use crate::diagnostic_logging::set_diagnostic_log_filter;
use anyhow::Context;
use axum::extract::{DefaultBodyLimit, Json, Query};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::Router;
use axum_server::tls_rustls::RustlsConfig;
use axum_streams::{HttpHeaderValue, StreamBodyAsOptions};
use cidr_map::CidrSet;
use data_loader::KeySource;
use kumo_server_runtime::spawn;
use serde::Deserialize;
use std::net::{IpAddr, SocketAddr, TcpListener};
use std::sync::Arc;
use tower_http::compression::CompressionLayer;
use tower_http::trace::TraceLayer;
use utoipa::openapi::security::{Http, HttpAuthScheme, SecurityScheme};
use utoipa::OpenApi;
use utoipa_rapidoc::RapiDoc;
use kumo_api_types::*;
pub mod auth;
use auth::*;
// Base OpenAPI document for the built-in admin endpoints declared in this
// module. `RouterAndDocs::make_docs` starts from this document and merges the
// caller-supplied docs into it.
// The document-wide `security` requirement names the "basic_auth" scheme;
// that scheme itself is registered in `components` by the `OptionalAuth`
// modifier listed in `modifiers(...)`.
// NOTE(review): the scope list `[""]` contains a single empty-string scope;
// HTTP Basic auth has no scopes, so `[]` may have been intended — confirm
// against the generated openapi.json before changing.
#[derive(OpenApi)]
#[openapi(
info(license(name = "Apache-2.0")),
paths(set_diagnostic_log_filter_v1, bump_config_epoch),
security(
("basic_auth" = [""])
),
components(schemas(SetDiagnosticFilterRequest)),
modifiers(&OptionalAuth),
)]
struct ApiDoc;
/// OpenAPI modifier that registers the `basic_auth` security scheme
/// (HTTP Basic) which `ApiDoc` names in its `security` declaration.
struct OptionalAuth;

impl utoipa::Modify for OptionalAuth {
    /// Add the `basic_auth` HTTP Basic scheme to the document's components.
    ///
    /// # Panics
    ///
    /// Panics if the document has no `components` section. `ApiDoc` always
    /// declares `components(...)`, so this does not happen for documents
    /// derived from it.
    fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
        let basic_scheme = SecurityScheme::Http(Http::new(HttpAuthScheme::Basic));
        openapi
            .components
            .as_mut()
            .expect("always set because we always have components above")
            .add_security_scheme("basic_auth", basic_scheme);
    }
}
/// Configuration for a single HTTP or HTTPS listener.
///
/// Deserialized with serde; unknown fields are rejected
/// (`deny_unknown_fields`). Every field has a default, so an empty
/// configuration is valid.
#[derive(Deserialize, Clone, Debug)]
#[serde(deny_unknown_fields)]
pub struct HttpListenerParams {
/// Host name handed to the TLS configuration helper when `use_tls` is set
/// (presumably used for certificate naming — TODO confirm).
/// Defaults to the system host name, or `"localhost"` if that is not valid UTF-8.
#[serde(default = "HttpListenerParams::default_hostname")]
pub hostname: String,
/// Address passed to `std::net::TcpListener::bind`. Defaults to `"127.0.0.1:8000"`.
#[serde(default = "HttpListenerParams::default_listen")]
pub listen: String,
/// Serve HTTPS (rustls) instead of plain HTTP. Defaults to `false`.
#[serde(default)]
pub use_tls: bool,
/// Maximum request body size, in bytes, applied to the routes supplied in
/// `RouterAndDocs::router`. When unset, 2 MiB is used.
#[serde(default)]
pub request_body_limit: Option<usize>,
/// TLS certificate source; only consulted when `use_tls` is set.
#[serde(default)]
pub tls_certificate: Option<KeySource>,
/// TLS private key source; only consulted when `use_tls` is set.
#[serde(default)]
pub tls_private_key: Option<KeySource>,
/// Networks whose peers are treated as trusted; handed to `auth_middleware`
/// via `AppState`. Defaults to `CidrSet::default_trusted_hosts()`.
#[serde(default = "CidrSet::default_trusted_hosts")]
pub trusted_hosts: CidrSet,
}
/// A router paired with the OpenAPI document that describes its routes.
///
/// `HttpListenerParams::start` consumes this, adds the built-in endpoints,
/// and serves the result.
pub struct RouterAndDocs {
/// Routes served in addition to the built-in admin and metrics endpoints.
pub router: Router,
/// OpenAPI document for `router`; its title is adopted and its contents
/// are merged into the module's base document by `make_docs`.
pub docs: utoipa::openapi::OpenApi,
}
impl RouterAndDocs {
    /// Build the complete OpenAPI document for the listener.
    ///
    /// Starts from `ApiDoc` and takes the title from `self.docs`. It then
    /// merges `self.docs` into it. Finally it sets the version to
    /// `version_info::kumo_version()` and the license to Apache-2.0.
    pub fn make_docs(&self) -> utoipa::openapi::OpenApi {
        let apache_license = utoipa::openapi::LicenseBuilder::new()
            .name("Apache-2.0")
            .build();

        let mut combined = ApiDoc::openapi();
        // Adopt the caller's title before merging their paths/components in.
        combined.info.title = self.docs.info.title.clone();
        combined.merge(self.docs.clone());
        combined.info.version = version_info::kumo_version().to_string();
        combined.info.license = Some(apache_license);
        combined
    }
}
/// State handed to `auth_middleware` by `HttpListenerParams::start`.
#[derive(Clone)]
pub struct AppState {
/// Trusted networks; wrapped in `Arc` so cloning the state does not copy the set.
trusted_hosts: Arc<CidrSet>,
}
impl AppState {
pub fn is_trusted_host(&self, addr: IpAddr) -> bool {
self.trusted_hosts.contains(addr)
}
}
impl HttpListenerParams {
    /// Serde default for `listen`: the loopback interface on port 8000.
    fn default_listen() -> String {
        String::from("127.0.0.1:8000")
    }

