kumo_server_common/
lib.rs

1use config::{
2    any_err, decorate_callback_name, from_lua_value, get_or_create_module, load_config,
3    serialize_options, CallbackSignature,
4};
5use kumo_server_runtime::available_parallelism;
6use mlua::{Function, Lua, LuaSerdeExt, Value, Variadic};
7use mod_redis::RedisConnKey;
8use serde::{Deserialize, Serialize};
9use std::sync::atomic::AtomicUsize;
10
11pub mod acct;
12pub mod authn_authz;
13pub mod config_handle;
14pub mod diagnostic_logging;
15pub mod disk_space;
16pub mod http_server;
17pub mod log;
18pub mod nodeid;
19pub mod panic;
20pub mod start;
21pub mod tls_helpers;
22
23pub fn register(lua: &Lua) -> anyhow::Result<()> {
24    for func in [
25        mod_redis::register,
26        data_loader::register,
27        mod_digest::register,
28        mod_encode::register,
29        mod_aws_sigv4::register,
30        cidr_map::register,
31        domain_map::register,
32        mod_amqp::register,
33        mod_filesystem::register,
34        mod_file_type::register,
35        mod_http::register,
36        mod_regex::register,
37        mod_serde::register,
38        mod_sqlite::register,
39        mod_crypto::register,
40        mod_smtp_response_normalize::register,
41        mod_string::register,
42        mod_time::register,
43        mod_dns_resolver::register,
44        mod_kafka::register,
45        mod_memoize::register,
46        mod_mimepart::register,
47        mod_mpsc::register,
48        mod_uuid::register,
49        kumo_api_types::shaping::register,
50        regex_set_map::register,
51        crate::authn_authz::register,
52    ] {
53        func(lua)?;
54    }
55
56    let kumo_mod = get_or_create_module(lua, "kumo")?;
57    kumo_mod.set("version", version_info::kumo_version())?;
58
59    fn event_registrar_name(name: &str) -> String {
60        format!("kumomta-event-registrars-{name}")
61    }
62
63    // Record the call stack of the code calling kumo.on so that
64    // kumo.get_event_registrars can retrieve it later
65    fn register_event_caller(lua: &Lua, name: &str) -> mlua::Result<()> {
66        let decorated_name = event_registrar_name(name);
67        let mut call_stack = vec![];
68        for n in 1.. {
69            match lua.inspect_stack(n, |info| {
70                let source = info.source();
71                format!(
72                    "{}:{}",
73                    source
74                        .short_src
75                        .as_ref()
76                        .map(|b| b.to_string())
77                        .unwrap_or_else(String::new),
78                    info.current_line().unwrap_or(0)
79                )
80            }) {
81                Some(info) => {
82                    call_stack.push(info);
83                }
84                None => break,
85            }
86        }
87
88        let tbl: Value = lua.named_registry_value(&decorated_name)?;
89        match tbl {
90            Value::Nil => {
91                let tbl = lua.create_table()?;
92                tbl.set(1, call_stack)?;
93                lua.set_named_registry_value(&decorated_name, tbl)?;
94                Ok(())
95            }
96            Value::Table(tbl) => {
97                let len = tbl.raw_len();
98                tbl.set(len + 1, call_stack)?;
99                Ok(())
100            }
101            _ => Err(mlua::Error::external(format!(
102                "registry key for {decorated_name} has invalid type",
103            ))),
104        }
105    }
106
107    // Returns the list of call-stacks of the code that registered
108    // for a specific named event
109    kumo_mod.set(
110        "get_event_registrars",
111        lua.create_function(move |lua, name: String| {
112            let decorated_name = event_registrar_name(&name);
113            let value: Value = lua.named_registry_value(&decorated_name)?;
114            Ok(value)
115        })?,
116    )?;
117
118    kumo_mod.set(
119        "on",
120        lua.create_function(move |lua, (name, func): (String, Function)| {
121            let decorated_name = decorate_callback_name(&name);
122
123            if let Ok(current_event) = lua.globals().get::<String>("_KUMO_CURRENT_EVENT") {
124                if current_event != "main" {
125                    return Err(mlua::Error::external(format!(
126                        "Attempting to register an event handler via \
127                    `kumo.on('{name}', ...)` from within the event handler \
128                    '{current_event}'. You must move your event handler registration \
129                    so that it is setup directly when the policy is loaded \
130                    in order for it to consistently trigger and handle events."
131                    )));
132                }
133            }
134
135            register_event_caller(lua, &name)?;
136
137            if config::does_callback_allow_multiple(&name) {
138                let tbl: Value = lua.named_registry_value(&decorated_name)?;
139                return match tbl {
140                    Value::Nil => {
141                        let tbl = lua.create_table()?;
142                        tbl.set(1, func)?;
143                        lua.set_named_registry_value(&decorated_name, tbl)?;
144                        Ok(())
145                    }
146                    Value::Table(tbl) => {
147                        let len = tbl.raw_len();
148                        tbl.set(len + 1, func)?;
149                        Ok(())
150                    }
151                    _ => Err(mlua::Error::external(format!(
152                        "registry key for {decorated_name} has invalid type",
153                    ))),
154                };
155            }
156
157            let existing: Value = lua.named_registry_value(&decorated_name)?;
158            match existing {
159                Value::Nil => {}
160                Value::Function(func) => {
161                    let info = func.info();
162                    let src = info.source.unwrap_or_else(|| "?".into());
163                    let line = info.line_defined.unwrap_or(0);
164                    return Err(mlua::Error::external(format!(
165                        "{name} event already has a handler defined at {src}:{line}"
166                    )));
167                }
168                _ => {
169                    return Err(mlua::Error::external(format!(
170                        "{name} event already has a handler"
171                    )));
172                }
173            }
174
175            lua.set_named_registry_value(&decorated_name, func)?;
176            Ok(())
177        })?,
178    )?;
179
180    kumo_mod.set(
181        "set_diagnostic_log_filter",
182        lua.create_function(move |_, filter: String| {
183            diagnostic_logging::set_diagnostic_log_filter(&filter).map_err(any_err)
184        })?,
185    )?;
186
187    fn variadic_to_string(args: Variadic<Value>) -> String {
188        let mut output = String::new();
189        for (idx, item) in args.into_iter().enumerate() {
190            if idx > 0 {
191                output.push(' ');
192            }
193
194            match item {
195                Value::String(s) => match s.to_str() {
196                    Ok(s) => output.push_str(&s),
197                    Err(_) => {
198                        let item = s.to_string_lossy();
199                        output.push_str(&item);
200                    }
201                },
202                item => match item.to_string() {
203                    Ok(s) => output.push_str(&s),
204                    Err(_) => output.push_str(&format!("{item:?}")),
205                },
206            }
207        }
208        output
209    }
210
211    fn get_caller(lua: &Lua) -> String {
212        match lua.inspect_stack(1, |info| {
213            let source = info.source();
214            let file_name = source
215                .short_src
216                .as_ref()
217                .map(|b| b.to_string())
218                .unwrap_or_else(String::new);
219            // Lua returns the somewhat obnoxious `[string "source.lua"]`
220            // Let's fix that up to be a bit nicer
221            let file_name = match file_name.strip_prefix("[string \"") {
222                Some(name) => name.strip_suffix("\"]").unwrap_or(name),
223                None => &file_name,
224            };
225
226            format!("{file_name}:{}", info.current_line().unwrap_or(0))
227        }) {
228            Some(info) => info,
229            None => "?".to_string(),
230        }
231    }
232
233    kumo_mod.set(
234        "log_error",
235        lua.create_function(move |lua, args: Variadic<Value>| {
236            if tracing::event_enabled!(target: "lua", tracing::Level::ERROR) {
237                let src = get_caller(lua);
238                tracing::error!(target: "lua", "{src}: {}", variadic_to_string(args));
239            }
240            Ok(())
241        })?,
242    )?;
243    kumo_mod.set(
244        "log_info",
245        lua.create_function(move |lua, args: Variadic<Value>| {
246            if tracing::event_enabled!(target: "lua", tracing::Level::INFO) {
247                let src = get_caller(lua);
248                tracing::info!(target: "lua", "{src}: {}", variadic_to_string(args));
249            }
250            Ok(())
251        })?,
252    )?;
253    kumo_mod.set(
254        "log_warn",
255        lua.create_function(move |lua, args: Variadic<Value>| {
256            if tracing::event_enabled!(target: "lua", tracing::Level::WARN) {
257                let src = get_caller(lua);
258                tracing::warn!(target: "lua", "{src}: {}", variadic_to_string(args));
259            }
260            Ok(())
261        })?,
262    )?;
263    kumo_mod.set(
264        "log_debug",
265        lua.create_function(move |lua, args: Variadic<Value>| {
266            if tracing::event_enabled!(target: "lua", tracing::Level::DEBUG) {
267                let src = get_caller(lua);
268                tracing::debug!(target: "lua", "{src}: {}", variadic_to_string(args));
269            }
270            Ok(())
271        })?,
272    )?;
273
274    kumo_mod.set(
275        "set_max_spare_lua_contexts",
276        lua.create_function(move |_, limit: usize| {
277            config::set_max_spare(limit);
278            Ok(())
279        })?,
280    )?;
281
282    kumo_mod.set(
283        "set_max_lua_context_use_count",
284        lua.create_function(move |_, limit: usize| {
285            config::set_max_use(limit);
286            Ok(())
287        })?,
288    )?;
289
290    kumo_mod.set(
291        "set_max_lua_context_age",
292        lua.create_function(move |_, limit: usize| {
293            config::set_max_age(limit);
294            Ok(())
295        })?,
296    )?;
297
298    kumo_mod.set(
299        "set_lua_gc_on_put",
300        lua.create_function(move |_, enable: u8| {
301            config::set_gc_on_put(enable);
302            Ok(())
303        })?,
304    )?;
305
306    kumo_mod.set(
307        "set_lruttl_cache_capacity",
308        lua.create_function(move |_, (name, capacity): (String, usize)| {
309            if lruttl::set_cache_capacity(&name, capacity) {
310                Ok(())
311            } else {
312                Err(mlua::Error::external(format!(
313                    "could not set capacity for cache {name} \
314                    as that is not a pre-defined lruttl cache name"
315                )))
316            }
317        })?,
318    )?;
319
320    kumo_mod.set(
321        "set_config_monitor_globs",
322        lua.create_function(move |_, globs: Vec<String>| {
323            config::epoch::set_globs(globs).map_err(any_err)?;
324            Ok(())
325        })?,
326    )?;
327    kumo_mod.set(
328        "eval_config_monitor_globs",
329        lua.create_async_function(|_, _: ()| async move {
330            config::epoch::eval_globs().await.map_err(any_err)
331        })?,
332    )?;
333    kumo_mod.set(
334        "bump_config_epoch",
335        lua.create_function(move |_, _: ()| {
336            config::epoch::bump_current_epoch();
337            Ok(())
338        })?,
339    )?;
340
341    kumo_mod.set(
342        "available_parallelism",
343        lua.create_function(move |_, _: ()| available_parallelism().map_err(any_err))?,
344    )?;
345
346    kumo_mod.set(
347        "set_memory_hard_limit",
348        lua.create_function(move |_, limit: usize| {
349            kumo_server_memory::set_hard_limit(limit);
350            Ok(())
351        })?,
352    )?;
353
354    kumo_mod.set(
355        "set_memory_low_thresh",
356        lua.create_function(move |_, limit: usize| {
357            kumo_server_memory::set_low_memory_thresh(limit);
358            Ok(())
359        })?,
360    )?;
361
362    kumo_mod.set(
363        "set_memory_soft_limit",
364        lua.create_function(move |_, limit: usize| {
365            kumo_server_memory::set_soft_limit(limit);
366            Ok(())
367        })?,
368    )?;
369
370    kumo_mod.set(
371        "get_memory_hard_limit",
372        lua.create_function(move |_, _: ()| Ok(kumo_server_memory::get_hard_limit()))?,
373    )?;
374
375    kumo_mod.set(
376        "get_memory_soft_limit",
377        lua.create_function(move |_, _: ()| Ok(kumo_server_memory::get_soft_limit()))?,
378    )?;
379
380    kumo_mod.set(
381        "get_memory_low_thresh",
382        lua.create_function(move |_, _: ()| Ok(kumo_server_memory::get_low_memory_thresh()))?,
383    )?;
384
385    kumo_mod.set(
386        "configure_redis_throttles",
387        lua.create_async_function(|lua, params: Value| async move {
388            let key: RedisConnKey = from_lua_value(&lua, params)?;
389            let conn = key.open().map_err(any_err)?;
390            conn.ping().await.map_err(any_err)?;
391            throttle::use_redis(conn).await.map_err(any_err)
392        })?,
393    )?;
394
395    kumo_mod.set(
396        "traceback",
397        lua.create_function(move |lua: &Lua, level: usize| {
398            #[derive(Debug, Serialize)]
399            struct Frame {
400                event: String,
401                name: Option<String>,
402                name_what: Option<String>,
403                source: Option<String>,
404                short_src: Option<String>,
405                line_defined: Option<usize>,
406                last_line_defined: Option<usize>,
407                what: &'static str,
408                curr_line: Option<usize>,
409                is_tail_call: bool,
410            }
411
412            let mut frames = vec![];
413            for n in level.. {
414                match lua.inspect_stack(n, |info| {
415                    let source = info.source();
416                    let names = info.names();
417                    Frame {
418                        curr_line: info.current_line(),
419                        is_tail_call: info.is_tail_call(),
420                        event: format!("{:?}", info.event()),
421                        last_line_defined: source.last_line_defined,
422                        line_defined: source.line_defined,
423                        name: names.name.as_ref().map(|b| b.to_string()),
424                        name_what: names.name_what.as_ref().map(|b| b.to_string()),
425                        source: source.source.as_ref().map(|b| b.to_string()),
426                        short_src: source.short_src.as_ref().map(|b| b.to_string()),
427                        what: source.what,
428                    }
429                }) {
430                    Some(frame) => {
431                        frames.push(frame);
432                    }
433                    None => break,
434                }
435            }
436
437            lua.to_value(&frames)
438        })?,
439    )?;
440
441    // TODO: options like restarting on error, delay between
442    // restarts and so on
443    #[derive(Deserialize, Debug)]
444    struct TaskParams {
445        event_name: String,
446        args: Vec<serde_json::Value>,
447    }
448
449    impl TaskParams {
450        async fn run(&self) -> anyhow::Result<()> {
451            let mut config = load_config().await?;
452
453            let sig = CallbackSignature::<Value, ()>::new(self.event_name.to_string());
454
455            config
456                .convert_args_and_call_callback(&sig, &self.args)
457                .await?;
458
459            config.put();
460
461            Ok(())
462        }
463    }
464
465    kumo_mod.set(
466        "spawn_task",
467        lua.create_function(|lua, params: Value| {
468            let params: TaskParams = lua.from_value(params)?;
469
470            if !config::is_validating() {
471                std::thread::Builder::new()
472                    .name(format!("spawned-task-{}", params.event_name))
473                    .spawn(move || {
474                        let runtime = tokio::runtime::Builder::new_current_thread()
475                            .enable_io()
476                            .enable_time()
477                            .on_thread_park(kumo_server_memory::purge_thread_cache)
478                            .build()
479                            .unwrap();
480                        let event_name = params.event_name.clone();
481
482                        let result = runtime.block_on(async move { params.run().await });
483                        if let Err(err) = result {
484                            tracing::error!("Error while dispatching {event_name}: {err:#}");
485                        }
486                    })?;
487            }
488
489            Ok(())
490        })?,
491    )?;
492
493    kumo_mod.set(
494        "spawn_thread_pool",
495        lua.create_function(|lua, params: Value| {
496            #[derive(Deserialize, Debug)]
497            struct ThreadPoolParams {
498                name: String,
499                num_threads: usize,
500            }
501
502            let params: ThreadPoolParams = lua.from_value(params)?;
503            let num_threads = AtomicUsize::new(params.num_threads);
504
505            if !config::is_validating() {
506                // Create the runtime. We don't need to hold on
507                // to it here, as it will be kept alive in the
508                // runtimes map in that crate
509                let _runtime = kumo_server_runtime::Runtime::new(
510                    &params.name,
511                    |_| params.num_threads,
512                    &num_threads,
513                )
514                .map_err(any_err)?;
515            }
516
517            Ok(())
518        })?,
519    )?;
520
521    kumo_mod.set(
522        "validation_failed",
523        lua.create_function(|_, ()| {
524            config::set_validation_failed();
525            Ok(())
526        })?,
527    )?;
528
529    kumo_mod.set(
530        "enable_memory_callstack_tracking",
531        lua.create_function(|_, enable: bool| {
532            kumo_server_memory::set_tracking_callstacks(enable);
533            Ok(())
534        })?,
535    )?;
536
537    // This function is intended for debugging and testing purposes only.
538    // It is potentially very expensive on a production system with many
539    // thousands of queues.
540    kumo_mod.set(
541        "prometheus_metrics",
542        lua.create_async_function(|lua, ()| async move {
543            use tokio_stream::StreamExt;
544            let mut json_text = String::new();
545            let mut stream = kumo_prometheus::registry::Registry::stream_json();
546            while let Some(text) = stream.next().await {
547                json_text.push_str(&text);
548            }
549            let value: serde_json::Value = serde_json::from_str(&json_text).map_err(any_err)?;
550            lua.to_value_with(&value, serialize_options())
551        })?,
552    )?;
553
554    Ok(())
555}