service.rs 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
//! # Server Function Service
//!
//! This module defines a service that handles incoming server function requests.
  3. use http::StatusCode;
  4. use server_fn::{Encoding, Payload};
  5. use std::sync::{Arc, RwLock};
  6. use crate::server_fn::collection::MIDDLEWARE;
  7. use crate::{
  8. layer::{BoxedService, Service},
  9. prelude::{DioxusServerContext, ProvideServerContext},
  10. };
/// Request and response body type used by this service (an alias for `axum::body::Body`).
type AxumBody = axum::body::Body;
  12. /// Create a server function handler with the given server context and server function.
  13. pub fn server_fn_service(
  14. context: DioxusServerContext,
  15. function: server_fn::ServerFnTraitObj<()>,
  16. ) -> crate::layer::BoxedService {
  17. let prefix = function.prefix().to_string();
  18. let url = function.url().to_string();
  19. if let Some(middleware) = MIDDLEWARE.get(&(&prefix, &url)) {
  20. let mut service = BoxedService(Box::new(ServerFnHandler::new(context, function)));
  21. for middleware in middleware {
  22. service = middleware.layer(service);
  23. }
  24. service
  25. } else {
  26. BoxedService(Box::new(ServerFnHandler::new(context, function)))
  27. }
  28. }
  29. #[derive(Clone)]
  30. /// A default handler for server functions. It will deserialize the request body, call the server function, and serialize the response.
  31. pub struct ServerFnHandler {
  32. server_context: DioxusServerContext,
  33. function: server_fn::ServerFnTraitObj<()>,
  34. }
  35. impl ServerFnHandler {
  36. /// Create a new server function handler with the given server context and server function.
  37. pub fn new(
  38. server_context: impl Into<DioxusServerContext>,
  39. function: server_fn::ServerFnTraitObj<()>,
  40. ) -> Self {
  41. let server_context = server_context.into();
  42. Self {
  43. server_context,
  44. function,
  45. }
  46. }
  47. }
  48. impl Service for ServerFnHandler {
  49. fn run(
  50. &mut self,
  51. req: http::Request<AxumBody>,
  52. ) -> std::pin::Pin<
  53. Box<
  54. dyn std::future::Future<
  55. Output = Result<http::Response<AxumBody>, server_fn::ServerFnError>,
  56. > + Send,
  57. >,
  58. > {
  59. let Self {
  60. server_context,
  61. function,
  62. } = self.clone();
  63. Box::pin(async move {
  64. let query = req.uri().query().unwrap_or_default().as_bytes().to_vec();
  65. let (parts, body) = req.into_parts();
  66. let body = axum::body::to_bytes(body, usize::MAX).await?.to_vec();
  67. let headers = &parts.headers;
  68. let accept_header = headers.get("Accept").cloned();
  69. let parts = Arc::new(RwLock::new(parts));
  70. // Because the future returned by `server_fn_handler` is `Send`, and the future returned by this function must be send, we need to spawn a new runtime
  71. let pool = get_local_pool();
  72. let result = pool
  73. .spawn_pinned({
  74. let function = function.clone();
  75. let mut server_context = server_context.clone();
  76. server_context.parts = parts;
  77. move || async move {
  78. let data = match function.encoding() {
  79. Encoding::Url | Encoding::Cbor => &body,
  80. Encoding::GetJSON | Encoding::GetCBOR => &query,
  81. };
  82. let server_function_future = function.call((), data);
  83. let server_function_future = ProvideServerContext::new(
  84. server_function_future,
  85. server_context.clone(),
  86. );
  87. server_function_future.await
  88. }
  89. })
  90. .await?;
  91. let mut res = http::Response::builder();
  92. // Set the headers from the server context
  93. let parts = server_context.response_parts().unwrap();
  94. *res.headers_mut().expect("empty headers should be valid") = parts.headers.clone();
  95. let serialized = result?;
  96. // if this is Accept: application/json then send a serialized JSON response
  97. let accept_header = accept_header.as_ref().and_then(|value| value.to_str().ok());
  98. if accept_header == Some("application/json")
  99. || accept_header
  100. == Some(
  101. "application/\
  102. x-www-form-urlencoded",
  103. )
  104. || accept_header == Some("application/cbor")
  105. {
  106. res = res.status(StatusCode::OK);
  107. }
  108. Ok(match serialized {
  109. Payload::Binary(data) => {
  110. res = res.header("Content-Type", "application/cbor");
  111. res.body(data.into())?
  112. }
  113. Payload::Url(data) => {
  114. res = res.header(
  115. "Content-Type",
  116. "application/\
  117. x-www-form-urlencoded",
  118. );
  119. res.body(data.into())?
  120. }
  121. Payload::Json(data) => {
  122. res = res.header("Content-Type", "application/json");
  123. res.body(data.into())?
  124. }
  125. })
  126. })
  127. }
  128. }
  129. fn get_local_pool() -> tokio_util::task::LocalPoolHandle {
  130. use once_cell::sync::OnceCell;
  131. static LOCAL_POOL: OnceCell<tokio_util::task::LocalPoolHandle> = OnceCell::new();
  132. LOCAL_POOL
  133. .get_or_init(|| {
  134. tokio_util::task::LocalPoolHandle::new(
  135. std::thread::available_parallelism()
  136. .map(Into::into)
  137. .unwrap_or(1),
  138. )
  139. })
  140. .clone()
  141. }