kumo_server_common/http_server/
mod.rs1use crate::diagnostic_logging::set_diagnostic_log_filter;
2use anyhow::Context;
3use axum::extract::{DefaultBodyLimit, Json, Query};
4use axum::http::StatusCode;
5use axum::response::{IntoResponse, Response};
6use axum::routing::{get, post};
7use axum::Router;
8use axum_server::tls_rustls::RustlsConfig;
9use axum_streams::{HttpHeaderValue, StreamBodyAsOptions};
10use cidr_map::CidrSet;
11use data_loader::KeySource;
12use kumo_server_memory::{get_usage_and_limit, tracking_stats, JemallocStats};
13use kumo_server_runtime::spawn;
14use serde::Deserialize;
15use std::net::{IpAddr, SocketAddr, TcpListener};
16use tower_http::compression::CompressionLayer;
17use tower_http::decompression::RequestDecompressionLayer;
18use tower_http::trace::TraceLayer;
19use utoipa::openapi::security::{Http, HttpAuthScheme, SecurityScheme};
20use utoipa::OpenApi;
21use utoipa_rapidoc::RapiDoc;
22use kumo_api_types::*;
27
28pub mod auth;
29
30use auth::*;
31
// OpenAPI description of the built-in admin endpoints that
// `HttpListenerParams::start` mounts on every listener.
// `RouterAndDocs::make_docs` merges the service-specific docs into this and
// then overwrites the title, version and license.
// NOTE(review): plain `//` comments are used here on purpose; utoipa may pick
// up `///` doc comments on this struct as spec text.
#[derive(OpenApi)]
#[openapi(
    info(license(name = "Apache-2.0")),
    paths(set_diagnostic_log_filter_v1, bump_config_epoch),
    // Every path advertises HTTP basic auth. The "basic_auth" name must match
    // the security scheme registered by the `OptionalAuth` modifier.
    security(
        ("basic_auth" = [""])
    ),
    components(schemas(SetDiagnosticFilterRequest)),
    // Registers the "basic_auth" scheme referenced in `security(...)`.
    modifiers(&OptionalAuth),
)]
struct ApiDoc;
46
47struct OptionalAuth;
48
49impl utoipa::Modify for OptionalAuth {
50 fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
51 let components = openapi
52 .components
53 .as_mut()
54 .expect("always set because we always have components above");
55 components.add_security_scheme(
57 "basic_auth",
58 SecurityScheme::Http(Http::new(HttpAuthScheme::Basic)),
59 );
60 }
61}
62
/// Configuration for a single HTTP(S) listener.
///
/// Deserialization rejects unknown fields.
#[derive(Deserialize, Clone, Debug)]
#[serde(deny_unknown_fields)]
pub struct HttpListenerParams {
    /// Hostname passed to the TLS server-config helper when `use_tls` is set.
    /// Defaults to the machine's hostname, or `"localhost"` when that name
    /// cannot be converted to a UTF-8 string.
    #[serde(default = "HttpListenerParams::default_hostname")]
    pub hostname: String,

    /// Address handed to `TcpListener::bind`; defaults to `"127.0.0.1:8000"`.
    #[serde(default = "HttpListenerParams::default_listen")]
    pub listen: String,

    /// Serve HTTPS (rustls) instead of plain HTTP. Defaults to `false`.
    #[serde(default)]
    pub use_tls: bool,

    /// Maximum request body size in bytes, applied via `DefaultBodyLimit` to
    /// the caller-supplied router in `start`; 2 MiB when unset.
    #[serde(default)]
    pub request_body_limit: Option<usize>,

    /// TLS certificate source, used to build the server config when `use_tls` is set.
    #[serde(default)]
    pub tls_certificate: Option<KeySource>,
    /// TLS private key source, used to build the server config when `use_tls` is set.
    #[serde(default)]
    pub tls_private_key: Option<KeySource>,

