kumo_server_common/http_server/
auth.rs1use crate::http_server::AppState;
2use axum::extract::{FromRequestParts, Request, State};
3use axum::http::StatusCode;
4use axum::middleware::Next;
5use axum::response::{IntoResponse, Response};
6use config::{load_config, CallbackSignature};
7use std::net::{IpAddr, SocketAddr};
8use tokio::time::{Duration, Instant};
9
// Process-wide cache of authentication outcomes, keyed by the credentials
// (`AuthKind`) that were presented. The stored value is the validation result:
// `Ok(bool)` for a definitive answer, or `Err(String)` holding the rendered
// error text when validation itself failed.
// The cache is registered under the name "http_server_auth"; the `128`
// argument is presumably the maximum number of entries (confirm against the
// `lruttl` documentation).
lruttl::declare_cache! {
static AUTH_CACHE: LruCacheWithTtl<AuthKind, Result<bool, String>>::new("http_server_auth", 128);
}
14
/// The way in which an HTTP request identified itself to the server.
///
/// Values of this type are used as keys in the authentication result cache
/// (hence `Hash`/`Eq`) and are stored in the request extensions by
/// `auth_middleware` so that handlers can extract them.
///
/// NOTE(review): the derived `Debug` output includes the Basic password and
/// the Bearer token verbatim; avoid `{:?}` formatting in logs and prefer
/// `summarize()`.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum AuthKind {
    /// No `Authorization` header was presented; the request was accepted
    /// because the peer address is a trusted host.
    TrustedIp(IpAddr),
    /// Credentials from an `Authorization: Basic` header.
    Basic {
        /// The portion of the decoded credentials before the first `:`,
        /// or the entire decoded value when there is no `:`.
        user: String,
        /// The portion after the first `:`; `None` when the decoded
        /// credentials contained no `:` separator.
        password: Option<String>,
    },
    /// Credentials from an `Authorization: Bearer` header.
    Bearer {
        /// The token text exactly as it appeared in the header.
        token: String,
    },
}
29
30impl AuthKind {
31 pub fn from_header(authorization: &str) -> Option<Self> {
32 let (kind, contents) = authorization.split_once(' ')?;
33 match kind {
34 "Basic" => {
35 let decoded = data_encoding::BASE64.decode(contents.as_bytes()).ok()?;
36 let decoded = String::from_utf8(decoded).ok()?;
37 let (user, password) = if let Some((id, password)) = decoded.split_once(':') {
38 (id.to_string(), Some(password.to_string()))
39 } else {
40 (decoded.to_string(), None)
41 };
42 Some(Self::Basic { user, password })
43 }
44 "Bearer" => Some(Self::Bearer {
45 token: contents.to_string(),
46 }),
47 _ => None,
48 }
49 }
50
51 async fn validate_impl(&self) -> anyhow::Result<bool> {
52 let mut config = load_config().await?;
53 let result = match self {
54 Self::TrustedIp(_) => true,
55 Self::Basic { user, password } => {
56 let sig = CallbackSignature::<(String, Option<String>), bool>::new(
57 "http_server_validate_auth_basic",
58 );
59 config
60 .async_call_callback(&sig, (user.to_string(), password.clone()))
61 .await?
62 }
63 Self::Bearer { token } => {
64 let sig =
65 CallbackSignature::<String, bool>::new("http_server_validate_auth_bearer");
66 config.async_call_callback(&sig, token.to_string()).await?
67 }
68 };
69 config.put();
70 Ok(result)
71 }
72
73 async fn lookup_cache(&self) -> Option<Result<bool, String>> {
74 AUTH_CACHE.get(self)
75 }
76
77 pub async fn validate(&self) -> anyhow::Result<bool> {
78 match self.lookup_cache().await {
79 Some(res) => res.map_err(|err| anyhow::anyhow!("{err}")),
80 None => {
81 let res = self.validate_impl().await.map_err(|err| format!("{err:#}"));
82
83 let res = AUTH_CACHE
84 .insert(self.clone(), res, Instant::now() + Duration::from_secs(60))
85 .await;
86
87 res.map_err(|err| anyhow::anyhow!("{err}"))
88 }
89 }
90 }
91
92 pub fn summarize(&self) -> String {
93 match self {
94 Self::TrustedIp(addr) => addr.to_string(),
95 Self::Basic { user, .. } => user.to_string(),
96 Self::Bearer { .. } => "Bearer".to_string(),
97 }
98 }
99}
100
101fn is_auth_exempt(uri: &axum::http::Uri) -> bool {
102 match uri.path() {
103 "/api/check-liveness/v1" => true,
104 _ => false,
105 }
106}
107
108pub async fn auth_middleware(
109 State(state): State<AppState>,
110 mut request: Request,
111 next: Next,
112) -> Response {
113 if is_auth_exempt(request.uri()) {
114 return next.run(request).await;
115 }
116
117 match request.headers().get(axum::http::header::AUTHORIZATION) {
119 None => {
120 if let Some(remote_addr) = request
121 .extensions()
122 .get::<axum::extract::ConnectInfo<SocketAddr>>()
123 .map(|ci| ci.0)
124 {
125 let ip = remote_addr.ip();
126 if state.is_trusted_host(ip) {
127 request.extensions_mut().insert(AuthKind::TrustedIp(ip));
128 return next.run(request).await;
129 }
130
131 return (
132 StatusCode::UNAUTHORIZED,
133 format!("{ip} is not a trusted host, and no Authorization header is present"),
134 )
135 .into_response();
136 }
137
138 (
139 StatusCode::UNAUTHORIZED,
140 "peer is unknown, and no Authorization header is present",
141 )
142 .into_response()
143 }
144 Some(authorization) => match authorization.to_str() {
145 Err(_) => (StatusCode::BAD_REQUEST, "Malformed Authorization header").into_response(),
146 Ok(authorization) => match AuthKind::from_header(authorization) {
147 None => (
148 StatusCode::BAD_REQUEST,
149 "Malformed or unsupported Authorization header",
150 )
151 .into_response(),
152 Some(kind) => match kind.validate().await {
153 Ok(true) => {
154 request.extensions_mut().insert(kind);
156 next.run(request).await
157 }
158 Ok(false) => {
159 (StatusCode::UNAUTHORIZED, "Invalid Authorization").into_response()
160 }
161 Err(err) => {
162 tracing::error!("Error validating {kind:?}: {err:#}");
163 (StatusCode::INTERNAL_SERVER_ERROR, "try again later").into_response()
164 }
165 },
166 },
167 },
168 }
169}
170
171impl<B> FromRequestParts<B> for AuthKind
172where
173 B: Send + Sync,
174{
175 type Rejection = (StatusCode, &'static str);
176
177 async fn from_request_parts(
178 parts: &mut axum::http::request::Parts,
179 _: &B,
180 ) -> Result<Self, Self::Rejection> {
181 let kind = parts
182 .extensions
183 .get::<AuthKind>()
184 .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
185
186 Ok(kind.clone())
187 }
188}
189
/// Extractor that only succeeds when the request was authenticated as
/// `AuthKind::TrustedIp`; add it to a handler's arguments to reject requests
/// that were authenticated in any other way.
pub struct TrustedIpRequired;
193
194impl<B> FromRequestParts<B> for TrustedIpRequired
195where
196 B: Send + Sync,
197{
198 type Rejection = (StatusCode, &'static str);
199
200 async fn from_request_parts(
201 parts: &mut axum::http::request::Parts,
202 _: &B,
203 ) -> Result<Self, Self::Rejection> {
204 let kind = parts
205 .extensions
206 .get::<AuthKind>()
207 .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;
208
209 match kind {
210 AuthKind::TrustedIp(_) => Ok(TrustedIpRequired),
211 _ => Err((StatusCode::UNAUTHORIZED, "Trusted IP required")),
212 }
213 }
214}