1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
use crate::diagnostic_logging::set_diagnostic_log_filter;
use anyhow::Context;
use axum::extract::{DefaultBodyLimit, Json, Query};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::Router;
use axum_server::tls_rustls::RustlsConfig;
use axum_streams::{HttpHeaderValue, StreamBodyAsOptions};
use cidr_map::CidrSet;
use data_loader::KeySource;
use kumo_server_runtime::spawn;
use serde::Deserialize;
use std::net::{IpAddr, SocketAddr, TcpListener};
use std::sync::Arc;
use tower_http::compression::CompressionLayer;
use tower_http::trace::TraceLayer;
use utoipa::openapi::security::{Http, HttpAuthScheme, SecurityScheme};
use utoipa::OpenApi;
use utoipa_rapidoc::RapiDoc;
// Avoid referencing api types as crate::name in the utoipa macros,
// otherwise it generates namespaced names in the openapi.json, which
// in turn require annotating each and every struct with the namespace
// in order for the document to be valid.
use kumo_api_types::*;

pub mod auth;

use auth::*;

// OpenAPI document describing the built-in admin endpoints.
// Caller-supplied route docs are merged into it at runtime by
// `RouterAndDocs::make_docs`.
#[derive(OpenApi)]
#[openapi(
    info(license(name = "Apache-2.0")),
    paths(set_diagnostic_log_filter_v1, bump_config_epoch),
    // Indicate that all paths can accept http basic auth.
    // The "basic_auth" name corresponds with the scheme
    // defined by the OptionalAuth addon defined below.
    // The scope list must be empty: OpenAPI only allows scope names for
    // oauth2/openIdConnect schemes, and `[""]` would declare a single
    // empty-named scope, which is invalid for an http scheme.
    security(
        ("basic_auth" = [])
    ),
    components(schemas(SetDiagnosticFilterRequest)),
    modifiers(&OptionalAuth),
)]
struct ApiDoc;

/// utoipa modifier that registers the `basic_auth` security scheme
/// (HTTP Basic) referenced by the `security(...)` list on `ApiDoc`.
struct OptionalAuth;

impl utoipa::Modify for OptionalAuth {
    fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
        // HTTP Basic authentication scheme to publish under "basic_auth".
        let basic = SecurityScheme::Http(Http::new(HttpAuthScheme::Basic));
        // `ApiDoc` declares `components(...)`, so the components section
        // exists by the time this modifier runs.
        openapi
            .components
            .as_mut()
            .expect("always set because we always have components above")
            .add_security_scheme("basic_auth", basic);
    }
}

/// Configuration for a single HTTP(S) listener, deserialized from user
/// configuration. Unknown fields are rejected.
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct HttpListenerParams {
    /// Server name passed to the TLS configuration builder.
    /// Defaults to the system hostname, falling back to `"localhost"`.
    #[serde(default = "HttpListenerParams::default_hostname")]
    pub hostname: String,

    /// Address to bind; defaults to `"127.0.0.1:8000"`.
    #[serde(default = "HttpListenerParams::default_listen")]
    pub listen: String,

    /// When `true`, serve HTTPS instead of plain HTTP. Defaults to `false`.
    #[serde(default)]
    pub use_tls: bool,

    /// Maximum request body size in bytes; `None` means 2 MiB.
    #[serde(default)]
    pub request_body_limit: Option<usize>,

    /// Certificate source for TLS; within this module it is only read when
    /// `use_tls` is set.
    #[serde(default)]
    pub tls_certificate: Option<KeySource>,
    /// Private key source for TLS; within this module it is only read when
    /// `use_tls` is set.
    #[serde(default)]
    pub tls_private_key: Option<KeySource>,

    /// Address set handed (via `AppState`) to `auth_middleware` for its
    /// trusted-host checks.
    #[serde(default = "CidrSet::default_trusted_hosts")]
    pub trusted_hosts: CidrSet,
}

/// An axum router paired with the OpenAPI fragment that documents its
/// routes; this is what `HttpListenerParams::start` serves.
pub struct RouterAndDocs {
    /// Application routes to serve.
    pub router: Router,
    /// OpenAPI fragment for `router`; its `info.title` becomes the title of
    /// the combined document.
    pub docs: utoipa::openapi::OpenApi,
}