    /// Peer addresses for which `AppState::is_trusted_host` returns `true`;
    /// defaults to `CidrSet::default_trusted_hosts()`.
    #[serde(default = "CidrSet::default_trusted_hosts")]
    pub trusted_hosts: CidrSet,
}
86
87pub struct RouterAndDocs {
88 pub router: Router<AppState>,
89 pub docs: utoipa::openapi::OpenApi,
90}
91
92impl RouterAndDocs {
93 pub fn make_docs(&self) -> utoipa::openapi::OpenApi {
94 let mut api_docs = ApiDoc::openapi();
95 api_docs.info.title = self.docs.info.title.to_string();
96 api_docs.merge(self.docs.clone());
97 api_docs.info.version = version_info::kumo_version().to_string();
98 api_docs.info.license = Some(
99 utoipa::openapi::LicenseBuilder::new()
100 .name("Apache-2.0")
101 .build(),
102 );
103
104 api_docs
105 }
106}
107
108#[derive(Clone)]
109pub struct AppState {
110 params: HttpListenerParams,
111 local_addr: SocketAddr,
112}
113
114impl AppState {
115 pub fn is_trusted_host(&self, addr: IpAddr) -> bool {
116 self.params.trusted_hosts.contains(addr)
117 }
118
119 pub fn params(&self) -> &HttpListenerParams {
120 &self.params
121 }
122
123 pub fn local_addr(&self) -> &SocketAddr {
124 &self.local_addr
125 }
126}
127
128impl HttpListenerParams {
129 fn default_listen() -> String {
130 "127.0.0.1:8000".to_string()
131 }
132
133 fn default_hostname() -> String {
134 gethostname::gethostname()
135 .to_str()
136 .unwrap_or("localhost")
137 .to_string()
138 }
139
140 pub async fn start(
151 self,
152 router_and_docs: RouterAndDocs,
153 runtime: Option<tokio::runtime::Handle>,
154 ) -> anyhow::Result<()> {
155 let api_docs = router_and_docs.make_docs();
156
157 let compression_layer: CompressionLayer = CompressionLayer::new()
158 .deflate(true)
159 .gzip(true)
160 .quality(tower_http::CompressionLevel::Fastest);
161 let decompression_layer = RequestDecompressionLayer::new().deflate(true).gzip(true);
162
163 let socket = TcpListener::bind(&self.listen)
164 .with_context(|| format!("listen on {}", self.listen))?;
165 let addr = socket.local_addr()?;
166
167 let app_state = AppState {
168 params: self.clone(),
169 local_addr: addr.clone(),
170 };
171
172 let app = router_and_docs
173 .router
174 .layer(DefaultBodyLimit::max(
175 self.request_body_limit.unwrap_or(2 * 1024 * 1024),
176 ))
177 .merge(RapiDoc::with_openapi("/api-docs/openapi.json", api_docs).path("/rapidoc"))
178 .route(
179 "/api/admin/set_diagnostic_log_filter/v1",
180 post(set_diagnostic_log_filter_v1),
181 )
182 .route("/api/admin/bump-config-epoch", post(bump_config_epoch))
183 .route("/api/admin/memory/stats", get(memory_stats))
184 .route("/metrics", get(report_metrics))
185 .route("/metrics.json", get(report_metrics_json))
186 .route_layer(axum::middleware::from_fn_with_state(
189 app_state.clone(),
190 auth_middleware,
191 ))
192 .layer(compression_layer)
193 .layer(decompression_layer)
194 .layer(TraceLayer::new_for_http())
195 .layer(axum_client_ip::ClientIpSource::ConnectInfo.into_extension())
196 .with_state(app_state);
197
198 let make_service = app.into_make_service_with_connect_info::<SocketAddr>();
199
200 if self.use_tls {
205 let config = self.tls_config().await?;
206 tracing::info!("https listener on {addr:?}");
207 let server = axum_server::from_tcp_rustls(socket, config);
208 let serve = async move { server.serve(make_service).await };
209
210 if let Some(runtime) = runtime {
211 runtime.spawn(serve);
212 } else {
213 spawn(format!("https {addr:?}"), serve)?;
214 }
215 } else {
216 tracing::info!("http listener on {addr:?}");
217 let server = axum_server::from_tcp(socket);
218 let serve = async move { server.serve(make_service).await };
219 if let Some(runtime) = runtime {
220 runtime.spawn(serve);
221 } else {
222 spawn(format!("http {addr:?}"), serve)?;
223 }
224 }
225 Ok(())
226 }
227
228 async fn tls_config(&self) -> anyhow::Result<RustlsConfig> {
229 let config = crate::tls_helpers::make_server_config(
230 &self.hostname,
231 &self.tls_private_key,
232 &self.tls_certificate,
233 &None,
234 )
235 .await?;
236 Ok(RustlsConfig::from_config(config))
237 }
238}
239
240#[derive(Debug)]
241pub struct AppError {
242 pub err: anyhow::Error,
243 pub code: StatusCode,
244}
245
246impl AppError {
247 pub fn new(code: StatusCode, err: impl Into<String>) -> Self {
248 let err: String = err.into();
249 Self {
250 err: anyhow::anyhow!(err),
251 code,
252 }
253 }
254}
255
256impl IntoResponse for AppError {
258 fn into_response(self) -> Response {
259 (self.code, format!("Error: {:#}", self.err)).into_response()
260 }
261}
262
263impl<E> From<E> for AppError
266where
267 E: Into<anyhow::Error>,
268{
269 fn from(err: E) -> Self {
270 Self {
271 err: err.into(),
272 code: StatusCode::INTERNAL_SERVER_ERROR,
273 }
274 }
275}
276
277#[utoipa::path(
281 post,
282 tag="config",
283 path="/api/admin/bump-config-epoch",
284 responses(
285 (status=200, description = "bump successful")
286 ),
287)]
288async fn bump_config_epoch(_: TrustedIpRequired) -> Result<(), AppError> {
289 config::epoch::bump_current_epoch();
290 Ok(())
291}
292
293#[utoipa::path(
297 get,
298 tag="memory",
299 path="/api/admin/memory/stats",
300 responses(
301 (status=200, description = "stats were returned")
302 ),
303)]
304async fn memory_stats(_: TrustedIpRequired) -> String {
305 use kumo_server_memory::NumBytes;
306 use std::fmt::Write;
307 let mut result = String::new();
308
309 let jstats = JemallocStats::collect();
310 writeln!(result, "{jstats:#?}").ok();
311
312 if let Ok((usage, limit)) = get_usage_and_limit() {
313 writeln!(result, "RSS = {:?}", NumBytes::from(usage.bytes)).ok();
314 writeln!(
315 result,
316 "soft limit = {:?}",
317 limit.soft_limit.map(NumBytes::from)
318 )
319 .ok();
320 writeln!(
321 result,
322 "hard limit = {:?}",
323 limit.hard_limit.map(NumBytes::from)
324 )
325 .ok();
326 }
327
328 let mut stats = tracking_stats();
329 writeln!(result, "live = {:?}", stats.live).ok();
330
331 if stats.top_callstacks.is_empty() {
332 write!(
333 result,
334 "\nuse kumo.enable_memory_callstack_tracking(true) to enable additional stats\n"
335 )
336 .ok();
337 } else {
338 writeln!(result, "small_threshold = {:?}", stats.small_threshold).ok();
339 write!(result, "\ntop call stacks:\n").ok();
340 for stack in &mut stats.top_callstacks {
341 writeln!(
342 result,
343 "sampled every {} allocations, estimated {} allocations of {} total bytes",
344 stack.stochastic_rate,
345 stack.count * stack.stochastic_rate,
346 stack.total_size * stack.stochastic_rate
347 )
348 .ok();
349 write!(result, "{:?}\n\n", stack.bt).ok();
350 }
351 }
352
353 result
354}
355
/// Query-string parameters accepted by the `/metrics` endpoint.
#[derive(Deserialize)]
struct PrometheusMetricsParams {
    /// Optional prefix passed through to `Registry::stream_text`; `None` when absent.
    #[serde(default)]
    prefix: Option<String>,
}
361
362async fn report_metrics(
363 _: TrustedIpRequired,
364 Query(params): Query<PrometheusMetricsParams>,
365) -> impl IntoResponse {
366 StreamBodyAsOptions::new()
367 .content_type(HttpHeaderValue::from_static("text/plain; charset=utf-8"))
368 .text(kumo_prometheus::registry::Registry::stream_text(
369 params.prefix.clone(),
370 ))
371}
372
373async fn report_metrics_json(_: TrustedIpRequired) -> impl IntoResponse {
374 StreamBodyAsOptions::new()
375 .content_type(HttpHeaderValue::from_static(
376 "application/json; charset=utf-8",
377 ))
378 .text(kumo_prometheus::registry::Registry::stream_json())
379}
380
381#[utoipa::path(
385 post,
386 tag="logging",
387 path="/api/admin/set_diagnostic_log_filter/v1",
388 responses(
389 (status = 200, description = "Diagnostic level set successfully")
390 ),
391)]
392async fn set_diagnostic_log_filter_v1(
393 _: TrustedIpRequired,
394 Json(request): Json<SetDiagnosticFilterRequest>,
396) -> Result<(), AppError> {
397 set_diagnostic_log_filter(&request.filter)?;
398 Ok(())
399}