auth.rs 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. use async_trait::async_trait;
  2. use axum::{
  3. http::Method,
  4. response::{IntoResponse, Response},
  5. routing::get,
  6. Router,
  7. };
  8. use axum_session::{SessionConfig, SessionLayer, SessionSqlitePool, SessionStore};
  9. use axum_session_auth::*;
  10. use core::pin::Pin;
  11. use dioxus_fullstack::prelude::*;
  12. use serde::{Deserialize, Serialize};
  13. use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
  14. use std::error::Error;
  15. use std::future::Future;
  16. use std::{collections::HashSet, net::SocketAddr, str::FromStr};
  17. #[derive(Debug, Clone, Serialize, Deserialize)]
  18. pub struct User {
  19. pub id: i32,
  20. pub anonymous: bool,
  21. pub username: String,
  22. pub permissions: HashSet<String>,
  23. }
  24. #[derive(sqlx::FromRow, Clone)]
  25. pub struct SqlPermissionTokens {
  26. pub token: String,
  27. }
  28. impl Default for User {
  29. fn default() -> Self {
  30. let mut permissions = HashSet::new();
  31. permissions.insert("Category::View".to_owned());
  32. Self {
  33. id: 1,
  34. anonymous: true,
  35. username: "Guest".into(),
  36. permissions,
  37. }
  38. }
  39. }
  40. #[async_trait]
  41. impl Authentication<User, i64, SqlitePool> for User {
  42. async fn load_user(userid: i64, pool: Option<&SqlitePool>) -> Result<User, anyhow::Error> {
  43. let pool = pool.unwrap();
  44. User::get_user(userid, pool)
  45. .await
  46. .ok_or_else(|| anyhow::anyhow!("Could not load user"))
  47. }
  48. fn is_authenticated(&self) -> bool {
  49. !self.anonymous
  50. }
  51. fn is_active(&self) -> bool {
  52. !self.anonymous
  53. }
  54. fn is_anonymous(&self) -> bool {
  55. self.anonymous
  56. }
  57. }
  58. #[async_trait]
  59. impl HasPermission<SqlitePool> for User {
  60. async fn has(&self, perm: &str, _pool: &Option<&SqlitePool>) -> bool {
  61. self.permissions.contains(perm)
  62. }
  63. }
  64. impl User {
  65. pub async fn get_user(id: i64, pool: &SqlitePool) -> Option<Self> {
  66. let sqluser = sqlx::query_as::<_, SqlUser>("SELECT * FROM users WHERE id = $1")
  67. .bind(id)
  68. .fetch_one(pool)
  69. .await
  70. .ok()?;
  71. //lets just get all the tokens the user can use, we will only use the full permissions if modifying them.
  72. let sql_user_perms = sqlx::query_as::<_, SqlPermissionTokens>(
  73. "SELECT token FROM user_permissions WHERE user_id = $1;",
  74. )
  75. .bind(id)
  76. .fetch_all(pool)
  77. .await
  78. .ok()?;
  79. Some(sqluser.into_user(Some(sql_user_perms)))
  80. }
  81. pub async fn create_user_tables(pool: &SqlitePool) {
  82. sqlx::query(
  83. r#"
  84. CREATE TABLE IF NOT EXISTS users (
  85. "id" INTEGER PRIMARY KEY,
  86. "anonymous" BOOLEAN NOT NULL,
  87. "username" VARCHAR(256) NOT NULL
  88. )
  89. "#,
  90. )
  91. .execute(pool)
  92. .await
  93. .unwrap();
  94. sqlx::query(
  95. r#"
  96. CREATE TABLE IF NOT EXISTS user_permissions (
  97. "user_id" INTEGER NOT NULL,
  98. "token" VARCHAR(256) NOT NULL
  99. )
  100. "#,
  101. )
  102. .execute(pool)
  103. .await
  104. .unwrap();
  105. sqlx::query(
  106. r#"
  107. INSERT INTO users
  108. (id, anonymous, username) SELECT 1, true, 'Guest'
  109. ON CONFLICT(id) DO UPDATE SET
  110. anonymous = EXCLUDED.anonymous,
  111. username = EXCLUDED.username
  112. "#,
  113. )
  114. .execute(pool)
  115. .await
  116. .unwrap();
  117. sqlx::query(
  118. r#"
  119. INSERT INTO users
  120. (id, anonymous, username) SELECT 2, false, 'Test'
  121. ON CONFLICT(id) DO UPDATE SET
  122. anonymous = EXCLUDED.anonymous,
  123. username = EXCLUDED.username
  124. "#,
  125. )
  126. .execute(pool)
  127. .await
  128. .unwrap();
  129. sqlx::query(
  130. r#"
  131. INSERT INTO user_permissions
  132. (user_id, token) SELECT 2, 'Category::View'
  133. "#,
  134. )
  135. .execute(pool)
  136. .await
  137. .unwrap();
  138. }
  139. }
  140. #[derive(sqlx::FromRow, Clone)]
  141. pub struct SqlUser {
  142. pub id: i32,
  143. pub anonymous: bool,
  144. pub username: String,
  145. }
  146. impl SqlUser {
  147. pub fn into_user(self, sql_user_perms: Option<Vec<SqlPermissionTokens>>) -> User {
  148. User {
  149. id: self.id,
  150. anonymous: self.anonymous,
  151. username: self.username,
  152. permissions: if let Some(user_perms) = sql_user_perms {
  153. user_perms
  154. .into_iter()
  155. .map(|x| x.token)
  156. .collect::<HashSet<String>>()
  157. } else {
  158. HashSet::<String>::new()
  159. },
  160. }
  161. }
  162. }
  163. pub async fn connect_to_database() -> SqlitePool {
  164. let connect_opts = SqliteConnectOptions::from_str("sqlite::memory:").unwrap();
  165. SqlitePoolOptions::new()
  166. .max_connections(5)
  167. .connect_with(connect_opts)
  168. .await
  169. .unwrap()
  170. }
  171. pub struct Session(
  172. pub axum_session_auth::AuthSession<
  173. crate::auth::User,
  174. i64,
  175. axum_session_auth::SessionSqlitePool,
  176. sqlx::SqlitePool,
  177. >,
  178. );
  179. impl std::ops::Deref for Session {
  180. type Target = axum_session_auth::AuthSession<
  181. crate::auth::User,
  182. i64,
  183. axum_session_auth::SessionSqlitePool,
  184. sqlx::SqlitePool,
  185. >;
  186. fn deref(&self) -> &Self::Target {
  187. &self.0
  188. }
  189. }
  190. impl std::ops::DerefMut for Session {
  191. fn deref_mut(&mut self) -> &mut Self::Target {
  192. &mut self.0
  193. }
  194. }
  195. #[derive(Debug)]
  196. pub struct AuthSessionLayerNotFound;
  197. impl std::fmt::Display for AuthSessionLayerNotFound {
  198. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  199. write!(f, "AuthSessionLayer was not found")
  200. }
  201. }
  202. impl std::error::Error for AuthSessionLayerNotFound {}
  203. impl IntoResponse for AuthSessionLayerNotFound {
  204. fn into_response(self) -> Response {
  205. (
  206. http::status::StatusCode::INTERNAL_SERVER_ERROR,
  207. "AuthSessionLayer was not found",
  208. )
  209. .into_response()
  210. }
  211. }
  212. #[async_trait]
  213. impl<S: std::marker::Sync + std::marker::Send> axum::extract::FromRequestParts<S> for Session {
  214. type Rejection = AuthSessionLayerNotFound;
  215. async fn from_request_parts(
  216. parts: &mut http::request::Parts,
  217. state: &S,
  218. ) -> Result<Self, Self::Rejection> {
  219. axum_session_auth::AuthSession::<
  220. crate::auth::User,
  221. i64,
  222. axum_session_auth::SessionSqlitePool,
  223. sqlx::SqlitePool,
  224. >::from_request_parts(parts, state)
  225. .await
  226. .map(Session)
  227. .map_err(|_| AuthSessionLayerNotFound)
  228. }
  229. }