kumo_server_common/http_server/auth.rs

1use crate::http_server::AppState;
2use axum::extract::{FromRequestParts, Request, State};
3use axum::http::StatusCode;
4use axum::middleware::Next;
5use axum::response::{IntoResponse, Response};
6use config::{load_config, CallbackSignature};
7use std::net::{IpAddr, SocketAddr};
8use tokio::time::{Duration, Instant};
9
10lruttl::declare_cache! {
11/// Caches the results of the http server auth validation by auth credential
12static AUTH_CACHE: LruCacheWithTtl<AuthKind, Result<bool, String>>::new("http_server_auth", 128);
13}
14
/// An identity presented by, or inferred for, an HTTP client.
///
/// Values produced by [`AuthKind::from_header`] are only *claimed*
/// identities. The auth middleware inserts an `AuthKind` into the request
/// extensions only after accepting it. A value obtained through the
/// `FromRequestParts` extractor therefore represents an authenticated
/// identity; use it as an extractor parameter when a handler needs to know
/// who the caller is.
///
/// `Debug` is implemented by hand so that passwords and bearer tokens are
/// never rendered (for example into log lines); only their presence is shown.
#[derive(Clone, Hash, Eq, PartialEq)]
pub enum AuthKind {
    /// The peer was admitted because its IP address is a trusted host.
    TrustedIp(IpAddr),
    /// HTTP `Basic` credentials. `password` is `None` when the decoded
    /// credential string contained no `:` separator.
    Basic {
        user: String,
        password: Option<String>,
    },
    /// HTTP `Bearer` credentials.
    Bearer {
        token: String,
    },
}

