123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
- //! Example websocket server.
- //!
- //! Run with
- //!
- //! ```not_rust
- //! cargo run -p example-websockets
- //! ```
- use axum::{
- extract::{
- ws::{Message, WebSocket, WebSocketUpgrade},
- TypedHeader,
- },
- http::StatusCode,
- response::IntoResponse,
- routing::{get, get_service},
- Router,
- };
- use std::net::SocketAddr;
- use tower_http::{
- services::ServeDir,
- trace::{DefaultMakeSpan, TraceLayer},
- };
- #[tokio::main]
- async fn main() {
- // Set the RUST_LOG, if it hasn't been explicitly defined
- if std::env::var_os("RUST_LOG").is_none() {
- std::env::set_var("RUST_LOG", "example_websockets=debug,tower_http=debug")
- }
- tracing_subscriber::fmt::init();
- // build our application with some routes
- let app = Router::new()
- .fallback(
- get_service(
- ServeDir::new("examples/axum_assets").append_index_html_on_directories(true),
- )
- .handle_error(|error: std::io::Error| async move {
- (
- StatusCode::INTERNAL_SERVER_ERROR,
- format!("Unhandled internal error: {}", error),
- )
- }),
- )
- // routes are matched from bottom to top, so we have to put `nest` at the
- // top since it matches all routes
- .route("/ws", get(ws_handler))
- // logging so we can see whats going on
- .layer(
- TraceLayer::new_for_http()
- .make_span_with(DefaultMakeSpan::default().include_headers(true)),
- );
- // run it with hyper
- let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
- tracing::debug!("listening on {}", addr);
- axum::Server::bind(&addr)
- .serve(app.into_make_service())
- .await
- .unwrap();
- }
- async fn ws_handler(
- ws: WebSocketUpgrade,
- user_agent: Option<TypedHeader<headers::UserAgent>>,
- ) -> impl IntoResponse {
- if let Some(TypedHeader(user_agent)) = user_agent {
- println!("`{}` connected", user_agent.as_str());
- }
- ws.on_upgrade(handle_socket)
- }
- async fn handle_socket(mut socket: WebSocket) {
- if let Some(msg) = socket.recv().await {
- if let Ok(msg) = msg {
- println!("Client says: {:?}", msg);
- } else {
- println!("client disconnected");
- return;
- }
- }
- loop {
- if socket
- .send(Message::Text(String::from("Hi!")))
- .await
- .is_err()
- {
- println!("client disconnected");
- return;
- }
- tokio::time::sleep(std::time::Duration::from_secs(1)).await;
- }
- }
|