kumo_server_common/http_server/
auth.rs1use crate::http_server::AppState;
2use axum::async_trait;
3use axum::extract::{FromRequestParts, Request, State};
4use axum::http::StatusCode;
5use axum::middleware::Next;
6use axum::response::{IntoResponse, Response};
7use config::{load_config, CallbackSignature};
8use std::net::{IpAddr, SocketAddr};
9use tokio::time::{Duration, Instant};
10
// Process-wide cache of recent `AuthKind` validation outcomes, keyed by the
// full credential value. The cached value is `Ok(accepted)` or `Err(message)`
// holding the formatted validation error. It is registered under the name
// "http_server_auth"; `128` is presumably the entry capacity (TODO confirm
// against lruttl). The per-entry expiry is supplied at insert time by
// `AuthKind::validate`.
lruttl::declare_cache! {
static AUTH_CACHE: LruCacheWithTtl<AuthKind, Result<bool, String>>::new("http_server_auth", 128);
}
14
/// How a request to the HTTP server was authenticated.
///
/// Values serve both as the key of the authentication cache and as a request
/// extension for downstream extractors. `Debug` is implemented by hand so that
/// passwords and bearer tokens are redacted and cannot leak into logs.
#[derive(Clone, Hash, Eq, PartialEq)]
pub enum AuthKind {
    /// Admitted without credentials because the peer address is trusted.
    TrustedIp(IpAddr),
    /// HTTP Basic credentials. `password` is `None` when the decoded payload
    /// contained no `:` separator.
    Basic {
        user: String,
        password: Option<String>,
    },
    /// An opaque HTTP Bearer token.
    Bearer { token: String },
}

