query.rs 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. use std::{cell::RefCell, rc::Rc};
  2. use serde::{de::DeserializeOwned, Deserialize};
  3. use serde_json::Value;
  4. use slab::Slab;
  5. use thiserror::Error;
  6. use tokio::sync::{broadcast::error::RecvError, mpsc::UnboundedSender};
  7. /// Tracks what query ids are currently active
  8. #[derive(Default, Clone)]
  9. struct SharedSlab {
  10. slab: Rc<RefCell<Slab<()>>>,
  11. }
  12. /// Handles sending and receiving arbitrary queries from the webview. Queries can be resolved non-sequentially, so we use ids to track them.
  13. #[derive(Clone)]
  14. pub(crate) struct QueryEngine {
  15. sender: Rc<tokio::sync::broadcast::Sender<QueryResult>>,
  16. active_requests: SharedSlab,
  17. }
  18. impl Default for QueryEngine {
  19. fn default() -> Self {
  20. let (sender, _) = tokio::sync::broadcast::channel(8);
  21. Self {
  22. sender: Rc::new(sender),
  23. active_requests: SharedSlab::default(),
  24. }
  25. }
  26. }
  27. impl QueryEngine {
  28. /// Creates a new query and returns a handle to it. The query will be resolved when the webview returns a result with the same id.
  29. pub fn new_query<V: DeserializeOwned>(
  30. &self,
  31. script: &str,
  32. tx: &UnboundedSender<String>,
  33. ) -> Query<V> {
  34. let request_id = self.active_requests.slab.borrow_mut().insert(());
  35. // start the query
  36. // We embed the return of the eval in a function so we can send it back to the main thread
  37. if let Err(err) = tx.send(format!(
  38. r#"window.ipc.postMessage(
  39. JSON.stringify({{
  40. "method":"query",
  41. "params": {{
  42. "id": {request_id},
  43. "data": (function(){{{script}}})()
  44. }}
  45. }})
  46. );"#
  47. )) {
  48. log::warn!("Query error: {err}");
  49. }
  50. Query {
  51. slab: self.active_requests.clone(),
  52. id: request_id,
  53. reciever: self.sender.subscribe(),
  54. phantom: std::marker::PhantomData,
  55. }
  56. }
  57. /// Send a query result
  58. pub fn send(&self, data: QueryResult) {
  59. let _ = self.sender.send(data);
  60. }
  61. }
  62. pub(crate) struct Query<V: DeserializeOwned> {
  63. slab: SharedSlab,
  64. id: usize,
  65. reciever: tokio::sync::broadcast::Receiver<QueryResult>,
  66. phantom: std::marker::PhantomData<V>,
  67. }
  68. impl<V: DeserializeOwned> Query<V> {
  69. /// Resolve the query
  70. pub async fn resolve(mut self) -> Result<V, QueryError> {
  71. let result = loop {
  72. match self.reciever.recv().await {
  73. Ok(result) => {
  74. if result.id == self.id {
  75. break V::deserialize(result.data).map_err(QueryError::DeserializeError);
  76. }
  77. }
  78. Err(err) => {
  79. break Err(QueryError::RecvError(err));
  80. }
  81. }
  82. };
  83. // Remove the query from the slab
  84. self.slab.slab.borrow_mut().remove(self.id);
  85. result
  86. }
  87. }
  88. #[derive(Error, Debug)]
  89. pub enum QueryError {
  90. #[error("Error receiving query result: {0}")]
  91. RecvError(RecvError),
  92. #[error("Error deserializing query result: {0}")]
  93. DeserializeError(serde_json::Error),
  94. }
  95. #[derive(Clone, Debug, Deserialize)]
  96. pub(crate) struct QueryResult {
  97. id: usize,
  98. data: Value,
  99. }