impl RouterAndDocs {
    /// Produces the combined OpenAPI document: the built-in admin API
    /// (`ApiDoc`) merged with `self.docs`, titled after `self.docs`, and
    /// stamped with the running kumo version and the Apache-2.0 license.
    pub fn make_docs(&self) -> utoipa::openapi::OpenApi {
        let mut merged = ApiDoc::openapi();
        merged.info.title = self.docs.info.title.clone();
        merged.merge(self.docs.clone());
        merged.info.version = version_info::kumo_version().to_string();

        let license = utoipa::openapi::LicenseBuilder::new()
            .name("Apache-2.0")
            .build();
        merged.info.license = Some(license);

        merged
    }
}

#[derive(Clone)]
pub struct AppState {
    trusted_hosts: Arc<CidrSet>,
}

impl AppState {
    pub fn is_trusted_host(&self, addr: IpAddr) -> bool {
        self.trusted_hosts.contains(addr)
    }
}

impl HttpListenerParams {
    /// Serde default for `listen`: the loopback address, port 8000.
    fn default_listen() -> String {
        "127.0.0.1:8000".to_string()
    }

    /// Serde default for `hostname`: the system hostname, or `"localhost"`
    /// when the hostname is not valid UTF-8.
    fn default_hostname() -> String {
        gethostname::gethostname()
            .to_str()
            .unwrap_or("localhost")
            .to_string()
    }

    // Note: it is possible to call
    // server.with_graceful_shutdown(ShutdownSubscription::get().shutting_down)
    // to have it listen for a shutdown request, but we're avoiding it:
    // the request is the start of a shutdown and we need to allow a grace
    // period for in-flight operations to complete.
    // Some of those may require call backs to the HTTP endpoint
    // if we're doing some kind of web hook like thing.
    // So, for now at least, we'll have to manually verify if
    // a request should proceed based on the results from the lifecycle
    // module.
    /// Binds `self.listen` and spawns a task serving `router_and_docs`
    /// together with the built-in routes: the RapiDoc UI at `/rapidoc`
    /// (backed by `/api-docs/openapi.json`), the admin endpoints, and
    /// `/metrics` / `/metrics.json`.
    ///
    /// The spawned server task is not awaited here; this returns as soon as
    /// the task has been spawned.
    ///
    /// # Errors
    ///
    /// Returns an error if the listen address cannot be bound, if the TLS
    /// configuration cannot be built (when `use_tls` is set), or if the
    /// server task cannot be spawned.
    pub async fn start(self, router_and_docs: RouterAndDocs) -> anyhow::Result<()> {
        let api_docs = router_and_docs.make_docs();

        // Enable deflate and gzip response compression at the fastest
        // (least CPU-intensive) level.
        let compression_layer: CompressionLayer = CompressionLayer::new()
            .deflate(true)
            .gzip(true)
            .quality(tower_http::CompressionLevel::Fastest);

        let app = router_and_docs
            .router
            // NOTE(review): axum's `Router::layer` only wraps routes that are
            // already registered. This body limit (2 MiB when
            // `request_body_limit` is unset) therefore covers the caller's
            // routes but not the rapidoc/admin/metrics routes added below.
            // Confirm that is intended.
            .layer(DefaultBodyLimit::max(
                self.request_body_limit.unwrap_or(2 * 1024 * 1024),
            ))
            .merge(RapiDoc::with_openapi("/api-docs/openapi.json", api_docs).path("/rapidoc"))
            .route(
                "/api/admin/set_diagnostic_log_filter/v1",
                post(set_diagnostic_log_filter_v1),
            )
            .route("/api/admin/bump-config-epoch", post(bump_config_epoch))
            .route("/metrics", get(report_metrics))
            .route("/metrics.json", get(report_metrics_json))
            // Require that all requests be authenticated as either coming
            // from a trusted IP address, or with an authorization header
            // (`route_layer` wraps every route registered above, including
            // the caller's routes and rapidoc).
            .route_layer(axum::middleware::from_fn_with_state(
                AppState {
                    trusted_hosts: Arc::new(self.trusted_hosts.clone()),
                },
                auth_middleware,
            ))
            .layer(compression_layer)
            .layer(TraceLayer::new_for_http());
        // Bind synchronously with std so that bind failures surface to the
        // caller before any server task is spawned.
        let socket = TcpListener::bind(&self.listen)
            .with_context(|| format!("listen on {}", self.listen))?;
        // The actual bound address, used for logging and task naming.
        let addr = socket.local_addr()?;

        // Both branches serve with connect info (`SocketAddr`) so the peer
        // address is available to extractors/middleware; presumably
        // `auth_middleware` needs it for the trusted-host check.
        if self.use_tls {
            let config = self.tls_config().await?;
            tracing::info!("https listener on {addr:?}");
            let server = axum_server::from_tcp_rustls(socket, config);
            spawn(format!("https {addr:?}"), async move {
                server
                    .serve(app.into_make_service_with_connect_info::<SocketAddr>())
                    .await
            })?;
        } else {
            tracing::info!("http listener on {addr:?}");
            let server = axum_server::from_tcp(socket);
            spawn(format!("http {addr:?}"), async move {
                server
                    .serve(app.into_make_service_with_connect_info::<SocketAddr>())
                    .await
            })?;
        }
        Ok(())
    }

