use crate::{DesktopContext, WeakDesktopContext}; use futures_util::{FutureExt, StreamExt}; use generational_box::Owner; use serde::{de::DeserializeOwned, Deserialize}; use serde_json::Value; use slab::Slab; use std::{cell::RefCell, rc::Rc}; use thiserror::Error; /// Tracks what query ids are currently active pub(crate) struct SharedSlab { pub slab: Rc>>, } impl Clone for SharedSlab { fn clone(&self) -> Self { Self { slab: self.slab.clone(), } } } impl Default for SharedSlab { fn default() -> Self { SharedSlab { slab: Rc::new(RefCell::new(Slab::new())), } } } pub(crate) struct QueryEntry { channel_sender: futures_channel::mpsc::UnboundedSender, return_sender: Option>>, pub owner: Option, } /// Handles sending and receiving arbitrary queries from the webview. Queries can be resolved non-sequentially, so we use ids to track them. #[derive(Clone, Default)] pub(crate) struct QueryEngine { pub active_requests: SharedSlab, } impl QueryEngine { /// Creates a new query and returns a handle to it. The query will be resolved when the webview returns a result with the same id. pub fn new_query( &self, script: &str, context: DesktopContext, ) -> Query { let (tx, rx) = futures_channel::mpsc::unbounded(); let (return_tx, return_rx) = futures_channel::oneshot::channel(); let request_id = self.active_requests.slab.borrow_mut().insert(QueryEntry { channel_sender: tx, return_sender: Some(return_tx), owner: None, }); // start the query // We embed the return of the eval in a function so we can send it back to the main thread if let Err(err) = context.webview.evaluate_script(&format!( r#"(function(){{ let dioxus = window.createQuery({request_id}); let post_error = function(err) {{ let returned_value = {{ "method": "query", "params": {{ "id": {request_id}, "data": {{ "data": err, "method": "return_error" }} }} }}; window.ipc.postMessage( JSON.stringify(returned_value) ); }}; try {{ const AsyncFunction = async function () {{}}.constructor; let promise = (new AsyncFunction("dioxus", {script:?}))(dioxus); promise .then((result)=>{{ dioxus.close(); let returned_value = {{ "method": "query", "params": {{ "id": {request_id}, "data": {{ "data": result, "method": "return" }} }} }}; window.ipc.postMessage( JSON.stringify(returned_value) ); }}) .catch(err => post_error(`Error running JS: ${{err}}`)); }} catch (error) {{ dioxus.close(); post_error(`Invalid JS: ${{error}}`); }} }})();"# )) { tracing::warn!("Query error: {err}"); } Query { id: request_id, receiver: rx, return_receiver: Some(return_rx), desktop: Rc::downgrade(&context), phantom: std::marker::PhantomData, } } /// Send a query channel message to the correct query pub fn send(&self, data: QueryResult) { let QueryResult { id, data } = data; let mut slab = self.active_requests.slab.borrow_mut(); if let Some(entry) = slab.get_mut(id) { match data { QueryResultData::Return { data } => { if let Some(sender) = entry.return_sender.take() { let _ = sender.send(Ok(data.unwrap_or_default())); } } QueryResultData::ReturnError { data } => { if let Some(sender) = entry.return_sender.take() { let _ = sender.send(Err(data.to_string())); } } QueryResultData::Drop => { slab.remove(id); } QueryResultData::Send { data } => { let _ = entry.channel_sender.unbounded_send(data); } } } } } pub(crate) struct Query { desktop: WeakDesktopContext, receiver: futures_channel::mpsc::UnboundedReceiver, return_receiver: Option>>, pub id: usize, phantom: std::marker::PhantomData, } impl Query { /// Resolve the query pub async fn resolve(mut self) -> Result { let result = self.result().await?; V::deserialize(result).map_err(QueryError::Deserialize) } /// Send a message to the query pub fn send(&self, message: S) -> Result<(), QueryError> { let queue_id = self.id; let data = message.to_string(); let script = format!(r#"window.getQuery({queue_id}).rustSend({data});"#); let desktop = self.desktop.upgrade().ok_or(QueryError::Finished)?; desktop .webview .evaluate_script(&script) .map_err(|e| QueryError::Send(e.to_string()))?; Ok(()) } /// Poll the query for a message pub fn poll_recv( &mut self, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { self.receiver .poll_next_unpin(cx) .map(|result| result.ok_or(QueryError::Recv(String::from("Receive channel closed")))) } /// Receive the result of the query pub async fn result(&mut self) -> Result { match self.return_receiver.take() { Some(receiver) => match receiver.await { Ok(Ok(data)) => Ok(data), Ok(Err(err)) => Err(QueryError::Recv(err)), Err(err) => Err(QueryError::Recv(err.to_string())), }, None => Err(QueryError::Finished), } } /// Poll the query for a result pub fn poll_result( &mut self, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { match self.return_receiver.as_mut() { Some(receiver) => receiver.poll_unpin(cx).map(|result| match result { Ok(Ok(data)) => Ok(data), Ok(Err(err)) => Err(QueryError::Recv(err)), Err(err) => Err(QueryError::Recv(err.to_string())), }), None => std::task::Poll::Ready(Err(QueryError::Finished)), } } } #[derive(Error, Debug)] #[non_exhaustive] pub enum QueryError { #[error("Error receiving query result: {0}")] Recv(String), #[error("Error sending message to query: {0}")] Send(String), #[error("Error deserializing query result: {0}")] Deserialize(serde_json::Error), #[error("Query has already been resolved")] Finished, } #[derive(Clone, Debug, Deserialize)] pub(crate) struct QueryResult { id: usize, data: QueryResultData, } #[derive(Clone, Debug, Deserialize)] #[serde(tag = "method")] enum QueryResultData { #[serde(rename = "return")] Return { data: Option }, #[serde(rename = "return_error")] ReturnError { data: Value }, #[serde(rename = "send")] Send { data: Value }, #[serde(rename = "drop")] Drop, }