1use crate::diagnostic_logging::set_diagnostic_log_filter;
2use crate::start::{MACHINE_INFO, ONLINE_SINCE};
3use anyhow::Context;
4use axum::extract::{DefaultBodyLimit, Json, Query, State};
5use axum::handler::Handler;
6use axum::http::StatusCode;
7use axum::response::{IntoResponse, Response};
8use axum::Router;
9use axum_server::tls_rustls::RustlsConfig;
10use axum_streams::{HttpHeaderValue, StreamBodyAsOptions};
11use cidr_map::CidrSet;
12use data_loader::KeySource;
13use kumo_server_memory::{get_usage_and_limit, tracking_stats, JemallocStats};
14use kumo_server_runtime::spawn;
15use serde::Deserialize;
16use std::net::{IpAddr, SocketAddr, TcpListener};
17use tower_http::compression::CompressionLayer;
18use tower_http::decompression::RequestDecompressionLayer;
19use tower_http::trace::TraceLayer;
20use utoipa::openapi::security::{Http, HttpAuthScheme, SecurityScheme};
21use utoipa::openapi::PathItem;
22use utoipa::OpenApi;
23use utoipa_rapidoc::RapiDoc;
24use kumo_api_types::*;
29
30pub mod auth;
31
32use auth::*;
33
34struct OptionalAuth;
35
36impl utoipa::Modify for OptionalAuth {
37 fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
38 let components = openapi
39 .components
40 .as_mut()
41 .expect("always set because we always have components above");
42 components.add_security_scheme(
44 "basic_auth",
45 SecurityScheme::Http(Http::new(HttpAuthScheme::Basic)),
46 );
47 }
48}
49
/// Configuration for one HTTP(S) listener, deserialized with serde.
/// Unknown fields are rejected (`deny_unknown_fields`).
#[derive(Deserialize, Clone, Debug)]
#[serde(deny_unknown_fields)]
pub struct HttpListenerParams {
    /// Hostname passed to `tls_helpers::make_server_config` when building the
    /// TLS config; defaults to the system hostname (or `"localhost"`).
    #[serde(default = "HttpListenerParams::default_hostname")]
    pub hostname: String,

    /// Address the listener binds to; defaults to `127.0.0.1:8000`.
    #[serde(default = "HttpListenerParams::default_listen")]
    pub listen: String,

    /// When true the listener serves HTTPS, otherwise plain HTTP.
    #[serde(default)]
    pub use_tls: bool,

    /// Maximum request body size passed to `DefaultBodyLimit::max`;
    /// `None` means 2 MiB (2 * 1024 * 1024).
    #[serde(default)]
    pub request_body_limit: Option<usize>,

    /// Certificate source for TLS. How `None` is handled is up to
    /// `tls_helpers::make_server_config` (not visible here).
    #[serde(default)]
    pub tls_certificate: Option<KeySource>,
    /// Private key source for TLS; see `tls_certificate`.
    #[serde(default)]
    pub tls_private_key: Option<KeySource>,

