1use crate::diagnostic_logging::set_diagnostic_log_filter;
2use crate::start::{MACHINE_INFO, ONLINE_SINCE};
3use anyhow::Context;
4use axum::extract::{DefaultBodyLimit, Json, Query, State};
5use axum::handler::Handler;
6use axum::http::StatusCode;
7use axum::response::{IntoResponse, Response};
8use axum::Router;
9use axum_server::tls_rustls::RustlsConfig;
10use axum_streams::{HttpHeaderValue, StreamBodyAsOptions};
11use cidr_map::CidrSet;
12use data_loader::KeySource;
13use kumo_server_memory::{get_usage_and_limit, tracking_stats, JemallocStats};
14use kumo_server_runtime::spawn;
15use serde::Deserialize;
16use std::net::{IpAddr, SocketAddr, TcpListener};
17use tower_http::compression::CompressionLayer;
18use tower_http::decompression::RequestDecompressionLayer;
19use tower_http::trace::TraceLayer;
20use utoipa::openapi::security::{Http, HttpAuthScheme, SecurityScheme};
21use utoipa::openapi::PathItem;
22use utoipa::OpenApi;
23use utoipa_rapidoc::RapiDoc;
24use kumo_api_types::*;
29
30pub mod auth;
31
32use auth::*;
33
34struct OptionalAuth;
35
36impl utoipa::Modify for OptionalAuth {
37 fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
38 let components = openapi
39 .components
40 .as_mut()
41 .expect("always set because we always have components above");
42 components.add_security_scheme(
44 "basic_auth",
45 SecurityScheme::Http(Http::new(HttpAuthScheme::Basic)),
46 );
47 }
48}
49
50#[derive(Deserialize, Clone, Debug)]
51#[serde(deny_unknown_fields)]
52pub struct HttpListenerParams {
53 #[serde(default = "HttpListenerParams::default_hostname")]
54 pub hostname: String,
55
56 #[serde(default = "HttpListenerParams::default_listen")]
57 pub listen: String,
58
59 #[serde(default)]
60 pub use_tls: bool,
61
62 #[serde(default)]
63 pub request_body_limit: Option<usize>,
64
65 #[serde(default)]
66 pub tls_certificate: Option<KeySource>,
67 #[serde(default)]
68 pub tls_private_key: Option<KeySource>,
69
70 #[serde(default = "CidrSet::default_trusted_hosts")]
71 pub trusted_hosts: CidrSet,
72}
73
74pub struct RouterAndDocs {
77 pub router: Router<AppState>,
78 pub docs: utoipa::openapi::OpenApi,
79}
80
81impl RouterAndDocs {
82 fn add_route<T, H: Handler<T, AppState>>(&mut self, path: &str, item: &PathItem, handler: H)
87 where
88 T: 'static,
89 H: 'static,
90 {
91 let router = std::mem::take(&mut self.router);
92 if item.get.is_some() {
93 self.router = router.route(path, axum::routing::get(handler));
94 } else if item.put.is_some() {
95 self.router = router.route(path, axum::routing::put(handler));
96 } else if item.post.is_some() {
97 self.router = router.route(path, axum::routing::post(handler));
98 } else if item.delete.is_some() {
99 self.router = router.route(path, axum::routing::delete(handler));
100 } else {
101 panic!("unhandled path operation");
102 }
103 }
104
105 pub fn register<T, H: Handler<T, AppState>>(
119 &mut self,
120 api: utoipa::openapi::OpenApi,
121 handler: H,
122 ) where
123 T: 'static,
124 H: 'static,
125 {
126 if let Some((path, item)) = api.paths.paths.iter().next() {
127 self.add_route(path, item, handler);
128 } else {
129 panic!("register didn't register any paths!");
130 }
131
132 self.docs.merge(api);
133 }
134
135 pub fn new(title: &str) -> Self {
139 #[derive(OpenApi)]
140 #[openapi(
141 info(
142 license(name="Apache-2.0"),
143 version=version_info::kumo_version()
144 ),
145 security(
149 ("basic_auth" = [""])
150 ),
151 modifiers(&OptionalAuth),
152 )]
153 struct ApiDoc;
154 let mut router = Self {
155 docs: ApiDoc::openapi(),
156 router: Router::new(),
157 };
158
159 router.docs.info.title = title.to_string();
160
161 router.register_common_handlers();
162 router
163 }
164
165 fn register_common_handlers(&mut self) {
167 macro_rules! add_handlers {
168 ($($handler:path $(,)?)*) => {
169 $(
170 {
171 #[derive(OpenApi)]
174 #[openapi(paths($handler))]
175 struct O;
176
177 self.register(O::openapi(), $handler);
178 }
179 )*
180 }
181 }
182
183 add_handlers!(
184 bump_config_epoch,
185 memory_stats,
186 report_metrics,
187 report_metrics_json,
188 set_diagnostic_log_filter_v1,
189 machine_info,
190 );
191 }
192}
193
/// Builds a `RouterAndDocs` from a title, a list of handlers and an
/// optional list of layers.
///
/// Shape of an invocation:
///
/// ```ignore
/// router_with_docs!(
///     title = "my service",
///     handlers = [handler_a, handler_b],
///     layers = [some_layer]
/// )
/// ```
///
/// * Each handler is named in an `#[openapi(paths(...))]` list and then
///   passed to `RouterAndDocs::register`.
/// * Each layer expression is applied to the router with `Router::layer`,
///   in the order written, after all handlers are registered.
///
/// The expansion refers to `RouterAndDocs` and the `OpenApi` derive by their
/// bare names, so both must be in scope where the macro is invoked.
#[macro_export]
macro_rules! router_with_docs {
    (title=$title:literal, handlers=[
        $($handler:path $(,)? )*
    ]
    $(, layers=[
        $(
            $layer:expr $(,)?
        )*
    ])?

    ) => {
        {
            // NOTE(review): presumably here so that handlers marked
            // deprecated can still be registered without warnings; confirm.
            #![allow(deprecated)]

            let mut assembled = RouterAndDocs::new($title);

            $(
                {
                    // One throwaway OpenAPI definition per handler.
                    #[derive(OpenApi)]
                    #[openapi(paths($handler))]
                    struct O;

                    assembled.register(O::openapi(), $handler);
                }
            )*

            $(
                $(
                    assembled.router = assembled.router.layer($layer);
                )*
            )?

            assembled
        }
    }
}
253
254#[derive(Clone)]
255pub struct AppState {
256 process_kind: String,
257 params: HttpListenerParams,
258 local_addr: SocketAddr,
259}
260
261impl AppState {
262 pub fn is_trusted_host(&self, addr: IpAddr) -> bool {
263 self.params.trusted_hosts.contains(addr)
264 }
265
266 pub fn params(&self) -> &HttpListenerParams {
267 &self.params
268 }
269
270 pub fn local_addr(&self) -> &SocketAddr {
271 &self.local_addr
272 }
273}
274
275impl HttpListenerParams {
276 fn default_listen() -> String {
277 "127.0.0.1:8000".to_string()
278 }
279
280 fn default_hostname() -> String {
281 gethostname::gethostname()
282 .to_str()
283 .unwrap_or("localhost")
284 .to_string()
285 }
286
287 pub async fn start(
298 self,
299 router_and_docs: RouterAndDocs,
300 runtime: Option<tokio::runtime::Handle>,
301 ) -> anyhow::Result<()> {
302 let compression_layer: CompressionLayer = CompressionLayer::new()
303 .deflate(true)
304 .gzip(true)
305 .quality(tower_http::CompressionLevel::Fastest);
306 let decompression_layer = RequestDecompressionLayer::new().deflate(true).gzip(true);
307
308 let socket = TcpListener::bind(&self.listen)
309 .with_context(|| format!("listen on {}", self.listen))?;
310 let addr = socket.local_addr()?;
311
312 let app_state = AppState {
313 process_kind: router_and_docs.docs.info.title.clone(),
314 params: self.clone(),
315 local_addr: addr.clone(),
316 };
317
318 let app = router_and_docs
319 .router
320 .layer(DefaultBodyLimit::max(
321 self.request_body_limit.unwrap_or(2 * 1024 * 1024),
322 ))
323 .merge(
324 RapiDoc::with_openapi("/api-docs/openapi.json", router_and_docs.docs)
325 .path("/rapidoc"),
326 )
327 .route_layer(axum::middleware::from_fn_with_state(
330 app_state.clone(),
331 auth_middleware,
332 ))
333 .layer(compression_layer)
334 .layer(decompression_layer)
335 .layer(TraceLayer::new_for_http())
336 .layer(axum_client_ip::ClientIpSource::ConnectInfo.into_extension())
337 .with_state(app_state);
338
339 let make_service = app.into_make_service_with_connect_info::<SocketAddr>();
340
341 if self.use_tls {
346 let config = self.tls_config().await?;
347 tracing::info!("https listener on {addr:?}");
348 let server = axum_server::from_tcp_rustls(socket, config);
349 let serve = async move { server.serve(make_service).await };
350
351 if let Some(runtime) = runtime {
352 runtime.spawn(serve);
353 } else {
354 spawn(format!("https {addr:?}"), serve)?;
355 }
356 } else {
357 tracing::info!("http listener on {addr:?}");
358 let server = axum_server::from_tcp(socket);
359 let serve = async move { server.serve(make_service).await };
360 if let Some(runtime) = runtime {
361 runtime.spawn(serve);
362 } else {
363 spawn(format!("http {addr:?}"), serve)?;
364 }
365 }
366 Ok(())
367 }
368
369 async fn tls_config(&self) -> anyhow::Result<RustlsConfig> {
370 let config = crate::tls_helpers::make_server_config(
371 &self.hostname,
372 &self.tls_private_key,
373 &self.tls_certificate,
374 &None,
375 )
376 .await?;
377 Ok(RustlsConfig::from_config(config))
378 }
379}
380
381#[derive(Debug)]
382pub struct AppError {
383 pub err: anyhow::Error,
384 pub code: StatusCode,
385}
386
387impl AppError {
388 pub fn new(code: StatusCode, err: impl Into<String>) -> Self {
389 let err: String = err.into();
390 Self {
391 err: anyhow::anyhow!(err),
392 code,
393 }
394 }
395}
396
397impl IntoResponse for AppError {
399 fn into_response(self) -> Response {
400 (self.code, format!("Error: {:#}", self.err)).into_response()
401 }
402}
403
404impl<E> From<E> for AppError
407where
408 E: Into<anyhow::Error>,
409{
410 fn from(err: E) -> Self {
411 Self {
412 err: err.into(),
413 code: StatusCode::INTERNAL_SERVER_ERROR,
414 }
415 }
416}
417
418#[utoipa::path(get, tag = "debugging", path = "/api/machine-info",
420 responses(
421 (status=200, description="Machine information", body=MachineInfoV1)
422 ),
423)]
424async fn machine_info(State(state): State<AppState>) -> Result<Json<MachineInfoV1>, AppError> {
425 let online_since = ONLINE_SINCE.clone();
426 match MACHINE_INFO.lock().as_ref() {
427 Some(info) => Ok(Json(MachineInfoV1 {
428 hostname: info.hostname.clone(),
429 mac_address: info.mac_address.clone(),
430 node_id: info.node_id.clone().unwrap_or_else(String::new),
431 num_cores: info.num_cores,
432 kernel_version: info.kernel_version.clone(),
433 platform: info.platform.clone(),
434 distribution: info.distribution.clone(),
435 os_version: info.os_version.clone(),
436 total_memory_bytes: info.total_memory_bytes.clone(),
437 container_runtime: info.container_runtime.clone(),
438 cpu_brand: info.cpu_brand.clone(),
439 fingerprint: info.fingerprint(),
440 online_since,
441 process_kind: state.process_kind.clone(),
442 version: version_info::kumo_version().to_string(),
443 })),
444 None => {
445 return Err(AppError::new(
446 StatusCode::SERVICE_UNAVAILABLE,
447 "machine info not yet available",
448 ));
449 }
450 }
451}
452
453#[utoipa::path(
457 post,
458 tag="config",
459 path="/api/admin/bump-config-epoch",
460 responses(
461 (status=200, description = "bump successful")
462 ),
463)]
464async fn bump_config_epoch() -> Result<(), AppError> {
465 config::epoch::bump_current_epoch();
466 Ok(())
467}
468
469#[utoipa::path(
473 get,
474 tag="memory",
475 path="/api/admin/memory/stats",
476 responses(
477 (status=200, description = "stats were returned")
478 ),
479)]
480async fn memory_stats() -> String {
481 use kumo_server_memory::NumBytes;
482 use std::fmt::Write;
483 let mut result = String::new();
484
485 let jstats = JemallocStats::collect();
486 writeln!(result, "{jstats:#?}").ok();
487
488 if let Ok((usage, limit)) = get_usage_and_limit() {
489 writeln!(result, "RSS = {:?}", NumBytes::from(usage.bytes)).ok();
490 writeln!(
491 result,
492 "soft limit = {:?}",
493 limit.soft_limit.map(NumBytes::from)
494 )
495 .ok();
496 writeln!(
497 result,
498 "hard limit = {:?}",
499 limit.hard_limit.map(NumBytes::from)
500 )
501 .ok();
502 }
503
504 let mut stats = tracking_stats();
505 writeln!(result, "live = {:?}", stats.live).ok();
506
507 if stats.top_callstacks.is_empty() {
508 write!(
509 result,
510 "\nuse kumo.enable_memory_callstack_tracking(true) to enable additional stats\n"
511 )
512 .ok();
513 } else {
514 writeln!(result, "small_threshold = {:?}", stats.small_threshold).ok();
515 write!(result, "\ntop call stacks:\n").ok();
516 for stack in &mut stats.top_callstacks {
517 writeln!(
518 result,
519 "sampled every {} allocations, estimated {} allocations of {} total bytes",
520 stack.stochastic_rate,
521 stack.count * stack.stochastic_rate,
522 stack.total_size * stack.stochastic_rate
523 )
524 .ok();
525 write!(result, "{:?}\n\n", stack.bt).ok();
526 }
527 }
528
529 result
530}
531
532#[derive(Deserialize)]
533struct PrometheusMetricsParams {
534 #[serde(default)]
535 prefix: Option<String>,
536}
537
538#[utoipa::path(get, path = "/metrics", responses(
576 (status = 200, content_type="text/plain")
577))]
578async fn report_metrics(Query(params): Query<PrometheusMetricsParams>) -> impl IntoResponse {
579 StreamBodyAsOptions::new()
580 .content_type(HttpHeaderValue::from_static("text/plain; charset=utf-8"))
581 .text(kumo_prometheus::registry::Registry::stream_text(
582 params.prefix.clone(),
583 ))
584}
585
586#[utoipa::path(get, path = "/metrics.json", responses(
620 (status = 200, content_type="application/json")
621))]
622async fn report_metrics_json() -> impl IntoResponse {
623 StreamBodyAsOptions::new()
624 .content_type(HttpHeaderValue::from_static(
625 "application/json; charset=utf-8",
626 ))
627 .text(kumo_prometheus::registry::Registry::stream_json())
628}
629
630#[utoipa::path(
634 post,
635 tags=["logging", "kcli:set-log-filter"],
636 path="/api/admin/set_diagnostic_log_filter/v1",
637 request_body=SetDiagnosticFilterRequest,
638 responses(
639 (status = 200, description = "Diagnostic level set successfully")
640 ),
641)]
642async fn set_diagnostic_log_filter_v1(
643 Json(request): Json<SetDiagnosticFilterRequest>,
645) -> Result<(), AppError> {
646 set_diagnostic_log_filter(&request.filter)?;
647 Ok(())
648}