query.rs 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. use std::{cell::RefCell, rc::Rc};
  2. use serde::{de::DeserializeOwned, Deserialize};
  3. use serde_json::Value;
  4. use slab::Slab;
  5. use thiserror::Error;
  6. use tokio::sync::broadcast::error::RecvError;
  7. use wry::webview::WebView;
  8. /// Tracks what query ids are currently active
  9. #[derive(Default, Clone)]
  10. struct SharedSlab {
  11. slab: Rc<RefCell<Slab<()>>>,
  12. }
  13. /// Handles sending and receiving arbitrary queries from the webview. Queries can be resolved non-sequentially, so we use ids to track them.
  14. #[derive(Clone)]
  15. pub(crate) struct QueryEngine {
  16. sender: Rc<tokio::sync::broadcast::Sender<QueryResult>>,
  17. active_requests: SharedSlab,
  18. }
  19. impl Default for QueryEngine {
  20. fn default() -> Self {
  21. let (sender, _) = tokio::sync::broadcast::channel(1000);
  22. Self {
  23. sender: Rc::new(sender),
  24. active_requests: SharedSlab::default(),
  25. }
  26. }
  27. }
  28. impl QueryEngine {
  29. /// Creates a new query and returns a handle to it. The query will be resolved when the webview returns a result with the same id.
  30. pub fn new_query<V: DeserializeOwned>(&self, script: &str, webview: &WebView) -> Query<V> {
  31. let request_id = self.active_requests.slab.borrow_mut().insert(());
  32. // start the query
  33. // We embed the return of the eval in a function so we can send it back to the main thread
  34. if let Err(err) = webview.evaluate_script(&format!(
  35. r#"window.ipc.postMessage(
  36. JSON.stringify({{
  37. "method":"query",
  38. "params": {{
  39. "id": {request_id},
  40. "data": (function(){{{script}}})()
  41. }}
  42. }})
  43. );"#
  44. )) {
  45. log::warn!("Query error: {err}");
  46. }
  47. Query {
  48. slab: self.active_requests.clone(),
  49. id: request_id,
  50. reciever: self.sender.subscribe(),
  51. phantom: std::marker::PhantomData,
  52. }
  53. }
  54. /// Send a query result
  55. pub fn send(&self, data: QueryResult) {
  56. let _ = self.sender.send(data);
  57. }
  58. }
  59. pub(crate) struct Query<V: DeserializeOwned> {
  60. slab: SharedSlab,
  61. id: usize,
  62. reciever: tokio::sync::broadcast::Receiver<QueryResult>,
  63. phantom: std::marker::PhantomData<V>,
  64. }
  65. impl<V: DeserializeOwned> Query<V> {
  66. /// Resolve the query
  67. pub async fn resolve(mut self) -> Result<V, QueryError> {
  68. let result = loop {
  69. match self.reciever.recv().await {
  70. Ok(result) => {
  71. if result.id == self.id {
  72. break V::deserialize(result.data).map_err(QueryError::DeserializeError);
  73. }
  74. }
  75. Err(err) => {
  76. break Err(QueryError::RecvError(err));
  77. }
  78. }
  79. };
  80. // Remove the query from the slab
  81. self.slab.slab.borrow_mut().remove(self.id);
  82. result
  83. }
  84. }
  85. #[derive(Error, Debug)]
  86. pub enum QueryError {
  87. #[error("Error receiving query result: {0}")]
  88. RecvError(RecvError),
  89. #[error("Error deserializing query result: {0}")]
  90. DeserializeError(serde_json::Error),
  91. }
  92. #[derive(Clone, Debug, Deserialize)]
  93. pub(crate) struct QueryResult {
  94. id: usize,
  95. data: Value,
  96. }