axum.rs 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. //! Example websocket server.
  2. //!
  3. //! Run with
  4. //!
  5. //! ```not_rust
  6. //! cargo run -p example-websockets
  7. //! ```
  8. use axum::{
  9. extract::{
  10. ws::{Message, WebSocket, WebSocketUpgrade},
  11. TypedHeader,
  12. },
  13. http::StatusCode,
  14. response::IntoResponse,
  15. routing::{get, get_service},
  16. Router,
  17. };
  18. use std::net::SocketAddr;
  19. use tower_http::{
  20. services::ServeDir,
  21. trace::{DefaultMakeSpan, TraceLayer},
  22. };
  23. #[tokio::main]
  24. async fn main() {
  25. // Set the RUST_LOG, if it hasn't been explicitly defined
  26. if std::env::var_os("RUST_LOG").is_none() {
  27. std::env::set_var("RUST_LOG", "example_websockets=debug,tower_http=debug")
  28. }
  29. tracing_subscriber::fmt::init();
  30. // build our application with some routes
  31. let app = Router::new()
  32. .fallback(
  33. get_service(
  34. ServeDir::new("examples/axum_assets").append_index_html_on_directories(true),
  35. )
  36. .handle_error(|error: std::io::Error| async move {
  37. (
  38. StatusCode::INTERNAL_SERVER_ERROR,
  39. format!("Unhandled internal error: {}", error),
  40. )
  41. }),
  42. )
  43. // routes are matched from bottom to top, so we have to put `nest` at the
  44. // top since it matches all routes
  45. .route("/ws", get(ws_handler))
  46. // logging so we can see whats going on
  47. .layer(
  48. TraceLayer::new_for_http()
  49. .make_span_with(DefaultMakeSpan::default().include_headers(true)),
  50. );
  51. // run it with hyper
  52. let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
  53. tracing::debug!("listening on {}", addr);
  54. axum::Server::bind(&addr)
  55. .serve(app.into_make_service())
  56. .await
  57. .unwrap();
  58. }
  59. async fn ws_handler(
  60. ws: WebSocketUpgrade,
  61. user_agent: Option<TypedHeader<headers::UserAgent>>,
  62. ) -> impl IntoResponse {
  63. if let Some(TypedHeader(user_agent)) = user_agent {
  64. println!("`{}` connected", user_agent.as_str());
  65. }
  66. ws.on_upgrade(handle_socket)
  67. }
  68. async fn handle_socket(mut socket: WebSocket) {
  69. if let Some(msg) = socket.recv().await {
  70. if let Ok(msg) = msg {
  71. println!("Client says: {:?}", msg);
  72. } else {
  73. println!("client disconnected");
  74. return;
  75. }
  76. }
  77. loop {
  78. if socket
  79. .send(Message::Text(String::from("Hi!")))
  80. .await
  81. .is_err()
  82. {
  83. println!("client disconnected");
  84. return;
  85. }
  86. tokio::time::sleep(std::time::Duration::from_secs(1)).await;
  87. }
  88. }