kumo_server_common/http_server/
mod.rs1use crate::diagnostic_logging::set_diagnostic_log_filter;
2use anyhow::Context;
3use axum::extract::{DefaultBodyLimit, Json, Query};
4use axum::handler::Handler;
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::Router;
8use axum_server::tls_rustls::RustlsConfig;
9use axum_streams::{HttpHeaderValue, StreamBodyAsOptions};
10use cidr_map::CidrSet;
11use data_loader::KeySource;
12use kumo_server_memory::{get_usage_and_limit, tracking_stats, JemallocStats};
13use kumo_server_runtime::spawn;
14use serde::Deserialize;
15use std::net::{IpAddr, SocketAddr, TcpListener};
16use tower_http::compression::CompressionLayer;
17use tower_http::decompression::RequestDecompressionLayer;
18use tower_http::trace::TraceLayer;
19use utoipa::openapi::security::{Http, HttpAuthScheme, SecurityScheme};
20use utoipa::openapi::PathItem;
21use utoipa::OpenApi;
22use utoipa_rapidoc::RapiDoc;
23use kumo_api_types::*;
28
29pub mod auth;
30
31use auth::*;
32
33struct OptionalAuth;
34
35impl utoipa::Modify for OptionalAuth {
36 fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
37 let components = openapi
38 .components
39 .as_mut()
40 .expect("always set because we always have components above");
41 components.add_security_scheme(
43 "basic_auth",
44 SecurityScheme::Http(Http::new(HttpAuthScheme::Basic)),
45 );
46 }
47}
48
49#[derive(Deserialize, Clone, Debug)]
50#[serde(deny_unknown_fields)]
51pub struct HttpListenerParams {
52 #[serde(default = "HttpListenerParams::default_hostname")]
53 pub hostname: String,
54
55 #[serde(default = "HttpListenerParams::default_listen")]
56 pub listen: String,
57
58 #[serde(default)]
59 pub use_tls: bool,
60
61 #[serde(default)]
62 pub request_body_limit: Option<usize>,
63
64 #[serde(default)]
65 pub tls_certificate: Option<KeySource>,
66 #[serde(default)]
67 pub tls_private_key: Option<KeySource>,
68
69 #[serde(default = "CidrSet::default_trusted_hosts")]
70 pub trusted_hosts: CidrSet,
71}
72
73pub struct RouterAndDocs {
76 pub router: Router<AppState>,
77 pub docs: utoipa::openapi::OpenApi,
78}
79
80impl RouterAndDocs {
81 fn add_route<T, H: Handler<T, AppState>>(&mut self, path: &str, item: &PathItem, handler: H)
86 where
87 T: 'static,
88 H: 'static,
89 {
90 let router = std::mem::take(&mut self.router);
91 if item.get.is_some() {
92 self.router = router.route(path, axum::routing::get(handler));
93 } else if item.put.is_some() {
94 self.router = router.route(path, axum::routing::put(handler));
95 } else if item.post.is_some() {
96 self.router = router.route(path, axum::routing::post(handler));
97 } else if item.delete.is_some() {
98 self.router = router.route(path, axum::routing::delete(handler));
99 } else {
100 panic!("unhandled path operation");
101 }
102 }
103
104 pub fn register<T, H: Handler<T, AppState>>(
118 &mut self,
119 api: utoipa::openapi::OpenApi,
120 handler: H,
121 ) where
122 T: 'static,
123 H: 'static,
124 {
125 if let Some((path, item)) = api.paths.paths.iter().next() {
126 self.add_route(path, item, handler);
127 } else {
128 panic!("register didn't register any paths!");
129 }
130
131 self.docs.merge(api);
132 }
133
134 pub fn new(title: &str) -> Self {
138 #[derive(OpenApi)]
139 #[openapi(
140 info(
141 license(name="Apache-2.0"),
142 version=version_info::kumo_version()
143 ),
144 security(
148 ("basic_auth" = [""])
149 ),
150 modifiers(&OptionalAuth),
151 )]
152 struct ApiDoc;
153 let mut router = Self {
154 docs: ApiDoc::openapi(),
155 router: Router::new(),
156 };
157
158 router.docs.info.title = title.to_string();
159
160 router.register_common_handlers();
161 router
162 }
163
164 fn register_common_handlers(&mut self) {
166 macro_rules! add_handlers {
167 ($($handler:path $(,)?)*) => {
168 $(
169 {
170 #[derive(OpenApi)]
173 #[openapi(paths($handler))]
174 struct O;
175
176 self.register(O::openapi(), $handler);
177 }
178 )*
179 }
180 }
181
182 add_handlers!(
183 bump_config_epoch,
184 memory_stats,
185 report_metrics,
186 report_metrics_json,
187 set_diagnostic_log_filter_v1,
188 );
189 }
190}
191
/// Builds a [`RouterAndDocs`] titled `title` that contains the common handlers
/// plus every handler listed in `handlers=[...]`.
///
/// Each handler must carry a `#[utoipa::path(...)]` annotation. The expansion
/// refers to `RouterAndDocs` and `OpenApi` unqualified, so the caller must have
/// both in scope.
#[macro_export]
macro_rules! router_with_docs {
    (title=$title:literal, handlers=[
        $($handler:path $(,)? )*
    ]) => {
        {
            #![allow(deprecated)]

            let mut router_and_docs = RouterAndDocs::new($title);

            $(
                {
                    // A throwaway single-path API document per handler; `register`
                    // routes the handler from it and merges it into the docs.
                    #[derive(OpenApi)]
                    #[openapi(paths($handler))]
                    struct HandlerApi;

                    router_and_docs.register(HandlerApi::openapi(), $handler);
                }
            )*

            router_and_docs
        }
    }
}
238
239#[derive(Clone)]
240pub struct AppState {
241 params: HttpListenerParams,
242 local_addr: SocketAddr,
243}
244
245impl AppState {
246 pub fn is_trusted_host(&self, addr: IpAddr) -> bool {
247 self.params.trusted_hosts.contains(addr)
248 }
249
250 pub fn params(&self) -> &HttpListenerParams {
251 &self.params
252 }
253
254 pub fn local_addr(&self) -> &SocketAddr {
255 &self.local_addr
256 }
257}
258
259impl HttpListenerParams {
260 fn default_listen() -> String {
261 "127.0.0.1:8000".to_string()
262 }
263
264 fn default_hostname() -> String {
265 gethostname::gethostname()
266 .to_str()
267 .unwrap_or("localhost")
268 .to_string()
269 }
270
271 pub async fn start(
282 self,
283 router_and_docs: RouterAndDocs,
284 runtime: Option<tokio::runtime::Handle>,
285 ) -> anyhow::Result<()> {
286 let compression_layer: CompressionLayer = CompressionLayer::new()
287 .deflate(true)
288 .gzip(true)
289 .quality(tower_http::CompressionLevel::Fastest);
290 let decompression_layer = RequestDecompressionLayer::new().deflate(true).gzip(true);
291
292 let socket = TcpListener::bind(&self.listen)
293 .with_context(|| format!("listen on {}", self.listen))?;
294 let addr = socket.local_addr()?;
295
296 let app_state = AppState {
297 params: self.clone(),
298 local_addr: addr.clone(),
299 };
300
301 let app = router_and_docs
302 .router
303 .layer(DefaultBodyLimit::max(
304 self.request_body_limit.unwrap_or(2 * 1024 * 1024),
305 ))
306 .merge(
307 RapiDoc::with_openapi("/api-docs/openapi.json", router_and_docs.docs)
308 .path("/rapidoc"),
309 )
310 .route_layer(axum::middleware::from_fn_with_state(
313 app_state.clone(),
314 auth_middleware,
315 ))
316 .layer(compression_layer)
317 .layer(decompression_layer)
318 .layer(TraceLayer::new_for_http())
319 .layer(axum_client_ip::ClientIpSource::ConnectInfo.into_extension())
320 .with_state(app_state);
321
322 let make_service = app.into_make_service_with_connect_info::<SocketAddr>();
323
324 if self.use_tls {
329 let config = self.tls_config().await?;
330 tracing::info!("https listener on {addr:?}");
331 let server = axum_server::from_tcp_rustls(socket, config);
332 let serve = async move { server.serve(make_service).await };
333
334 if let Some(runtime) = runtime {
335 runtime.spawn(serve);
336 } else {
337 spawn(format!("https {addr:?}"), serve)?;
338 }
339 } else {
340 tracing::info!("http listener on {addr:?}");
341 let server = axum_server::from_tcp(socket);
342 let serve = async move { server.serve(make_service).await };
343 if let Some(runtime) = runtime {
344 runtime.spawn(serve);
345 } else {
346 spawn(format!("http {addr:?}"), serve)?;
347 }
348 }
349 Ok(())
350 }
351
352 async fn tls_config(&self) -> anyhow::Result<RustlsConfig> {
353 let config = crate::tls_helpers::make_server_config(
354 &self.hostname,
355 &self.tls_private_key,
356 &self.tls_certificate,
357 &None,
358 )
359 .await?;
360 Ok(RustlsConfig::from_config(config))
361 }
362}
363
364#[derive(Debug)]
365pub struct AppError {
366 pub err: anyhow::Error,
367 pub code: StatusCode,
368}
369
370impl AppError {
371 pub fn new(code: StatusCode, err: impl Into<String>) -> Self {
372 let err: String = err.into();
373 Self {
374 err: anyhow::anyhow!(err),
375 code,
376 }
377 }
378}
379
380impl IntoResponse for AppError {
382 fn into_response(self) -> Response {
383 (self.code, format!("Error: {:#}", self.err)).into_response()
384 }
385}
386
387impl<E> From<E> for AppError
390where
391 E: Into<anyhow::Error>,
392{
393 fn from(err: E) -> Self {
394 Self {
395 err: err.into(),
396 code: StatusCode::INTERNAL_SERVER_ERROR,
397 }
398 }
399}
400
401#[utoipa::path(
405 post,
406 tag="config",
407 path="/api/admin/bump-config-epoch",
408 responses(
409 (status=200, description = "bump successful")
410 ),
411)]
412async fn bump_config_epoch() -> Result<(), AppError> {
413 config::epoch::bump_current_epoch();
414 Ok(())
415}
416
417#[utoipa::path(
421 get,
422 tag="memory",
423 path="/api/admin/memory/stats",
424 responses(
425 (status=200, description = "stats were returned")
426 ),
427)]
428async fn memory_stats() -> String {
429 use kumo_server_memory::NumBytes;
430 use std::fmt::Write;
431 let mut result = String::new();
432
433 let jstats = JemallocStats::collect();
434 writeln!(result, "{jstats:#?}").ok();
435
436 if let Ok((usage, limit)) = get_usage_and_limit() {
437 writeln!(result, "RSS = {:?}", NumBytes::from(usage.bytes)).ok();
438 writeln!(
439 result,
440 "soft limit = {:?}",
441 limit.soft_limit.map(NumBytes::from)
442 )
443 .ok();
444 writeln!(
445 result,
446 "hard limit = {:?}",
447 limit.hard_limit.map(NumBytes::from)
448 )
449 .ok();
450 }
451
452 let mut stats = tracking_stats();
453 writeln!(result, "live = {:?}", stats.live).ok();
454
455 if stats.top_callstacks.is_empty() {
456 write!(
457 result,
458 "\nuse kumo.enable_memory_callstack_tracking(true) to enable additional stats\n"
459 )
460 .ok();
461 } else {
462 writeln!(result, "small_threshold = {:?}", stats.small_threshold).ok();
463 write!(result, "\ntop call stacks:\n").ok();
464 for stack in &mut stats.top_callstacks {
465 writeln!(
466 result,
467 "sampled every {} allocations, estimated {} allocations of {} total bytes",
468 stack.stochastic_rate,
469 stack.count * stack.stochastic_rate,
470 stack.total_size * stack.stochastic_rate
471 )
472 .ok();
473 write!(result, "{:?}\n\n", stack.bt).ok();
474 }
475 }
476
477 result
478}
479
480#[derive(Deserialize)]
481struct PrometheusMetricsParams {
482 #[serde(default)]
483 prefix: Option<String>,
484}
485
486#[utoipa::path(get, path = "/metrics", responses(
524 (status = 200, content_type="text/plain")
525))]
526async fn report_metrics(Query(params): Query<PrometheusMetricsParams>) -> impl IntoResponse {
527 StreamBodyAsOptions::new()
528 .content_type(HttpHeaderValue::from_static("text/plain; charset=utf-8"))
529 .text(kumo_prometheus::registry::Registry::stream_text(
530 params.prefix.clone(),
531 ))
532}
533
534#[utoipa::path(get, path = "/metrics.json", responses(
568 (status = 200, content_type="application/json")
569))]
570async fn report_metrics_json() -> impl IntoResponse {
571 StreamBodyAsOptions::new()
572 .content_type(HttpHeaderValue::from_static(
573 "application/json; charset=utf-8",
574 ))
575 .text(kumo_prometheus::registry::Registry::stream_json())
576}
577
578#[utoipa::path(
582 post,
583 tags=["logging", "kcli:set-log-filter"],
584 path="/api/admin/set_diagnostic_log_filter/v1",
585 request_body=SetDiagnosticFilterRequest,
586 responses(
587 (status = 200, description = "Diagnostic level set successfully")
588 ),
589)]
590async fn set_diagnostic_log_filter_v1(
591 Json(request): Json<SetDiagnosticFilterRequest>,
593) -> Result<(), AppError> {
594 set_diagnostic_log_filter(&request.filter)?;
595 Ok(())
596}