kumo_server_common/
lib.rs

1use bstr::ByteSlice;
2use config::{
3    any_err, decorate_callback_name, from_lua_value, get_or_create_module, load_config,
4    serialize_options, CallbackSignature,
5};
6use kumo_server_runtime::available_parallelism;
7use mlua::{Function, Lua, LuaSerdeExt, Value, Variadic};
8use mod_redis::RedisConnKey;
9use serde::{Deserialize, Serialize};
10use std::sync::atomic::AtomicUsize;
11
12pub mod acct;
13pub mod authn_authz;
14pub mod config_handle;
15pub mod diagnostic_logging;
16pub mod disk_space;
17pub mod http_server;
18pub mod log;
19pub mod nodeid;
20pub mod panic;
21pub mod start;
22pub mod tls_helpers;
23
24pub fn register(lua: &Lua) -> anyhow::Result<()> {
25    for func in [
26        mod_redis::register,
27        data_loader::register,
28        mod_digest::register,
29        mod_encode::register,
30        mod_aws_sigv4::register,
31        cidr_map::register,
32        domain_map::register,
33        mod_amqp::register,
34        mod_filesystem::register,
35        mod_file_type::register,
36        mod_http::register,
37        mod_regex::register,
38        mod_serde::register,
39        mod_sqlite::register,
40        mod_crypto::register,
41        mod_smtp_response_normalize::register,
42        kumo_jsonl::lua::register,
43        mod_string::register,
44        mod_time::register,
45        mod_dns_resolver::register,
46        mod_kafka::register,
47        mod_memoize::register,
48        mod_mimepart::register,
49        mod_mpsc::register,
50        mod_nats::register,
51        mod_uuid::register,
52        kumo_api_types::shaping::register,
53        regex_set_map::register,
54        crate::authn_authz::register,
55    ] {
56        func(lua)?;
57    }
58
59    let kumo_mod = get_or_create_module(lua, "kumo")?;
60    kumo_mod.set("version", version_info::kumo_version())?;
61
62    fn event_registrar_name(name: &str) -> String {
63        format!("kumomta-event-registrars-{name}")
64    }
65
66    // Record the call stack of the code calling kumo.on so that
67    // kumo.get_event_registrars can retrieve it later
68    fn register_event_caller(lua: &Lua, name: &str) -> mlua::Result<()> {
69        let decorated_name = event_registrar_name(name);
70        let mut call_stack = vec![];
71        for n in 1.. {
72            match lua.inspect_stack(n, |info| {
73                let source = info.source();
74                format!(
75                    "{}:{}",
76                    source
77                        .short_src
78                        .as_ref()
79                        .map(|b| b.to_string())
80                        .unwrap_or_else(String::new),
81                    info.current_line().unwrap_or(0)
82                )
83            }) {
84                Some(info) => {
85                    call_stack.push(info);
86                }
87                None => break,
88            }
89        }
90
91        let tbl: Value = lua.named_registry_value(&decorated_name)?;
92        match tbl {
93            Value::Nil => {
94                let tbl = lua.create_table()?;
95                tbl.set(1, call_stack)?;
96                lua.set_named_registry_value(&decorated_name, tbl)?;
97                Ok(())
98            }
99            Value::Table(tbl) => {
100                let len = tbl.raw_len();
101                tbl.set(len + 1, call_stack)?;
102                Ok(())
103            }
104            _ => Err(mlua::Error::external(format!(
105                "registry key for {decorated_name} has invalid type",
106            ))),
107        }
108    }
109
110    // Returns the list of call-stacks of the code that registered
111    // for a specific named event
112    kumo_mod.set(
113        "get_event_registrars",
114        lua.create_function(move |lua, name: String| {
115            let decorated_name = event_registrar_name(&name);
116            let value: Value = lua.named_registry_value(&decorated_name)?;
117            Ok(value)
118        })?,
119    )?;
120
121    kumo_mod.set(
122        "on",
123        lua.create_function(move |lua, (name, func): (String, Function)| {
124            let decorated_name = decorate_callback_name(&name);
125
126            if let Ok(current_event) = lua.globals().get::<String>("_KUMO_CURRENT_EVENT") {
127                if current_event != "main" {
128                    return Err(mlua::Error::external(format!(
129                        "Attempting to register an event handler via \
130                    `kumo.on('{name}', ...)` from within the event handler \
131                    '{current_event}'. You must move your event handler registration \
132                    so that it is setup directly when the policy is loaded \
133                    in order for it to consistently trigger and handle events."
134                    )));
135                }
136            }
137
138            register_event_caller(lua, &name)?;
139
140            if config::does_callback_allow_multiple(&name) {
141                let tbl: Value = lua.named_registry_value(&decorated_name)?;
142                return match tbl {
143                    Value::Nil => {
144                        let tbl = lua.create_table()?;
145                        tbl.set(1, func)?;
146                        lua.set_named_registry_value(&decorated_name, tbl)?;
147                        Ok(())
148                    }
149                    Value::Table(tbl) => {
150                        let len = tbl.raw_len();
151                        tbl.set(len + 1, func)?;
152                        Ok(())
153                    }
154                    _ => Err(mlua::Error::external(format!(
155                        "registry key for {decorated_name} has invalid type",
156                    ))),
157                };
158            }
159
160            let existing: Value = lua.named_registry_value(&decorated_name)?;
161            match existing {
162                Value::Nil => {}
163                Value::Function(func) => {
164                    let info = func.info();
165                    let src = info.source.unwrap_or_else(|| "?".into());
166                    let line = info.line_defined.unwrap_or(0);
167                    return Err(mlua::Error::external(format!(
168                        "{name} event already has a handler defined at {src}:{line}"
169                    )));
170                }
171                _ => {
172                    return Err(mlua::Error::external(format!(
173                        "{name} event already has a handler"
174                    )));
175                }
176            }
177
178            lua.set_named_registry_value(&decorated_name, func)?;
179            Ok(())
180        })?,
181    )?;
182
183    kumo_mod.set(
184        "set_diagnostic_log_filter",
185        lua.create_function(move |_, filter: String| {
186            diagnostic_logging::set_diagnostic_log_filter(&filter).map_err(any_err)
187        })?,
188    )?;
189
190    fn variadic_to_string(args: Variadic<Value>) -> String {
191        let mut output = String::new();
192        for (idx, item) in args.into_iter().enumerate() {
193            if idx > 0 {
194                output.push(' ');
195            }
196
197            match item {
198                Value::String(s) => {
199                    let bytes = s.as_bytes();
200                    for (start, end, c) in bytes.char_indices() {
201                        if c == std::char::REPLACEMENT_CHARACTER {
202                            let c_slice = &bytes[start..end];
203                            for &b in c_slice.iter() {
204                                output.push_str(&format!("\\x{b:02X}"));
205                            }
206                        } else {
207                            output.push(c);
208                        }
209                    }
210                }
211                item => match item.to_string() {
212                    Ok(s) => output.push_str(&s),
213                    Err(_) => output.push_str(&format!("{item:?}")),
214                },
215            }
216        }
217        output
218    }
219
220    fn get_caller(lua: &Lua) -> String {
221        match lua.inspect_stack(1, |info| {
222            let source = info.source();
223            let file_name = source
224                .short_src
225                .as_ref()
226                .map(|b| b.to_string())
227                .unwrap_or_else(String::new);
228            // Lua returns the somewhat obnoxious `[string "source.lua"]`
229            // Let's fix that up to be a bit nicer
230            let file_name = match file_name.strip_prefix("[string \"") {
231                Some(name) => name.strip_suffix("\"]").unwrap_or(name),
232                None => &file_name,
233            };
234
235            format!("{file_name}:{}", info.current_line().unwrap_or(0))
236        }) {
237            Some(info) => info,
238            None => "?".to_string(),
239        }
240    }
241
242    kumo_mod.set(
243        "log_error",
244        lua.create_function(move |lua, args: Variadic<Value>| {
245            if tracing::event_enabled!(target: "lua", tracing::Level::ERROR) {
246                let src = get_caller(lua);
247                tracing::error!(target: "lua", "{src}: {}", variadic_to_string(args));
248            }
249            Ok(())
250        })?,
251    )?;
252    kumo_mod.set(
253        "log_info",
254        lua.create_function(move |lua, args: Variadic<Value>| {
255            if tracing::event_enabled!(target: "lua", tracing::Level::INFO) {
256                let src = get_caller(lua);
257                tracing::info!(target: "lua", "{src}: {}", variadic_to_string(args));
258            }
259            Ok(())
260        })?,
261    )?;
262    kumo_mod.set(
263        "log_warn",
264        lua.create_function(move |lua, args: Variadic<Value>| {
265            if tracing::event_enabled!(target: "lua", tracing::Level::WARN) {
266                let src = get_caller(lua);
267                tracing::warn!(target: "lua", "{src}: {}", variadic_to_string(args));
268            }
269            Ok(())
270        })?,
271    )?;
272    kumo_mod.set(
273        "log_debug",
274        lua.create_function(move |lua, args: Variadic<Value>| {
275            if tracing::event_enabled!(target: "lua", tracing::Level::DEBUG) {
276                let src = get_caller(lua);
277                tracing::debug!(target: "lua", "{src}: {}", variadic_to_string(args));
278            }
279            Ok(())
280        })?,
281    )?;
282
283    kumo_mod.set(
284        "set_max_spare_lua_contexts",
285        lua.create_function(move |_, limit: usize| {
286            config::set_max_spare(limit);
287            Ok(())
288        })?,
289    )?;
290
291    kumo_mod.set(
292        "set_max_lua_context_use_count",
293        lua.create_function(move |_, limit: usize| {
294            config::set_max_use(limit);
295            Ok(())
296        })?,
297    )?;
298
299    kumo_mod.set(
300        "set_max_lua_context_age",
301        lua.create_function(move |_, limit: usize| {
302            config::set_max_age(limit);
303            Ok(())
304        })?,
305    )?;
306
307    kumo_mod.set(
308        "set_lua_gc_on_put",
309        lua.create_function(move |_, enable: u8| {
310            config::set_gc_on_put(enable);
311            Ok(())
312        })?,
313    )?;
314
315    kumo_mod.set(
316        "set_lruttl_cache_capacity",
317        lua.create_function(move |_, (name, capacity): (String, usize)| {
318            if lruttl::set_cache_capacity(&name, capacity) {
319                Ok(())
320            } else {
321                Err(mlua::Error::external(format!(
322                    "could not set capacity for cache {name} \
323                    as that is not a pre-defined lruttl cache name"
324                )))
325            }
326        })?,
327    )?;
328
329    kumo_mod.set(
330        "set_config_monitor_globs",
331        lua.create_function(move |_, globs: Vec<String>| {
332            config::epoch::set_globs(globs).map_err(any_err)?;
333            Ok(())
334        })?,
335    )?;
336    kumo_mod.set(
337        "eval_config_monitor_globs",
338        lua.create_async_function(|_, _: ()| async move {
339            config::epoch::eval_globs().await.map_err(any_err)
340        })?,
341    )?;
342    kumo_mod.set(
343        "bump_config_epoch",
344        lua.create_function(move |_, _: ()| {
345            config::epoch::bump_current_epoch();
346            Ok(())
347        })?,
348    )?;
349
350    kumo_mod.set(
351        "available_parallelism",
352        lua.create_function(move |_, _: ()| available_parallelism().map_err(any_err))?,
353    )?;
354
355    kumo_mod.set(
356        "set_memory_hard_limit",
357        lua.create_function(move |_, limit: usize| {
358            kumo_server_memory::set_hard_limit(limit);
359            Ok(())
360        })?,
361    )?;
362
363    kumo_mod.set(
364        "set_memory_low_thresh",
365        lua.create_function(move |_, limit: usize| {
366            kumo_server_memory::set_low_memory_thresh(limit);
367            Ok(())
368        })?,
369    )?;
370
371    kumo_mod.set(
372        "set_memory_soft_limit",
373        lua.create_function(move |_, limit: usize| {
374            kumo_server_memory::set_soft_limit(limit);
375            Ok(())
376        })?,
377    )?;
378
379    kumo_mod.set(
380        "get_memory_hard_limit",
381        lua.create_function(move |_, _: ()| Ok(kumo_server_memory::get_hard_limit()))?,
382    )?;
383
384    kumo_mod.set(
385        "get_memory_soft_limit",
386        lua.create_function(move |_, _: ()| Ok(kumo_server_memory::get_soft_limit()))?,
387    )?;
388
389    kumo_mod.set(
390        "get_memory_low_thresh",
391        lua.create_function(move |_, _: ()| Ok(kumo_server_memory::get_low_memory_thresh()))?,
392    )?;
393
394    kumo_mod.set(
395        "configure_redis_throttles",
396        lua.create_async_function(|lua, params: Value| async move {
397            let key: RedisConnKey = from_lua_value(&lua, params)?;
398            let conn = key.open().map_err(any_err)?;
399            conn.ping().await.map_err(any_err)?;
400            throttle::use_redis(conn).await.map_err(any_err)
401        })?,
402    )?;
403
404    kumo_mod.set(
405        "traceback",
406        lua.create_function(move |lua: &Lua, level: usize| {
407            #[derive(Debug, Serialize)]
408            struct Frame {
409                event: String,
410                name: Option<String>,
411                name_what: Option<String>,
412                source: Option<String>,
413                short_src: Option<String>,
414                line_defined: Option<usize>,
415                last_line_defined: Option<usize>,
416                what: &'static str,
417                curr_line: Option<usize>,
418                is_tail_call: bool,
419            }
420
421            let mut frames = vec![];
422            for n in level.. {
423                match lua.inspect_stack(n, |info| {
424                    let source = info.source();
425                    let names = info.names();
426                    Frame {
427                        curr_line: info.current_line(),
428                        is_tail_call: info.is_tail_call(),
429                        event: format!("{:?}", info.event()),
430                        last_line_defined: source.last_line_defined,
431                        line_defined: source.line_defined,
432                        name: names.name.as_ref().map(|b| b.to_string()),
433                        name_what: names.name_what.as_ref().map(|b| b.to_string()),
434                        source: source.source.as_ref().map(|b| b.to_string()),
435                        short_src: source.short_src.as_ref().map(|b| b.to_string()),
436                        what: source.what,
437                    }
438                }) {
439                    Some(frame) => {
440                        frames.push(frame);
441                    }
442                    None => break,
443                }
444            }
445
446            lua.to_value(&frames)
447        })?,
448    )?;
449
450    // TODO: options like restarting on error, delay between
451    // restarts and so on
452    #[derive(Deserialize, Debug)]
453    struct TaskParams {
454        event_name: String,
455        args: Vec<serde_json::Value>,
456    }
457
458    impl TaskParams {
459        async fn run(&self) -> anyhow::Result<()> {
460            let mut config = load_config().await?;
461
462            let sig = CallbackSignature::<Value, ()>::new(self.event_name.to_string());
463
464            config
465                .convert_args_and_call_callback(&sig, &self.args)
466                .await?;
467
468            config.put();
469
470            Ok(())
471        }
472    }
473
474    kumo_mod.set(
475        "spawn_task",
476        lua.create_function(|lua, params: Value| {
477            let params: TaskParams = lua.from_value(params)?;
478
479            if !config::is_validating() {
480                std::thread::Builder::new()
481                    .name(format!("spawned-task-{}", params.event_name))
482                    .spawn(move || {
483                        let runtime = tokio::runtime::Builder::new_current_thread()
484                            .enable_io()
485                            .enable_time()
486                            .on_thread_park(kumo_server_memory::purge_thread_cache)
487                            .build()
488                            .unwrap();
489                        let event_name = params.event_name.clone();
490
491                        let result = runtime.block_on(async move { params.run().await });
492                        if let Err(err) = result {
493                            tracing::error!("Error while dispatching {event_name}: {err:#}");
494                        }
495                    })?;
496            }
497
498            Ok(())
499        })?,
500    )?;
501
502    kumo_mod.set(
503        "spawn_thread_pool",
504        lua.create_function(|lua, params: Value| {
505            #[derive(Deserialize, Debug)]
506            struct ThreadPoolParams {
507                name: String,
508                num_threads: usize,
509            }
510
511            let params: ThreadPoolParams = lua.from_value(params)?;
512            let num_threads = AtomicUsize::new(params.num_threads);
513
514            if !config::is_validating() {
515                // Create the runtime. We don't need to hold on
516                // to it here, as it will be kept alive in the
517                // runtimes map in that crate
518                let _runtime = kumo_server_runtime::Runtime::new(
519                    &params.name,
520                    |_| params.num_threads,
521                    &num_threads,
522                )
523                .map_err(any_err)?;
524            }
525
526            Ok(())
527        })?,
528    )?;
529
530    kumo_mod.set(
531        "validation_failed",
532        lua.create_function(|_, ()| {
533            config::set_validation_failed();
534            Ok(())
535        })?,
536    )?;
537
538    kumo_mod.set(
539        "enable_memory_callstack_tracking",
540        lua.create_function(|_, enable: bool| {
541            kumo_server_memory::set_tracking_callstacks(enable);
542            Ok(())
543        })?,
544    )?;
545
546    // This function is intended for debugging and testing purposes only.
547    // It is potentially very expensive on a production system with many
548    // thousands of queues.
549    kumo_mod.set(
550        "prometheus_metrics",
551        lua.create_async_function(|lua, ()| async move {
552            use tokio_stream::StreamExt;
553            let mut json_text = String::new();
554            let mut stream = kumo_prometheus::registry::Registry::stream_json();
555            while let Some(text) = stream.next().await {
556                json_text.push_str(&text);
557            }
558            let value: serde_json::Value = serde_json::from_str(&json_text).map_err(any_err)?;
559            lua.to_value_with(&value, serialize_options())
560        })?,
561    )?;
562
563    Ok(())
564}