kumo_server_common/http_server/
mod.rs1use crate::diagnostic_logging::set_diagnostic_log_filter;
2use anyhow::Context;
3use axum::extract::{DefaultBodyLimit, Json, Query};
4use axum::http::StatusCode;
5use axum::response::{IntoResponse, Response};
6use axum::routing::{get, post};
7use axum::Router;
8use axum_server::tls_rustls::RustlsConfig;
9use axum_streams::{HttpHeaderValue, StreamBodyAsOptions};
10use cidr_map::CidrSet;
11use data_loader::KeySource;
12use kumo_server_memory::{get_usage_and_limit, tracking_stats, JemallocStats};
13use kumo_server_runtime::spawn;
14use serde::Deserialize;
15use std::net::{IpAddr, SocketAddr, TcpListener};
16use std::sync::Arc;
17use tower_http::compression::CompressionLayer;
18use tower_http::trace::TraceLayer;
19use utoipa::openapi::security::{Http, HttpAuthScheme, SecurityScheme};
20use utoipa::OpenApi;
21use utoipa_rapidoc::RapiDoc;
22use kumo_api_types::*;
27
28pub mod auth;
29
30use auth::*;
31
32#[derive(OpenApi)]
33#[openapi(
34 info(license(name = "Apache-2.0")),
35 paths(set_diagnostic_log_filter_v1, bump_config_epoch),
36 security(
40 ("basic_auth" = [""])
41 ),
42 components(schemas(SetDiagnosticFilterRequest)),
43 modifiers(&OptionalAuth),
44)]
45struct ApiDoc;
46
47struct OptionalAuth;
48
49impl utoipa::Modify for OptionalAuth {
50 fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
51 let components = openapi
52 .components
53 .as_mut()
54 .expect("always set because we always have components above");
55 components.add_security_scheme(
57 "basic_auth",
58 SecurityScheme::Http(Http::new(HttpAuthScheme::Basic)),
59 );
60 }
61}
62
63#[derive(Deserialize, Clone, Debug)]
64#[serde(deny_unknown_fields)]
65pub struct HttpListenerParams {
66 #[serde(default = "HttpListenerParams::default_hostname")]
67 pub hostname: String,
68
69 #[serde(default = "HttpListenerParams::default_listen")]
70 pub listen: String,
71
72 #[serde(default)]
73 pub use_tls: bool,
74
75 #[serde(default)]
76 pub request_body_limit: Option<usize>,
77
78 #[serde(default)]
79 pub tls_certificate: Option<KeySource>,
80 #[serde(default)]
81 pub tls_private_key: Option<KeySource>,
82
83 #[serde(default = "CidrSet::default_trusted_hosts")]
84 pub trusted_hosts: CidrSet,
85}
86
87pub struct RouterAndDocs {
88 pub router: Router,
89 pub docs: utoipa::openapi::OpenApi,
90}
91
92impl RouterAndDocs {
93 pub fn make_docs(&self) -> utoipa::openapi::OpenApi {
94 let mut api_docs = ApiDoc::openapi();
95 api_docs.info.title = self.docs.info.title.to_string();
96 api_docs.merge(self.docs.clone());
97 api_docs.info.version = version_info::kumo_version().to_string();
98 api_docs.info.license = Some(
99 utoipa::openapi::LicenseBuilder::new()
100 .name("Apache-2.0")
101 .build(),
102 );
103
104 api_docs
105 }
106}
107
108#[derive(Clone)]
109pub struct AppState {
110 trusted_hosts: Arc<CidrSet>,
111}
112
113impl AppState {
114 pub fn is_trusted_host(&self, addr: IpAddr) -> bool {
115 self.trusted_hosts.contains(addr)
116 }
117}
118
119impl HttpListenerParams {
120 fn default_listen() -> String {
121 "127.0.0.1:8000".to_string()
122 }
123
124 fn default_hostname() -> String {
125 gethostname::gethostname()
126 .to_str()
127 .unwrap_or("localhost")
128 .to_string()
129 }
130
131 pub async fn start(
142 self,
143 router_and_docs: RouterAndDocs,
144 runtime: Option<tokio::runtime::Handle>,
145 ) -> anyhow::Result<()> {
146 let api_docs = router_and_docs.make_docs();
147
148 let compression_layer: CompressionLayer = CompressionLayer::new()
149 .deflate(true)
150 .gzip(true)
151 .quality(tower_http::CompressionLevel::Fastest);
152
153 let app = router_and_docs
154 .router
155 .layer(DefaultBodyLimit::max(
156 self.request_body_limit.unwrap_or(2 * 1024 * 1024),
157 ))
158 .merge(RapiDoc::with_openapi("/api-docs/openapi.json", api_docs).path("/rapidoc"))
159 .route(
160 "/api/admin/set_diagnostic_log_filter/v1",
161 post(set_diagnostic_log_filter_v1),
162 )
163 .route("/api/admin/bump-config-epoch", post(bump_config_epoch))
164 .route("/api/admin/memory/stats", get(memory_stats))
165 .route("/metrics", get(report_metrics))
166 .route("/metrics.json", get(report_metrics_json))
167 .route_layer(axum::middleware::from_fn_with_state(
170 AppState {
171 trusted_hosts: Arc::new(self.trusted_hosts.clone()),
172 },
173 auth_middleware,
174 ))
175 .layer(compression_layer)
176 .layer(TraceLayer::new_for_http());
177 let socket = TcpListener::bind(&self.listen)
178 .with_context(|| format!("listen on {}", self.listen))?;
179 let addr = socket.local_addr()?;
180
181 let make_service = app.into_make_service_with_connect_info::<SocketAddr>();
182
183 if self.use_tls {
188 let config = self.tls_config().await?;
189 tracing::info!("https listener on {addr:?}");
190 let server = axum_server::from_tcp_rustls(socket, config);
191 let serve = async move { server.serve(make_service).await };
192
193 if let Some(runtime) = runtime {
194 runtime.spawn(serve);
195 } else {
196 spawn(format!("https {addr:?}"), serve)?;
197 }
198 } else {
199 tracing::info!("http listener on {addr:?}");
200 let server = axum_server::from_tcp(socket);
201 let serve = async move { server.serve(make_service).await };
202 if let Some(runtime) = runtime {
203 runtime.spawn(serve);
204 } else {
205 spawn(format!("http {addr:?}"), serve)?;
206 }
207 }
208 Ok(())
209 }
210
211 async fn tls_config(&self) -> anyhow::Result<RustlsConfig> {
212 let config = crate::tls_helpers::make_server_config(
213 &self.hostname,
214 &self.tls_private_key,
215 &self.tls_certificate,
216 )
217 .await?;
218 Ok(RustlsConfig::from_config(config))
219 }
220}
221
222#[derive(Debug)]
223pub struct AppError(pub anyhow::Error);
224
225impl IntoResponse for AppError {
227 fn into_response(self) -> Response {
228 (
229 StatusCode::INTERNAL_SERVER_ERROR,
230 format!("Error: {:#}", self.0),
231 )
232 .into_response()
233 }
234}
235
236impl<E> From<E> for AppError
239where
240 E: Into<anyhow::Error>,
241{
242 fn from(err: E) -> Self {
243 Self(err.into())
244 }
245}
246
247#[utoipa::path(
251 post,
252 tag="config",
253 path="/api/admin/bump-config-epoch",
254 responses(
255 (status=200, description = "bump successful")
256 ),
257)]
258async fn bump_config_epoch(_: TrustedIpRequired) -> Result<(), AppError> {
259 config::epoch::bump_current_epoch();
260 Ok(())
261}
262
263#[utoipa::path(
267 get,
268 tag="memory",
269 path="/api/admin/memory/stats",
270 responses(
271 (status=200, description = "stats were returned")
272 ),
273)]
274async fn memory_stats(_: TrustedIpRequired) -> String {
275 use kumo_server_memory::NumBytes;
276 use std::fmt::Write;
277 let mut result = String::new();
278
279 let jstats = JemallocStats::collect();
280 writeln!(result, "{jstats:#?}").ok();
281
282 if let Ok((usage, limit)) = get_usage_and_limit() {
283 writeln!(result, "RSS = {:?}", NumBytes::from(usage.bytes)).ok();
284 writeln!(
285 result,
286 "soft limit = {:?}",
287 limit.soft_limit.map(NumBytes::from)
288 )
289 .ok();
290 writeln!(
291 result,
292 "hard limit = {:?}",
293 limit.hard_limit.map(NumBytes::from)
294 )
295 .ok();
296 }
297
298 let mut stats = tracking_stats();
299 writeln!(result, "live = {:?}", stats.live).ok();
300
301 if stats.top_callstacks.is_empty() {
302 write!(
303 result,
304 "\nuse kumo.enable_memory_callstack_tracking(true) to enable additional stats\n"
305 )
306 .ok();
307 } else {
308 writeln!(result, "small_threshold = {:?}", stats.small_threshold).ok();
309 write!(result, "\ntop call stacks:\n").ok();
310 for stack in &mut stats.top_callstacks {
311 writeln!(
312 result,
313 "sampled every {} allocations, estimated {} allocations of {} total bytes",
314 stack.stochastic_rate,
315 stack.count * stack.stochastic_rate,
316 stack.total_size * stack.stochastic_rate
317 )
318 .ok();
319 write!(result, "{:?}\n\n", stack.bt).ok();
320 }
321 }
322
323 result
324}
325
326#[derive(Deserialize)]
327struct PrometheusMetricsParams {
328 #[serde(default)]
329 prefix: Option<String>,
330}
331
332async fn report_metrics(
333 _: TrustedIpRequired,
334 Query(params): Query<PrometheusMetricsParams>,
335) -> impl IntoResponse {
336 StreamBodyAsOptions::new()
337 .content_type(HttpHeaderValue::from_static("text/plain; charset=utf-8"))
338 .text(kumo_prometheus::registry::Registry::stream_text(
339 params.prefix.clone(),
340 ))
341}
342
343async fn report_metrics_json(_: TrustedIpRequired) -> impl IntoResponse {
344 StreamBodyAsOptions::new()
345 .content_type(HttpHeaderValue::from_static(
346 "application/json; charset=utf-8",
347 ))
348 .text(kumo_prometheus::registry::Registry::stream_json())
349}
350
351#[utoipa::path(
355 post,
356 tag="logging",
357 path="/api/admin/set_diagnostic_log_filter/v1",
358 responses(
359 (status = 200, description = "Diagnostic level set successfully")
360 ),
361)]
362async fn set_diagnostic_log_filter_v1(
363 _: TrustedIpRequired,
364 Json(request): Json<SetDiagnosticFilterRequest>,
366) -> Result<(), AppError> {
367 set_diagnostic_log_filter(&request.filter)?;
368 Ok(())
369}