kumo_server_common/http_server/
mod.rs

1use crate::diagnostic_logging::set_diagnostic_log_filter;
2use anyhow::Context;
3use axum::extract::{DefaultBodyLimit, Json, Query};
4use axum::http::StatusCode;
5use axum::response::{IntoResponse, Response};
6use axum::routing::{get, post};
7use axum::Router;
8use axum_server::tls_rustls::RustlsConfig;
9use axum_streams::{HttpHeaderValue, StreamBodyAsOptions};
10use cidr_map::CidrSet;
11use data_loader::KeySource;
12use kumo_server_memory::{get_usage_and_limit, tracking_stats, JemallocStats};
13use kumo_server_runtime::spawn;
14use serde::Deserialize;
15use std::net::{IpAddr, SocketAddr, TcpListener};
16use tower_http::compression::CompressionLayer;
17use tower_http::decompression::RequestDecompressionLayer;
18use tower_http::trace::TraceLayer;
19use utoipa::openapi::security::{Http, HttpAuthScheme, SecurityScheme};
20use utoipa::OpenApi;
21use utoipa_rapidoc::RapiDoc;
22// Avoid referencing api types as crate::name in the utoipa macros,
23// otherwise it generates namespaced names in the openapi.json, which
24// in turn require annotating each and every struct with the namespace
25// in order for the document to be valid.
26use kumo_api_types::*;
27
28pub mod auth;
29
30use auth::*;
31
32#[derive(OpenApi)]
33#[openapi(
34    info(license(name = "Apache-2.0")),
35    paths(set_diagnostic_log_filter_v1, bump_config_epoch),
36    // Indicate that all paths can accept http basic auth.
37    // the "basic_auth" name corresponds with the scheme
38    // defined by the OptionalAuth addon defined below
39    security(
40        ("basic_auth" = [""])
41    ),
42    components(schemas(SetDiagnosticFilterRequest)),
43    modifiers(&OptionalAuth),
44)]
45struct ApiDoc;
46
47struct OptionalAuth;
48
49impl utoipa::Modify for OptionalAuth {
50    fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
51        let components = openapi
52            .components
53            .as_mut()
54            .expect("always set because we always have components above");
55        // Define basic_auth as http basic auth
56        components.add_security_scheme(
57            "basic_auth",
58            SecurityScheme::Http(Http::new(HttpAuthScheme::Basic)),
59        );
60    }
61}
62
/// Configuration for a single HTTP or HTTPS listener.
///
/// Deserialization rejects any field that is not declared here.
#[derive(Deserialize, Clone, Debug)]
#[serde(deny_unknown_fields)]
pub struct HttpListenerParams {
    /// Host name handed to the TLS helper when building the server TLS
    /// configuration. Defaults to the system host name, or `localhost`
    /// when the system host name is not valid UTF-8.
    #[serde(default = "HttpListenerParams::default_hostname")]
    pub hostname: String,

    /// Address to bind the listening socket to.
    /// Defaults to `127.0.0.1:8000`.
    #[serde(default = "HttpListenerParams::default_listen")]
    pub listen: String,

    /// When true, the listener is served over TLS (rustls); otherwise
    /// plain HTTP is used. Defaults to false.
    #[serde(default)]
    pub use_tls: bool,

    /// Maximum request body size, in bytes, applied via axum's
    /// `DefaultBodyLimit`. When unset, `start` uses 2 MiB.
    #[serde(default)]
    pub request_body_limit: Option<usize>,

    /// Optional source of the TLS certificate, passed to the TLS helper.
    #[serde(default)]
    pub tls_certificate: Option<KeySource>,
    /// Optional source of the TLS private key, passed to the TLS helper.
    #[serde(default)]
    pub tls_private_key: Option<KeySource>,