    /// Serde default for `hostname`: the system host name, falling back to
    /// `"localhost"` when that name is not valid UTF-8.
    fn default_hostname() -> String {
        let system_name = gethostname::gethostname();
        system_name.to_str().unwrap_or("localhost").to_owned()
    }

    /// Bind the configured address and spawn the HTTP or HTTPS server.
    ///
    /// The router from `router_and_docs` is extended with the following:
    /// - the RapiDoc UI (`/rapidoc`, backed by `/api-docs/openapi.json`)
    /// - the admin endpoints
    /// - the metrics endpoints
    /// - the auth middleware
    /// - response compression
    /// - request tracing
    ///
    /// The server future is handed to `kumo_server_runtime::spawn`, and this
    /// function returns without waiting for it.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `listen` cannot be bound
    /// - its local address cannot be read
    /// - `use_tls` is set and the TLS configuration cannot be built
    /// - spawning the server fails
    pub async fn start(self, router_and_docs: RouterAndDocs) -> anyhow::Result<()> {
        // Body limit used when `request_body_limit` is not configured (2 MiB).
        const DEFAULT_REQUEST_BODY_LIMIT: usize = 2 * 1024 * 1024;

        let api_docs = router_and_docs.make_docs();
        let body_limit = self
            .request_body_limit
            .unwrap_or(DEFAULT_REQUEST_BODY_LIMIT);
        let compression: CompressionLayer = CompressionLayer::new()
            .deflate(true)
            .gzip(true)
            .quality(tower_http::CompressionLevel::Fastest);
        let auth_state = AppState {
            trusted_hosts: Arc::new(self.trusted_hosts.clone()),
        };

        // Ordering matters here. axum's `Router::layer` and `route_layer`
        // only wrap routes that already exist when they are called.
        // - The body limit therefore applies only to the caller-supplied
        //   routes. The RapiDoc and built-in routes keep axum's default limit.
        // - The auth middleware wraps every route registered before it.
        let app = router_and_docs
            .router
            .layer(DefaultBodyLimit::max(body_limit))
            .merge(RapiDoc::with_openapi("/api-docs/openapi.json", api_docs).path("/rapidoc"))
            .route(
                "/api/admin/set_diagnostic_log_filter/v1",
                post(set_diagnostic_log_filter_v1),
            )
            .route("/api/admin/bump-config-epoch", post(bump_config_epoch))
            .route("/metrics", get(report_metrics))
            .route("/metrics.json", get(report_metrics_json))
            .route_layer(axum::middleware::from_fn_with_state(
                auth_state,
                auth_middleware,
            ))
            .layer(compression)
            .layer(TraceLayer::new_for_http());

        let listener = TcpListener::bind(&self.listen)
            .with_context(|| format!("listen on {}", self.listen))?;
        let addr = listener.local_addr()?;

        if !self.use_tls {
            tracing::info!("http listener on {addr:?}");
            let server = axum_server::from_tcp(listener);
            spawn(format!("http {addr:?}"), async move {
                server
                    .serve(app.into_make_service_with_connect_info::<SocketAddr>())
                    .await
            })?;
            return Ok(());
        }

        let rustls = self.tls_config().await?;
        tracing::info!("https listener on {addr:?}");
        let server = axum_server::from_tcp_rustls(listener, rustls);
        spawn(format!("https {addr:?}"), async move {
            server
                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
                .await
        })?;
        Ok(())
    }

