use async_trait::async_trait; use axum::{ http::Method, response::{IntoResponse, Response}, routing::get, Router, }; use axum_session::{SessionConfig, SessionLayer, SessionSqlitePool, SessionStore}; use axum_session_auth::*; use core::pin::Pin; use dioxus_fullstack::prelude::*; use serde::{Deserialize, Serialize}; use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions}; use std::error::Error; use std::future::Future; use std::{collections::HashSet, net::SocketAddr, str::FromStr}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct User { pub id: i32, pub anonymous: bool, pub username: String, pub permissions: HashSet, } #[derive(sqlx::FromRow, Clone)] pub struct SqlPermissionTokens { pub token: String, } impl Default for User { fn default() -> Self { let mut permissions = HashSet::new(); permissions.insert("Category::View".to_owned()); Self { id: 1, anonymous: true, username: "Guest".into(), permissions, } } } #[async_trait] impl Authentication for User { async fn load_user(userid: i64, pool: Option<&SqlitePool>) -> Result { let pool = pool.unwrap(); User::get_user(userid, pool) .await .ok_or_else(|| anyhow::anyhow!("Could not load user")) } fn is_authenticated(&self) -> bool { !self.anonymous } fn is_active(&self) -> bool { !self.anonymous } fn is_anonymous(&self) -> bool { self.anonymous } } #[async_trait] impl HasPermission for User { async fn has(&self, perm: &str, _pool: &Option<&SqlitePool>) -> bool { self.permissions.contains(perm) } } impl User { pub async fn get_user(id: i64, pool: &SqlitePool) -> Option { let sqluser = sqlx::query_as::<_, SqlUser>("SELECT * FROM users WHERE id = $1") .bind(id) .fetch_one(pool) .await .ok()?; //lets just get all the tokens the user can use, we will only use the full permissions if modifying them. let sql_user_perms = sqlx::query_as::<_, SqlPermissionTokens>( "SELECT token FROM user_permissions WHERE user_id = $1;", ) .bind(id) .fetch_all(pool) .await .ok()?; Some(sqluser.into_user(Some(sql_user_perms))) } pub async fn create_user_tables(pool: &SqlitePool) { sqlx::query( r#" CREATE TABLE IF NOT EXISTS users ( "id" INTEGER PRIMARY KEY, "anonymous" BOOLEAN NOT NULL, "username" VARCHAR(256) NOT NULL ) "#, ) .execute(pool) .await .unwrap(); sqlx::query( r#" CREATE TABLE IF NOT EXISTS user_permissions ( "user_id" INTEGER NOT NULL, "token" VARCHAR(256) NOT NULL ) "#, ) .execute(pool) .await .unwrap(); sqlx::query( r#" INSERT INTO users (id, anonymous, username) SELECT 1, true, 'Guest' ON CONFLICT(id) DO UPDATE SET anonymous = EXCLUDED.anonymous, username = EXCLUDED.username "#, ) .execute(pool) .await .unwrap(); sqlx::query( r#" INSERT INTO users (id, anonymous, username) SELECT 2, false, 'Test' ON CONFLICT(id) DO UPDATE SET anonymous = EXCLUDED.anonymous, username = EXCLUDED.username "#, ) .execute(pool) .await .unwrap(); sqlx::query( r#" INSERT INTO user_permissions (user_id, token) SELECT 2, 'Category::View' "#, ) .execute(pool) .await .unwrap(); } } #[derive(sqlx::FromRow, Clone)] pub struct SqlUser { pub id: i32, pub anonymous: bool, pub username: String, } impl SqlUser { pub fn into_user(self, sql_user_perms: Option>) -> User { User { id: self.id, anonymous: self.anonymous, username: self.username, permissions: if let Some(user_perms) = sql_user_perms { user_perms .into_iter() .map(|x| x.token) .collect::>() } else { HashSet::::new() }, } } } pub async fn connect_to_database() -> SqlitePool { let connect_opts = SqliteConnectOptions::from_str("sqlite::memory:").unwrap(); SqlitePoolOptions::new() .max_connections(5) .connect_with(connect_opts) .await .unwrap() } pub struct Session( pub axum_session_auth::AuthSession< crate::auth::User, i64, axum_session_auth::SessionSqlitePool, sqlx::SqlitePool, >, ); impl std::ops::Deref for Session { type Target = axum_session_auth::AuthSession< crate::auth::User, i64, axum_session_auth::SessionSqlitePool, sqlx::SqlitePool, >; fn deref(&self) -> &Self::Target { &self.0 } } impl std::ops::DerefMut for Session { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } #[derive(Debug)] pub struct AuthSessionLayerNotFound; impl std::fmt::Display for AuthSessionLayerNotFound { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "AuthSessionLayer was not found") } } impl std::error::Error for AuthSessionLayerNotFound {} impl IntoResponse for AuthSessionLayerNotFound { fn into_response(self) -> Response { ( http::status::StatusCode::INTERNAL_SERVER_ERROR, "AuthSessionLayer was not found", ) .into_response() } } #[async_trait] impl axum::extract::FromRequestParts for Session { type Rejection = AuthSessionLayerNotFound; async fn from_request_parts( parts: &mut http::request::Parts, state: &S, ) -> Result { axum_session_auth::AuthSession::< crate::auth::User, i64, axum_session_auth::SessionSqlitePool, sqlx::SqlitePool, >::from_request_parts(parts, state) .await .map(Session) .map_err(|_| AuthSessionLayerNotFound) } }