kumo_server_common/http_server/
auth.rs1use crate::http_server::AppState;
2use axum::extract::{FromRequestParts, Request, State};
3use axum::http::StatusCode;
4use axum::middleware::Next;
5use axum::response::{IntoResponse, Response};
6use config::{load_config, CallbackSignature};
7use std::net::{IpAddr, SocketAddr};
8use tokio::time::{Duration, Instant};
9
// Cache of credential-validation outcomes, keyed by the presented `AuthKind`.
// Each entry's expiry is chosen by the code that inserts it; the `128` is
// presumably the LRU capacity (TODO confirm against lruttl docs).
// NOTE(review): the keys contain plaintext Basic passwords / Bearer tokens,
// which stay resident in process memory until the entry is evicted or expires.
lruttl::declare_cache! {
    static AUTH_CACHE: LruCacheWithTtl<AuthKind, Result<bool, String>>::new("http_server_auth", 128);
}
13
/// How an HTTP request was (or is trying to be) authenticated.
///
/// `Debug` is implemented by hand, not derived, so that passwords and bearer
/// tokens are never written out by `{:?}` formatting (e.g. in error logs).
#[derive(Clone, Hash, Eq, PartialEq)]
pub enum AuthKind {
    /// The peer address was accepted as a trusted host; no credentials were supplied.
    TrustedIp(IpAddr),
    /// HTTP `Basic` credentials; `password` is `None` when the decoded
    /// payload contained no `:` separator.
    Basic {
        user: String,
        password: Option<String>,
    },
    /// HTTP `Bearer` token.
    Bearer { token: String },
}

