kumo_server_common/http_server/mod.rs

1use crate::diagnostic_logging::set_diagnostic_log_filter;
2use anyhow::Context;
3use axum::extract::{DefaultBodyLimit, Json, Query};
4use axum::http::StatusCode;
5use axum::response::{IntoResponse, Response};
6use axum::routing::{get, post};
7use axum::Router;
8use axum_server::tls_rustls::RustlsConfig;
9use axum_streams::{HttpHeaderValue, StreamBodyAsOptions};
10use cidr_map::CidrSet;
11use data_loader::KeySource;
12use kumo_server_memory::{get_usage_and_limit, tracking_stats, JemallocStats};
13use kumo_server_runtime::spawn;
14use serde::Deserialize;
15use std::net::{IpAddr, SocketAddr, TcpListener};
16use std::sync::Arc;
17use tower_http::compression::CompressionLayer;
18use tower_http::trace::TraceLayer;
19use utoipa::openapi::security::{Http, HttpAuthScheme, SecurityScheme};
20use utoipa::OpenApi;
21use utoipa_rapidoc::RapiDoc;
22// Avoid referencing api types as crate::name in the utoipa macros,
23// otherwise it generates namespaced names in the openapi.json, which
24// in turn require annotating each and every struct with the namespace
25// in order for the document to be valid.
26use kumo_api_types::*;
27
28pub mod auth;
29
30use auth::*;
31
32#[derive(OpenApi)]
33#[openapi(
34    info(license(name = "Apache-2.0")),
35    paths(set_diagnostic_log_filter_v1, bump_config_epoch),
36    // Indicate that all paths can accept http basic auth.
37    // the "basic_auth" name corresponds with the scheme
38    // defined by the OptionalAuth addon defined below
39    security(
40        ("basic_auth" = [""])
41    ),
42    components(schemas(SetDiagnosticFilterRequest)),
43    modifiers(&OptionalAuth),
44)]
45struct ApiDoc;
46
47struct OptionalAuth;
48
49impl utoipa::Modify for OptionalAuth {
50    fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
51        let components = openapi
52            .components
53            .as_mut()
54            .expect("always set because we always have components above");
55        // Define basic_auth as http basic auth
56        components.add_security_scheme(
57            "basic_auth",
58            SecurityScheme::Http(Http::new(HttpAuthScheme::Basic)),
59        );
60    }
61}
62
63#[derive(Deserialize, Clone, Debug)]
64#[serde(deny_unknown_fields)]
65pub struct HttpListenerParams {
66    #[serde(default = "HttpListenerParams::default_hostname")]
67    pub hostname: String,
68
69    #[serde(default = "HttpListenerParams::default_listen")]
70    pub listen: String,
71
72    #[serde(default)]
73    pub use_tls: bool,
74
75    #[serde(default)]
76    pub request_body_limit: Option<usize>,
77
78    #[serde(default)]
79    pub tls_certificate: Option<KeySource>,
80    #[serde(default)]
81    pub tls_private_key: Option<KeySource>,
82
83    #[serde(default = "CidrSet::default_trusted_hosts")]
84    pub trusted_hosts: CidrSet,
85}
86
87pub struct RouterAndDocs {
88    pub router: Router,
89    pub docs: utoipa::openapi::OpenApi,
90}
91
92impl RouterAndDocs {
93    pub fn make_docs(&self) -> utoipa::openapi::OpenApi {
94        let mut api_docs = ApiDoc::openapi();
95        api_docs.info.title = self.docs.info.title.to_string();
96        api_docs.merge(self.docs.clone());
97        api_docs.info.version = version_info::kumo_version().to_string();
98        api_docs.info.license = Some(
99            utoipa::openapi::LicenseBuilder::new()
100                .name("Apache-2.0")
101                .build(),
102        );
103
104        api_docs
105    }
106}
107
108#[derive(Clone)]
109pub struct AppState {
110    trusted_hosts: Arc<CidrSet>,
111}
112
113impl AppState {
114    pub fn is_trusted_host(&self, addr: IpAddr) -> bool {
115        self.trusted_hosts.contains(addr)
116    }
117}
118
119impl HttpListenerParams {
120    fn default_listen() -> String {
121        "127.0.0.1:8000".to_string()
122    }
123
124    fn default_hostname() -> String {
125        gethostname::gethostname()
126            .to_str()
127            .unwrap_or("localhost")
128            .to_string()
129    }
130
131    // Note: it is possible to call
132    // server.with_graceful_shutdown(ShutdownSubscription::get().shutting_down)
133    // to have it listen for a shutdown request, but we're avoiding it:
134    // the request is the start of a shutdown and we need to allow a grace
135    // period for in-flight operations to complete.
136    // Some of those may require call backs to the HTTP endpoint
137    // if we're doing some kind of web hook like thing.
138    // So, for now at least, we'll have to manually verify if
139    // a request should proceed based on the results from the lifecycle
140    // module.
141    pub async fn start(
142        self,
143        router_and_docs: RouterAndDocs,
144        runtime: Option<tokio::runtime::Handle>,
145    ) -> anyhow::Result<()> {
146        let api_docs = router_and_docs.make_docs();
147
148        let compression_layer: CompressionLayer = CompressionLayer::new()
149            .deflate(true)
150            .gzip(true)
151            .quality(tower_http::CompressionLevel::Fastest);
152
153        let app = router_and_docs
154            .router
155            .layer(DefaultBodyLimit::max(
156                self.request_body_limit.unwrap_or(2 * 1024 * 1024),
157            ))
158            .merge(RapiDoc::with_openapi("/api-docs/openapi.json", api_docs).path("/rapidoc"))
159            .route(
160                "/api/admin/set_diagnostic_log_filter/v1",
161                post(set_diagnostic_log_filter_v1),
162            )
163            .route("/api/admin/bump-config-epoch", post(bump_config_epoch))
164            .route("/api/admin/memory/stats", get(memory_stats))
165            .route("/metrics", get(report_metrics))
166            .route("/metrics.json", get(report_metrics_json))
167            // Require that all requests be authenticated as either coming
168            // from a trusted IP address, or with an authorization header
169            .route_layer(axum::middleware::from_fn_with_state(
170                AppState {
171                    trusted_hosts: Arc::new(self.trusted_hosts.clone()),
172                },
173                auth_middleware,
174            ))
175            .layer(compression_layer)
176            .layer(TraceLayer::new_for_http());
177        let socket = TcpListener::bind(&self.listen)
178            .with_context(|| format!("listen on {}", self.listen))?;
179        let addr = socket.local_addr()?;
180
181        let make_service = app.into_make_service_with_connect_info::<SocketAddr>();
182
183        // The logic below is a bit repeatey, but it is still fewer
184        // lines of magic than it would be to factor out into a
185        // generic function because of all of the trait bounds
186        // that it would require.
187        if self.use_tls {
188            let config = self.tls_config().await?;
189            tracing::info!("https listener on {addr:?}");
190            let server = axum_server::from_tcp_rustls(socket, config);
191            let serve = async move { server.serve(make_service).await };
192
193            if let Some(runtime) = runtime {
194                runtime.spawn(serve);
195            } else {
196                spawn(format!("https {addr:?}"), serve)?;
197            }
198        } else {
199            tracing::info!("http listener on {addr:?}");
200            let server = axum_server::from_tcp(socket);
201            let serve = async move { server.serve(make_service).await };
202            if let Some(runtime) = runtime {
203                runtime.spawn(serve);
204            } else {
205                spawn(format!("http {addr:?}"), serve)?;
206            }
207        }
208        Ok(())
209    }
210
211    async fn tls_config(&self) -> anyhow::Result<RustlsConfig> {
212        let config = crate::tls_helpers::make_server_config(
213            &self.hostname,
214            &self.tls_private_key,
215            &self.tls_certificate,
216        )
217        .await?;
218        Ok(RustlsConfig::from_config(config))
219    }
220}
221
222#[derive(Debug)]
223pub struct AppError(pub anyhow::Error);
224
225// Tell axum how to convert `AppError` into a response.
226impl IntoResponse for AppError {
227    fn into_response(self) -> Response {
228        (
229            StatusCode::INTERNAL_SERVER_ERROR,
230            format!("Error: {:#}", self.0),
231        )
232            .into_response()
233    }
234}
235
236// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
237// `Result<_, AppError>`. That way you don't need to do that manually.
238impl<E> From<E> for AppError
239where
240    E: Into<anyhow::Error>,
241{
242    fn from(err: E) -> Self {
243        Self(err.into())
244    }
245}
246
247/// Allows the system operator to trigger a configuration epoch bump,
248/// which causes various configs that are using the Epoch strategy to
249/// be re-evaluated by triggering the appropriate callbacks.
250#[utoipa::path(
251    post,
252    tag="config",
253    path="/api/admin/bump-config-epoch",
254    responses(
255        (status=200, description = "bump successful")
256    ),
257)]
258async fn bump_config_epoch(_: TrustedIpRequired) -> Result<(), AppError> {
259    config::epoch::bump_current_epoch();
260    Ok(())
261}
262
263/// Returns information about the system memory usage in an unstructured
264/// human readable format.  The output is not machine parseable and may
265/// change without notice between versions of kumomta.
266#[utoipa::path(
267    get,
268    tag="memory",
269    path="/api/admin/memory/stats",
270    responses(
271        (status=200, description = "stats were returned")
272    ),
273)]
274async fn memory_stats(_: TrustedIpRequired) -> String {
275    use kumo_server_memory::NumBytes;
276    use std::fmt::Write;
277    let mut result = String::new();
278
279    let jstats = JemallocStats::collect();
280    writeln!(result, "{jstats:#?}").ok();
281
282    if let Ok((usage, limit)) = get_usage_and_limit() {
283        writeln!(result, "RSS = {:?}", NumBytes::from(usage.bytes)).ok();
284        writeln!(
285            result,
286            "soft limit = {:?}",
287            limit.soft_limit.map(NumBytes::from)
288        )
289        .ok();
290        writeln!(
291            result,
292            "hard limit = {:?}",
293            limit.hard_limit.map(NumBytes::from)
294        )
295        .ok();
296    }
297
298    let mut stats = tracking_stats();
299    writeln!(result, "live = {:?}", stats.live).ok();
300
301    if stats.top_callstacks.is_empty() {
302        write!(
303            result,
304            "\nuse kumo.enable_memory_callstack_tracking(true) to enable additional stats\n"
305        )
306        .ok();
307    } else {
308        writeln!(result, "small_threshold = {:?}", stats.small_threshold).ok();
309        write!(result, "\ntop call stacks:\n").ok();
310        for stack in &mut stats.top_callstacks {
311            writeln!(
312                result,
313                "sampled every {} allocations, estimated {} allocations of {} total bytes",
314                stack.stochastic_rate,
315                stack.count * stack.stochastic_rate,
316                stack.total_size * stack.stochastic_rate
317            )
318            .ok();
319            write!(result, "{:?}\n\n", stack.bt).ok();
320        }
321    }
322
323    result
324}
325
326#[derive(Deserialize)]
327struct PrometheusMetricsParams {
328    #[serde(default)]
329    prefix: Option<String>,
330}
331
332async fn report_metrics(
333    _: TrustedIpRequired,
334    Query(params): Query<PrometheusMetricsParams>,
335) -> impl IntoResponse {
336    StreamBodyAsOptions::new()
337        .content_type(HttpHeaderValue::from_static("text/plain; charset=utf-8"))
338        .text(kumo_prometheus::registry::Registry::stream_text(
339            params.prefix.clone(),
340        ))
341}
342
343async fn report_metrics_json(_: TrustedIpRequired) -> impl IntoResponse {
344    StreamBodyAsOptions::new()
345        .content_type(HttpHeaderValue::from_static(
346            "application/json; charset=utf-8",
347        ))
348        .text(kumo_prometheus::registry::Registry::stream_json())
349}
350
351/// Changes the diagnostic log filter dynamically.
352/// See <https://docs.kumomta.com/reference/kumo/set_diagnostic_log_filter/>
353/// for more information on diagnostic log filters.
354#[utoipa::path(
355    post,
356    tag="logging",
357    path="/api/admin/set_diagnostic_log_filter/v1",
358    responses(
359        (status = 200, description = "Diagnostic level set successfully")
360    ),
361)]
362async fn set_diagnostic_log_filter_v1(
363    _: TrustedIpRequired,
364    // Note: Json<> must be last in the param list
365    Json(request): Json<SetDiagnosticFilterRequest>,
366) -> Result<(), AppError> {
367    set_diagnostic_log_filter(&request.filter)?;
368    Ok(())
369}