kumo_server_common/
lib.rs

1use config::{
2    any_err, decorate_callback_name, from_lua_value, get_or_create_module, load_config,
3    CallbackSignature,
4};
5use kumo_server_runtime::available_parallelism;
6use mlua::{Function, Lua, LuaSerdeExt, Value, Variadic};
7use mod_redis::RedisConnKey;
8use serde::{Deserialize, Serialize};
9use std::sync::atomic::AtomicUsize;
10
11pub mod config_handle;
12pub mod diagnostic_logging;
13pub mod disk_space;
14pub mod http_server;
15pub mod nodeid;
16pub mod panic;
17pub mod start;
18pub mod tls_helpers;
19
20pub fn register(lua: &Lua) -> anyhow::Result<()> {
21    for func in [
22        mod_redis::register,
23        data_loader::register,
24        mod_digest::register,
25        mod_encode::register,
26        cidr_map::register,
27        domain_map::register,
28        mod_amqp::register,
29        mod_filesystem::register,
30        mod_http::register,
31        mod_regex::register,
32        mod_serde::register,
33        mod_sqlite::register,
34        mod_string::register,
35        mod_time::register,
36        mod_dns_resolver::register,
37        mod_kafka::register,
38        mod_memoize::register,
39        mod_uuid::register,
40        kumo_api_types::shaping::register,
41        regex_set_map::register,
42    ] {
43        func(lua)?;
44    }
45
46    let kumo_mod = get_or_create_module(lua, "kumo")?;
47
48    fn event_registrar_name(name: &str) -> String {
49        format!("kumomta-event-registrars-{name}")
50    }
51
52    // Record the call stack of the code calling kumo.on so that
53    // kumo.get_event_registrars can retrieve it later
54    fn register_event_caller(lua: &Lua, name: &str) -> mlua::Result<()> {
55        let decorated_name = event_registrar_name(name);
56        let mut call_stack = vec![];
57        for n in 1.. {
58            match lua.inspect_stack(n) {
59                Some(info) => {
60                    let source = info.source();
61                    call_stack.push(format!(
62                        "{}:{}",
63                        source
64                            .short_src
65                            .as_ref()
66                            .map(|b| b.to_string())
67                            .unwrap_or_else(String::new),
68                        info.curr_line()
69                    ));
70                }
71                None => break,
72            }
73        }
74
75        let tbl: Value = lua.named_registry_value(&decorated_name)?;
76        match tbl {
77            Value::Nil => {
78                let tbl = lua.create_table()?;
79                tbl.set(1, call_stack)?;
80                lua.set_named_registry_value(&decorated_name, tbl)?;
81                Ok(())
82            }
83            Value::Table(tbl) => {
84                let len = tbl.raw_len();
85                tbl.set(len + 1, call_stack)?;
86                Ok(())
87            }
88            _ => Err(mlua::Error::external(format!(
89                "registry key for {decorated_name} has invalid type",
90            ))),
91        }
92    }
93
94    // Returns the list of call-stacks of the code that registered
95    // for a specific named event
96    kumo_mod.set(
97        "get_event_registrars",
98        lua.create_function(move |lua, name: String| {
99            let decorated_name = event_registrar_name(&name);
100            let value: Value = lua.named_registry_value(&decorated_name)?;
101            Ok(value)
102        })?,
103    )?;
104
105    kumo_mod.set(
106        "on",
107        lua.create_function(move |lua, (name, func): (String, Function)| {
108            let decorated_name = decorate_callback_name(&name);
109
110            if let Ok(current_event) = lua.globals().get::<String>("_KUMO_CURRENT_EVENT") {
111                if current_event != "main" {
112                    return Err(mlua::Error::external(format!(
113                        "Attempting to register an event handler via \
114                    `kumo.on('{name}', ...)` from within the event handler \
115                    '{current_event}'. You must move your event handler registration \
116                    so that it is setup directly when the policy is loaded \
117                    in order for it to consistently trigger and handle events."
118                    )));
119                }
120            }
121
122            register_event_caller(lua, &name)?;
123
124            if config::does_callback_allow_multiple(&name) {
125                let tbl: Value = lua.named_registry_value(&decorated_name)?;
126                return match tbl {
127                    Value::Nil => {
128                        let tbl = lua.create_table()?;
129                        tbl.set(1, func)?;
130                        lua.set_named_registry_value(&decorated_name, tbl)?;
131                        Ok(())
132                    }
133                    Value::Table(tbl) => {
134                        let len = tbl.raw_len();
135                        tbl.set(len + 1, func)?;
136                        Ok(())
137                    }
138                    _ => Err(mlua::Error::external(format!(
139                        "registry key for {decorated_name} has invalid type",
140                    ))),
141                };
142            }
143
144            let existing: Value = lua.named_registry_value(&decorated_name)?;
145            match existing {
146                Value::Nil => {}
147                Value::Function(func) => {
148                    let info = func.info();
149                    let src = info.source.unwrap_or_else(|| "?".into());
150                    let line = info.line_defined.unwrap_or(0);
151                    return Err(mlua::Error::external(format!(
152                        "{name} event already has a handler defined at {src}:{line}"
153                    )));
154                }
155                _ => {
156                    return Err(mlua::Error::external(format!(
157                        "{name} event already has a handler"
158                    )));
159                }
160            }
161
162            lua.set_named_registry_value(&decorated_name, func)?;
163            Ok(())
164        })?,
165    )?;
166
167    kumo_mod.set(
168        "set_diagnostic_log_filter",
169        lua.create_function(move |_, filter: String| {
170            diagnostic_logging::set_diagnostic_log_filter(&filter).map_err(any_err)
171        })?,
172    )?;
173
174    fn variadic_to_string(args: Variadic<Value>) -> String {
175        let mut output = String::new();
176        for (idx, item) in args.into_iter().enumerate() {
177            if idx > 0 {
178                output.push(' ');
179            }
180
181            match item {
182                Value::String(s) => match s.to_str() {
183                    Ok(s) => output.push_str(&s),
184                    Err(_) => {
185                        let item = s.to_string_lossy();
186                        output.push_str(&item);
187                    }
188                },
189                item => match item.to_string() {
190                    Ok(s) => output.push_str(&s),
191                    Err(_) => output.push_str(&format!("{item:?}")),
192                },
193            }
194        }
195        output
196    }
197
198    fn get_caller(lua: &Lua) -> String {
199        match lua.inspect_stack(1) {
200            Some(info) => {
201                let source = info.source();
202                let file_name = source
203                    .short_src
204                    .as_ref()
205                    .map(|b| b.to_string())
206                    .unwrap_or_else(String::new);
207                // Lua returns the somewhat obnoxious `[string "source.lua"]`
208                // Let's fix that up to be a bit nicer
209                let file_name = match file_name.strip_prefix("[string \"") {
210                    Some(name) => name.strip_suffix("\"]").unwrap_or(name),
211                    None => &file_name,
212                };
213
214                format!("{file_name}:{}", info.curr_line())
215            }
216            None => "?".to_string(),
217        }
218    }
219
220    kumo_mod.set(
221        "log_error",
222        lua.create_function(move |lua, args: Variadic<Value>| {
223            if tracing::event_enabled!(target: "lua", tracing::Level::ERROR) {
224                let src = get_caller(lua);
225                tracing::error!(target: "lua", "{src}: {}", variadic_to_string(args));
226            }
227            Ok(())
228        })?,
229    )?;
230    kumo_mod.set(
231        "log_info",
232        lua.create_function(move |lua, args: Variadic<Value>| {
233            if tracing::event_enabled!(target: "lua", tracing::Level::INFO) {
234                let src = get_caller(lua);
235                tracing::info!(target: "lua", "{src}: {}", variadic_to_string(args));
236            }
237            Ok(())
238        })?,
239    )?;
240    kumo_mod.set(
241        "log_warn",
242        lua.create_function(move |lua, args: Variadic<Value>| {
243            if tracing::event_enabled!(target: "lua", tracing::Level::WARN) {
244                let src = get_caller(lua);
245                tracing::warn!(target: "lua", "{src}: {}", variadic_to_string(args));
246            }
247            Ok(())
248        })?,
249    )?;
250    kumo_mod.set(
251        "log_debug",
252        lua.create_function(move |lua, args: Variadic<Value>| {
253            if tracing::event_enabled!(target: "lua", tracing::Level::DEBUG) {
254                let src = get_caller(lua);
255                tracing::debug!(target: "lua", "{src}: {}", variadic_to_string(args));
256            }
257            Ok(())
258        })?,
259    )?;
260
261    kumo_mod.set(
262        "set_max_spare_lua_contexts",
263        lua.create_function(move |_, limit: usize| {
264            config::set_max_spare(limit);
265            Ok(())
266        })?,
267    )?;
268
269    kumo_mod.set(
270        "set_max_lua_context_use_count",
271        lua.create_function(move |_, limit: usize| {
272            config::set_max_use(limit);
273            Ok(())
274        })?,
275    )?;
276
277    kumo_mod.set(
278        "set_max_lua_context_age",
279        lua.create_function(move |_, limit: usize| {
280            config::set_max_age(limit);
281            Ok(())
282        })?,
283    )?;
284
285    kumo_mod.set(
286        "set_lua_gc_on_put",
287        lua.create_function(move |_, enable: u8| {
288            config::set_gc_on_put(enable);
289            Ok(())
290        })?,
291    )?;
292
293    kumo_mod.set(
294        "set_lruttl_cache_capacity",
295        lua.create_function(move |_, (name, capacity): (String, usize)| {
296            if lruttl::set_cache_capacity(&name, capacity) {
297                Ok(())
298            } else {
299                Err(mlua::Error::external(format!(
300                    "could not set capacity for cache {name} \
301                    as that is not a pre-defined lruttl cache name"
302                )))
303            }
304        })?,
305    )?;
306
307    kumo_mod.set(
308        "set_config_monitor_globs",
309        lua.create_function(move |_, globs: Vec<String>| {
310            config::epoch::set_globs(globs).map_err(any_err)?;
311            Ok(())
312        })?,
313    )?;
314    kumo_mod.set(
315        "eval_config_monitor_globs",
316        lua.create_async_function(|_, _: ()| async move {
317            config::epoch::eval_globs().await.map_err(any_err)
318        })?,
319    )?;
320    kumo_mod.set(
321        "bump_config_epoch",
322        lua.create_function(move |_, _: ()| {
323            config::epoch::bump_current_epoch();
324            Ok(())
325        })?,
326    )?;
327
328    kumo_mod.set(
329        "available_parallelism",
330        lua.create_function(move |_, _: ()| available_parallelism().map_err(any_err))?,
331    )?;
332
333    kumo_mod.set(
334        "set_memory_hard_limit",
335        lua.create_function(move |_, limit: usize| {
336            kumo_server_memory::set_hard_limit(limit);
337            Ok(())
338        })?,
339    )?;
340
341    kumo_mod.set(
342        "set_memory_low_thresh",
343        lua.create_function(move |_, limit: usize| {
344            kumo_server_memory::set_low_memory_thresh(limit);
345            Ok(())
346        })?,
347    )?;
348
349    kumo_mod.set(
350        "set_memory_soft_limit",
351        lua.create_function(move |_, limit: usize| {
352            kumo_server_memory::set_soft_limit(limit);
353            Ok(())
354        })?,
355    )?;
356
357    kumo_mod.set(
358        "configure_redis_throttles",
359        lua.create_async_function(|lua, params: Value| async move {
360            let key: RedisConnKey = from_lua_value(&lua, params)?;
361            let conn = key.open().map_err(any_err)?;
362            conn.ping().await.map_err(any_err)?;
363            throttle::use_redis(conn).await.map_err(any_err)
364        })?,
365    )?;
366
367    kumo_mod.set(
368        "traceback",
369        lua.create_function(move |lua: &Lua, level: usize| {
370            #[derive(Debug, Serialize)]
371            struct Frame {
372                event: String,
373                name: Option<String>,
374                name_what: Option<String>,
375                source: Option<String>,
376                short_src: Option<String>,
377                line_defined: Option<usize>,
378                last_line_defined: Option<usize>,
379                what: &'static str,
380                curr_line: i32,
381                is_tail_call: bool,
382            }
383
384            let mut frames = vec![];
385            for n in level.. {
386                match lua.inspect_stack(n) {
387                    Some(info) => {
388                        let source = info.source();
389                        let names = info.names();
390                        frames.push(Frame {
391                            curr_line: info.curr_line(),
392                            is_tail_call: info.is_tail_call(),
393                            event: format!("{:?}", info.event()),
394                            last_line_defined: source.last_line_defined,
395                            line_defined: source.line_defined,
396                            name: names.name.as_ref().map(|b| b.to_string()),
397                            name_what: names.name_what.as_ref().map(|b| b.to_string()),
398                            source: source.source.as_ref().map(|b| b.to_string()),
399                            short_src: source.short_src.as_ref().map(|b| b.to_string()),
400                            what: source.what,
401                        });
402                    }
403                    None => break,
404                }
405            }
406
407            lua.to_value(&frames)
408        })?,
409    )?;
410
411    // TODO: options like restarting on error, delay between
412    // restarts and so on
413    #[derive(Deserialize, Debug)]
414    struct TaskParams {
415        event_name: String,
416        args: Vec<serde_json::Value>,
417    }
418
419    impl TaskParams {
420        async fn run(&self) -> anyhow::Result<()> {
421            let mut config = load_config().await?;
422
423            let sig = CallbackSignature::<Value, ()>::new(self.event_name.to_string());
424
425            config
426                .convert_args_and_call_callback(&sig, &self.args)
427                .await?;
428
429            config.put();
430
431            Ok(())
432        }
433    }
434
435    kumo_mod.set(
436        "spawn_task",
437        lua.create_function(|lua, params: Value| {
438            let params: TaskParams = lua.from_value(params)?;
439
440            if !config::is_validating() {
441                std::thread::Builder::new()
442                    .name(format!("spawned-task-{}", params.event_name))
443                    .spawn(move || {
444                        let runtime = tokio::runtime::Builder::new_current_thread()
445                            .enable_io()
446                            .enable_time()
447                            .on_thread_park(kumo_server_memory::purge_thread_cache)
448                            .build()
449                            .unwrap();
450                        let event_name = params.event_name.clone();
451
452                        let result = runtime.block_on(async move { params.run().await });
453                        if let Err(err) = result {
454                            tracing::error!("Error while dispatching {event_name}: {err:#}");
455                        }
456                    })?;
457            }
458
459            Ok(())
460        })?,
461    )?;
462
463    kumo_mod.set(
464        "spawn_thread_pool",
465        lua.create_function(|lua, params: Value| {
466            #[derive(Deserialize, Debug)]
467            struct ThreadPoolParams {
468                name: String,
469                num_threads: usize,
470            }
471
472            let params: ThreadPoolParams = lua.from_value(params)?;
473            let num_threads = AtomicUsize::new(params.num_threads);
474
475            if !config::is_validating() {
476                // Create the runtime. We don't need to hold on
477                // to it here, as it will be kept alive in the
478                // runtimes map in that crate
479                let _runtime = kumo_server_runtime::Runtime::new(
480                    &params.name,
481                    |_| params.num_threads,
482                    &num_threads,
483                )
484                .map_err(any_err)?;
485            }
486
487            Ok(())
488        })?,
489    )?;
490
491    kumo_mod.set(
492        "validation_failed",
493        lua.create_function(|_, ()| {
494            config::set_validation_failed();
495            Ok(())
496        })?,
497    )?;
498
499    kumo_mod.set(
500        "enable_memory_callstack_tracking",
501        lua.create_function(|_, enable: bool| {
502            kumo_server_memory::set_tracking_callstacks(enable);
503            Ok(())
504        })?,
505    )?;
506
507    Ok(())
508}