kumo_server_common/http_server/
auth.rs

1use crate::http_server::AppState;
2use axum::extract::{FromRequestParts, Request, State};
3use axum::http::StatusCode;
4use axum::middleware::Next;
5use axum::response::{IntoResponse, Response};
6use config::{load_config, CallbackSignature};
7use std::net::{IpAddr, SocketAddr};
8use tokio::time::{Duration, Instant};
9
// Process-wide cache of credential validation outcomes, declared under the
// name "http_server_auth". Keys are `AuthKind` values and each entry stores
// the `Result<bool, String>` produced by a validation attempt, so failures
// are remembered alongside accept/reject decisions.
// NOTE(review): the trailing `128` is presumably the maximum number of
// cached entries — confirm against the `lruttl::declare_cache!` docs.
lruttl::declare_cache! {
static AUTH_CACHE: LruCacheWithTtl<AuthKind, Result<bool, String>>::new("http_server_auth", 128);
}
13
/// Represents some authenticated identity.
///
/// Use this as an extractor parameter when you need to reference
/// that identity in the handler.
///
/// `Debug` is implemented by hand rather than derived so that passwords and
/// bearer tokens are never written to logs or error messages; only the
/// non-secret parts of the identity (peer address, user name) are shown.
#[derive(Clone, Hash, Eq, PartialEq)]
pub enum AuthKind {
    /// The request carried no credentials but came from a trusted address.
    TrustedIp(IpAddr),
    /// Credentials supplied through the HTTP `Basic` scheme.
    Basic {
        /// The user id portion of the credentials.
        user: String,
        /// The password, or `None` when the decoded credentials had no `:`.
        password: Option<String>,
    },
    /// Credentials supplied through the HTTP `Bearer` scheme.
    Bearer {
        /// The opaque token exactly as presented by the client.
        token: String,
    },
}