impl std::fmt::Debug for AuthKind {
    /// Formats like the derived impl, except that secret material (the Basic
    /// password and the Bearer token) is replaced with `"<redacted>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TrustedIp(ip) => f.debug_tuple("TrustedIp").field(ip).finish(),
            Self::Basic { user, password } => f
                .debug_struct("Basic")
                .field("user", user)
                // Keep presence/absence visible; hide the value itself.
                .field("password", &password.as_ref().map(|_| "<redacted>"))
                .finish(),
            Self::Bearer { .. } => f
                .debug_struct("Bearer")
                .field("token", &"<redacted>")
                .finish(),
        }
    }
}
29
30impl AuthKind {
31 pub fn from_header(authorization: &str) -> Option<Self> {
32 let (kind, contents) = authorization.split_once(' ')?;
33 match kind {
34 "Basic" => {
35 let decoded = data_encoding::BASE64.decode(contents.as_bytes()).ok()?;
36 let decoded = String::from_utf8(decoded).ok()?;
37 let (user, password) = if let Some((id, password)) = decoded.split_once(':') {
38 (id.to_string(), Some(password.to_string()))
39 } else {
40 (decoded.to_string(), None)
41 };
42 Some(Self::Basic { user, password })
43 }
44 "Bearer" => Some(Self::Bearer {
45 token: contents.to_string(),
46 }),
47 _ => None,
48 }
49 }
50
51 async fn validate_impl(&self) -> anyhow::Result<bool> {
52 let mut config = load_config().await?;
53 let result = match self {
54 Self::TrustedIp(_) => true,
55 Self::Basic { user, password } => {
56 let sig = CallbackSignature::<(String, Option<String>), bool>::new(
57 "http_server_validate_auth_basic",
58 );
59 config
60 .async_call_callback(&sig, (user.to_string(), password.clone()))
61 .await?
62 }
63 Self::Bearer { token } => {
64 let sig =
65 CallbackSignature::<String, bool>::new("http_server_validate_auth_bearer");
66 config.async_call_callback(&sig, token.to_string()).await?
67 }
68 };
69 config.put();
70 Ok(result)
71 }
72
73 async fn lookup_cache(&self) -> Option<Result<bool, String>> {
74 AUTH_CACHE.get(self)
75 }
76
77 pub async fn validate(&self) -> anyhow::Result<bool> {
78 match self.lookup_cache().await {
79 Some(res) => res.map_err(|err| anyhow::anyhow!("{err}")),
80 None => {
81 let res = self.validate_impl().await.map_err(|err| format!("{err:#}"));
82
83 let res = AUTH_CACHE
84 .insert(self.clone(), res, Instant::now() + Duration::from_secs(60))
85 .await;
86
87 res.map_err(|err| anyhow::anyhow!("{err}"))
88 }
89 }
90 }
91
92 pub fn summarize(&self) -> String {
93 match self {
94 Self::TrustedIp(addr) => addr.to_string(),
95 Self::Basic { user, .. } => user.to_string(),
96 Self::Bearer { .. } => "Bearer".to_string(),
97 }
98 }
99}
100
101fn is_auth_exempt(uri: &axum::http::Uri) -> bool {
102 match uri.path() {
103 "/api/check-liveness/v1" => true,
104 _ => false,
105 }
106}
107
108pub async fn auth_middleware(
109 State(state): State<AppState>,
110 mut request: Request,
111 next: Next,
112) -> Response {
113 if is_auth_exempt(request.uri()) {
114 return next.run(request).await;
115 }
116
117 match request.headers().get(axum::http::header::AUTHORIZATION) {
119 None => {
120 if let Some(remote_addr) = request
121 .extensions()
122 .get::<axum::extract::ConnectInfo<SocketAddr>>()
123 .map(|ci| ci.0)
124 {
125 let ip = remote_addr.ip();
126 if state.is_trusted_host(ip) {
127 request.extensions_mut().insert(AuthKind::TrustedIp(ip));
128 return next.run(request).await;
129 }
130 }
131
132 (StatusCode::UNAUTHORIZED, "Missing Authorization header").into_response()
133 }
134 Some(authorization) => match authorization.to_str() {
135 Err(_) => (StatusCode::BAD_REQUEST, "Malformed Authorization header").into_response(),
136 Ok(authorization) => match AuthKind::from_header(authorization) {
137 None => (
138 StatusCode::BAD_REQUEST,
139 "Malformed or unsupported Authorization header",
140 )
141 .into_response(),
142 Some(kind) => match kind.validate().await {
143 Ok(true) => {
144 request.extensions_mut().insert(kind);
146 next.run(request).await
147 }
148 Ok(false) => {
149 (StatusCode::UNAUTHORIZED, "Invalid Authorization").into_response()
150 }
151 Err(err) => {
152 tracing::error!("Error validating {kind:?}: {err:#}");
153 (StatusCode::INTERNAL_SERVER_ERROR, "try again later").into_response()
154 }
155 },
156 },
157 },
158 }
159}
160
161#[async_trait]
162impl<B> FromRequestParts<B> for AuthKind
163where
164 B: Send + Sync,
165{
166 type Rejection = (StatusCode, &'static str);
167
168 async fn from_request_parts(
169 parts: &mut axum::http::request::Parts,
170 _: &B,
171 ) -> Result<Self, Self::Rejection> {
172 let kind = parts
173 .extensions
174 .get::<AuthKind>()
175 .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
176
177 Ok(kind.clone())
178 }
179}
180
181pub struct TrustedIpRequired;
184
185#[async_trait]
186impl<B> FromRequestParts<B> for TrustedIpRequired
187where
188 B: Send + Sync,
189{
190 type Rejection = (StatusCode, &'static str);
191
192 async fn from_request_parts(
193 parts: &mut axum::http::request::Parts,
194 _: &B,
195 ) -> Result<Self, Self::Rejection> {
196 let kind = parts
197 .extensions
198 .get::<AuthKind>()
199 .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
200
201 match kind {
202 AuthKind::TrustedIp(_) => Ok(TrustedIpRequired),
203 _ => Err((StatusCode::UNAUTHORIZED, "Trusted IP required")),
204 }
205 }
206}