From a5ffb69ccdc6b5b854ef3401e4686f8e12a76380 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Charlotte=20=F0=9F=A6=9D=20Delenk?= Date: Wed, 4 Dec 2024 11:47:33 +0100 Subject: [PATCH] add authentication header support --- Cargo.lock | 1 + Cargo.nix | 6 +- chir-rs-db/Cargo.toml | 1 + chir-rs-db/src/session.rs | 46 +++++++++- chir-rs-http-api/src/auth/mod.rs | 11 +++ chir-rs-http-api/src/errors/mod.rs | 23 ++++- chir-rs-http/Cargo.toml | 5 +- chir-rs-http/src/auth/mod.rs | 1 + chir-rs-http/src/auth/req_auth/auth_header.rs | 83 +++++++++++++++++++ chir-rs-http/src/auth/req_auth/mod.rs | 3 + src/main.rs | 40 +++++---- 11 files changed, 196 insertions(+), 24 deletions(-) create mode 100644 chir-rs-http/src/auth/req_auth/auth_header.rs create mode 100644 chir-rs-http/src/auth/req_auth/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 381bace..0f2419c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -987,6 +987,7 @@ dependencies = [ "chir-rs-http-api", "chir-rs-misc", "eyre", + "futures", "mime", "rand", "serde", diff --git a/Cargo.nix b/Cargo.nix index 77d057c..a543310 100644 --- a/Cargo.nix +++ b/Cargo.nix @@ -32,7 +32,7 @@ args@{ ignoreLockHash, }: let - nixifiedLockHash = "11a64cfe7de901d3281b189f031b32ae0457044c4134a845d433d25e040d5018"; + nixifiedLockHash = "7dc55944ae2b0d1b65a004e3be3f5314339c3d2e023e794d54b030b197c5bf87"; workspaceSrc = if args.workspaceSrc == null then ./. else args.workspaceSrc; currentLockHash = builtins.hashFile "sha256" (workspaceSrc + /Cargo.lock); lockHashIgnored = @@ -3074,6 +3074,10 @@ else (rustPackages."registry+https://github.com/rust-lang/crates.io-index".eyre."0.6.12" { inherit profileName; }).out; + futures = + (rustPackages."registry+https://github.com/rust-lang/crates.io-index".futures."0.3.31" { + inherit profileName; + }).out; mime = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".mime."0.3.17" { inherit profileName; diff --git a/chir-rs-db/Cargo.toml b/chir-rs-db/Cargo.toml index 406e7f8..3b3cf70 100644 --- a/chir-rs-db/Cargo.toml +++ b/chir-rs-db/Cargo.toml @@ -15,6 +15,7 @@ mime = "0.3.17" chir-rs-http-api = { version = "0.1.0", path = "../chir-rs-http-api" } chir-rs-misc = { version = "0.1.0", path = "../chir-rs-misc", features = ["id-generator"] } rand = "0.8.5" +futures = "0.3.31" [lints.rust] deprecated-safe = "forbid" diff --git a/chir-rs-db/src/session.rs b/chir-rs-db/src/session.rs index 6bb66d9..da37de0 100644 --- a/chir-rs-db/src/session.rs +++ b/chir-rs-db/src/session.rs @@ -1,10 +1,12 @@ //! Session-related functionality -use std::time::Duration; +use std::{collections::HashSet, time::Duration}; use crate::Database; +use chir_rs_http_api::auth::Scope; use chir_rs_misc::id_generator; use eyre::Result; +use futures::StreamExt as _; use rand::{thread_rng, Rng as _}; use sqlx::query; use tracing::{error, info, instrument}; @@ -19,6 +21,7 @@ pub async fn expire(db: &Database) -> Result<()> { let id = id_generator::generate(); let oldest_acceptable_id = id - ((24 * 60 * 60) << 48); let oldest_acceptable_id = oldest_acceptable_id.to_be_bytes(); + #[expect(clippy::panic, reason = "sqlx moment")] query!( r#"DELETE FROM "session_scopes" WHERE session_id < $1"#, &oldest_acceptable_id @@ -28,6 +31,47 @@ pub async fn expire(db: &Database) -> Result<()> { Ok(()) } +/// Returns username and scopes for a session ID +/// +/// # Errors +/// This function returns an error if accessing the database fails +#[instrument] +#[expect(clippy::panic, reason = "sqlx moment")] +pub async fn fetch_session_info( + db: &Database, + session_id: u128, +) -> Result)>> { + let session_id = session_id.to_be_bytes(); + let Some(username_record) = query!( + r#" + SELECT "user".username FROM "user" + INNER JOIN "sessions" + ON "sessions".user_id = "user".id + WHERE "sessions".id = $1 + "#, + &session_id + ) + .fetch_optional(&*db.0) + .await? + else { + return Ok(None); + }; + + let mut scopes = HashSet::with_capacity(4); + + let mut scopes_records = query!( + "SELECT scope FROM session_scopes WHERE session_id = $1", + &session_id + ) + .fetch(&*db.0); + + while let Some(scope_record) = scopes_records.next().await { + scopes.insert(Scope::from_i64(scope_record?.scope)?); + } + + Ok(Some((username_record.username, scopes))) +} + /// Automatically expires outdated sessions /// /// This is intended to be called on a dedicated job. diff --git a/chir-rs-http-api/src/auth/mod.rs b/chir-rs-http-api/src/auth/mod.rs index 2251a9d..1d48844 100644 --- a/chir-rs-http-api/src/auth/mod.rs +++ b/chir-rs-http-api/src/auth/mod.rs @@ -21,6 +21,17 @@ impl Scope { Self::Full => 0, } } + + /// Converts a scope ID to the scope + /// + /// # Errors + /// This function returns an error if the scope ID is invalid. + pub fn from_i64(id: i64) -> Result { + match id { + 0 => Ok(Self::Full), + _ => bail!("Invalid scope ID {id}"), + } + } } /// Login request for the user diff --git a/chir-rs-http-api/src/errors/mod.rs b/chir-rs-http-api/src/errors/mod.rs index 8c72d85..0118e56 100644 --- a/chir-rs-http-api/src/errors/mod.rs +++ b/chir-rs-http-api/src/errors/mod.rs @@ -42,6 +42,21 @@ pub enum APIError { /// Invalid password #[error("Invalid password for user {0}")] InvalidPassword(String), + /// Missing authorization header + #[error("Missing authorization header")] + MissingAuthorizationHeader, + /// Invalid Authorization header value + #[error("Invalid authorization header: {0}")] + InvalidAuthorizationHeader(String), + /// Invalid authorization method + #[error("Invalid authorization method: {0}, expected {1}")] + InvalidAuthorizationMethod(String, String), + /// Unauthorized + #[error("Unauthorized")] + Unauthorized, + /// Invalid session + #[error("Invalid session")] + InvalidSession, } impl APIError { @@ -54,7 +69,13 @@ impl APIError { } Self::PayloadTooBig => StatusCode::PAYLOAD_TOO_LARGE, Self::PayloadLoadError | Self::PayloadInvalid => StatusCode::BAD_REQUEST, - Self::UserNotFound(_) | Self::InvalidPassword(_) => StatusCode::UNAUTHORIZED, + Self::UserNotFound(_) + | Self::InvalidPassword(_) + | Self::MissingAuthorizationHeader + | Self::InvalidAuthorizationHeader(_) + | Self::InvalidAuthorizationMethod(_, _) + | Self::Unauthorized + | Self::InvalidSession => StatusCode::UNAUTHORIZED, Self::Unknown(_) | Self::DatabaseError(_) => StatusCode::INTERNAL_SERVER_ERROR, } } diff --git a/chir-rs-http/Cargo.toml b/chir-rs-http/Cargo.toml index 5fc9fa2..afa85bc 100644 --- a/chir-rs-http/Cargo.toml +++ b/chir-rs-http/Cargo.toml @@ -21,10 +21,7 @@ chir-rs-misc = { version = "0.1.0", path = "../chir-rs-misc", features = [ chrono = "0.4.38" eyre = "0.6.12" mime = "0.3.17" -rusty_paseto = { version = "0.7.1", default-features = false, features = [ - "batteries_included", - "v4_local", -] } +rusty_paseto = { version = "0.7.1", default-features = false, features = ["batteries_included", "v4_local"] } sentry-tower = { version = "0.34.0", features = ["axum", "axum-matched-path"] } tokio = { version = "1.41.1", features = ["fs", "net"] } tower-http = { version = "0.6.2", features = ["trace"] } diff --git a/chir-rs-http/src/auth/mod.rs b/chir-rs-http/src/auth/mod.rs index 07a6281..3f3228f 100644 --- a/chir-rs-http/src/auth/mod.rs +++ b/chir-rs-http/src/auth/mod.rs @@ -1,3 +1,4 @@ //! Authentication related functionality pub mod password_login; +pub mod req_auth; diff --git a/chir-rs-http/src/auth/req_auth/auth_header.rs b/chir-rs-http/src/auth/req_auth/auth_header.rs new file mode 100644 index 0000000..a00198e --- /dev/null +++ b/chir-rs-http/src/auth/req_auth/auth_header.rs @@ -0,0 +1,83 @@ +//! Authentication header handler + +use std::collections::HashSet; + +use axum::{ + async_trait, + extract::FromRequestParts, + http::{header::AUTHORIZATION, request::Parts}, +}; +use chir_rs_db::session::fetch_session_info; +use chir_rs_http_api::{auth::Scope, errors::APIError}; +use eyre::{Context as _, OptionExt as _}; +use rusty_paseto::core::{Local, V4}; +use rusty_paseto::prelude::PasetoParser; +use tracing::{error, info}; + +use crate::AppState; + +/// Read Authorization from the bearer token. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct AuthHeader(pub String, pub HashSet); + +#[async_trait] +impl FromRequestParts for AuthHeader { + type Rejection = APIError; + + async fn from_request_parts( + parts: &mut Parts, + state: &AppState, + ) -> Result { + let Some(authorization_header) = parts.headers.get(AUTHORIZATION) else { + return Err(APIError::MissingAuthorizationHeader); + }; + let authorization_header = authorization_header + .to_str() + .context("Parsing the authorization header") + .map_err(|e| APIError::InvalidAuthorizationHeader(format!("{e:?}")))?; + + let Some((method, key)) = authorization_header.split_once(' ') else { + return Err(APIError::InvalidAuthorizationHeader( + authorization_header.to_string(), + )); + }; + + if !method.trim().eq_ignore_ascii_case("Bearer") { + return Err(APIError::InvalidAuthorizationMethod( + method.trim().to_string(), + "Bearer".to_string(), + )); + } + + let json = PasetoParser::::default() + .parse(key.trim(), &state.paseto_key) + .context("Verifying paseto token") + .map_err(|e| { + info!("Failed authentication with: {e:?}"); + APIError::Unauthorized + })?; + + let session_id: u128 = json["jti"] + .as_str() + .ok_or_eyre("Reading the token ID as a string") + .and_then(|v| v.parse().context("Parsing session ID")) + .map_err(|e| { + error!("Invalid issued token: {e:?}"); + APIError::Unknown(format!("Invalid issued token: {e:?}")) + })?; + + let session_info = fetch_session_info(&state.db, session_id) + .await + .with_context(|| format!("Verifying session {session_id}")) + .map_err(|e| { + error!("Failed to verify session: {e:?}"); + APIError::Unknown(format!("Failed to verify session: {e:?}")) + })? + .ok_or_eyre("Found session info") + .map_err(|e| { + info!("User Error validating session {e:?}"); + APIError::InvalidSession + })?; + Ok(Self(session_info.0, session_info.1)) + } +} diff --git a/chir-rs-http/src/auth/req_auth/mod.rs b/chir-rs-http/src/auth/req_auth/mod.rs new file mode 100644 index 0000000..4ac04ed --- /dev/null +++ b/chir-rs-http/src/auth/req_auth/mod.rs @@ -0,0 +1,3 @@ +//! Request authentication + +pub mod auth_header; diff --git a/src/main.rs b/src/main.rs index 4e0add1..1ece6c8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,23 +14,8 @@ use tracing_subscriber::{ fmt::format::JsonFields, layer::SubscriberExt as _, util::SubscriberInitExt as _, Layer, }; -fn main() -> Result<()> { - color_eyre::install().ok(); - dotenvy::dotenv().ok(); - - // NO THREADS BEFORE THIS POINT - - let cfg = ChirRs::read_from_env().context("Reading chir.rs configuration")?; - - let _guard = sentry::init(sentry::ClientOptions { - dsn: cfg.logging.sentry_dsn.clone(), - release: sentry::release_name!(), - traces_sample_rate: 0.1, - attach_stacktrace: true, - debug: cfg!(debug_assertions), - ..Default::default() - }); - +/// Initializes logging for the application +fn init_logging(cfg: &ChirRs) -> Result<()> { let log_filter = tracing_subscriber::EnvFilter::from_str(&cfg.logging.log_level) .with_context(|| format!("Setting log filter to {}", cfg.logging.log_level))?; @@ -85,6 +70,27 @@ fn main() -> Result<()> { .init(); } } + Ok(()) +} + +fn main() -> Result<()> { + color_eyre::install().ok(); + dotenvy::dotenv().ok(); + + // NO THREADS BEFORE THIS POINT + + let cfg = ChirRs::read_from_env().context("Reading chir.rs configuration")?; + + let _guard = sentry::init(sentry::ClientOptions { + dsn: cfg.logging.sentry_dsn.clone(), + release: sentry::release_name!(), + traces_sample_rate: 0.1, + attach_stacktrace: true, + debug: cfg!(debug_assertions), + ..Default::default() + }); + + init_logging(&cfg)?; tokio::runtime::Builder::new_multi_thread() .enable_all()