impl std::fmt::Debug for AuthKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TrustedIp(ip) => f.debug_tuple("TrustedIp").field(ip).finish(),
            Self::Basic { user, password } => f
                .debug_struct("Basic")
                .field("user", user)
                // Reveal only whether a password was supplied, never its value.
                .field("password", &password.as_ref().map(|_| "<redacted>"))
                .finish(),
            Self::Bearer { .. } => f
                .debug_struct("Bearer")
                .field("token", &"<redacted>")
                .finish(),
        }
    }
}
28
29impl AuthKind {
30 pub fn from_header(authorization: &str) -> Option<Self> {
31 let (kind, contents) = authorization.split_once(' ')?;
32 match kind {
33 "Basic" => {
34 let decoded = data_encoding::BASE64.decode(contents.as_bytes()).ok()?;
35 let decoded = String::from_utf8(decoded).ok()?;
36 let (user, password) = if let Some((id, password)) = decoded.split_once(':') {
37 (id.to_string(), Some(password.to_string()))
38 } else {
39 (decoded.to_string(), None)
40 };
41 Some(Self::Basic { user, password })
42 }
43 "Bearer" => Some(Self::Bearer {
44 token: contents.to_string(),
45 }),
46 _ => None,
47 }
48 }
49
50 async fn validate_impl(&self) -> anyhow::Result<bool> {
51 let mut config = load_config().await?;
52 let result = match self {
53 Self::TrustedIp(_) => true,
54 Self::Basic { user, password } => {
55 let sig = CallbackSignature::<(String, Option<String>), bool>::new(
56 "http_server_validate_auth_basic",
57 );
58 config
59 .async_call_callback(&sig, (user.to_string(), password.clone()))
60 .await?
61 }
62 Self::Bearer { token } => {
63 let sig =
64 CallbackSignature::<String, bool>::new("http_server_validate_auth_bearer");
65 config.async_call_callback(&sig, token.to_string()).await?
66 }
67 };
68 config.put();
69 Ok(result)
70 }
71
72 async fn lookup_cache(&self) -> Option<Result<bool, String>> {
73 AUTH_CACHE.get(self)
74 }
75
76 pub async fn validate(&self) -> anyhow::Result<bool> {
77 match self.lookup_cache().await {
78 Some(res) => res.map_err(|err| anyhow::anyhow!("{err}")),
79 None => {
80 let res = self.validate_impl().await.map_err(|err| format!("{err:#}"));
81
82 let res = AUTH_CACHE
83 .insert(self.clone(), res, Instant::now() + Duration::from_secs(60))
84 .await;
85
86 res.map_err(|err| anyhow::anyhow!("{err}"))
87 }
88 }
89 }
90
91 pub fn summarize(&self) -> String {
92 match self {
93 Self::TrustedIp(addr) => addr.to_string(),
94 Self::Basic { user, .. } => user.to_string(),
95 Self::Bearer { .. } => "Bearer".to_string(),
96 }
97 }
98}
99
100fn is_auth_exempt(uri: &axum::http::Uri) -> bool {
101 match uri.path() {
102 "/api/check-liveness/v1" => true,
103 _ => false,
104 }
105}
106
107pub async fn auth_middleware(
108 State(state): State<AppState>,
109 mut request: Request,
110 next: Next,
111) -> Response {
112 if is_auth_exempt(request.uri()) {
113 return next.run(request).await;
114 }
115
116 match request.headers().get(axum::http::header::AUTHORIZATION) {
118 None => {
119 if let Some(remote_addr) = request
120 .extensions()
121 .get::<axum::extract::ConnectInfo<SocketAddr>>()
122 .map(|ci| ci.0)
123 {
124 let ip = remote_addr.ip();
125 if state.is_trusted_host(ip) {
126 request.extensions_mut().insert(AuthKind::TrustedIp(ip));
127 return next.run(request).await;
128 }
129
130 return (
131 StatusCode::UNAUTHORIZED,
132 format!("{ip} is not a trusted host, and no Authorization header is present"),
133 )
134 .into_response();
135 }
136
137 (
138 StatusCode::UNAUTHORIZED,
139 "peer is unknown, and no Authorization header is present",
140 )
141 .into_response()
142 }
143 Some(authorization) => match authorization.to_str() {
144 Err(_) => (StatusCode::BAD_REQUEST, "Malformed Authorization header").into_response(),
145 Ok(authorization) => match AuthKind::from_header(authorization) {
146 None => (
147 StatusCode::BAD_REQUEST,
148 "Malformed or unsupported Authorization header",
149 )
150 .into_response(),
151 Some(kind) => match kind.validate().await {
152 Ok(true) => {
153 request.extensions_mut().insert(kind);
155 next.run(request).await
156 }
157 Ok(false) => {
158 (StatusCode::UNAUTHORIZED, "Invalid Authorization").into_response()
159 }
160 Err(err) => {
161 tracing::error!("Error validating {kind:?}: {err:#}");
162 (StatusCode::INTERNAL_SERVER_ERROR, "try again later").into_response()
163 }
164 },
165 },
166 },
167 }
168}
169
170impl<B> FromRequestParts<B> for AuthKind
171where
172 B: Send + Sync,
173{
174 type Rejection = (StatusCode, &'static str);
175
176 async fn from_request_parts(
177 parts: &mut axum::http::request::Parts,
178 _: &B,
179 ) -> Result<Self, Self::Rejection> {
180 let kind = parts
181 .extensions
182 .get::<AuthKind>()
183 .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
184
185 Ok(kind.clone())
186 }
187}
188
189pub struct TrustedIpRequired;
192
193impl<B> FromRequestParts<B> for TrustedIpRequired
194where
195 B: Send + Sync,
196{
197 type Rejection = (StatusCode, &'static str);
198
199 async fn from_request_parts(
200 parts: &mut axum::http::request::Parts,
201 _: &B,
202 ) -> Result<Self, Self::Rejection> {
203 let kind = parts
204 .extensions
205 .get::<AuthKind>()
206 .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
207
208 match kind {
209 AuthKind::TrustedIp(_) => Ok(TrustedIpRequired),
210 _ => Err((StatusCode::UNAUTHORIZED, "Trusted IP required")),
211 }
212 }
213}