kumo_server_common/http_server/mod.rs

1use crate::diagnostic_logging::set_diagnostic_log_filter;
2use anyhow::Context;
3use axum::extract::{DefaultBodyLimit, Json, Query};
4use axum::http::StatusCode;
5use axum::response::{IntoResponse, Response};
6use axum::routing::{get, post};
7use axum::Router;
8use axum_server::tls_rustls::RustlsConfig;
9use axum_streams::{HttpHeaderValue, StreamBodyAsOptions};
10use cidr_map::CidrSet;
11use data_loader::KeySource;
12use kumo_server_memory::{get_usage_and_limit, tracking_stats, JemallocStats};
13use kumo_server_runtime::spawn;
14use serde::Deserialize;
15use std::net::{IpAddr, SocketAddr, TcpListener};
16use std::sync::Arc;
17use tower_http::compression::CompressionLayer;
18use tower_http::decompression::RequestDecompressionLayer;
19use tower_http::trace::TraceLayer;
20use utoipa::openapi::security::{Http, HttpAuthScheme, SecurityScheme};
21use utoipa::OpenApi;
22use utoipa_rapidoc::RapiDoc;
23// Avoid referencing api types as crate::name in the utoipa macros,
24// otherwise it generates namespaced names in the openapi.json, which
25// in turn require annotating each and every struct with the namespace
26// in order for the document to be valid.
27use kumo_api_types::*;
28
29pub mod auth;
30
31use auth::*;
32
33#[derive(OpenApi)]
34#[openapi(
35    info(license(name = "Apache-2.0")),
36    paths(set_diagnostic_log_filter_v1, bump_config_epoch),
37    // Indicate that all paths can accept http basic auth.
38    // the "basic_auth" name corresponds with the scheme
39    // defined by the OptionalAuth addon defined below
40    security(
41        ("basic_auth" = [""])
42    ),
43    components(schemas(SetDiagnosticFilterRequest)),
44    modifiers(&OptionalAuth),
45)]
46struct ApiDoc;
47
48struct OptionalAuth;
49
50impl utoipa::Modify for OptionalAuth {
51    fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
52        let components = openapi
53            .components
54            .as_mut()
55            .expect("always set because we always have components above");
56        // Define basic_auth as http basic auth
57        components.add_security_scheme(
58            "basic_auth",
59            SecurityScheme::Http(Http::new(HttpAuthScheme::Basic)),
60        );
61    }
62}
63
64#[derive(Deserialize, Clone, Debug)]
65#[serde(deny_unknown_fields)]
66pub struct HttpListenerParams {
67    #[serde(default = "HttpListenerParams::default_hostname")]
68    pub hostname: String,
69
70    #[serde(default = "HttpListenerParams::default_listen")]
71    pub listen: String,
72
73    #[serde(default)]
74    pub use_tls: bool,
75
76    #[serde(default)]
77    pub request_body_limit: Option<usize>,
78
79    #[serde(default)]
80    pub tls_certificate: Option<KeySource>,
81    #[serde(default)]
82    pub tls_private_key: Option<KeySource>,
83
84    #[serde(default = "CidrSet::default_trusted_hosts")]
85    pub trusted_hosts: CidrSet,
86}
87
88pub struct RouterAndDocs {
89    pub router: Router,
90    pub docs: utoipa::openapi::OpenApi,
91}
92
93impl RouterAndDocs {
94    pub fn make_docs(&self) -> utoipa::openapi::OpenApi {
95        let mut api_docs = ApiDoc::openapi();
96        api_docs.info.title = self.docs.info.title.to_string();
97        api_docs.merge(self.docs.clone());
98        api_docs.info.version = version_info::kumo_version().to_string();
99        api_docs.info.license = Some(
100            utoipa::openapi::LicenseBuilder::new()
101                .name("Apache-2.0")
102                .build(),
103        );
104
105        api_docs
106    }
107}
108
109#[derive(Clone)]
110pub struct AppState {
111    trusted_hosts: Arc<CidrSet>,
112}
113
114impl AppState {
115    pub fn is_trusted_host(&self, addr: IpAddr) -> bool {
116        self.trusted_hosts.contains(addr)
117    }
118}
119
120impl HttpListenerParams {
121    fn default_listen() -> String {
122        "127.0.0.1:8000".to_string()
123    }
124
125    fn default_hostname() -> String {
126        gethostname::gethostname()
127            .to_str()
128            .unwrap_or("localhost")
129            .to_string()
130    }
131
132    // Note: it is possible to call
133    // server.with_graceful_shutdown(ShutdownSubscription::get().shutting_down)
134    // to have it listen for a shutdown request, but we're avoiding it:
135    // the request is the start of a shutdown and we need to allow a grace
136    // period for in-flight operations to complete.
137    // Some of those may require call backs to the HTTP endpoint
138    // if we're doing some kind of web hook like thing.
139    // So, for now at least, we'll have to manually verify if
140    // a request should proceed based on the results from the lifecycle
141    // module.
142    pub async fn start(
143        self,
144        router_and_docs: RouterAndDocs,
145        runtime: Option<tokio::runtime::Handle>,
146    ) -> anyhow::Result<()> {
147        let api_docs = router_and_docs.make_docs();
148
149        let compression_layer: CompressionLayer = CompressionLayer::new()
150            .deflate(true)
151            .gzip(true)
152            .quality(tower_http::CompressionLevel::Fastest);
153        let decompression_layer = RequestDecompressionLayer::new().deflate(true).gzip(true);
154        let app = router_and_docs
155            .router
156            .layer(DefaultBodyLimit::max(
157                self.request_body_limit.unwrap_or(2 * 1024 * 1024),
158            ))
159            .merge(RapiDoc::with_openapi("/api-docs/openapi.json", api_docs).path("/rapidoc"))
160            .route(
161                "/api/admin/set_diagnostic_log_filter/v1",
162                post(set_diagnostic_log_filter_v1),
163            )
164            .route("/api/admin/bump-config-epoch", post(bump_config_epoch))
165            .route("/api/admin/memory/stats", get(memory_stats))
166            .route("/metrics", get(report_metrics))
167            .route("/metrics.json", get(report_metrics_json))
168            // Require that all requests be authenticated as either coming
169            // from a trusted IP address, or with an authorization header
170            .route_layer(axum::middleware::from_fn_with_state(
171                AppState {
172                    trusted_hosts: Arc::new(self.trusted_hosts.clone()),
173                },
174                auth_middleware,
175            ))
176            .layer(compression_layer)
177            .layer(decompression_layer)
178            .layer(TraceLayer::new_for_http());
179        let socket = TcpListener::bind(&self.listen)
180            .with_context(|| format!("listen on {}", self.listen))?;
181        let addr = socket.local_addr()?;
182
183        let make_service = app.into_make_service_with_connect_info::<SocketAddr>();
184
185        // The logic below is a bit repeatey, but it is still fewer
186        // lines of magic than it would be to factor out into a
187        // generic function because of all of the trait bounds
188        // that it would require.
189        if self.use_tls {
190            let config = self.tls_config().await?;
191            tracing::info!("https listener on {addr:?}");
192            let server = axum_server::from_tcp_rustls(socket, config);
193            let serve = async move { server.serve(make_service).await };
194
195            if let Some(runtime) = runtime {
196                runtime.spawn(serve);
197            } else {
198                spawn(format!("https {addr:?}"), serve)?;
199            }
200        } else {
201            tracing::info!("http listener on {addr:?}");
202            let server = axum_server::from_tcp(socket);
203            let serve = async move { server.serve(make_service).await };
204            if let Some(runtime) = runtime {
205                runtime.spawn(serve);
206            } else {
207                spawn(format!("http {addr:?}"), serve)?;
208            }
209        }
210        Ok(())
211    }
212
213    async fn tls_config(&self) -> anyhow::Result<RustlsConfig> {
214        let config = crate::tls_helpers::make_server_config(
215            &self.hostname,
216            &self.tls_private_key,
217            &self.tls_certificate,
218            &None,
219        )
220        .await?;
221        Ok(RustlsConfig::from_config(config))
222    }
223}
224
225#[derive(Debug)]
226pub struct AppError(pub anyhow::Error);
227
228// Tell axum how to convert `AppError` into a response.
229impl IntoResponse for AppError {
230    fn into_response(self) -> Response {
231        (
232            StatusCode::INTERNAL_SERVER_ERROR,
233            format!("Error: {:#}", self.0),
234        )
235            .into_response()
236    }
237}
238
239// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
240// `Result<_, AppError>`. That way you don't need to do that manually.
241impl<E> From<E> for AppError
242where
243    E: Into<anyhow::Error>,
244{
245    fn from(err: E) -> Self {
246        Self(err.into())
247    }
248}
249
250/// Allows the system operator to trigger a configuration epoch bump,
251/// which causes various configs that are using the Epoch strategy to
252/// be re-evaluated by triggering the appropriate callbacks.
253#[utoipa::path(
254    post,
255    tag="config",
256    path="/api/admin/bump-config-epoch",
257    responses(
258        (status=200, description = "bump successful")
259    ),
260)]
261async fn bump_config_epoch(_: TrustedIpRequired) -> Result<(), AppError> {
262    config::epoch::bump_current_epoch();
263    Ok(())
264}
265
266/// Returns information about the system memory usage in an unstructured
267/// human readable format.  The output is not machine parseable and may
268/// change without notice between versions of kumomta.
269#[utoipa::path(
270    get,
271    tag="memory",
272    path="/api/admin/memory/stats",
273    responses(
274        (status=200, description = "stats were returned")
275    ),
276)]
277async fn memory_stats(_: TrustedIpRequired) -> String {
278    use kumo_server_memory::NumBytes;
279    use std::fmt::Write;
280    let mut result = String::new();
281
282    let jstats = JemallocStats::collect();
283    writeln!(result, "{jstats:#?}").ok();
284
285    if let Ok((usage, limit)) = get_usage_and_limit() {
286        writeln!(result, "RSS = {:?}", NumBytes::from(usage.bytes)).ok();
287        writeln!(
288            result,
289            "soft limit = {:?}",
290            limit.soft_limit.map(NumBytes::from)
291        )
292        .ok();
293        writeln!(
294            result,
295            "hard limit = {:?}",
296            limit.hard_limit.map(NumBytes::from)
297        )
298        .ok();
299    }
300
301    let mut stats = tracking_stats();
302    writeln!(result, "live = {:?}", stats.live).ok();
303
304    if stats.top_callstacks.is_empty() {
305        write!(
306            result,
307            "\nuse kumo.enable_memory_callstack_tracking(true) to enable additional stats\n"
308        )
309        .ok();
310    } else {
311        writeln!(result, "small_threshold = {:?}", stats.small_threshold).ok();
312        write!(result, "\ntop call stacks:\n").ok();
313        for stack in &mut stats.top_callstacks {
314            writeln!(
315                result,
316                "sampled every {} allocations, estimated {} allocations of {} total bytes",
317                stack.stochastic_rate,
318                stack.count * stack.stochastic_rate,
319                stack.total_size * stack.stochastic_rate
320            )
321            .ok();
322            write!(result, "{:?}\n\n", stack.bt).ok();
323        }
324    }
325
326    result
327}
328
329#[derive(Deserialize)]
330struct PrometheusMetricsParams {
331    #[serde(default)]
332    prefix: Option<String>,
333}
334
335async fn report_metrics(
336    _: TrustedIpRequired,
337    Query(params): Query<PrometheusMetricsParams>,
338) -> impl IntoResponse {
339    StreamBodyAsOptions::new()
340        .content_type(HttpHeaderValue::from_static("text/plain; charset=utf-8"))
341        .text(kumo_prometheus::registry::Registry::stream_text(
342            params.prefix.clone(),
343        ))
344}
345
346async fn report_metrics_json(_: TrustedIpRequired) -> impl IntoResponse {
347    StreamBodyAsOptions::new()
348        .content_type(HttpHeaderValue::from_static(
349            "application/json; charset=utf-8",
350        ))
351        .text(kumo_prometheus::registry::Registry::stream_json())
352}
353
354/// Changes the diagnostic log filter dynamically.
355/// See <https://docs.kumomta.com/reference/kumo/set_diagnostic_log_filter/>
356/// for more information on diagnostic log filters.
357#[utoipa::path(
358    post,
359    tag="logging",
360    path="/api/admin/set_diagnostic_log_filter/v1",
361    responses(
362        (status = 200, description = "Diagnostic level set successfully")
363    ),
364)]
365async fn set_diagnostic_log_filter_v1(
366    _: TrustedIpRequired,
367    // Note: Json<> must be last in the param list
368    Json(request): Json<SetDiagnosticFilterRequest>,
369) -> Result<(), AppError> {
370    set_diagnostic_log_filter(&request.filter)?;
371    Ok(())
372}