    /// Builds the rustls server configuration from `hostname`,
    /// `tls_private_key` and `tls_certificate` via
    /// `crate::tls_helpers::make_server_config`.
    async fn tls_config(&self) -> anyhow::Result<RustlsConfig> {
        let config = crate::tls_helpers::make_server_config(
            &self.hostname,
            &self.tls_private_key,
            &self.tls_certificate,
        )
        .await?;
        Ok(RustlsConfig::from_config(config))
    }
}

#[derive(Debug)]
pub struct AppError(pub anyhow::Error);

// Tell axum how to convert `AppError` into a response.
impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Error: {:#}", self.0),
        )
            .into_response()
    }
}

// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
// `Result<_, AppError>`. That way you don't need to do that manually.
impl<E> From<E> for AppError
where
    E: Into<anyhow::Error>,
{
    fn from(err: E) -> Self {
        Self(err.into())
    }
}

// The `///` text below is also consumed by `#[utoipa::path]` as the
// operation description in the generated OpenAPI document.
/// Allows the system operator to trigger a configuration epoch bump,
/// which causes various configs that are using the Epoch strategy to
/// be re-evaluated by triggering the appropriate callbacks.
#[utoipa::path(
    post,
    tag="config",
    path="/api/admin/bump-config-epoch",
    responses(
        (status=200, description = "bump successful")
    ),
)]
async fn bump_config_epoch(_: TrustedIpRequired) -> Result<(), AppError> {
    // The `TrustedIpRequired` extractor is present only to gate access
    // (presumably rejecting untrusted callers; see the `auth` module).
    // Its value is discarded.
    config::epoch::bump_current_epoch();
    Ok(())
}

/// Query-string parameters accepted by the `/metrics` endpoint.
#[derive(Deserialize)]
struct PrometheusMetricsParams {
    /// Optional prefix forwarded to `Registry::stream_text`.
    #[serde(default)]
    prefix: Option<String>,
}

/// Streams the metrics registry as `text/plain` via
/// `Registry::stream_text`, passing through the optional `?prefix=`
/// query parameter. Requires the caller to satisfy `TrustedIpRequired`.
async fn report_metrics(
    _: TrustedIpRequired,
    Query(params): Query<PrometheusMetricsParams>,
) -> impl IntoResponse {
    StreamBodyAsOptions::new()
        .content_type(HttpHeaderValue::from_static("text/plain; charset=utf-8"))
        // `params` is owned by this handler, so the prefix is moved out
        // rather than cloned.
        .text(kumo_prometheus::registry::Registry::stream_text(
            params.prefix,
        ))
}

/// Streams the metrics registry as JSON via `Registry::stream_json`.
/// Requires the caller to satisfy `TrustedIpRequired`.
async fn report_metrics_json(_: TrustedIpRequired) -> impl IntoResponse {
    let options = StreamBodyAsOptions::new().content_type(HttpHeaderValue::from_static(
        "application/json; charset=utf-8",
    ));
    options.text(kumo_prometheus::registry::Registry::stream_json())
}

/// Changes the diagnostic log filter dynamically.
/// See <https://docs.kumomta.com/reference/kumo/set_diagnostic_log_filter/>
/// for more information on diagnostic log filters.
#[utoipa::path(
    post,
    tag="logging",
    path="/api/admin/set_diagnostic_log_filter/v1",
    responses(
        (status = 200, description = "Diagnostic level set successfully")
    ),
)]
async fn set_diagnostic_log_filter_v1(
    _: TrustedIpRequired,
    // axum requires a body-consuming extractor such as Json<> to be the
    // final parameter.
    Json(payload): Json<SetDiagnosticFilterRequest>,
) -> Result<(), AppError> {
    // Any failure to apply the filter is converted into an `AppError` by `?`.
    let new_filter = &payload.filter;
    set_diagnostic_log_filter(new_filter)?;
    Ok(())
}