kumo_server_common/http_server/
mod.rs1use crate::diagnostic_logging::set_diagnostic_log_filter;
2use anyhow::Context;
3use axum::extract::{DefaultBodyLimit, Json, Query};
4use axum::http::StatusCode;
5use axum::response::{IntoResponse, Response};
6use axum::routing::{get, post};
7use axum::Router;
8use axum_server::tls_rustls::RustlsConfig;
9use axum_streams::{HttpHeaderValue, StreamBodyAsOptions};
10use cidr_map::CidrSet;
11use data_loader::KeySource;
12use kumo_server_memory::{get_usage_and_limit, tracking_stats, JemallocStats};
13use kumo_server_runtime::spawn;
14use serde::Deserialize;
15use std::net::{IpAddr, SocketAddr, TcpListener};
16use std::sync::Arc;
17use tower_http::compression::CompressionLayer;
18use tower_http::decompression::RequestDecompressionLayer;
19use tower_http::trace::TraceLayer;
20use utoipa::openapi::security::{Http, HttpAuthScheme, SecurityScheme};
21use utoipa::OpenApi;
22use utoipa_rapidoc::RapiDoc;
23use kumo_api_types::*;
28
29pub mod auth;
30
31use auth::*;
32
33#[derive(OpenApi)]
34#[openapi(
35 info(license(name = "Apache-2.0")),
36 paths(set_diagnostic_log_filter_v1, bump_config_epoch),
37 security(
41 ("basic_auth" = [""])
42 ),
43 components(schemas(SetDiagnosticFilterRequest)),
44 modifiers(&OptionalAuth),
45)]
46struct ApiDoc;
47
48struct OptionalAuth;
49
50impl utoipa::Modify for OptionalAuth {
51 fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
52 let components = openapi
53 .components
54 .as_mut()
55 .expect("always set because we always have components above");
56 components.add_security_scheme(
58 "basic_auth",
59 SecurityScheme::Http(Http::new(HttpAuthScheme::Basic)),
60 );
61 }
62}
63
/// Configuration for an HTTP(S) listener, deserialized from user config.
///
/// Unknown keys are rejected (`deny_unknown_fields`).
#[derive(Deserialize, Clone, Debug)]
#[serde(deny_unknown_fields)]
pub struct HttpListenerParams {
    /// Hostname passed to the TLS server-config builder.
    /// Defaults to the system hostname, or `"localhost"` if that hostname
    /// is not valid UTF-8.
    #[serde(default = "HttpListenerParams::default_hostname")]
    pub hostname: String,

    /// Address to bind, as accepted by `std::net::TcpListener::bind`.
    /// Defaults to `"127.0.0.1:8000"`.
    #[serde(default = "HttpListenerParams::default_listen")]
    pub listen: String,

    /// Serve HTTPS when `true`, plain HTTP otherwise. Defaults to `false`.
    #[serde(default)]
    pub use_tls: bool,

    /// Maximum request body size (bytes) applied via `DefaultBodyLimit` to
    /// the router passed to `start`. Falls back to 2 MiB when unset.
    #[serde(default)]
    pub request_body_limit: Option<usize>,

    /// Source of the TLS certificate; only consulted when `use_tls` is set.
    /// NOTE(review): behavior when `None` is decided by
    /// `tls_helpers::make_server_config` — confirm there.
    #[serde(default)]
    pub tls_certificate: Option<KeySource>,
    /// Source of the TLS private key; only consulted when `use_tls` is set.
    /// NOTE(review): behavior when `None` is decided by
    /// `tls_helpers::make_server_config` — confirm there.
    #[serde(default)]
    pub tls_private_key: Option<KeySource>,