impl std::fmt::Debug for AuthKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder printed in place of any secret value.
        const REDACTED: &str = "<redacted>";
        match self {
            Self::TrustedIp(ip) => f.debug_tuple("TrustedIp").field(ip).finish(),
            Self::Basic { user, password } => f
                .debug_struct("Basic")
                .field("user", user)
                .field("password", &password.as_ref().map(|_| REDACTED))
                .finish(),
            Self::Bearer { .. } => f.debug_struct("Bearer").field("token", &REDACTED).finish(),
        }
    }
}
29
30impl AuthKind {
31    pub fn from_header(authorization: &str) -> Option<Self> {
32        let (kind, contents) = authorization.split_once(' ')?;
33        match kind {
34            "Basic" => {
35                let decoded = data_encoding::BASE64.decode(contents.as_bytes()).ok()?;
36                let decoded = String::from_utf8(decoded).ok()?;
37                let (user, password) = if let Some((id, password)) = decoded.split_once(':') {
38                    (id.to_string(), Some(password.to_string()))
39                } else {
40                    (decoded.to_string(), None)
41                };
42                Some(Self::Basic { user, password })
43            }
44            "Bearer" => Some(Self::Bearer {
45                token: contents.to_string(),
46            }),
47            _ => None,
48        }
49    }
50
51    async fn validate_impl(&self) -> anyhow::Result<bool> {
52        let mut config = load_config().await?;
53        let result = match self {
54            Self::TrustedIp(_) => true,
55            Self::Basic { user, password } => {
56                let sig = CallbackSignature::<(String, Option<String>), bool>::new(
57                    "http_server_validate_auth_basic",
58                );
59                config
60                    .async_call_callback(&sig, (user.to_string(), password.clone()))
61                    .await?
62            }
63            Self::Bearer { token } => {
64                let sig =
65                    CallbackSignature::<String, bool>::new("http_server_validate_auth_bearer");
66                config.async_call_callback(&sig, token.to_string()).await?
67            }
68        };
69        config.put();
70        Ok(result)
71    }
72
73    async fn lookup_cache(&self) -> Option<Result<bool, String>> {
74        AUTH_CACHE.get(self)
75    }
76
77    pub async fn validate(&self) -> anyhow::Result<bool> {
78        match self.lookup_cache().await {
79            Some(res) => res.map_err(|err| anyhow::anyhow!("{err}")),
80            None => {
81                let res = self.validate_impl().await.map_err(|err| format!("{err:#}"));
82
83                let res = AUTH_CACHE
84                    .insert(self.clone(), res, Instant::now() + Duration::from_secs(60))
85                    .await;
86
87                res.map_err(|err| anyhow::anyhow!("{err}"))
88            }
89        }
90    }
91
92    pub fn summarize(&self) -> String {
93        match self {
94            Self::TrustedIp(addr) => addr.to_string(),
95            Self::Basic { user, .. } => user.to_string(),
96            Self::Bearer { .. } => "Bearer".to_string(),
97        }
98    }
99}
100
101fn is_auth_exempt(uri: &axum::http::Uri) -> bool {
102    match uri.path() {
103        "/api/check-liveness/v1" => true,
104        _ => false,
105    }
106}
107
108pub async fn auth_middleware(
109    State(state): State<AppState>,
110    mut request: Request,
111    next: Next,
112) -> Response {
113    if is_auth_exempt(request.uri()) {
114        return next.run(request).await;
115    }
116
117    // Get authorization header
118    match request.headers().get(axum::http::header::AUTHORIZATION) {
119        None => {
120            if let Some(remote_addr) = request
121                .extensions()
122                .get::<axum::extract::ConnectInfo<SocketAddr>>()
123                .map(|ci| ci.0)
124            {
125                let ip = remote_addr.ip();
126                if state.is_trusted_host(ip) {
127                    request.extensions_mut().insert(AuthKind::TrustedIp(ip));
128                    return next.run(request).await;
129                }
130
131                return (
132                    StatusCode::UNAUTHORIZED,
133                    format!("{ip} is not a trusted host, and no Authorization header is present"),
134                )
135                    .into_response();
136            }
137
138            (
139                StatusCode::UNAUTHORIZED,
140                "peer is unknown, and no Authorization header is present",
141            )
142                .into_response()
143        }
144        Some(authorization) => match authorization.to_str() {
145            Err(_) => (StatusCode::BAD_REQUEST, "Malformed Authorization header").into_response(),
146            Ok(authorization) => match AuthKind::from_header(authorization) {
147                None => (
148                    StatusCode::BAD_REQUEST,
149                    "Malformed or unsupported Authorization header",
150                )
151                    .into_response(),
152                Some(kind) => match kind.validate().await {
153                    Ok(true) => {
154                        // Store the authentication inform for later retrieval
155                        request.extensions_mut().insert(kind);
156                        next.run(request).await
157                    }
158                    Ok(false) => {
159                        (StatusCode::UNAUTHORIZED, "Invalid Authorization").into_response()
160                    }
161                    Err(err) => {
162                        tracing::error!("Error validating {kind:?}: {err:#}");
163                        (StatusCode::INTERNAL_SERVER_ERROR, "try again later").into_response()
164                    }
165                },
166            },
167        },
168    }
169}
170
171impl<B> FromRequestParts<B> for AuthKind
172where
173    B: Send + Sync,
174{
175    type Rejection = (StatusCode, &'static str);
176
177    async fn from_request_parts(
178        parts: &mut axum::http::request::Parts,
179        _: &B,
180    ) -> Result<Self, Self::Rejection> {
181        let kind = parts
182            .extensions
183            .get::<AuthKind>()
184            .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
185
186        Ok(kind.clone())
187    }
188}
189
190/// Use this type as an extractor parameter when the handler must
191/// only be accessible to trusted IP addresses
192pub struct TrustedIpRequired;
193
194impl<B> FromRequestParts<B> for TrustedIpRequired
195where
196    B: Send + Sync,
197{
198    type Rejection = (StatusCode, &'static str);
199
200    async fn from_request_parts(
201        parts: &mut axum::http::request::Parts,
202        _: &B,
203    ) -> Result<Self, Self::Rejection> {
204        let kind = parts
205            .extensions
206            .get::<AuthKind>()
207            .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
208
209        match kind {
210            AuthKind::TrustedIp(_) => Ok(TrustedIpRequired),
211            _ => Err((StatusCode::UNAUTHORIZED, "Trusted IP required")),
212        }
213    }
214}