config/
lib.rs

1use crate::epoch::{get_current_epoch, ConfigEpoch};
2use crate::pool::{pool_get, pool_put};
3pub use crate::pool::{set_gc_on_put, set_max_age, set_max_spare, set_max_use};
4use anyhow::Context;
5use mlua::{
6    FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, Lua, LuaSerdeExt, MetaMethod, RegistryKey, Table,
7    UserData, UserDataMethods, Value,
8};
9use parking_lot::FairMutex as Mutex;
10pub use paste;
11use prometheus::{CounterVec, HistogramTimer, HistogramVec};
12use serde::Serialize;
13use std::borrow::Cow;
14use std::collections::HashSet;
15use std::path::PathBuf;
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::sync::{LazyLock, Once};
18use std::time::Instant;
19
20pub mod epoch;
21mod pool;
22
// Path of the policy script. Written by `set_policy_path` and read by
// `load_config` (through `get_policy_path`) when building a new context.
static POLICY_FILE: LazyLock<Mutex<Option<PathBuf>>> = LazyLock::new(|| Mutex::new(None));
// Functions added via `register`; `load_config` applies each one to every
// newly constructed lua context, before the policy script runs.
static FUNCS: LazyLock<Mutex<Vec<RegisterFunc>>> = LazyLock::new(|| Mutex::new(vec![]));
// Incremented each time `load_config` misses the pool and starts building a
// new context. It is bumped before any fallible setup step, so it counts
// attempts rather than successes.
static LUA_LOAD_COUNT: LazyLock<metrics::Counter> = LazyLock::new(|| {
    metrics::describe_counter!(
        "lua_load_count",
        "how many times the policy lua script has been \
         loaded into a new context"
    );
    metrics::counter!("lua_load_count")
});
// Number of live contexts: incremented by `load_config` once a context has
// been fully built, decremented when a `LuaConfigInner` is dropped.
static LUA_COUNT: LazyLock<metrics::Gauge> = LazyLock::new(|| {
    metrics::describe_gauge!("lua_count", "the number of lua contexts currently alive");
    metrics::gauge!("lua_count")
});
// Names of event signatures declared as allowing multiple handlers. Filled by
// `CallbackSignature::register`, queried via `does_callback_allow_multiple`.
static CALLBACK_ALLOWS_MULTIPLE: LazyLock<Mutex<HashSet<String>>> =
    LazyLock::new(|| Mutex::new(HashSet::new()));

/// When true, `load_config` sets the `_VALIDATING_CONFIG` lua global to
/// `true` in every new context. Read through `is_validating`.
pub static VALIDATE_ONLY: AtomicBool = AtomicBool::new(false);
/// Set by `set_validation_failed` and read by `validation_failed`;
/// nothing in this file resets it.
pub static VALIDATION_FAILED: AtomicBool = AtomicBool::new(false);
// Per-event callback latency, labelled by event name; see `latency_timer`.
// The macro registers with prometheus' default registry and `.unwrap()`
// panics (on first access) if that registration fails.
static LATENCY_HIST: LazyLock<HistogramVec> = LazyLock::new(|| {
    prometheus::register_histogram_vec!(
        "lua_event_latency",
        "how long a given lua event callback took",
        &["event"]
    )
    .unwrap()
});
// Per-event count of callback invocations that have started; paired with
// LATENCY_HIST's count of completed ones. Same registration/panic caveat.
static EVENT_STARTED_COUNT: LazyLock<CounterVec> = LazyLock::new(|| {
    prometheus::register_counter_vec!(
        "lua_event_started",
        "Incremented each time we start to call a lua event callback. Use lua_event_latency_count to track completed events",
        &["event"]
    )
    .unwrap()
});

