kumo_server_common/http_server/
auth.rs

1use crate::http_server::AppState;
2use axum::async_trait;
3use axum::extract::{FromRequestParts, Request, State};
4use axum::http::StatusCode;
5use axum::middleware::Next;
6use axum::response::{IntoResponse, Response};
7use config::{load_config, CallbackSignature};
8use std::net::{IpAddr, SocketAddr};
9use tokio::time::{Duration, Instant};
10
// Process-wide cache of authentication outcomes, keyed by the presented
// identity. Because the key is the whole `AuthKind`, cached keys include any
// Basic password or Bearer token held in memory.
// Values are either the boolean validation result or a stringified error.
// `AuthKind::validate` inserts entries with a 60-second expiry.
// NOTE(review): `128` is presumably the maximum number of cached entries —
// confirm against the `lruttl` crate.
11lruttl::declare_cache! {
12static AUTH_CACHE: LruCacheWithTtl<AuthKind, Result<bool, String>>::new("http_server_auth", 128);
13}
14
/// Represents some authenticated identity.
/// Use this as an extractor parameter when you need to reference
/// that identity in the handler.
///
/// `Debug` is implemented by hand so that secrets (the Basic password and
/// the Bearer token) are redacted. Values of this type are formatted into
/// log messages, and credentials must never end up in logs.
#[derive(Clone, Hash, Eq, PartialEq)]
pub enum AuthKind {
    /// No credentials were presented; the peer address is on the trusted
    /// host list.
    TrustedIp(IpAddr),
    /// HTTP Basic credentials. `from_header` leaves `password` as `None`
    /// when the decoded credentials contain no `:` separator.
    Basic {
        user: String,
        password: Option<String>,
    },
    /// HTTP Bearer credentials; `token` is the credential text from the header.
    Bearer {
        token: String,
    },
}

