123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154 |
//! # Server function service
//!
//! This module defines the service used to handle incoming server function requests.
- use http::StatusCode;
- use server_fn::{Encoding, Payload};
- use std::sync::{Arc, RwLock};
- use crate::server_fn::collection::MIDDLEWARE;
- use crate::{
- layer::{BoxedService, Service},
- prelude::{DioxusServerContext, ProvideServerContext},
- };
/// Shorthand for the axum body type used for both requests and responses in this module.
type AxumBody = axum::body::Body;
- /// Create a server function handler with the given server context and server function.
- pub fn server_fn_service(
- context: DioxusServerContext,
- function: server_fn::ServerFnTraitObj<()>,
- ) -> crate::layer::BoxedService {
- let prefix = function.prefix().to_string();
- let url = function.url().to_string();
- if let Some(middleware) = MIDDLEWARE.get(&(&prefix, &url)) {
- let mut service = BoxedService(Box::new(ServerFnHandler::new(context, function)));
- for middleware in middleware {
- service = middleware.layer(service);
- }
- service
- } else {
- BoxedService(Box::new(ServerFnHandler::new(context, function)))
- }
- }
- #[derive(Clone)]
- /// A default handler for server functions. It will deserialize the request body, call the server function, and serialize the response.
- pub struct ServerFnHandler {
- server_context: DioxusServerContext,
- function: server_fn::ServerFnTraitObj<()>,
- }
- impl ServerFnHandler {
- /// Create a new server function handler with the given server context and server function.
- pub fn new(
- server_context: impl Into<DioxusServerContext>,
- function: server_fn::ServerFnTraitObj<()>,
- ) -> Self {
- let server_context = server_context.into();
- Self {
- server_context,
- function,
- }
- }
- }
- impl Service for ServerFnHandler {
- fn run(
- &mut self,
- req: http::Request<AxumBody>,
- ) -> std::pin::Pin<
- Box<
- dyn std::future::Future<
- Output = Result<http::Response<AxumBody>, server_fn::ServerFnError>,
- > + Send,
- >,
- > {
- let Self {
- server_context,
- function,
- } = self.clone();
- Box::pin(async move {
- let query = req.uri().query().unwrap_or_default().as_bytes().to_vec();
- let (parts, body) = req.into_parts();
- let body = axum::body::to_bytes(body, usize::MAX).await?.to_vec();
- let headers = &parts.headers;
- let accept_header = headers.get("Accept").cloned();
- let parts = Arc::new(RwLock::new(parts));
- // Because the future returned by `server_fn_handler` is `Send`, and the future returned by this function must be send, we need to spawn a new runtime
- let pool = get_local_pool();
- let result = pool
- .spawn_pinned({
- let function = function.clone();
- let mut server_context = server_context.clone();
- server_context.parts = parts;
- move || async move {
- let data = match function.encoding() {
- Encoding::Url | Encoding::Cbor => &body,
- Encoding::GetJSON | Encoding::GetCBOR => &query,
- };
- let server_function_future = function.call((), data);
- let server_function_future = ProvideServerContext::new(
- server_function_future,
- server_context.clone(),
- );
- server_function_future.await
- }
- })
- .await?;
- let mut res = http::Response::builder();
- // Set the headers from the server context
- let parts = server_context.response_parts().unwrap();
- *res.headers_mut().expect("empty headers should be valid") = parts.headers.clone();
- let serialized = result?;
- // if this is Accept: application/json then send a serialized JSON response
- let accept_header = accept_header.as_ref().and_then(|value| value.to_str().ok());
- if accept_header == Some("application/json")
- || accept_header
- == Some(
- "application/\
- x-www-form-urlencoded",
- )
- || accept_header == Some("application/cbor")
- {
- res = res.status(StatusCode::OK);
- }
- Ok(match serialized {
- Payload::Binary(data) => {
- res = res.header("Content-Type", "application/cbor");
- res.body(data.into())?
- }
- Payload::Url(data) => {
- res = res.header(
- "Content-Type",
- "application/\
- x-www-form-urlencoded",
- );
- res.body(data.into())?
- }
- Payload::Json(data) => {
- res = res.header("Content-Type", "application/json");
- res.body(data.into())?
- }
- })
- })
- }
- }
- fn get_local_pool() -> tokio_util::task::LocalPoolHandle {
- use once_cell::sync::OnceCell;
- static LOCAL_POOL: OnceCell<tokio_util::task::LocalPoolHandle> = OnceCell::new();
- LOCAL_POOL
- .get_or_init(|| {
- tokio_util::task::LocalPoolHandle::new(
- std::thread::available_parallelism()
- .map(Into::into)
- .unwrap_or(1),
- )
- })
- .clone()
- }
|