    /// Networks whose peer addresses count as trusted (see
    /// `AppState::is_trusted_host`). Defaults to
    /// `CidrSet::default_trusted_hosts()`.
    #[serde(default = "CidrSet::default_trusted_hosts")]
    pub trusted_hosts: CidrSet,
}
87
88pub struct RouterAndDocs {
89 pub router: Router,
90 pub docs: utoipa::openapi::OpenApi,
91}
92
93impl RouterAndDocs {
94 pub fn make_docs(&self) -> utoipa::openapi::OpenApi {
95 let mut api_docs = ApiDoc::openapi();
96 api_docs.info.title = self.docs.info.title.to_string();
97 api_docs.merge(self.docs.clone());
98 api_docs.info.version = version_info::kumo_version().to_string();
99 api_docs.info.license = Some(
100 utoipa::openapi::LicenseBuilder::new()
101 .name("Apache-2.0")
102 .build(),
103 );
104
105 api_docs
106 }
107}
108
109#[derive(Clone)]
110pub struct AppState {
111 trusted_hosts: Arc<CidrSet>,
112}
113
114impl AppState {
115 pub fn is_trusted_host(&self, addr: IpAddr) -> bool {
116 self.trusted_hosts.contains(addr)
117 }
118}
119
120impl HttpListenerParams {
121 fn default_listen() -> String {
122 "127.0.0.1:8000".to_string()
123 }
124
125 fn default_hostname() -> String {
126 gethostname::gethostname()
127 .to_str()
128 .unwrap_or("localhost")
129 .to_string()
130 }
131
132 pub async fn start(
143 self,
144 router_and_docs: RouterAndDocs,
145 runtime: Option<tokio::runtime::Handle>,
146 ) -> anyhow::Result<()> {
147 let api_docs = router_and_docs.make_docs();
148
149 let compression_layer: CompressionLayer = CompressionLayer::new()
150 .deflate(true)
151 .gzip(true)
152 .quality(tower_http::CompressionLevel::Fastest);
153 let decompression_layer = RequestDecompressionLayer::new().deflate(true).gzip(true);
154 let app = router_and_docs
155 .router
156 .layer(DefaultBodyLimit::max(
157 self.request_body_limit.unwrap_or(2 * 1024 * 1024),
158 ))
159 .merge(RapiDoc::with_openapi("/api-docs/openapi.json", api_docs).path("/rapidoc"))
160 .route(
161 "/api/admin/set_diagnostic_log_filter/v1",
162 post(set_diagnostic_log_filter_v1),
163 )
164 .route("/api/admin/bump-config-epoch", post(bump_config_epoch))
165 .route("/api/admin/memory/stats", get(memory_stats))
166 .route("/metrics", get(report_metrics))
167 .route("/metrics.json", get(report_metrics_json))
168 .route_layer(axum::middleware::from_fn_with_state(
171 AppState {
172 trusted_hosts: Arc::new(self.trusted_hosts.clone()),
173 },
174 auth_middleware,
175 ))
176 .layer(compression_layer)
177 .layer(decompression_layer)
178 .layer(TraceLayer::new_for_http());
179 let socket = TcpListener::bind(&self.listen)
180 .with_context(|| format!("listen on {}", self.listen))?;
181 let addr = socket.local_addr()?;
182
183 let make_service = app.into_make_service_with_connect_info::<SocketAddr>();
184
185 if self.use_tls {
190 let config = self.tls_config().await?;
191 tracing::info!("https listener on {addr:?}");
192 let server = axum_server::from_tcp_rustls(socket, config);
193 let serve = async move { server.serve(make_service).await };
194
195 if let Some(runtime) = runtime {
196 runtime.spawn(serve);
197 } else {
198 spawn(format!("https {addr:?}"), serve)?;
199 }
200 } else {
201 tracing::info!("http listener on {addr:?}");
202 let server = axum_server::from_tcp(socket);
203 let serve = async move { server.serve(make_service).await };
204 if let Some(runtime) = runtime {
205 runtime.spawn(serve);
206 } else {
207 spawn(format!("http {addr:?}"), serve)?;
208 }
209 }
210 Ok(())
211 }
212
213 async fn tls_config(&self) -> anyhow::Result<RustlsConfig> {
214 let config = crate::tls_helpers::make_server_config(
215 &self.hostname,
216 &self.tls_private_key,
217 &self.tls_certificate,
218 &None,
219 )
220 .await?;
221 Ok(RustlsConfig::from_config(config))
222 }
223}
224
225#[derive(Debug)]
226pub struct AppError(pub anyhow::Error);
227
228impl IntoResponse for AppError {
230 fn into_response(self) -> Response {
231 (
232 StatusCode::INTERNAL_SERVER_ERROR,
233 format!("Error: {:#}", self.0),
234 )
235 .into_response()
236 }
237}
238
239impl<E> From<E> for AppError
242where
243 E: Into<anyhow::Error>,
244{
245 fn from(err: E) -> Self {
246 Self(err.into())
247 }
248}
249
250#[utoipa::path(
254 post,
255 tag="config",
256 path="/api/admin/bump-config-epoch",
257 responses(
258 (status=200, description = "bump successful")
259 ),
260)]
261async fn bump_config_epoch(_: TrustedIpRequired) -> Result<(), AppError> {
262 config::epoch::bump_current_epoch();
263 Ok(())
264}
265
266#[utoipa::path(
270 get,
271 tag="memory",
272 path="/api/admin/memory/stats",
273 responses(
274 (status=200, description = "stats were returned")
275 ),
276)]
277async fn memory_stats(_: TrustedIpRequired) -> String {
278 use kumo_server_memory::NumBytes;
279 use std::fmt::Write;
280 let mut result = String::new();
281
282 let jstats = JemallocStats::collect();
283 writeln!(result, "{jstats:#?}").ok();
284
285 if let Ok((usage, limit)) = get_usage_and_limit() {
286 writeln!(result, "RSS = {:?}", NumBytes::from(usage.bytes)).ok();
287 writeln!(
288 result,
289 "soft limit = {:?}",
290 limit.soft_limit.map(NumBytes::from)
291 )
292 .ok();
293 writeln!(
294 result,
295 "hard limit = {:?}",
296 limit.hard_limit.map(NumBytes::from)
297 )
298 .ok();
299 }
300
301 let mut stats = tracking_stats();
302 writeln!(result, "live = {:?}", stats.live).ok();
303
304 if stats.top_callstacks.is_empty() {
305 write!(
306 result,
307 "\nuse kumo.enable_memory_callstack_tracking(true) to enable additional stats\n"
308 )
309 .ok();
310 } else {
311 writeln!(result, "small_threshold = {:?}", stats.small_threshold).ok();
312 write!(result, "\ntop call stacks:\n").ok();
313 for stack in &mut stats.top_callstacks {
314 writeln!(
315 result,
316 "sampled every {} allocations, estimated {} allocations of {} total bytes",
317 stack.stochastic_rate,
318 stack.count * stack.stochastic_rate,
319 stack.total_size * stack.stochastic_rate
320 )
321 .ok();
322 write!(result, "{:?}\n\n", stack.bt).ok();
323 }
324 }
325
326 result
327}
328
329#[derive(Deserialize)]
330struct PrometheusMetricsParams {
331 #[serde(default)]
332 prefix: Option<String>,
333}
334
335async fn report_metrics(
336 _: TrustedIpRequired,
337 Query(params): Query<PrometheusMetricsParams>,
338) -> impl IntoResponse {
339 StreamBodyAsOptions::new()
340 .content_type(HttpHeaderValue::from_static("text/plain; charset=utf-8"))
341 .text(kumo_prometheus::registry::Registry::stream_text(
342 params.prefix.clone(),
343 ))
344}
345
346async fn report_metrics_json(_: TrustedIpRequired) -> impl IntoResponse {
347 StreamBodyAsOptions::new()
348 .content_type(HttpHeaderValue::from_static(
349 "application/json; charset=utf-8",
350 ))
351 .text(kumo_prometheus::registry::Registry::stream_json())
352}
353
354#[utoipa::path(
358 post,
359 tag="logging",
360 path="/api/admin/set_diagnostic_log_filter/v1",
361 responses(
362 (status = 200, description = "Diagnostic level set successfully")
363 ),
364)]
365async fn set_diagnostic_log_filter_v1(
366 _: TrustedIpRequired,
367 Json(request): Json<SetDiagnosticFilterRequest>,
369) -> Result<(), AppError> {
370 set_diagnostic_log_filter(&request.filter)?;
371 Ok(())
372}