// query.rs (8.2 KB)


  1. use crate::{DesktopContext, WeakDesktopContext};
  2. use futures_util::{FutureExt, StreamExt};
  3. use generational_box::Owner;
  4. use serde::{de::DeserializeOwned, Deserialize};
  5. use serde_json::Value;
  6. use slab::Slab;
  7. use std::{cell::RefCell, rc::Rc};
  8. use thiserror::Error;
  9. /// Tracks what query ids are currently active
  10. pub(crate) struct SharedSlab<T = ()> {
  11. pub slab: Rc<RefCell<Slab<T>>>,
  12. }
  13. impl<T> Clone for SharedSlab<T> {
  14. fn clone(&self) -> Self {
  15. Self {
  16. slab: self.slab.clone(),
  17. }
  18. }
  19. }
  20. impl<T> Default for SharedSlab<T> {
  21. fn default() -> Self {
  22. SharedSlab {
  23. slab: Rc::new(RefCell::new(Slab::new())),
  24. }
  25. }
  26. }
/// Rust-side state for one in-flight query, stored in `QueryEngine::active_requests`
/// under the query's id.
pub(crate) struct QueryEntry {
    /// Forwards payloads of `send` messages from the webview to the matching `Query`'s
    /// message receiver.
    channel_sender: futures_channel::mpsc::UnboundedSender<Value>,
    /// Delivers the script's final result (or error message) exactly once; it is `take()`n
    /// when the first `return`/`return_error` message for this id arrives.
    return_sender: Option<futures_channel::oneshot::Sender<Result<Value, String>>>,
    /// Initialized to `None` by `QueryEngine::new_query` and not read anywhere in this file.
    /// NOTE(review): presumably assigned by callers elsewhere — confirm its purpose there.
    pub owner: Option<Owner>,
}
  32. /// Handles sending and receiving arbitrary queries from the webview. Queries can be resolved non-sequentially, so we use ids to track them.
  33. #[derive(Clone, Default)]
  34. pub(crate) struct QueryEngine {
  35. pub active_requests: SharedSlab<QueryEntry>,
  36. }
  37. impl QueryEngine {
  38. /// Creates a new query and returns a handle to it. The query will be resolved when the webview returns a result with the same id.
  39. pub fn new_query<V: DeserializeOwned>(
  40. &self,
  41. script: &str,
  42. context: DesktopContext,
  43. ) -> Query<V> {
  44. let (tx, rx) = futures_channel::mpsc::unbounded();
  45. let (return_tx, return_rx) = futures_channel::oneshot::channel();
  46. let request_id = self.active_requests.slab.borrow_mut().insert(QueryEntry {
  47. channel_sender: tx,
  48. return_sender: Some(return_tx),
  49. owner: None,
  50. });
  51. // start the query
  52. // We embed the return of the eval in a function so we can send it back to the main thread
  53. if let Err(err) = context.webview.evaluate_script(&format!(
  54. r#"(function(){{
  55. let dioxus = window.createQuery({request_id});
  56. let post_error = function(err) {{
  57. let returned_value = {{
  58. "method": "query",
  59. "params": {{
  60. "id": {request_id},
  61. "data": {{
  62. "data": err,
  63. "method": "return_error"
  64. }}
  65. }}
  66. }};
  67. window.ipc.postMessage(
  68. JSON.stringify(returned_value)
  69. );
  70. }};
  71. try {{
  72. const AsyncFunction = async function () {{}}.constructor;
  73. let promise = (new AsyncFunction("dioxus", {script:?}))(dioxus);
  74. promise
  75. .then((result)=>{{
  76. dioxus.close();
  77. let returned_value = {{
  78. "method": "query",
  79. "params": {{
  80. "id": {request_id},
  81. "data": {{
  82. "data": result,
  83. "method": "return"
  84. }}
  85. }}
  86. }};
  87. window.ipc.postMessage(
  88. JSON.stringify(returned_value)
  89. );
  90. }})
  91. .catch(err => post_error(`Error running JS: ${{err}}`));
  92. }} catch (error) {{
  93. dioxus.close();
  94. post_error(`Invalid JS: ${{error}}`);
  95. }}
  96. }})();"#
  97. )) {
  98. tracing::warn!("Query error: {err}");
  99. }
  100. Query {
  101. id: request_id,
  102. receiver: rx,
  103. return_receiver: Some(return_rx),
  104. desktop: Rc::downgrade(&context),
  105. phantom: std::marker::PhantomData,
  106. }
  107. }
  108. /// Send a query channel message to the correct query
  109. pub fn send(&self, data: QueryResult) {
  110. let QueryResult { id, data } = data;
  111. let mut slab = self.active_requests.slab.borrow_mut();
  112. if let Some(entry) = slab.get_mut(id) {
  113. match data {
  114. QueryResultData::Return { data } => {
  115. if let Some(sender) = entry.return_sender.take() {
  116. let _ = sender.send(Ok(data.unwrap_or_default()));
  117. }
  118. }
  119. QueryResultData::ReturnError { data } => {
  120. if let Some(sender) = entry.return_sender.take() {
  121. let _ = sender.send(Err(data.to_string()));
  122. }
  123. }
  124. QueryResultData::Drop => {
  125. slab.remove(id);
  126. }
  127. QueryResultData::Send { data } => {
  128. let _ = entry.channel_sender.unbounded_send(data);
  129. }
  130. }
  131. }
  132. }
  133. }
  134. pub(crate) struct Query<V: DeserializeOwned> {
  135. desktop: WeakDesktopContext,
  136. receiver: futures_channel::mpsc::UnboundedReceiver<Value>,
  137. return_receiver: Option<futures_channel::oneshot::Receiver<Result<Value, String>>>,
  138. pub id: usize,
  139. phantom: std::marker::PhantomData<V>,
  140. }
  141. impl<V: DeserializeOwned> Query<V> {
  142. /// Resolve the query
  143. pub async fn resolve(mut self) -> Result<V, QueryError> {
  144. let result = self.result().await?;
  145. V::deserialize(result).map_err(QueryError::Deserialize)
  146. }
  147. /// Send a message to the query
  148. pub fn send<S: ToString>(&self, message: S) -> Result<(), QueryError> {
  149. let queue_id = self.id;
  150. let data = message.to_string();
  151. let script = format!(r#"window.getQuery({queue_id}).rustSend({data});"#);
  152. let desktop = self.desktop.upgrade().ok_or(QueryError::Finished)?;
  153. desktop
  154. .webview
  155. .evaluate_script(&script)
  156. .map_err(|e| QueryError::Send(e.to_string()))?;
  157. Ok(())
  158. }
  159. /// Poll the query for a message
  160. pub fn poll_recv(
  161. &mut self,
  162. cx: &mut std::task::Context<'_>,
  163. ) -> std::task::Poll<Result<Value, QueryError>> {
  164. self.receiver
  165. .poll_next_unpin(cx)
  166. .map(|result| result.ok_or(QueryError::Recv(String::from("Receive channel closed"))))
  167. }
  168. /// Receive the result of the query
  169. pub async fn result(&mut self) -> Result<Value, QueryError> {
  170. match self.return_receiver.take() {
  171. Some(receiver) => match receiver.await {
  172. Ok(Ok(data)) => Ok(data),
  173. Ok(Err(err)) => Err(QueryError::Recv(err)),
  174. Err(err) => Err(QueryError::Recv(err.to_string())),
  175. },
  176. None => Err(QueryError::Finished),
  177. }
  178. }
  179. /// Poll the query for a result
  180. pub fn poll_result(
  181. &mut self,
  182. cx: &mut std::task::Context<'_>,
  183. ) -> std::task::Poll<Result<Value, QueryError>> {
  184. match self.return_receiver.as_mut() {
  185. Some(receiver) => receiver.poll_unpin(cx).map(|result| match result {
  186. Ok(Ok(data)) => Ok(data),
  187. Ok(Err(err)) => Err(QueryError::Recv(err)),
  188. Err(err) => Err(QueryError::Recv(err.to_string())),
  189. }),
  190. None => std::task::Poll::Ready(Err(QueryError::Finished)),
  191. }
  192. }
  193. }
/// Errors produced while running a query in the webview or exchanging messages with it.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum QueryError {
    /// Receiving from the query failed: the script reported an error, the result channel
    /// was closed, or the message channel ended. Carries a human-readable description.
    #[error("Error receiving query result: {0}")]
    Recv(String),
    /// The webview failed to evaluate the script that forwards a message to the query.
    #[error("Error sending message to query: {0}")]
    Send(String),
    /// The query's result could not be deserialized into the requested type.
    #[error("Error deserializing query result: {0}")]
    Deserialize(serde_json::Error),
    /// The result was already taken, or the desktop context is no longer alive.
    #[error("Query has already been resolved")]
    Finished,
}
  206. #[derive(Clone, Debug, Deserialize)]
  207. pub(crate) struct QueryResult {
  208. id: usize,
  209. data: QueryResultData,
  210. }
  211. #[derive(Clone, Debug, Deserialize)]
  212. #[serde(tag = "method")]
  213. enum QueryResultData {
  214. #[serde(rename = "return")]
  215. Return { data: Option<Value> },
  216. #[serde(rename = "return_error")]
  217. ReturnError { data: Value },
  218. #[serde(rename = "send")]
  219. Send { data: Value },
  220. #[serde(rename = "drop")]
  221. Drop,
  222. }