    /// Addresses considered trusted by `AppState::is_trusted_host`;
    /// defaults to `CidrSet::default_trusted_hosts()`.
    #[serde(default = "CidrSet::default_trusted_hosts")]
    pub trusted_hosts: CidrSet,
}
73
/// An axum router paired with the OpenAPI document describing its routes.
/// Handlers are added through `RouterAndDocs::register`, which keeps the two
/// in sync.
pub struct RouterAndDocs {
    /// Routes registered so far; state is supplied later by `HttpListenerParams::start`.
    pub router: Router<AppState>,
    /// OpenAPI document accumulated from every registered handler.
    pub docs: utoipa::openapi::OpenApi,
}
80
81impl RouterAndDocs {
82 fn add_route<T, H: Handler<T, AppState>>(&mut self, path: &str, item: &PathItem, handler: H)
87 where
88 T: 'static,
89 H: 'static,
90 {
91 let router = std::mem::take(&mut self.router);
92 if item.get.is_some() {
93 self.router = router.route(path, axum::routing::get(handler));
94 } else if item.put.is_some() {
95 self.router = router.route(path, axum::routing::put(handler));
96 } else if item.post.is_some() {
97 self.router = router.route(path, axum::routing::post(handler));
98 } else if item.delete.is_some() {
99 self.router = router.route(path, axum::routing::delete(handler));
100 } else {
101 panic!("unhandled path operation");
102 }
103 }
104
105 pub fn register<T, H: Handler<T, AppState>>(
119 &mut self,
120 api: utoipa::openapi::OpenApi,
121 handler: H,
122 ) where
123 T: 'static,
124 H: 'static,
125 {
126 if let Some((path, item)) = api.paths.paths.iter().next() {
127 self.add_route(path, item, handler);
128 } else {
129 panic!("register didn't register any paths!");
130 }
131
132 self.docs.merge(api);
133 }
134
135 pub fn new(title: &str) -> Self {
139 #[derive(OpenApi)]
140 #[openapi(
141 info(
142 license(name="Apache-2.0"),
143 version=version_info::kumo_version()
144 ),
145 security(
149 ("basic_auth" = [""])
150 ),
151 modifiers(&OptionalAuth),
152 )]
153 struct ApiDoc;
154 let mut router = Self {
155 docs: ApiDoc::openapi(),
156 router: Router::new(),
157 };
158
159 router.docs.info.title = title.to_string();
160
161 router.register_common_handlers();
162 router
163 }
164
165 fn register_common_handlers(&mut self) {
167 macro_rules! add_handlers {
168 ($($handler:path $(,)?)*) => {
169 $(
170 {
171 #[derive(OpenApi)]
174 #[openapi(paths($handler))]
175 struct O;
176
177 self.register(O::openapi(), $handler);
178 }
179 )*
180 }
181 }
182
183 add_handlers!(
184 bump_config_epoch,
185 memory_stats,
186 task_dump,
187 report_metrics,
188 report_metrics_json,
189 set_diagnostic_log_filter_v1,
190 machine_info,
191 );
192 }
193}
194
/// Builds a `RouterAndDocs` titled `title`, registers every handler listed in
/// `handlers`, then applies each entry of the optional `layers` list to the
/// router in the order given.
///
/// Each handler must carry a `#[utoipa::path]` annotation, because it is fed
/// to `#[openapi(paths(...))]`. The common handlers added by
/// `RouterAndDocs::new` are always present.
///
/// Note: `RouterAndDocs` and `OpenApi` are referenced unqualified, so both
/// must be in scope at the call site.
///
/// ```ignore
/// let router = router_with_docs!(title = "my service", handlers = [handler_a, handler_b]);
/// ```
#[macro_export]
macro_rules! router_with_docs {
    (title=$title:literal, handlers=[
        $($handler:path $(,)? )*
    ]
    $(, layers=[
        $(
            $layer:expr $(,)?
        )*
    ])?

    ) => {
        {
            // NOTE(review): presumably allows registering handlers marked
            // #[deprecated] without warnings — confirm.
            #![allow(deprecated)]

            let mut router = RouterAndDocs::new($title);

            $(
                {
                    // One single-path OpenAPI document per handler.
                    #[derive(OpenApi)]
                    #[openapi(paths($handler))]
                    struct O;

                    router.register(O::openapi(), $handler);
                }
            )*

            // Layers wrap everything registered above, including the common handlers.
            $(
                $(
                    router.router = router.router.layer($layer);
                )*
            )?

            router
        }
    }
}
254
/// State shared by all handlers and by `auth_middleware`.
#[derive(Clone)]
pub struct AppState {
    /// Process kind; `HttpListenerParams::start` fills it from the API docs title.
    process_kind: String,
    /// Configuration of the listener serving this request.
    params: HttpListenerParams,
    /// Address the listener actually bound (read back from the socket after
    /// bind, so it reflects the real port).
    local_addr: SocketAddr,
}
261
262impl AppState {
263 pub fn is_trusted_host(&self, addr: IpAddr) -> bool {
264 self.params.trusted_hosts.contains(addr)
265 }
266
267 pub fn params(&self) -> &HttpListenerParams {
268 &self.params
269 }
270
271 pub fn local_addr(&self) -> &SocketAddr {
272 &self.local_addr
273 }
274}
275
276impl HttpListenerParams {
277 fn default_listen() -> String {
278 "127.0.0.1:8000".to_string()
279 }
280
281 fn default_hostname() -> String {
282 gethostname::gethostname()
283 .to_str()
284 .unwrap_or("localhost")
285 .to_string()
286 }
287
288 pub async fn start(
299 self,
300 router_and_docs: RouterAndDocs,
301 runtime: Option<tokio::runtime::Handle>,
302 ) -> anyhow::Result<()> {
303 let compression_layer: CompressionLayer = CompressionLayer::new()
304 .deflate(true)
305 .gzip(true)
306 .quality(tower_http::CompressionLevel::Fastest);
307 let decompression_layer = RequestDecompressionLayer::new().deflate(true).gzip(true);
308
309 let socket = TcpListener::bind(&self.listen)
310 .with_context(|| format!("listen on {}", self.listen))?;
311 let addr = socket.local_addr()?;
312
313 let app_state = AppState {
314 process_kind: router_and_docs.docs.info.title.clone(),
315 params: self.clone(),
316 local_addr: addr.clone(),
317 };
318
319 let app = router_and_docs
320 .router
321 .layer(DefaultBodyLimit::max(
322 self.request_body_limit.unwrap_or(2 * 1024 * 1024),
323 ))
324 .merge(
325 RapiDoc::with_openapi("/api-docs/openapi.json", router_and_docs.docs)
326 .path("/rapidoc"),
327 )
328 .route_layer(axum::middleware::from_fn_with_state(
331 app_state.clone(),
332 auth_middleware,
333 ))
334 .layer(compression_layer)
335 .layer(decompression_layer)
336 .layer(TraceLayer::new_for_http())
337 .layer(axum_client_ip::ClientIpSource::ConnectInfo.into_extension())
338 .with_state(app_state);
339
340 let make_service = app.into_make_service_with_connect_info::<SocketAddr>();
341
342 if self.use_tls {
347 let config = self.tls_config().await?;
348 tracing::info!("https listener on {addr:?}");
349 let server = axum_server::from_tcp_rustls(socket, config);
350 let serve = async move { server.serve(make_service).await };
351
352 if let Some(runtime) = runtime {
353 runtime.spawn(serve);
354 } else {
355 spawn(format!("https {addr:?}"), serve)?;
356 }
357 } else {
358 tracing::info!("http listener on {addr:?}");
359 let server = axum_server::from_tcp(socket);
360 let serve = async move { server.serve(make_service).await };
361 if let Some(runtime) = runtime {
362 runtime.spawn(serve);
363 } else {
364 spawn(format!("http {addr:?}"), serve)?;
365 }
366 }
367 Ok(())
368 }
369
370 async fn tls_config(&self) -> anyhow::Result<RustlsConfig> {
371 let config = crate::tls_helpers::make_server_config(
372 &self.hostname,
373 &self.tls_private_key,
374 &self.tls_certificate,
375 &None,
376 )
377 .await?;
378 Ok(RustlsConfig::from_config(config))
379 }
380}
381
382#[derive(Debug)]
383pub struct AppError {
384 pub err: anyhow::Error,
385 pub code: StatusCode,
386}
387
388impl AppError {
389 pub fn new(code: StatusCode, err: impl Into<String>) -> Self {
390 let err: String = err.into();
391 Self {
392 err: anyhow::anyhow!(err),
393 code,
394 }
395 }
396}
397
398impl IntoResponse for AppError {
400 fn into_response(self) -> Response {
401 (self.code, format!("Error: {:#}", self.err)).into_response()
402 }
403}
404
405impl<E> From<E> for AppError
408where
409 E: Into<anyhow::Error>,
410{
411 fn from(err: E) -> Self {
412 Self {
413 err: err.into(),
414 code: StatusCode::INTERNAL_SERVER_ERROR,
415 }
416 }
417}
418
419#[utoipa::path(get, tag = "debugging", path = "/api/machine-info",
421 responses(
422 (status=200, description="Machine information", body=MachineInfoV1)
423 ),
424)]
425async fn machine_info(State(state): State<AppState>) -> Result<Json<MachineInfoV1>, AppError> {
426 let online_since = ONLINE_SINCE.clone();
427 match MACHINE_INFO.lock().as_ref() {
428 Some(info) => Ok(Json(MachineInfoV1 {
429 hostname: info.hostname.clone(),
430 mac_address: info.mac_address.clone(),
431 node_id: info.node_id.clone().unwrap_or_else(String::new),
432 num_cores: info.num_cores,
433 kernel_version: info.kernel_version.clone(),
434 platform: info.platform.clone(),
435 distribution: info.distribution.clone(),
436 os_version: info.os_version.clone(),
437 total_memory_bytes: info.total_memory_bytes.clone(),
438 container_runtime: info.container_runtime.clone(),
439 cpu_brand: info.cpu_brand.clone(),
440 fingerprint: info.fingerprint(),
441 online_since,
442 process_kind: state.process_kind.clone(),
443 version: version_info::kumo_version().to_string(),
444 })),
445 None => {
446 return Err(AppError::new(
447 StatusCode::SERVICE_UNAVAILABLE,
448 "machine info not yet available",
449 ));
450 }
451 }
452}
453
454#[utoipa::path(
458 post,
459 tag="config",
460 path="/api/admin/bump-config-epoch",
461 responses(
462 (status=200, description = "bump successful")
463 ),
464)]
465async fn bump_config_epoch() -> Result<(), AppError> {
466 config::epoch::bump_current_epoch();
467 Ok(())
468}
469
/// Query parameters for `/api/admin/task-dump`.
#[derive(Deserialize)]
struct TaskDumpParams {
    /// Timeout in seconds passed to `dump_all_runtimes`; `task_dump` uses 5 when absent.
    #[serde(default)]
    timeout: Option<u64>,
}
475
476#[utoipa::path(
491 get,
492 tag="debugging",
493 path="/api/admin/task-dump",
494 responses(
495 (status=200, description="data was returned")
496 ),
497)]
498async fn task_dump(Query(params): Query<TaskDumpParams>) -> String {
499 kumo_server_runtime::dump_all_runtimes(tokio::time::Duration::from_secs(
500 params.timeout.unwrap_or(5),
501 ))
502 .await
503}
504
505#[utoipa::path(
509 get,
510 tag="memory",
511 path="/api/admin/memory/stats",
512 responses(
513 (status=200, description = "stats were returned")
514 ),
515)]
516async fn memory_stats() -> String {
517 use kumo_server_memory::NumBytes;
518 use std::fmt::Write;
519 let mut result = String::new();
520
521 let jstats = JemallocStats::collect();
522 writeln!(result, "{jstats:#?}").ok();
523
524 if let Ok((usage, limit)) = get_usage_and_limit() {
525 writeln!(result, "RSS = {:?}", NumBytes::from(usage.bytes)).ok();
526 writeln!(
527 result,
528 "soft limit = {:?}",
529 limit.soft_limit.map(NumBytes::from)
530 )
531 .ok();
532 writeln!(
533 result,
534 "hard limit = {:?}",
535 limit.hard_limit.map(NumBytes::from)
536 )
537 .ok();
538 }
539
540 let mut stats = tracking_stats();
541 writeln!(result, "live = {:?}", stats.live).ok();
542
543 if stats.top_callstacks.is_empty() {
544 write!(
545 result,
546 "\nuse kumo.enable_memory_callstack_tracking(true) to enable additional stats\n"
547 )
548 .ok();
549 } else {
550 writeln!(result, "small_threshold = {:?}", stats.small_threshold).ok();
551 write!(result, "\ntop call stacks:\n").ok();
552 for stack in &mut stats.top_callstacks {
553 writeln!(
554 result,
555 "sampled every {} allocations, estimated {} allocations of {} total bytes",
556 stack.stochastic_rate,
557 stack.count * stack.stochastic_rate,
558 stack.total_size * stack.stochastic_rate
559 )
560 .ok();
561 write!(result, "{:?}\n\n", stack.bt).ok();
562 }
563 }
564
565 result
566}
567
/// Query parameters for `/metrics`.
#[derive(Deserialize)]
struct PrometheusMetricsParams {
    /// Optional prefix forwarded to `Registry::stream_text`. Its exact effect
    /// on metric names is defined there (presumably prepended — confirm).
    #[serde(default)]
    prefix: Option<String>,
}
573
574#[utoipa::path(get, path = "/metrics", responses(
612 (status = 200, content_type="text/plain")
613))]
614async fn report_metrics(Query(params): Query<PrometheusMetricsParams>) -> impl IntoResponse {
615 StreamBodyAsOptions::new()
616 .content_type(HttpHeaderValue::from_static("text/plain; charset=utf-8"))
617 .text(kumo_prometheus::registry::Registry::stream_text(
618 params.prefix.clone(),
619 ))
620}
621
622#[utoipa::path(get, path = "/metrics.json", responses(
656 (status = 200, content_type="application/json")
657))]
658async fn report_metrics_json() -> impl IntoResponse {
659 StreamBodyAsOptions::new()
660 .content_type(HttpHeaderValue::from_static(
661 "application/json; charset=utf-8",
662 ))
663 .text(kumo_prometheus::registry::Registry::stream_json())
664}
665
666#[utoipa::path(
670 post,
671 tags=["logging", "kcli:set-log-filter"],
672 path="/api/admin/set_diagnostic_log_filter/v1",
673 request_body=SetDiagnosticFilterRequest,
674 responses(
675 (status = 200, description = "Diagnostic level set successfully")
676 ),
677)]
678async fn set_diagnostic_log_filter_v1(
679 Json(request): Json<SetDiagnosticFilterRequest>,
681) -> Result<(), AppError> {
682 set_diagnostic_log_filter(&request.filter)?;
683 Ok(())
684}