impl std::fmt::Debug for AuthKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TrustedIp(addr) => f.debug_tuple("TrustedIp").field(addr).finish(),
            Self::Basic { user, password } => f
                .debug_struct("Basic")
                .field("user", user)
                // Keep the Some/None distinction but never the secret itself.
                .field("password", &password.as_ref().map(|_| "<redacted>"))
                .finish(),
            Self::Bearer { .. } => f
                .debug_struct("Bearer")
                .field("token", &"<redacted>")
                .finish(),
        }
    }
}
28
29impl AuthKind {
30    pub fn from_header(authorization: &str) -> Option<Self> {
31        let (kind, contents) = authorization.split_once(' ')?;
32        match kind {
33            "Basic" => {
34                let decoded = data_encoding::BASE64.decode(contents.as_bytes()).ok()?;
35                let decoded = String::from_utf8(decoded).ok()?;
36                let (user, password) = if let Some((id, password)) = decoded.split_once(':') {
37                    (id.to_string(), Some(password.to_string()))
38                } else {
39                    (decoded.to_string(), None)
40                };
41                Some(Self::Basic { user, password })
42            }
43            "Bearer" => Some(Self::Bearer {
44                token: contents.to_string(),
45            }),
46            _ => None,
47        }
48    }
49
50    async fn validate_impl(&self) -> anyhow::Result<bool> {
51        let mut config = load_config().await?;
52        let result = match self {
53            Self::TrustedIp(_) => true,
54            Self::Basic { user, password } => {
55                let sig = CallbackSignature::<(String, Option<String>), bool>::new(
56                    "http_server_validate_auth_basic",
57                );
58                config
59                    .async_call_callback(&sig, (user.to_string(), password.clone()))
60                    .await?
61            }
62            Self::Bearer { token } => {
63                let sig =
64                    CallbackSignature::<String, bool>::new("http_server_validate_auth_bearer");
65                config.async_call_callback(&sig, token.to_string()).await?
66            }
67        };
68        config.put();
69        Ok(result)
70    }
71
72    async fn lookup_cache(&self) -> Option<Result<bool, String>> {
73        AUTH_CACHE.get(self)
74    }
75
76    pub async fn validate(&self) -> anyhow::Result<bool> {
77        match self.lookup_cache().await {
78            Some(res) => res.map_err(|err| anyhow::anyhow!("{err}")),
79            None => {
80                let res = self.validate_impl().await.map_err(|err| format!("{err:#}"));
81
82                let res = AUTH_CACHE
83                    .insert(self.clone(), res, Instant::now() + Duration::from_secs(60))
84                    .await;
85
86                res.map_err(|err| anyhow::anyhow!("{err}"))
87            }
88        }
89    }
90
91    pub fn summarize(&self) -> String {
92        match self {
93            Self::TrustedIp(addr) => addr.to_string(),
94            Self::Basic { user, .. } => user.to_string(),
95            Self::Bearer { .. } => "Bearer".to_string(),
96        }
97    }
98}
99
100fn is_auth_exempt(uri: &axum::http::Uri) -> bool {
101    match uri.path() {
102        "/api/check-liveness/v1" => true,
103        _ => false,
104    }
105}
106
107pub async fn auth_middleware(
108    State(state): State<AppState>,
109    mut request: Request,
110    next: Next,
111) -> Response {
112    if is_auth_exempt(request.uri()) {
113        return next.run(request).await;
114    }
115
116    // Get authorization header
117    match request.headers().get(axum::http::header::AUTHORIZATION) {
118        None => {
119            if let Some(remote_addr) = request
120                .extensions()
121                .get::<axum::extract::ConnectInfo<SocketAddr>>()
122                .map(|ci| ci.0)
123            {
124                let ip = remote_addr.ip();
125                if state.is_trusted_host(ip) {
126                    request.extensions_mut().insert(AuthKind::TrustedIp(ip));
127                    return next.run(request).await;
128                }
129
130                return (
131                    StatusCode::UNAUTHORIZED,
132                    format!("{ip} is not a trusted host, and no Authorization header is present"),
133                )
134                    .into_response();
135            }
136
137            (
138                StatusCode::UNAUTHORIZED,
139                "peer is unknown, and no Authorization header is present",
140            )
141                .into_response()
142        }
143        Some(authorization) => match authorization.to_str() {
144            Err(_) => (StatusCode::BAD_REQUEST, "Malformed Authorization header").into_response(),
145            Ok(authorization) => match AuthKind::from_header(authorization) {
146                None => (
147                    StatusCode::BAD_REQUEST,
148                    "Malformed or unsupported Authorization header",
149                )
150                    .into_response(),
151                Some(kind) => match kind.validate().await {
152                    Ok(true) => {
153                        // Store the authentication inform for later retrieval
154                        request.extensions_mut().insert(kind);
155                        next.run(request).await
156                    }
157                    Ok(false) => {
158                        (StatusCode::UNAUTHORIZED, "Invalid Authorization").into_response()
159                    }
160                    Err(err) => {
161                        tracing::error!("Error validating {kind:?}: {err:#}");
162                        (StatusCode::INTERNAL_SERVER_ERROR, "try again later").into_response()
163                    }
164                },
165            },
166        },
167    }
168}
169
170impl<B> FromRequestParts<B> for AuthKind
171where
172    B: Send + Sync,
173{
174    type Rejection = (StatusCode, &'static str);
175
176    async fn from_request_parts(
177        parts: &mut axum::http::request::Parts,
178        _: &B,
179    ) -> Result<Self, Self::Rejection> {
180        let kind = parts
181            .extensions
182            .get::<AuthKind>()
183            .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
184
185        Ok(kind.clone())
186    }
187}
188
189/// Use this type as an extractor parameter when the handler must
190/// only be accessible to trusted IP addresses
191pub struct TrustedIpRequired;
192
193impl<B> FromRequestParts<B> for TrustedIpRequired
194where
195    B: Send + Sync,
196{
197    type Rejection = (StatusCode, &'static str);
198
199    async fn from_request_parts(
200        parts: &mut axum::http::request::Parts,
201        _: &B,
202    ) -> Result<Self, Self::Rejection> {
203        let kind = parts
204            .extensions
205            .get::<AuthKind>()
206            .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
207
208        match kind {
209            AuthKind::TrustedIp(_) => Ok(TrustedIpRequired),
210            _ => Err((StatusCode::UNAUTHORIZED, "Trusted IP required")),
211        }
212    }
213}