    /// Build the rustls server configuration from `hostname`,
    /// `tls_private_key` and `tls_certificate`.
    async fn tls_config(&self) -> anyhow::Result<RustlsConfig> {
        let server_config = crate::tls_helpers::make_server_config(
            &self.hostname,
            &self.tls_private_key,
            &self.tls_certificate,
        )
        .await?;
        Ok(RustlsConfig::from_config(server_config))
    }
}
/// Error type returned by the HTTP handlers.
///
/// Wraps an `anyhow::Error`. Its `IntoResponse` impl renders it as a
/// `500 Internal Server Error` whose body is the error text.
#[derive(Debug)]
pub struct AppError(pub anyhow::Error);
impl IntoResponse for AppError {
    /// Render as `500 Internal Server Error` with a plain-text body.
    ///
    /// The `{:#}` (alternate) format makes anyhow include the chain of
    /// underlying causes, not just the outermost message.
    fn into_response(self) -> Response {
        let body = format!("Error: {:#}", self.0);
        (StatusCode::INTERNAL_SERVER_ERROR, body).into_response()
    }
}
impl<E> From<E> for AppError
where
E: Into<anyhow::Error>,
{
fn from(err: E) -> Self {
Self(err.into())
}
}
/// Bump the current configuration epoch.
#[utoipa::path(
    post,
    tag="config",
    path="/api/admin/bump-config-epoch",
    responses(
        (status=200, description = "bump successful")
    ),
)]
async fn bump_config_epoch(_: TrustedIpRequired) -> Result<(), AppError> {
    // The `///` line above becomes the OpenAPI operation summary (utoipa reads
    // doc comments on `#[utoipa::path]` handlers).
    config::epoch::bump_current_epoch();
    Ok(())
}
/// Query-string parameters accepted by `GET /metrics`.
#[derive(Deserialize)]
struct PrometheusMetricsParams {
/// Optional prefix forwarded to `Registry::stream_text`; `None` when absent.
/// (Presumably prepended to metric names — TODO confirm in kumo_prometheus.)
#[serde(default)]
prefix: Option<String>,
}
/// `GET /metrics`: stream the metrics registry as text.
///
/// The body is produced by `Registry::stream_text` and served with
/// `Content-Type: text/plain; charset=utf-8`. The optional `prefix` query
/// parameter is forwarded to `stream_text`.
async fn report_metrics(
    _: TrustedIpRequired,
    Query(params): Query<PrometheusMetricsParams>,
) -> impl IntoResponse {
    StreamBodyAsOptions::new()
        .content_type(HttpHeaderValue::from_static("text/plain; charset=utf-8"))
        // `params` is owned by this handler and not used again, so move the
        // prefix instead of cloning it.
        .text(kumo_prometheus::registry::Registry::stream_text(
            params.prefix,
        ))
}
/// `GET /metrics.json`: stream the metrics registry as JSON.
///
/// The body comes from `Registry::stream_json` and is served with
/// `Content-Type: application/json; charset=utf-8`.
async fn report_metrics_json(_: TrustedIpRequired) -> impl IntoResponse {
    let json_content_type = HttpHeaderValue::from_static("application/json; charset=utf-8");
    let json_stream = kumo_prometheus::registry::Registry::stream_json();
    StreamBodyAsOptions::new()
        .content_type(json_content_type)
        .text(json_stream)
}
/// Set the diagnostic log filter.
///
/// The `filter` field of the JSON request body is applied as the new
/// diagnostic logging filter. If it cannot be applied, the server responds
/// with status 500 and the error text as the body.
#[utoipa::path(
    post,
    tag="logging",
    path="/api/admin/set_diagnostic_log_filter/v1",
    request_body = SetDiagnosticFilterRequest,
    responses(
        (status = 200, description = "Diagnostic level set successfully"),
        (status = 500, description = "The filter could not be applied")
    ),
)]
async fn set_diagnostic_log_filter_v1(
    _: TrustedIpRequired,
    // Json<> consumes the request body, so axum requires it to be the last
    // extractor in the parameter list.
    Json(request): Json<SetDiagnosticFilterRequest>,
) -> Result<(), AppError> {
    set_diagnostic_log_filter(&request.filter)?;
    Ok(())
}