/// Signature of the functions accepted by `register`; each is applied to
/// every newly created lua context.
pub type RegisterFunc = fn(&Lua) -> anyhow::Result<()>;
60
61fn latency_timer(label: &str) -> HistogramTimer {
62    EVENT_STARTED_COUNT
63        .get_metric_with_label_values(&[label])
64        .expect("to get counter")
65        .inc();
66    LATENCY_HIST
67        .get_metric_with_label_values(&[label])
68        .expect("to get histo")
69        .start_timer()
70}
71
/// The state owned by one lua context: the interpreter plus bookkeeping
/// recorded when `load_config` created it.
#[derive(Debug)]
struct LuaConfigInner {
    /// The lua interpreter state.
    lua: Lua,
    /// The instant at which `load_config` created this context.
    // NOTE(review): presumably consulted by the pool for `set_max_age` — confirm in pool.rs
    created: Instant,
    /// Set to 1 when `load_config` creates the context.
    // NOTE(review): presumably bumped by the pool on reuse for `set_max_use` — confirm
    use_count: usize,
    /// The config epoch that was current when this context was created.
    epoch: ConfigEpoch,
}

impl Drop for LuaConfigInner {
    /// Keep the `lua_count` gauge in step with live contexts; the matching
    /// increment happens in `load_config` once construction succeeds.
    fn drop(&mut self) {
        LUA_COUNT.decrement(1.);
    }
}
85
/// A handle to a lua context that has been loaded with the policy.
///
/// Within this file `inner` is only taken by `put`, which consumes the
/// handle and returns the context to the pool. There is no `Drop` impl
/// here, so dropping a `LuaConfig` without calling `put` simply drops the
/// context instead of pooling it.
#[derive(Debug)]
pub struct LuaConfig {
    inner: Option<LuaConfigInner>,
}
90
91pub async fn set_policy_path(path: PathBuf) -> anyhow::Result<()> {
92    POLICY_FILE.lock().replace(path);
93    let config = load_config().await?;
94    config.put();
95    Ok(())
96}
97
98fn get_policy_path() -> Option<PathBuf> {
99    POLICY_FILE.lock().clone()
100}
101
102fn get_funcs() -> Vec<RegisterFunc> {
103    FUNCS.lock().clone()
104}
105pub fn is_validating() -> bool {
106    VALIDATE_ONLY.load(Ordering::Relaxed)
107}
108
109pub fn validation_failed() -> bool {
110    VALIDATION_FAILED.load(Ordering::Relaxed)
111}
112
113pub fn set_validation_failed() {
114    VALIDATION_FAILED.store(true, Ordering::Relaxed)
115}
116
117pub async fn load_config() -> anyhow::Result<LuaConfig> {
118    if let Some(pool) = pool_get() {
119        return Ok(pool);
120    }
121
122    LUA_LOAD_COUNT.increment(1);
123    let lua = Lua::new();
124    let created = Instant::now();
125    let epoch = get_current_epoch();
126
127    {
128        let globals = lua.globals();
129
130        if is_validating() {
131            globals.set("_VALIDATING_CONFIG", true)?;
132        }
133
134        let package: Table = globals.get("package")?;
135        let package_path: String = package.get("path")?;
136        let mut path_array: Vec<String> = package_path.split(";").map(|s| s.to_owned()).collect();
137
138        fn prefix_path(array: &mut Vec<String>, path: &str) {
139            array.insert(0, format!("{}/?.lua", path));
140            array.insert(1, format!("{}/?/init.lua", path));
141        }
142
143        prefix_path(&mut path_array, "/opt/kumomta/etc/policy");
144        prefix_path(&mut path_array, "/opt/kumomta/share");
145
146        #[cfg(debug_assertions)]
147        prefix_path(&mut path_array, "assets");
148
149        package.set("path", path_array.join(";"))?;
150    }
151
152    register_declared_events();
153
154    for func in get_funcs() {
155        (func)(&lua)?;
156    }
157
158    if let Some(policy) = get_policy_path() {
159        let code = tokio::fs::read_to_string(&policy)
160            .await
161            .with_context(|| format!("reading policy file {policy:?}"))?;
162
163        let func = {
164            let chunk = lua.load(&code);
165            let chunk = chunk.set_name(policy.to_string_lossy());
166            chunk.into_function()?
167        };
168
169        let _timer = latency_timer("context-creation");
170        func.call_async::<()>(()).await?;
171    }
172    LUA_COUNT.increment(1.);
173
174    Ok(LuaConfig {
175        inner: Some(LuaConfigInner {
176            lua,
177            created,
178            use_count: 1,
179            epoch,
180        }),
181    })
182}
183
184pub fn register(func: RegisterFunc) {
185    FUNCS.lock().push(func);
186}
187
188impl LuaConfig {
189    fn set_current_event(&mut self, name: &str) -> mlua::Result<()> {
190        self.inner
191            .as_mut()
192            .unwrap()
193            .lua
194            .globals()
195            .set("_KUMO_CURRENT_EVENT", name.to_string())
196    }
197
198    /// Intended to be used together with kumo.spawn_task
199    pub async fn convert_args_and_call_callback<A: Serialize>(
200        &mut self,
201        sig: &CallbackSignature<Value, ()>,
202        args: A,
203    ) -> anyhow::Result<()> {
204        let lua = self.inner.as_mut().unwrap();
205        let args = lua.lua.to_value(&args)?;
206
207        let name = sig.name();
208        let decorated_name = sig.decorated_name();
209
210        match lua
211            .lua
212            .named_registry_value::<mlua::Function>(&decorated_name)
213        {
214            Ok(func) => {
215                let _timer = latency_timer(name);
216                Ok(func.call_async(args).await?)
217            }
218            _ => anyhow::bail!("{name} has not been registered"),
219        }
220    }
221
222    /// Explicitly put the config object back into its containing pool.
223    /// Ideally we'd do this automatically when the object is dropped,
224    /// but lua's garbage collection makes this problematic:
225    /// if a future whose graph contains an async lua call within
226    /// this config object is cancelled (eg: simply stopped without
227    /// calling it again), and the config object is not explicitly garbage
228    /// collected, any futures and data owned by any dependencies of
229    /// the cancelled future remain alive until the next gc run,
230    /// which can cause things like async locks and semaphores to
231    /// have a lifetime extended by the maximum age of the lua context.
232    ///
233    /// The combat this, consumers of LuaConfig should explicitly
234    /// call `config.put()` after successfully using the config
235    /// object.
236    ///
237    /// Or framing it another way: consumers must not call `config.put()`
238    /// if a transitive dep might have been cancelled.
239    pub fn put(mut self) {
240        if let Some(inner) = self.inner.take() {
241            pool_put(inner);
242        }
243    }
244
245    pub async fn async_call_callback<A: IntoLuaMulti + Clone, R: FromLuaMulti + Default>(
246        &mut self,
247        sig: &CallbackSignature<A, R>,
248        args: A,
249    ) -> anyhow::Result<R> {
250        let name = sig.name();
251        self.set_current_event(name)?;
252        let lua = self.inner.as_mut().unwrap();
253        async_call_callback(&lua.lua, sig, args).await
254    }
255
256    pub async fn async_call_callback_non_default<A: IntoLuaMulti + Clone, R: FromLuaMulti>(
257        &mut self,
258        sig: &CallbackSignature<A, R>,
259        args: A,
260    ) -> anyhow::Result<R> {
261        let name = sig.name();
262        self.set_current_event(name)?;
263        let lua = self.inner.as_mut().unwrap();
264        async_call_callback_non_default(&lua.lua, sig, args).await
265    }
266
267    pub async fn async_call_callback_non_default_opt<A: IntoLuaMulti + Clone, R: FromLua>(
268        &mut self,
269        sig: &CallbackSignature<A, Option<R>>,
270        args: A,
271    ) -> anyhow::Result<Option<R>> {
272        let name = sig.name();
273        let decorated_name = sig.decorated_name();
274        self.set_current_event(name)?;
275        let lua = self.inner.as_mut().unwrap();
276
277        match lua
278            .lua
279            .named_registry_value::<mlua::Value>(&decorated_name)?
280        {
281            Value::Table(tbl) => {
282                for func in tbl.sequence_values::<mlua::Function>().collect::<Vec<_>>() {
283                    let func = func?;
284                    let _timer = latency_timer(name);
285                    let result: mlua::MultiValue = func.call_async(args.clone()).await?;
286                    if result.is_empty() {
287                        // Continue with other handlers
288                        continue;
289                    }
290                    let result = R::from_lua_multi(result, &lua.lua)?;
291                    return Ok(Some(result));
292                }
293                Ok(None)
294            }
295            Value::Function(func) => {
296                sig.raise_error_if_allow_multiple()?;
297                let _timer = latency_timer(name);
298                let value: Value = func.call_async(args.clone()).await?;
299
300                match value {
301                    Value::Nil => Ok(None),
302                    value => {
303                        let result = R::from_lua(value, &lua.lua)?;
304                        Ok(Some(result))
305                    }
306                }
307            }
308            _ => Ok(None),
309        }
310    }
311
312    pub fn remove_registry_value(&mut self, value: RegistryKey) -> anyhow::Result<()> {
313        Ok(self
314            .inner
315            .as_mut()
316            .unwrap()
317            .lua
318            .remove_registry_value(value)?)
319    }
320
321    /// Call a constructor registered via `on`. Returns a registry key that can be
322    /// used to reference the returned value again later on this same Lua instance
323    pub async fn async_call_ctor<A: IntoLuaMulti + Clone>(
324        &mut self,
325        sig: &CallbackSignature<A, Value>,
326        args: A,
327    ) -> anyhow::Result<RegistryKey> {
328        let name = sig.name();
329        anyhow::ensure!(
330            !sig.allow_multiple(),
331            "ctor event signature for {name} is defined as allow_multiple, which is not supported"
332        );
333
334        let decorated_name = sig.decorated_name();
335        self.set_current_event(name)?;
336
337        let inner = self.inner.as_mut().unwrap();
338
339        let func = inner
340            .lua
341            .named_registry_value::<mlua::Function>(&decorated_name)?;
342
343        let _timer = latency_timer(name);
344        let value: Value = func.call_async(args.clone()).await?;
345        drop(func);
346
347        Ok(inner.lua.create_registry_value(value)?)
348    }
349
350    /// Operate on an object/value that was previously constructed via
351    /// async_call_ctor.
352    pub async fn with_registry_value<F, R, FUT>(
353        &mut self,
354        value: &RegistryKey,
355        func: F,
356    ) -> anyhow::Result<R>
357    where
358        R: FromLuaMulti,
359        F: FnOnce(Value) -> anyhow::Result<FUT>,
360        FUT: std::future::Future<Output = anyhow::Result<R>>,
361    {
362        let inner = self.inner.as_mut().unwrap();
363        let value = inner.lua.registry_value(value)?;
364        let future = (func)(value)?;
365        future.await
366    }
367}
368
369pub async fn async_call_callback<A: IntoLuaMulti + Clone, R: FromLuaMulti + Default>(
370    lua: &Lua,
371    sig: &CallbackSignature<A, R>,
372    args: A,
373) -> anyhow::Result<R> {
374    let name = sig.name();
375    let decorated_name = sig.decorated_name();
376
377    match lua.named_registry_value::<mlua::Value>(&decorated_name)? {
378        Value::Table(tbl) => {
379            for func in tbl.sequence_values::<mlua::Function>().collect::<Vec<_>>() {
380                let func = func?;
381                let _timer = latency_timer(name);
382                let result: mlua::MultiValue = func.call_async(args.clone()).await?;
383                if result.is_empty() {
384                    // Continue with other handlers
385                    continue;
386                }
387                let result = R::from_lua_multi(result, lua)?;
388                return Ok(result);
389            }
390            Ok(R::default())
391        }
392        Value::Function(func) => {
393            sig.raise_error_if_allow_multiple()?;
394            let _timer = latency_timer(name);
395            Ok(func.call_async(args.clone()).await?)
396        }
397        _ => Ok(R::default()),
398    }
399}
400
401pub async fn async_call_callback_non_default<A: IntoLuaMulti + Clone, R: FromLuaMulti>(
402    lua: &Lua,
403    sig: &CallbackSignature<A, R>,
404    args: A,
405) -> anyhow::Result<R> {
406    let name = sig.name();
407    let decorated_name = sig.decorated_name();
408
409    match lua.named_registry_value::<mlua::Value>(&decorated_name)? {
410        Value::Table(tbl) => {
411            for func in tbl.sequence_values::<mlua::Function>().collect::<Vec<_>>() {
412                let func = func?;
413                let _timer = latency_timer(name);
414                let result: mlua::MultiValue = func.call_async(args.clone()).await?;
415                if result.is_empty() {
416                    // Continue with other handlers
417                    continue;
418                }
419                let result = R::from_lua_multi(result, lua)?;
420                return Ok(result);
421            }
422            anyhow::bail!("invalid return type for {name} event");
423        }
424        Value::Function(func) => {
425            sig.raise_error_if_allow_multiple()?;
426            let _timer = latency_timer(name);
427            Ok(func.call_async(args.clone()).await?)
428        }
429        _ => anyhow::bail!("Event {name} has not been registered"),
430    }
431}
432
433pub fn get_or_create_module(lua: &Lua, name: &str) -> anyhow::Result<mlua::Table> {
434    let globals = lua.globals();
435    let package: Table = globals.get("package")?;
436    let loaded: Table = package.get("loaded")?;
437
438    let module = loaded.get(name)?;
439    match module {
440        Value::Nil => {
441            let module = lua.create_table()?;
442            loaded.set(name, module.clone())?;
443            Ok(module)
444        }
445        Value::Table(table) => Ok(table),
446        wat => anyhow::bail!(
447            "cannot register module {} as package.loaded.{} is already set to a value of type {}",
448            name,
449            name,
450            wat.type_name()
451        ),
452    }
453}
454
455/// Given a name path like `foo` or `foo.bar.baz`, sets up the module
456/// registry hierarchy to instantiate that path.
457/// Returns the leaf node of that path to allow the caller to
458/// register/assign functions etc. into it
459pub fn get_or_create_sub_module(lua: &Lua, name_path: &str) -> anyhow::Result<mlua::Table> {
460    let mut parent = get_or_create_module(lua, "kumo")?;
461    let mut path_so_far = String::new();
462
463    for name in name_path.split('.') {
464        if !path_so_far.is_empty() {
465            path_so_far.push('.');
466        }
467        path_so_far.push_str(name);
468
469        let sub = parent.get(name)?;
470        match sub {
471            Value::Nil => {
472                let sub = lua.create_table()?;
473                parent.set(name, sub.clone())?;
474                parent = sub;
475            }
476            Value::Table(sub) => {
477                parent = sub;
478            }
479            wat => anyhow::bail!(
480                "cannot register module kumo.{path_so_far} as it is already set to a value of type {}",
481                wat.type_name()
482            ),
483        }
484    }
485
486    Ok(parent)
487}
488
489/// Helper for mapping back to lua errors
490pub fn any_err<E: std::fmt::Display>(err: E) -> mlua::Error {
491    mlua::Error::external(format!("{err:#}"))
492}
493
494/// Provides implementations of __pairs, __index and __len metamethods
495/// for a type that is Serialize and UserData.
496/// Neither implementation is considered to be ideal, as we must
497/// first serialize the value into a json Value which is then either
498/// iterated over, or indexed to produce the appropriate result for
499/// the metamethod.
500pub fn impl_pairs_and_index<T, M>(methods: &mut M)
501where
502    T: UserData + Serialize,
503    M: UserDataMethods<T>,
504{
505    methods.add_meta_method(MetaMethod::Pairs, move |lua, this, _: ()| {
506        let Ok(serde_json::Value::Object(map)) = serde_json::to_value(this).map_err(any_err) else {
507            return Err(mlua::Error::external("must serialize to Map"));
508        };
509
510        let mut value_iter = map.into_iter();
511
512        let iter_func = lua.create_function_mut(
513            move |lua, (_state, _control): (Value, Value)| match value_iter.next() {
514                Some((key, value)) => {
515                    let key = lua.to_value(&key)?;
516                    let value = lua.to_value(&value)?;
517                    Ok((key, value))
518                }
519                None => Ok((Value::Nil, Value::Nil)),
520            },
521        )?;
522
523        Ok((Value::Function(iter_func), Value::Nil, Value::Nil))
524    });
525
526    methods.add_meta_method(MetaMethod::Index, move |lua, this, field: Value| {
527        let value = lua.to_value(this)?;
528        match value {
529            Value::Table(t) => t.get(field),
530            _ => Ok(Value::Nil),
531        }
532    });
533
534    methods.add_meta_method(MetaMethod::Len, move |lua, this, _: ()| {
535        let value = lua.to_value(this)?;
536        match value {
537            Value::Table(v) => v.len(),
538            Value::String(v) => Ok(v.as_bytes().len() as i64),
539            _ => Ok(0),
540        }
541    });
542}
543
544/// This function will try to obtain a native lua representation
545/// of the provided value. It does this by attempting to iterate
546/// the pairs of any userdata it finds as either the value itself
547/// or the values of a table value by recursively applying
548/// materialize_to_lua_value to the value.
549/// This produces a lua value that can then be processed by the
550/// Deserialize impl on Value.
551pub fn materialize_to_lua_value(lua: &Lua, value: mlua::Value) -> mlua::Result<mlua::Value> {
552    match value {
553        mlua::Value::UserData(ud) => {
554            let mt = ud.metatable()?;
555            let Ok(pairs) = mt.get::<mlua::Function>("__pairs") else {
556                let value = ud.into_lua(lua)?;
557                return Err(mlua::Error::external(format!(
558                    "cannot materialize_to_lua_value {value:?} \
559                     because it has no __pairs metamethod"
560                )));
561            };
562            let tbl = lua.create_table()?;
563            let (iter_func, state, mut control): (mlua::Function, mlua::Value, mlua::Value) =
564                pairs.call(mlua::Value::UserData(ud.clone()))?;
565
566            loop {
567                let (k, v): (mlua::Value, mlua::Value) =
568                    iter_func.call((state.clone(), control))?;
569                if k.is_nil() {
570                    break;
571                }
572
573                tbl.set(k.clone(), materialize_to_lua_value(lua, v)?)?;
574                control = k;
575            }
576
577            Ok(mlua::Value::Table(tbl))
578        }
579        mlua::Value::Table(t) => {
580            let tbl = lua.create_table()?;
581            for pair in t.pairs::<mlua::Value, mlua::Value>() {
582                let (k, v) = pair?;
583                tbl.set(k.clone(), materialize_to_lua_value(lua, v)?)?;
584            }
585            Ok(mlua::Value::Table(tbl))
586        }
587        value => Ok(value),
588    }
589}
590
591/// Convert from a lua value to a deserializable type,
592/// with a slightly more helpful error message in case of failure.
593/// NOTE: the ", while processing" portion of the error messages generated
594/// here is coupled with a regex in typing.lua!
595pub fn from_lua_value<R>(lua: &Lua, value: mlua::Value) -> mlua::Result<R>
596where
597    R: serde::de::DeserializeOwned,
598{
599    let value_cloned = value.clone();
600    match lua.from_value(value) {
601        Ok(r) => Ok(r),
602        Err(err) => match materialize_to_lua_value(lua, value_cloned.clone()) {
603            Ok(materialized) => match lua.from_value(materialized.clone()) {
604                Ok(r) => Ok(r),
605                Err(err) => {
606                    let mut serializer = serde_json::Serializer::new(Vec::new());
607                    let serialized = match materialized.serialize(&mut serializer) {
608                        Ok(_) => String::from_utf8_lossy(&serializer.into_inner()).to_string(),
609                        Err(err) => format!("<unable to encode as json: {err:#}>"),
610                    };
611                    Err(mlua::Error::external(format!(
612                        "{err:#}, while processing {serialized}"
613                    )))
614                }
615            },
616            Err(materialize_err) => Err(mlua::Error::external(format!(
617                "{err:#}, while processing a userdata. \
618                    Additionally, encountered {materialize_err:#} \
619                    when trying to iterate the pairs of that userdata"
620            ))),
621        },
622    }
623}
624
/// CallbackSignature is a bit of sugar to aid with statically typing event
/// callback function invocation.
///
/// The idea is that you declare a signature instance that is typed
/// with its argument tuple (A), and its return type tuple (R).
///
/// The signature instance can then be used to invoke the callback by name.
///
/// The register method allows pre-registering events so that `kumo.on`
/// can reason about them better.  The main function enabled by this is
/// `allow_multiple`; when that is set to true, `kumo.on` will allow
/// recording multiple callback instances, calling them in sequence
/// until one of them returns a value.
pub struct CallbackSignature<A, R>
where
    A: IntoLuaMulti,
    R: FromLuaMulti,
{
    // Ties the argument and return types to the signature without storing values.
    marker: std::marker::PhantomData<(A, R)>,
    // When true, `register` records `name` in CALLBACK_ALLOWS_MULTIPLE.
    allow_multiple: bool,
    // Undecorated event name; `decorated_name` derives the registry key from it.
    name: Cow<'static, str>,
}
647
/// Registration hooks generated by `declare_event!`, one per declared event
/// signature, gathered at link time by `linkme`. They are run by
/// `register_declared_events`.
#[linkme::distributed_slice]
pub static CALLBACK_SIGNATURES: [fn()];
650
/// Helper for declaring a named event handler callback signature.
///
/// Usage looks like:
///
/// ```rust,ignore
/// declare_event! {
/// pub static GET_Q_CONFIG_SIG: Multiple(
///         "get_queue_config",
///         domain: &'static str,
///         tenant: Option<&'static str>,
///         campaign: Option<&'static str>,
///         routing_domain: Option<&'static str>,
///     ) -> QueueConfig;
/// }
/// ```
///
/// A handler can be either `Single` or `Multiple`, indicating whether
/// only a single registration or multiple registrations are permitted.
/// The string literal is the name of the event, followed by a fn-style
/// parameter list which names each parameter in sequence, followed by
/// the return value.  The names are not currently used in any way,
/// but enhance the readability of the code.
///
/// The argument type is spelled `($($args),*)`. With no parameters it is
/// `()`, and with exactly one parameter it is that bare type rather than
/// a 1-tuple.
///
/// In addition to declaring the signature in a global, some glue
/// is generated that will register the signature appropriately
/// so that lua knows whether it is single or multiple and can
/// act appropriately when `kumo.on` is called.
///
/// The glue names `linkme::distributed_slice` by path. The invoking crate
/// must therefore be able to resolve `linkme` itself, e.g. by depending
/// on it directly.
#[macro_export]
macro_rules! declare_event {
    // `Multiple`: builds the signature with `new_with_multiple`.
    ($vis:vis static $sym:ident: Multiple($name:literal $(,)? $($param_name:ident: $args:ty),* $(,)? ) -> $ret:ty;) => {
        $vis static $sym: ::std::sync::LazyLock<
            $crate::CallbackSignature<($($args),*), $ret>> =
                ::std::sync::LazyLock::new(|| $crate::CallbackSignature::new_with_multiple($name));

        $crate::paste::paste! {
            #[linkme::distributed_slice($crate::CALLBACK_SIGNATURES)]
            static [<CALLBACK_SIG_REGISTER_ $sym>]: fn() = || {
                $sym.register();
            };
        }
    };
    // `Single`: identical glue, but builds the signature with `new`.
    ($vis:vis static $sym:ident: Single($name:literal $(,)? $($param_name:ident: $args:ty),* $(,)? ) -> $ret:ty;) => {
        $vis static $sym: ::std::sync::LazyLock<
            $crate::CallbackSignature<($($args),*), $ret>> =
                ::std::sync::LazyLock::new(|| $crate::CallbackSignature::new($name));

        $crate::paste::paste! {
            #[linkme::distributed_slice($crate::CALLBACK_SIGNATURES)]
            static [<CALLBACK_SIG_REGISTER_ $sym>]: fn() = || {
                $sym.register();
            };
        }
    };
}
705
706/// For each event handler CallbackSignature that was declared via
707/// `declare_event!`, call its `.register()` method to register
708/// it so that `kumo.on` can give appropriate messaging if misused,
709/// and so that runtime dispatch will work correctly.
710///
711/// This should be called once, prior to running any lua code.
712fn register_declared_events() {
713    static ONCE: Once = Once::new();
714    ONCE.call_once(|| {
715        for reg_func in CALLBACK_SIGNATURES {
716            reg_func();
717        }
718    });
719}
720
721impl<A, R> CallbackSignature<A, R>
722where
723    A: IntoLuaMulti,
724    R: FromLuaMulti,
725{
726    pub fn new<S: Into<Cow<'static, str>>>(name: S) -> Self {
727        let name = name.into();
728
729        Self {
730            marker: std::marker::PhantomData,
731            allow_multiple: false,
732            name,
733        }
734    }
735
736    /// Make sure that you call .register() on this from
737    /// eg: mod_kumo::register in order for it to be instantiated
738    /// and visible to the config loader
739    pub fn new_with_multiple<S: Into<Cow<'static, str>>>(name: S) -> Self {
740        let name = name.into();
741
742        Self {
743            marker: std::marker::PhantomData,
744            allow_multiple: true,
745            name,
746        }
747    }
748
749    pub fn register(&self) {
750        if self.allow_multiple {
751            CALLBACK_ALLOWS_MULTIPLE
752                .lock()
753                .insert(self.name.to_string());
754        }
755    }
756
757    pub fn raise_error_if_allow_multiple(&self) -> anyhow::Result<()> {
758        anyhow::ensure!(
759            !self.allow_multiple(),
760            "handler {} is set to allow multiple handlers \
761                    but is registered with a single instance. This indicates that \
762                    register() was not called on the signature when initializing \
763                    the lua context. Please report this issue to the KumoMTA team!",
764            self.name
765        );
766        Ok(())
767    }
768
769    /// Return true if this signature allows multiple instances to be registered
770    /// and called.
771    pub fn allow_multiple(&self) -> bool {
772        self.allow_multiple
773    }
774
775    pub fn name(&self) -> &str {
776        &self.name
777    }
778
779    pub fn decorated_name(&self) -> String {
780        decorate_callback_name(&self.name)
781    }
782}
783
784pub fn does_callback_allow_multiple(name: &str) -> bool {
785    CALLBACK_ALLOWS_MULTIPLE.lock().contains(name)
786}
787
/// Build the lua registry key under which handlers for event `name` are
/// stored: the event name prefixed with `kumomta-on-`.
pub fn decorate_callback_name(name: &str) -> String {
    let mut key = String::from("kumomta-on-");
    key.push_str(name);
    key
}
791
792pub fn serialize_options() -> mlua::SerializeOptions {
793    mlua::SerializeOptions::new()
794        .serialize_none_to_null(false)
795        .serialize_unit_to_null(false)
796}