123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262 |
- use async_trait::async_trait;
- use axum::{
- http::Method,
- response::{IntoResponse, Response},
- routing::get,
- Router,
- };
- use axum_session::{SessionConfig, SessionLayer, SessionSqlitePool, SessionStore};
- use axum_session_auth::*;
- use core::pin::Pin;
- use dioxus_fullstack::prelude::*;
- use serde::{Deserialize, Serialize};
- use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
- use std::error::Error;
- use std::future::Future;
- use std::{collections::HashSet, net::SocketAddr, str::FromStr};
- #[derive(Debug, Clone, Serialize, Deserialize)]
- pub struct User {
- pub id: i32,
- pub anonymous: bool,
- pub username: String,
- pub permissions: HashSet<String>,
- }
- #[derive(sqlx::FromRow, Clone)]
- pub struct SqlPermissionTokens {
- pub token: String,
- }
- impl Default for User {
- fn default() -> Self {
- let mut permissions = HashSet::new();
- permissions.insert("Category::View".to_owned());
- Self {
- id: 1,
- anonymous: true,
- username: "Guest".into(),
- permissions,
- }
- }
- }
- #[async_trait]
- impl Authentication<User, i64, SqlitePool> for User {
- async fn load_user(userid: i64, pool: Option<&SqlitePool>) -> Result<User, anyhow::Error> {
- let pool = pool.unwrap();
- User::get_user(userid, pool)
- .await
- .ok_or_else(|| anyhow::anyhow!("Could not load user"))
- }
- fn is_authenticated(&self) -> bool {
- !self.anonymous
- }
- fn is_active(&self) -> bool {
- !self.anonymous
- }
- fn is_anonymous(&self) -> bool {
- self.anonymous
- }
- }
- #[async_trait]
- impl HasPermission<SqlitePool> for User {
- async fn has(&self, perm: &str, _pool: &Option<&SqlitePool>) -> bool {
- self.permissions.contains(perm)
- }
- }
- impl User {
- pub async fn get_user(id: i64, pool: &SqlitePool) -> Option<Self> {
- let sqluser = sqlx::query_as::<_, SqlUser>("SELECT * FROM users WHERE id = $1")
- .bind(id)
- .fetch_one(pool)
- .await
- .ok()?;
- //lets just get all the tokens the user can use, we will only use the full permissions if modifying them.
- let sql_user_perms = sqlx::query_as::<_, SqlPermissionTokens>(
- "SELECT token FROM user_permissions WHERE user_id = $1;",
- )
- .bind(id)
- .fetch_all(pool)
- .await
- .ok()?;
- Some(sqluser.into_user(Some(sql_user_perms)))
- }
- pub async fn create_user_tables(pool: &SqlitePool) {
- sqlx::query(
- r#"
- CREATE TABLE IF NOT EXISTS users (
- "id" INTEGER PRIMARY KEY,
- "anonymous" BOOLEAN NOT NULL,
- "username" VARCHAR(256) NOT NULL
- )
- "#,
- )
- .execute(pool)
- .await
- .unwrap();
- sqlx::query(
- r#"
- CREATE TABLE IF NOT EXISTS user_permissions (
- "user_id" INTEGER NOT NULL,
- "token" VARCHAR(256) NOT NULL
- )
- "#,
- )
- .execute(pool)
- .await
- .unwrap();
- sqlx::query(
- r#"
- INSERT INTO users
- (id, anonymous, username) SELECT 1, true, 'Guest'
- ON CONFLICT(id) DO UPDATE SET
- anonymous = EXCLUDED.anonymous,
- username = EXCLUDED.username
- "#,
- )
- .execute(pool)
- .await
- .unwrap();
- sqlx::query(
- r#"
- INSERT INTO users
- (id, anonymous, username) SELECT 2, false, 'Test'
- ON CONFLICT(id) DO UPDATE SET
- anonymous = EXCLUDED.anonymous,
- username = EXCLUDED.username
- "#,
- )
- .execute(pool)
- .await
- .unwrap();
- sqlx::query(
- r#"
- INSERT INTO user_permissions
- (user_id, token) SELECT 2, 'Category::View'
- "#,
- )
- .execute(pool)
- .await
- .unwrap();
- }
- }
- #[derive(sqlx::FromRow, Clone)]
- pub struct SqlUser {
- pub id: i32,
- pub anonymous: bool,
- pub username: String,
- }
- impl SqlUser {
- pub fn into_user(self, sql_user_perms: Option<Vec<SqlPermissionTokens>>) -> User {
- User {
- id: self.id,
- anonymous: self.anonymous,
- username: self.username,
- permissions: if let Some(user_perms) = sql_user_perms {
- user_perms
- .into_iter()
- .map(|x| x.token)
- .collect::<HashSet<String>>()
- } else {
- HashSet::<String>::new()
- },
- }
- }
- }
- pub async fn connect_to_database() -> SqlitePool {
- let connect_opts = SqliteConnectOptions::from_str("sqlite::memory:").unwrap();
- SqlitePoolOptions::new()
- .max_connections(5)
- .connect_with(connect_opts)
- .await
- .unwrap()
- }
- pub struct Session(
- pub axum_session_auth::AuthSession<
- crate::auth::User,
- i64,
- axum_session_auth::SessionSqlitePool,
- sqlx::SqlitePool,
- >,
- );
- impl std::ops::Deref for Session {
- type Target = axum_session_auth::AuthSession<
- crate::auth::User,
- i64,
- axum_session_auth::SessionSqlitePool,
- sqlx::SqlitePool,
- >;
- fn deref(&self) -> &Self::Target {
- &self.0
- }
- }
- impl std::ops::DerefMut for Session {
- fn deref_mut(&mut self) -> &mut Self::Target {
- &mut self.0
- }
- }
- #[derive(Debug)]
- pub struct AuthSessionLayerNotFound;
- impl std::fmt::Display for AuthSessionLayerNotFound {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "AuthSessionLayer was not found")
- }
- }
- impl std::error::Error for AuthSessionLayerNotFound {}
- impl IntoResponse for AuthSessionLayerNotFound {
- fn into_response(self) -> Response {
- (
- http::status::StatusCode::INTERNAL_SERVER_ERROR,
- "AuthSessionLayer was not found",
- )
- .into_response()
- }
- }
- #[async_trait]
- impl<S: std::marker::Sync + std::marker::Send> axum::extract::FromRequestParts<S> for Session {
- type Rejection = AuthSessionLayerNotFound;
- async fn from_request_parts(
- parts: &mut http::request::Parts,
- state: &S,
- ) -> Result<Self, Self::Rejection> {
- axum_session_auth::AuthSession::<
- crate::auth::User,
- i64,
- axum_session_auth::SessionSqlitePool,
- sqlx::SqlitePool,
- >::from_request_parts(parts, state)
- .await
- .map(Session)
- .map_err(|_| AuthSessionLayerNotFound)
- }
- }
|