    /// Set of addresses consulted by `AppState::is_trusted_host`.
    /// Defaults to `CidrSet::default_trusted_hosts()`.
    #[serde(default = "CidrSet::default_trusted_hosts")]
    pub trusted_hosts: CidrSet,
}
86
/// A router bundled with the OpenAPI document supplied alongside it.
pub struct RouterAndDocs {
    /// Routes supplied by the caller; `HttpListenerParams::start` adds
    /// the admin and metrics routes on top of these.
    pub router: Router<AppState>,
    /// OpenAPI document that `make_docs` merges into this module's own
    /// `ApiDoc`; its `info.title` becomes the title of the result.
    pub docs: utoipa::openapi::OpenApi,
}
91
92impl RouterAndDocs {
93    pub fn make_docs(&self) -> utoipa::openapi::OpenApi {
94        let mut api_docs = ApiDoc::openapi();
95        api_docs.info.title = self.docs.info.title.to_string();
96        api_docs.merge(self.docs.clone());
97        api_docs.info.version = version_info::kumo_version().to_string();
98        api_docs.info.license = Some(
99            utoipa::openapi::LicenseBuilder::new()
100                .name("Apache-2.0")
101                .build(),
102        );
103
104        api_docs
105    }
106}
107
/// Per-listener state shared with request handlers and middleware.
#[derive(Clone)]
pub struct AppState {
    /// Copy of the parameters the listener was started with.
    params: HttpListenerParams,
    /// Address the listening socket is actually bound to, as reported
    /// by the socket after binding.
    local_addr: SocketAddr,
}
113
114impl AppState {
115    pub fn is_trusted_host(&self, addr: IpAddr) -> bool {
116        self.params.trusted_hosts.contains(addr)
117    }
118
119    pub fn params(&self) -> &HttpListenerParams {
120        &self.params
121    }
122
123    pub fn local_addr(&self) -> &SocketAddr {
124        &self.local_addr
125    }
126}
127
128impl HttpListenerParams {
129    fn default_listen() -> String {
130        "127.0.0.1:8000".to_string()
131    }
132
133    fn default_hostname() -> String {
134        gethostname::gethostname()
135            .to_str()
136            .unwrap_or("localhost")
137            .to_string()
138    }
139
140    // Note: it is possible to call
141    // server.with_graceful_shutdown(ShutdownSubscription::get().shutting_down)
142    // to have it listen for a shutdown request, but we're avoiding it:
143    // the request is the start of a shutdown and we need to allow a grace
144    // period for in-flight operations to complete.
145    // Some of those may require call backs to the HTTP endpoint
146    // if we're doing some kind of web hook like thing.
147    // So, for now at least, we'll have to manually verify if
148    // a request should proceed based on the results from the lifecycle
149    // module.
150    pub async fn start(
151        self,
152        router_and_docs: RouterAndDocs,
153        runtime: Option<tokio::runtime::Handle>,
154    ) -> anyhow::Result<()> {
155        let api_docs = router_and_docs.make_docs();
156
157        let compression_layer: CompressionLayer = CompressionLayer::new()
158            .deflate(true)
159            .gzip(true)
160            .quality(tower_http::CompressionLevel::Fastest);
161        let decompression_layer = RequestDecompressionLayer::new().deflate(true).gzip(true);
162
163        let socket = TcpListener::bind(&self.listen)
164            .with_context(|| format!("listen on {}", self.listen))?;
165        let addr = socket.local_addr()?;
166
167        let app_state = AppState {
168            params: self.clone(),
169            local_addr: addr.clone(),
170        };
171
172        let app = router_and_docs
173            .router
174            .layer(DefaultBodyLimit::max(
175                self.request_body_limit.unwrap_or(2 * 1024 * 1024),
176            ))
177            .merge(RapiDoc::with_openapi("/api-docs/openapi.json", api_docs).path("/rapidoc"))
178            .route(
179                "/api/admin/set_diagnostic_log_filter/v1",
180                post(set_diagnostic_log_filter_v1),
181            )
182            .route("/api/admin/bump-config-epoch", post(bump_config_epoch))
183            .route("/api/admin/memory/stats", get(memory_stats))
184            .route("/metrics", get(report_metrics))
185            .route("/metrics.json", get(report_metrics_json))
186            // Require that all requests be authenticated as either coming
187            // from a trusted IP address, or with an authorization header
188            .route_layer(axum::middleware::from_fn_with_state(
189                app_state.clone(),
190                auth_middleware,
191            ))
192            .layer(compression_layer)
193            .layer(decompression_layer)
194            .layer(TraceLayer::new_for_http())
195            .layer(axum_client_ip::ClientIpSource::ConnectInfo.into_extension())
196            .with_state(app_state);
197
198        let make_service = app.into_make_service_with_connect_info::<SocketAddr>();
199
200        // The logic below is a bit repeatey, but it is still fewer
201        // lines of magic than it would be to factor out into a
202        // generic function because of all of the trait bounds
203        // that it would require.
204        if self.use_tls {
205            let config = self.tls_config().await?;
206            tracing::info!("https listener on {addr:?}");
207            let server = axum_server::from_tcp_rustls(socket, config);
208            let serve = async move { server.serve(make_service).await };
209
210            if let Some(runtime) = runtime {
211                runtime.spawn(serve);
212            } else {
213                spawn(format!("https {addr:?}"), serve)?;
214            }
215        } else {
216            tracing::info!("http listener on {addr:?}");
217            let server = axum_server::from_tcp(socket);
218            let serve = async move { server.serve(make_service).await };
219            if let Some(runtime) = runtime {
220                runtime.spawn(serve);
221            } else {
222                spawn(format!("http {addr:?}"), serve)?;
223            }
224        }
225        Ok(())
226    }
227
228    async fn tls_config(&self) -> anyhow::Result<RustlsConfig> {
229        let config = crate::tls_helpers::make_server_config(
230            &self.hostname,
231            &self.tls_private_key,
232            &self.tls_certificate,
233            &None,
234        )
235        .await?;
236        Ok(RustlsConfig::from_config(config))
237    }
238}
239
/// Error type returned by request handlers; pairs an error with the
/// HTTP status code to respond with.
#[derive(Debug)]
pub struct AppError {
    /// The underlying error; its alternate (`{:#}`) rendering is sent
    /// in the response body.
    pub err: anyhow::Error,
    /// HTTP status code used for the response.
    pub code: StatusCode,
}
245
246impl AppError {
247    pub fn new(code: StatusCode, err: impl Into<String>) -> Self {
248        let err: String = err.into();
249        Self {
250            err: anyhow::anyhow!(err),
251            code,
252        }
253    }
254}
255
256// Tell axum how to convert `AppError` into a response.
257impl IntoResponse for AppError {
258    fn into_response(self) -> Response {
259        (self.code, format!("Error: {:#}", self.err)).into_response()
260    }
261}
262
263// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
264// `Result<_, AppError>`. That way you don't need to do that manually.
265impl<E> From<E> for AppError
266where
267    E: Into<anyhow::Error>,
268{
269    fn from(err: E) -> Self {
270        Self {
271            err: err.into(),
272            code: StatusCode::INTERNAL_SERVER_ERROR,
273        }
274    }
275}
276
/// Allows the system operator to trigger a configuration epoch bump,
/// which causes various configs that are using the Epoch strategy to
/// be re-evaluated by triggering the appropriate callbacks.
#[utoipa::path(
    post,
    tag="config",
    path="/api/admin/bump-config-epoch",
    responses(
        (status=200, description = "bump successful")
    ),
)]
async fn bump_config_epoch(_: TrustedIpRequired) -> Result<(), AppError> {
    // NOTE(review): the `///` text above sits under `#[utoipa::path]`;
    // utoipa is documented to publish it as the operation description,
    // so treat it as user-facing text when editing.
    //
    // NOTE(review): `TrustedIpRequired` comes from the `auth` module and
    // presumably rejects callers outside `trusted_hosts` -- confirm there.
    config::epoch::bump_current_epoch();
    // Nothing in this handler produces an error of its own.
    Ok(())
}
292
293/// Returns information about the system memory usage in an unstructured
294/// human readable format.  The output is not machine parseable and may
295/// change without notice between versions of kumomta.
296#[utoipa::path(
297    get,
298    tag="memory",
299    path="/api/admin/memory/stats",
300    responses(
301        (status=200, description = "stats were returned")
302    ),
303)]
304async fn memory_stats(_: TrustedIpRequired) -> String {
305    use kumo_server_memory::NumBytes;
306    use std::fmt::Write;
307    let mut result = String::new();
308
309    let jstats = JemallocStats::collect();
310    writeln!(result, "{jstats:#?}").ok();
311
312    if let Ok((usage, limit)) = get_usage_and_limit() {
313        writeln!(result, "RSS = {:?}", NumBytes::from(usage.bytes)).ok();
314        writeln!(
315            result,
316            "soft limit = {:?}",
317            limit.soft_limit.map(NumBytes::from)
318        )
319        .ok();
320        writeln!(
321            result,
322            "hard limit = {:?}",
323            limit.hard_limit.map(NumBytes::from)
324        )
325        .ok();
326    }
327
328    let mut stats = tracking_stats();
329    writeln!(result, "live = {:?}", stats.live).ok();
330
331    if stats.top_callstacks.is_empty() {
332        write!(
333            result,
334            "\nuse kumo.enable_memory_callstack_tracking(true) to enable additional stats\n"
335        )
336        .ok();
337    } else {
338        writeln!(result, "small_threshold = {:?}", stats.small_threshold).ok();
339        write!(result, "\ntop call stacks:\n").ok();
340        for stack in &mut stats.top_callstacks {
341            writeln!(
342                result,
343                "sampled every {} allocations, estimated {} allocations of {} total bytes",
344                stack.stochastic_rate,
345                stack.count * stack.stochastic_rate,
346                stack.total_size * stack.stochastic_rate
347            )
348            .ok();
349            write!(result, "{:?}\n\n", stack.bt).ok();
350        }
351    }
352
353    result
354}
355
/// Query string parameters accepted by the `/metrics` endpoint.
#[derive(Deserialize)]
struct PrometheusMetricsParams {
    /// Optional prefix forwarded to `Registry::stream_text`.
    /// NOTE(review): presumably prepended to each metric name -- confirm
    /// against the `kumo_prometheus` registry implementation.
    #[serde(default)]
    prefix: Option<String>,
}
361
362async fn report_metrics(
363    _: TrustedIpRequired,
364    Query(params): Query<PrometheusMetricsParams>,
365) -> impl IntoResponse {
366    StreamBodyAsOptions::new()
367        .content_type(HttpHeaderValue::from_static("text/plain; charset=utf-8"))
368        .text(kumo_prometheus::registry::Registry::stream_text(
369            params.prefix.clone(),
370        ))
371}
372
373async fn report_metrics_json(_: TrustedIpRequired) -> impl IntoResponse {
374    StreamBodyAsOptions::new()
375        .content_type(HttpHeaderValue::from_static(
376            "application/json; charset=utf-8",
377        ))
378        .text(kumo_prometheus::registry::Registry::stream_json())
379}
380
/// Changes the diagnostic log filter dynamically.
/// See <https://docs.kumomta.com/reference/kumo/set_diagnostic_log_filter/>
/// for more information on diagnostic log filters.
#[utoipa::path(
    post,
    tag="logging",
    path="/api/admin/set_diagnostic_log_filter/v1",
    responses(
        (status = 200, description = "Diagnostic level set successfully")
    ),
)]
async fn set_diagnostic_log_filter_v1(
    _: TrustedIpRequired,
    // Note: Json<> must be last in the param list
    Json(request): Json<SetDiagnosticFilterRequest>,
) -> Result<(), AppError> {
    // NOTE(review): the `///` text above sits under `#[utoipa::path]`;
    // utoipa is documented to publish it as the operation description,
    // so treat it as user-facing text when editing.
    //
    // A failure to apply the filter is converted by `?` into an
    // `AppError`, which carries a 500 status code.
    set_diagnostic_log_filter(&request.filter)?;
    Ok(())
}