impl std::fmt::Debug for AuthKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a secret was present, never its value.
        const REDACTED: &str = "<redacted>";
        match self {
            Self::TrustedIp(addr) => f.debug_tuple("TrustedIp").field(addr).finish(),
            Self::Basic { user, password } => f
                .debug_struct("Basic")
                .field("user", user)
                .field("password", &password.as_ref().map(|_| REDACTED))
                .finish(),
            Self::Bearer { .. } => f
                .debug_struct("Bearer")
                .field("token", &REDACTED)
                .finish(),
        }
    }
}
29
30impl AuthKind {
31    pub fn from_header(authorization: &str) -> Option<Self> {
32        let (kind, contents) = authorization.split_once(' ')?;
33        match kind {
34            "Basic" => {
35                let decoded = data_encoding::BASE64.decode(contents.as_bytes()).ok()?;
36                let decoded = String::from_utf8(decoded).ok()?;
37                let (user, password) = if let Some((id, password)) = decoded.split_once(':') {
38                    (id.to_string(), Some(password.to_string()))
39                } else {
40                    (decoded.to_string(), None)
41                };
42                Some(Self::Basic { user, password })
43            }
44            "Bearer" => Some(Self::Bearer {
45                token: contents.to_string(),
46            }),
47            _ => None,
48        }
49    }
50
51    async fn validate_impl(&self) -> anyhow::Result<bool> {
52        let mut config = load_config().await?;
53        let result = match self {
54            Self::TrustedIp(_) => true,
55            Self::Basic { user, password } => {
56                let sig = CallbackSignature::<(String, Option<String>), bool>::new(
57                    "http_server_validate_auth_basic",
58                );
59                config
60                    .async_call_callback(&sig, (user.to_string(), password.clone()))
61                    .await?
62            }
63            Self::Bearer { token } => {
64                let sig =
65                    CallbackSignature::<String, bool>::new("http_server_validate_auth_bearer");
66                config.async_call_callback(&sig, token.to_string()).await?
67            }
68        };
69        config.put();
70        Ok(result)
71    }
72
73    async fn lookup_cache(&self) -> Option<Result<bool, String>> {
74        AUTH_CACHE.get(self)
75    }
76
77    pub async fn validate(&self) -> anyhow::Result<bool> {
78        match self.lookup_cache().await {
79            Some(res) => res.map_err(|err| anyhow::anyhow!("{err}")),
80            None => {
81                let res = self.validate_impl().await.map_err(|err| format!("{err:#}"));
82
83                let res = AUTH_CACHE
84                    .insert(self.clone(), res, Instant::now() + Duration::from_secs(60))
85                    .await;
86
87                res.map_err(|err| anyhow::anyhow!("{err}"))
88            }
89        }
90    }
91
92    pub fn summarize(&self) -> String {
93        match self {
94            Self::TrustedIp(addr) => addr.to_string(),
95            Self::Basic { user, .. } => user.to_string(),
96            Self::Bearer { .. } => "Bearer".to_string(),
97        }
98    }
99}
100
101fn is_auth_exempt(uri: &axum::http::Uri) -> bool {
102    match uri.path() {
103        "/api/check-liveness/v1" => true,
104        _ => false,
105    }
106}
107
108pub async fn auth_middleware(
109    State(state): State<AppState>,
110    mut request: Request,
111    next: Next,
112) -> Response {
113    if is_auth_exempt(request.uri()) {
114        return next.run(request).await;
115    }
116
117    // Get authorization header
118    match request.headers().get(axum::http::header::AUTHORIZATION) {
119        None => {
120            if let Some(remote_addr) = request
121                .extensions()
122                .get::<axum::extract::ConnectInfo<SocketAddr>>()
123                .map(|ci| ci.0)
124            {
125                let ip = remote_addr.ip();
126                if state.is_trusted_host(ip) {
127                    request.extensions_mut().insert(AuthKind::TrustedIp(ip));
128                    return next.run(request).await;
129                }
130            }
131
132            (StatusCode::UNAUTHORIZED, "Missing Authorization header").into_response()
133        }
134        Some(authorization) => match authorization.to_str() {
135            Err(_) => (StatusCode::BAD_REQUEST, "Malformed Authorization header").into_response(),
136            Ok(authorization) => match AuthKind::from_header(authorization) {
137                None => (
138                    StatusCode::BAD_REQUEST,
139                    "Malformed or unsupported Authorization header",
140                )
141                    .into_response(),
142                Some(kind) => match kind.validate().await {
143                    Ok(true) => {
144                        // Store the authentication inform for later retrieval
145                        request.extensions_mut().insert(kind);
146                        next.run(request).await
147                    }
148                    Ok(false) => {
149                        (StatusCode::UNAUTHORIZED, "Invalid Authorization").into_response()
150                    }
151                    Err(err) => {
152                        tracing::error!("Error validating {kind:?}: {err:#}");
153                        (StatusCode::INTERNAL_SERVER_ERROR, "try again later").into_response()
154                    }
155                },
156            },
157        },
158    }
159}
160
161#[async_trait]
162impl<B> FromRequestParts<B> for AuthKind
163where
164    B: Send + Sync,
165{
166    type Rejection = (StatusCode, &'static str);
167
168    async fn from_request_parts(
169        parts: &mut axum::http::request::Parts,
170        _: &B,
171    ) -> Result<Self, Self::Rejection> {
172        let kind = parts
173            .extensions
174            .get::<AuthKind>()
175            .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
176
177        Ok(kind.clone())
178    }
179}
180
181/// Use this type as an extractor parameter when the handler must
182/// only be accessible to trusted IP addresses
183pub struct TrustedIpRequired;
184
185#[async_trait]
186impl<B> FromRequestParts<B> for TrustedIpRequired
187where
188    B: Send + Sync,
189{
190    type Rejection = (StatusCode, &'static str);
191
192    async fn from_request_parts(
193        parts: &mut axum::http::request::Parts,
194        _: &B,
195    ) -> Result<Self, Self::Rejection> {
196        let kind = parts
197            .extensions
198            .get::<AuthKind>()
199            .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
200
201        match kind {
202            AuthKind::TrustedIp(_) => Ok(TrustedIpRequired),
203            _ => Err((StatusCode::UNAUTHORIZED, "Trusted IP required")),
204        }
205    }
206}