config/
lib.rs

1use crate::epoch::{get_current_epoch, ConfigEpoch};
2use crate::pool::{pool_get, pool_put};
3pub use crate::pool::{set_gc_on_put, set_max_age, set_max_spare, set_max_use};
4use anyhow::Context;
5use mlua::{
6    FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, Lua, LuaSerdeExt, MetaMethod, RegistryKey, Table,
7    UserData, UserDataMethods, Value,
8};
9use parking_lot::FairMutex as Mutex;
10pub use pastey as paste;
11use prometheus::{CounterVec, HistogramTimer, HistogramVec};
12use serde::Serialize;
13use std::borrow::Cow;
14use std::collections::HashSet;
15use std::path::PathBuf;
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::sync::{LazyLock, Once};
18use std::time::Instant;
19
20pub mod epoch;
21mod pool;
22
/// Path of the policy lua script. In this file it is only assigned by
/// `set_policy_path`; while it is `None`, `load_config` creates contexts
/// without running any policy script.
static POLICY_FILE: LazyLock<Mutex<Option<PathBuf>>> = LazyLock::new(|| Mutex::new(None));
/// Setup functions added via `register`; `load_config` runs each of them
/// against every newly created lua state.
static FUNCS: LazyLock<Mutex<Vec<RegisterFunc>>> = LazyLock::new(|| Mutex::new(vec![]));
/// `lua_load_count` counter: bumped each time `load_config` builds a brand
/// new lua state (contexts reused from the pool are not counted).
static LUA_LOAD_COUNT: LazyLock<metrics::Counter> = LazyLock::new(|| {
    metrics::describe_counter!(
        "lua_load_count",
        "how many times the policy lua script has been \
         loaded into a new context"
    );
    metrics::counter!("lua_load_count")
});
/// `lua_count` gauge: incremented when `load_config` successfully finishes
/// building a context and decremented when a `LuaConfigInner` is dropped.
static LUA_COUNT: LazyLock<metrics::Gauge> = LazyLock::new(|| {
    metrics::describe_gauge!("lua_count", "the number of lua contexts currently alive");
    metrics::gauge!("lua_count")
});
/// Names of events whose `CallbackSignature` allows multiple handlers.
/// Populated by `CallbackSignature::register`, queried through
/// `does_callback_allow_multiple`.
static CALLBACK_ALLOWS_MULTIPLE: LazyLock<Mutex<HashSet<String>>> =
    LazyLock::new(|| Mutex::new(HashSet::new()));

/// When true, new lua contexts get the global `_VALIDATING_CONFIG = true`.
// NOTE(review): this flag is never written in this file; presumably the
// binary sets it when only validating the policy — confirm.
pub static VALIDATE_ONLY: AtomicBool = AtomicBool::new(false);
/// Set by `set_validation_failed` and read by `validation_failed`.
pub static VALIDATION_FAILED: AtomicBool = AtomicBool::new(false);
/// `lua_event_latency` histogram, labelled by event name and fed by the
/// timers returned from `latency_timer`. It is registered with the default
/// prometheus registry on first use, and a registration failure panics.
static LATENCY_HIST: LazyLock<HistogramVec> = LazyLock::new(|| {
    prometheus::register_histogram_vec!(
        "lua_event_latency",
        "how long a given lua event callback took",
        &["event"]
    )
    .unwrap()
});
/// `lua_event_started` counter, labelled by event name. `latency_timer`
/// increments it just before starting the matching latency timer.
static EVENT_STARTED_COUNT: LazyLock<CounterVec> = LazyLock::new(|| {
    prometheus::register_counter_vec!(
        "lua_event_started",
        "Incremented each time we start to call a lua event callback. Use lua_event_latency_count to track completed events",
        &["event"]
    )
    .unwrap()
});

/// Signature of the setup functions accepted by `register`; each one is
/// invoked with a freshly created lua state by `load_config`.
pub type RegisterFunc = fn(&Lua) -> anyhow::Result<()>;
60
61fn latency_timer(label: &str) -> HistogramTimer {
62    EVENT_STARTED_COUNT
63        .get_metric_with_label_values(&[label])
64        .expect("to get counter")
65        .inc();
66    LATENCY_HIST
67        .get_metric_with_label_values(&[label])
68        .expect("to get histo")
69        .start_timer()
70}
71
/// The state backing a live `LuaConfig`: the lua context plus the
/// bookkeeping recorded when `load_config` created it.
#[derive(Debug)]
struct LuaConfigInner {
    /// The lua context itself.
    lua: Lua,
    /// Captured in `load_config` right after the `Lua` state was created.
    created: Instant,
    /// Initialised to 1 by `load_config`.
    // NOTE(review): presumably maintained by the pool module to enforce
    // `set_max_use` — confirm in pool.rs.
    use_count: usize,
    /// The config epoch that was current when this context was created.
    epoch: ConfigEpoch,
}

impl Drop for LuaConfigInner {
    /// Keeps the `lua_count` gauge balanced: `load_config` increments it
    /// immediately before constructing an inner, and this decrements it.
    fn drop(&mut self) {
        LUA_COUNT.decrement(1.);
    }
}

/// A handle to a lua configuration context.
///
/// `inner` is an `Option` so that `put` can move the state out and hand it
/// back to the pool.
#[derive(Debug)]
pub struct LuaConfig {
    inner: Option<LuaConfigInner>,
}
90
91pub async fn set_policy_path(path: PathBuf) -> anyhow::Result<()> {
92    POLICY_FILE.lock().replace(path);
93    let config = load_config().await?;
94    config.put();
95    Ok(())
96}
97
98fn get_policy_path() -> Option<PathBuf> {
99    POLICY_FILE.lock().clone()
100}
101
102fn get_funcs() -> Vec<RegisterFunc> {
103    FUNCS.lock().clone()
104}
105pub fn is_validating() -> bool {
106    VALIDATE_ONLY.load(Ordering::Relaxed)
107}
108
109pub fn validation_failed() -> bool {
110    VALIDATION_FAILED.load(Ordering::Relaxed)
111}
112
113pub fn set_validation_failed() {
114    VALIDATION_FAILED.store(true, Ordering::Relaxed)
115}
116
117pub async fn load_config() -> anyhow::Result<LuaConfig> {
118    if let Some(pool) = pool_get() {
119        return Ok(pool);
120    }
121
122    LUA_LOAD_COUNT.increment(1);
123    let lua = Lua::new();
124    let created = Instant::now();
125    let epoch = get_current_epoch();
126
127    {
128        let globals = lua.globals();
129
130        if is_validating() {
131            globals.set("_VALIDATING_CONFIG", true)?;
132        }
133
134        let package: Table = globals.get("package")?;
135        let package_path: String = package.get("path")?;
136        let mut path_array: Vec<String> = package_path.split(";").map(|s| s.to_owned()).collect();
137
138        fn prefix_path(array: &mut Vec<String>, path: &str) {
139            array.insert(0, format!("{}/?.lua", path));
140            array.insert(1, format!("{}/?/init.lua", path));
141        }
142
143        prefix_path(&mut path_array, "/opt/kumomta/etc/policy");
144        prefix_path(&mut path_array, "/opt/kumomta/share");
145
146        #[cfg(debug_assertions)]
147        prefix_path(&mut path_array, "assets");
148
149        package.set("path", path_array.join(";"))?;
150    }
151
152    register_declared_events();
153
154    for func in get_funcs() {
155        (func)(&lua)?;
156    }
157
158    if let Some(policy) = get_policy_path() {
159        let code = tokio::fs::read_to_string(&policy)
160            .await
161            .with_context(|| format!("reading policy file {policy:?}"))?;
162
163        let func = {
164            let chunk = lua.load(&code);
165            let chunk = chunk.set_name(policy.to_string_lossy());
166            chunk.into_function()?
167        };
168
169        let _timer = latency_timer("context-creation");
170        func.call_async::<()>(()).await?;
171    }
172    LUA_COUNT.increment(1.);
173
174    Ok(LuaConfig {
175        inner: Some(LuaConfigInner {
176            lua,
177            created,
178            use_count: 1,
179            epoch,
180        }),
181    })
182}
183
184pub fn register(func: RegisterFunc) {
185    FUNCS.lock().push(func);
186}
187
188impl LuaConfig {
189    fn set_current_event(&mut self, name: &str) -> mlua::Result<()> {
190        self.inner
191            .as_mut()
192            .unwrap()
193            .lua
194            .globals()
195            .set("_KUMO_CURRENT_EVENT", name.to_string())
196    }
197
198    /// Convert an array of args into a MultiValue that can be passed
199    /// to a callback signature
200    pub fn convert_args_to_multi<A: Serialize>(
201        &self,
202        args: &[A],
203    ) -> anyhow::Result<mlua::MultiValue> {
204        let lua = self.inner.as_ref().unwrap();
205        let mut arg_vec = vec![];
206        for a in args.iter() {
207            arg_vec.push(lua.lua.to_value(a)?);
208        }
209        Ok(mlua::MultiValue::from_vec(arg_vec))
210    }
211
212    /// Intended to be used together with kumo.spawn_task
213    pub async fn convert_args_and_call_callback<A: Serialize>(
214        &mut self,
215        sig: &CallbackSignature<Value, ()>,
216        args: A,
217    ) -> anyhow::Result<()> {
218        let lua = self.inner.as_mut().unwrap();
219        let args = lua.lua.to_value(&args)?;
220
221        let name = sig.name();
222        let decorated_name = sig.decorated_name();
223
224        match lua
225            .lua
226            .named_registry_value::<mlua::Function>(&decorated_name)
227        {
228            Ok(func) => {
229                let _timer = latency_timer(name);
230                Ok(func.call_async(args).await?)
231            }
232            _ => anyhow::bail!("{name} has not been registered"),
233        }
234    }
235
236    /// Explicitly put the config object back into its containing pool.
237    /// Ideally we'd do this automatically when the object is dropped,
238    /// but lua's garbage collection makes this problematic:
239    /// if a future whose graph contains an async lua call within
240    /// this config object is cancelled (eg: simply stopped without
241    /// calling it again), and the config object is not explicitly garbage
242    /// collected, any futures and data owned by any dependencies of
243    /// the cancelled future remain alive until the next gc run,
244    /// which can cause things like async locks and semaphores to
245    /// have a lifetime extended by the maximum age of the lua context.
246    ///
247    /// The combat this, consumers of LuaConfig should explicitly
248    /// call `config.put()` after successfully using the config
249    /// object.
250    ///
251    /// Or framing it another way: consumers must not call `config.put()`
252    /// if a transitive dep might have been cancelled.
253    pub fn put(mut self) {
254        if let Some(inner) = self.inner.take() {
255            pool_put(inner);
256        }
257    }
258
259    pub async fn call_callback<'a, A: IntoLuaMulti + Clone, R: FromLuaMulti>(
260        &mut self,
261        sig: &'a CallbackSignature<A, R>,
262        args: A,
263    ) -> anyhow::Result<CallbackDisposition<'a, R>> {
264        let name = sig.name();
265        self.set_current_event(name)?;
266        let lua = self.inner.as_mut().unwrap();
267        sig.call(&lua.lua, args).await
268    }
269
270    pub async fn async_call_callback<A: IntoLuaMulti + Clone, R: FromLuaMulti + Default>(
271        &mut self,
272        sig: &CallbackSignature<A, R>,
273        args: A,
274    ) -> anyhow::Result<R> {
275        let name = sig.name();
276        self.set_current_event(name)?;
277        let lua = self.inner.as_mut().unwrap();
278        Ok(sig.call(&lua.lua, args).await?.or_default())
279    }
280
281    pub async fn async_call_callback_non_default<A: IntoLuaMulti + Clone, R: FromLuaMulti>(
282        &mut self,
283        sig: &CallbackSignature<A, R>,
284        args: A,
285    ) -> anyhow::Result<R> {
286        let name = sig.name();
287        self.set_current_event(name)?;
288        let lua = self.inner.as_mut().unwrap();
289        sig.call(&lua.lua, args).await?.require_value()
290    }
291
292    pub async fn async_call_callback_non_default_opt<A: IntoLuaMulti + Clone, R: FromLua>(
293        &mut self,
294        sig: &CallbackSignature<A, Option<R>>,
295        args: A,
296    ) -> anyhow::Result<Option<R>> {
297        let name = sig.name();
298        self.set_current_event(name)?;
299        let lua = self.inner.as_mut().unwrap();
300        let result = sig.call(&lua.lua, args).await?;
301        match result.result {
302            None => Ok(None),
303            Some(result) => Ok(result),
304        }
305    }
306
307    pub fn remove_registry_value(&mut self, value: RegistryKey) -> anyhow::Result<()> {
308        Ok(self
309            .inner
310            .as_mut()
311            .unwrap()
312            .lua
313            .remove_registry_value(value)?)
314    }
315
316    /// Call a constructor registered via `on`. Returns a registry key that can be
317    /// used to reference the returned value again later on this same Lua instance
318    pub async fn async_call_ctor<A: IntoLuaMulti + Clone>(
319        &mut self,
320        sig: &CallbackSignature<A, Value>,
321        args: A,
322    ) -> anyhow::Result<RegistryKey> {
323        let name = sig.name();
324        anyhow::ensure!(
325            !sig.allow_multiple(),
326            "ctor event signature for {name} is defined as allow_multiple, which is not supported"
327        );
328
329        let decorated_name = sig.decorated_name();
330        self.set_current_event(name)?;
331
332        let inner = self.inner.as_mut().unwrap();
333
334        let func = inner
335            .lua
336            .named_registry_value::<mlua::Function>(&decorated_name)?;
337
338        let _timer = latency_timer(name);
339        let value: Value = func.call_async(args.clone()).await?;
340        drop(func);
341
342        Ok(inner.lua.create_registry_value(value)?)
343    }
344
345    /// Operate on an object/value that was previously constructed via
346    /// async_call_ctor.
347    pub async fn with_registry_value<F, R, FUT>(
348        &mut self,
349        value: &RegistryKey,
350        func: F,
351    ) -> anyhow::Result<R>
352    where
353        R: FromLuaMulti,
354        F: FnOnce(Value) -> anyhow::Result<FUT>,
355        FUT: std::future::Future<Output = anyhow::Result<R>>,
356    {
357        let inner = self.inner.as_mut().unwrap();
358        let value = inner.lua.registry_value(value)?;
359        let future = (func)(value)?;
360        future.await
361    }
362}
363
364pub fn get_or_create_module(lua: &Lua, name: &str) -> anyhow::Result<mlua::Table> {
365    let globals = lua.globals();
366    let package: Table = globals.get("package")?;
367    let loaded: Table = package.get("loaded")?;
368
369    let module = loaded.get(name)?;
370    match module {
371        Value::Nil => {
372            let module = lua.create_table()?;
373            loaded.set(name, module.clone())?;
374            Ok(module)
375        }
376        Value::Table(table) => Ok(table),
377        wat => anyhow::bail!(
378            "cannot register module {} as package.loaded.{} is already set to a value of type {}",
379            name,
380            name,
381            wat.type_name()
382        ),
383    }
384}
385
386/// Given a name path like `foo` or `foo.bar.baz`, sets up the module
387/// registry hierarchy to instantiate that path.
388/// Returns the leaf node of that path to allow the caller to
389/// register/assign functions etc. into it
390pub fn get_or_create_sub_module(lua: &Lua, name_path: &str) -> anyhow::Result<mlua::Table> {
391    let mut parent = get_or_create_module(lua, "kumo")?;
392    let mut path_so_far = String::new();
393
394    for name in name_path.split('.') {
395        if !path_so_far.is_empty() {
396            path_so_far.push('.');
397        }
398        path_so_far.push_str(name);
399
400        let sub = parent.get(name)?;
401        match sub {
402            Value::Nil => {
403                let sub = lua.create_table()?;
404                parent.set(name, sub.clone())?;
405                parent = sub;
406            }
407            Value::Table(sub) => {
408                parent = sub;
409            }
410            wat => anyhow::bail!(
411                "cannot register module kumo.{path_so_far} as it is already set to a value of type {}",
412                wat.type_name()
413            ),
414        }
415    }
416
417    Ok(parent)
418}
419
420/// Helper for mapping back to lua errors
421pub fn any_err<E: std::fmt::Display>(err: E) -> mlua::Error {
422    mlua::Error::external(format!("{err:#}"))
423}
424
425/// Provides implementations of __pairs, __index and __len metamethods
426/// for a type that is Serialize and UserData.
427/// Neither implementation is considered to be ideal, as we must
428/// first serialize the value into a json Value which is then either
429/// iterated over, or indexed to produce the appropriate result for
430/// the metamethod.
431pub fn impl_pairs_and_index<T, M>(methods: &mut M)
432where
433    T: UserData + Serialize,
434    M: UserDataMethods<T>,
435{
436    methods.add_meta_method(MetaMethod::Pairs, move |lua, this, _: ()| {
437        let Ok(serde_json::Value::Object(map)) = serde_json::to_value(this).map_err(any_err) else {
438            return Err(mlua::Error::external("must serialize to Map"));
439        };
440
441        let mut value_iter = map.into_iter();
442
443        let iter_func = lua.create_function_mut(
444            move |lua, (_state, _control): (Value, Value)| match value_iter.next() {
445                Some((key, value)) => {
446                    let key = lua.to_value(&key)?;
447                    let value = lua.to_value(&value)?;
448                    Ok((key, value))
449                }
450                None => Ok((Value::Nil, Value::Nil)),
451            },
452        )?;
453
454        Ok((Value::Function(iter_func), Value::Nil, Value::Nil))
455    });
456
457    methods.add_meta_method(MetaMethod::Index, move |lua, this, field: Value| {
458        let value = lua.to_value(this)?;
459        match value {
460            Value::Table(t) => t.get(field),
461            _ => Ok(Value::Nil),
462        }
463    });
464
465    methods.add_meta_method(MetaMethod::Len, move |lua, this, _: ()| {
466        let value = lua.to_value(this)?;
467        match value {
468            Value::Table(v) => v.len(),
469            Value::String(v) => Ok(v.as_bytes().len() as i64),
470            _ => Ok(0),
471        }
472    });
473}
474
/// Attempts to produce a plain lua representation of `value`, so that the
/// result can be processed by the `Deserialize` impl used via
/// `lua.from_value`.
///
/// * Userdata: its metatable's `__pairs` function is called with the
///   userdata. The returned `(iter_func, state, control)` triple is then
///   driven as `iter_func(state, control)` until it yields a nil key. Each
///   key and its recursively materialized value are copied into a new table.
/// * Table: copied into a new table. Keys are kept as-is and each value is
///   recursively materialized.
/// * Anything else is returned unchanged.
///
/// # Errors
/// Errors from fetching the userdata's metatable are propagated. A userdata
/// whose metatable lookup for a `__pairs` function fails produces a
/// "cannot materialize_to_lua_value ..." error. Failures of any lua call or
/// table operation are propagated.
///
/// NOTE(review): there is no cycle detection; a table whose values
/// (transitively) refer back to itself would recurse without bound.
pub fn materialize_to_lua_value(lua: &Lua, value: mlua::Value) -> mlua::Result<mlua::Value> {
    match value {
        mlua::Value::UserData(ud) => {
            let mt = ud.metatable()?;
            let Ok(pairs) = mt.get::<mlua::Function>("__pairs") else {
                // Convert back to a Value purely to render it in the error.
                let value = ud.into_lua(lua)?;
                return Err(mlua::Error::external(format!(
                    "cannot materialize_to_lua_value {value:?} \
                     because it has no __pairs metamethod"
                )));
            };
            let tbl = lua.create_table()?;
            // Generic-for protocol: __pairs returns the iterator function,
            // its invariant state, and the initial control value.
            let (iter_func, state, mut control): (mlua::Function, mlua::Value, mlua::Value) =
                pairs.call(mlua::Value::UserData(ud.clone()))?;

            loop {
                let (k, v): (mlua::Value, mlua::Value) =
                    iter_func.call((state.clone(), control))?;
                // A nil key signals the end of iteration.
                if k.is_nil() {
                    break;
                }

                tbl.set(k.clone(), materialize_to_lua_value(lua, v)?)?;
                // The last key becomes the control value for the next step.
                control = k;
            }

            Ok(mlua::Value::Table(tbl))
        }
        mlua::Value::Table(t) => {
            let tbl = lua.create_table()?;
            for pair in t.pairs::<mlua::Value, mlua::Value>() {
                let (k, v) = pair?;
                tbl.set(k.clone(), materialize_to_lua_value(lua, v)?)?;
            }
            Ok(mlua::Value::Table(tbl))
        }
        value => Ok(value),
    }
}
521
522/// Helper wrapper type for passing/returning serde encoded values from/to lua
523pub struct SerdeWrappedValue<T>(pub T);
524
525impl<T: serde::Serialize> serde::Serialize for SerdeWrappedValue<T> {
526    fn serialize<S>(
527        &self,
528        s: S,
529    ) -> Result<<S as serde::Serializer>::Ok, <S as serde::Serializer>::Error>
530    where
531        S: serde::Serializer,
532    {
533        self.0.serialize(s)
534    }
535}
536
537impl<T: Clone> Clone for SerdeWrappedValue<T> {
538    fn clone(&self) -> Self {
539        SerdeWrappedValue(self.0.clone())
540    }
541}
542
543impl<T: Default> Default for SerdeWrappedValue<T> {
544    fn default() -> Self {
545        SerdeWrappedValue(Default::default())
546    }
547}
548
549impl<T: serde::Serialize> SerdeWrappedValue<T> {
550    pub fn to_lua_value(&self, lua: &Lua) -> mlua::Result<mlua::Value> {
551        lua.to_value_with(&self.0, serialize_options())
552    }
553}
554
555impl<T: serde::Serialize> IntoLua for SerdeWrappedValue<T> {
556    fn into_lua(self, lua: &Lua) -> mlua::Result<mlua::Value> {
557        lua.to_value_with(&self.0, serialize_options())
558    }
559}
560
561impl<T: serde::de::DeserializeOwned> FromLua for SerdeWrappedValue<T> {
562    fn from_lua(value: mlua::Value, lua: &Lua) -> mlua::Result<SerdeWrappedValue<T>> {
563        let inner: T = from_lua_value(lua, value)?;
564        Ok(SerdeWrappedValue(inner))
565    }
566}
567
568impl<T> std::ops::Deref for SerdeWrappedValue<T> {
569    type Target = T;
570    fn deref(&self) -> &T {
571        &self.0
572    }
573}
574
575impl<T> std::ops::DerefMut for SerdeWrappedValue<T> {
576    fn deref_mut(&mut self) -> &mut T {
577        &mut self.0
578    }
579}
580
581/// Convert from a lua value to a deserializable type,
582/// with a slightly more helpful error message in case of failure.
583/// NOTE: the ", while processing" portion of the error messages generated
584/// here is coupled with a regex in typing.lua!
585pub fn from_lua_value<R>(lua: &Lua, value: mlua::Value) -> mlua::Result<R>
586where
587    R: serde::de::DeserializeOwned,
588{
589    let value_cloned = value.clone();
590    match lua.from_value(value) {
591        Ok(r) => Ok(r),
592        Err(err) => match materialize_to_lua_value(lua, value_cloned.clone()) {
593            Ok(materialized) => match lua.from_value(materialized.clone()) {
594                Ok(r) => Ok(r),
595                Err(err) => {
596                    let mut serializer = serde_json::Serializer::new(Vec::new());
597                    let serialized = match materialized.serialize(&mut serializer) {
598                        Ok(_) => String::from_utf8_lossy(&serializer.into_inner()).to_string(),
599                        Err(err) => format!("<unable to encode as json: {err:#}>"),
600                    };
601                    Err(mlua::Error::external(format!(
602                        "{err:#}, while processing {serialized}"
603                    )))
604                }
605            },
606            Err(materialize_err) => Err(mlua::Error::external(format!(
607                "{err:#}, while processing a userdata. \
608                    Additionally, encountered {materialize_err:#} \
609                    when trying to iterate the pairs of that userdata"
610            ))),
611        },
612    }
613}
614
/// CallbackSignature is a bit of sugar to aid with statically typing event
/// callback function invocation.
///
/// The idea is that you declare a signature instance that is typed
/// with its argument tuple (A), and its return type tuple (R).
///
/// The signature instance can then be used to invoke the callback by name.
///
/// The register method allows pre-registering events so that `kumo.on`
/// can reason about them better. The main function enabled by this is
/// `allow_multiple`; when that is set to true, `kumo.on` will allow
/// recording multiple callback instances, calling them in sequence
/// until one of them returns a value.
pub struct CallbackSignature<A, R>
where
    A: IntoLuaMulti,
    R: FromLuaMulti,
{
    // Zero-sized marker binding the argument/return types to this signature.
    marker: std::marker::PhantomData<(A, R)>,
    // Whether multiple handlers may be registered for this event.
    allow_multiple: bool,
    // The undecorated event name; see `decorated_name` for the registry key.
    name: Cow<'static, str>,
}
637
/// Link-time collection of registration thunks, one per `declare_event!`
/// invocation; each thunk calls `.register()` on its signature.
/// `register_declared_events` runs them.
#[linkme::distributed_slice]
pub static CALLBACK_SIGNATURES: [fn()];
640
/// Helper for declaring a named event handler callback signature.
///
/// Usage looks like:
///
/// ```rust,ignore
/// declare_event! {
/// pub static GET_Q_CONFIG_SIG: Multiple(
///         "get_queue_config",
///         domain: &'static str,
///         tenant: Option<&'static str>,
///         campaign: Option<&'static str>,
///         routing_domain: Option<&'static str>,
///     ) -> QueueConfig;
/// }
/// ```
///
/// A handler can be either `Single` or `Multiple`, indicating whether
/// only a single registration or multiple registrations are permitted.
/// The string literal is the name of the event, followed by a fn-style
/// parameter list which names each parameter in sequence, followed by
/// the return type. The names are not currently used in any way,
/// but enhance the readability of the code.
///
/// The argument type is formed as `($($args),*)`, which has three cases:
/// * zero parameters gives `()`;
/// * one parameter gives that type itself (`(T)` is just `T`, not the
///   1-tuple `(T,)`);
/// * two or more give a tuple.
///
/// In addition to declaring the signature in a global, some glue
/// is generated that will register the signature appropriately
/// so that lua knows whether it is single or multiple and can
/// act appropriately when `kumo.on` is called.
///
/// The glue uses the path `linkme::distributed_slice`, which resolves in
/// the invoking crate, so that crate must have `linkme` available under
/// that name.
#[macro_export]
macro_rules! declare_event {
    // Multiple: any number of handlers may be registered for the event.
    ($vis:vis static $sym:ident: Multiple($name:literal $(,)? $($param_name:ident: $args:ty),* $(,)? ) -> $ret:ty;) => {
        $vis static $sym: ::std::sync::LazyLock<
            $crate::CallbackSignature<($($args),*), $ret>> =
                ::std::sync::LazyLock::new(|| $crate::CallbackSignature::new_with_multiple($name));

        // Registration thunk collected into CALLBACK_SIGNATURES.
        $crate::paste::paste! {
            #[linkme::distributed_slice($crate::CALLBACK_SIGNATURES)]
            static [<CALLBACK_SIG_REGISTER_ $sym>]: fn() = || {
                $sym.register();
            };
        }
    };
    // Single: exactly one handler may be registered for the event.
    ($vis:vis static $sym:ident: Single($name:literal $(,)? $($param_name:ident: $args:ty),* $(,)? ) -> $ret:ty;) => {
        $vis static $sym: ::std::sync::LazyLock<
            $crate::CallbackSignature<($($args),*), $ret>> =
                ::std::sync::LazyLock::new(|| $crate::CallbackSignature::new($name));

        // Registration thunk collected into CALLBACK_SIGNATURES.
        $crate::paste::paste! {
            #[linkme::distributed_slice($crate::CALLBACK_SIGNATURES)]
            static [<CALLBACK_SIG_REGISTER_ $sym>]: fn() = || {
                $sym.register();
            };
        }
    };
}
695
696/// For each event handler CallbackSignature that was declared via
697/// `declare_event!`, call its `.register()` method to register
698/// it so that `kumo.on` can give appropriate messaging if misused,
699/// and so that runtime dispatch will work correctly.
700///
701/// This should be called once, prior to running any lua code.
702fn register_declared_events() {
703    static ONCE: Once = Once::new();
704    ONCE.call_once(|| {
705        for reg_func in CALLBACK_SIGNATURES {
706            reg_func();
707        }
708    });
709}
710
711impl<A, R> CallbackSignature<A, R>
712where
713    A: IntoLuaMulti,
714    R: FromLuaMulti,
715{
716    pub fn new<S: Into<Cow<'static, str>>>(name: S) -> Self {
717        let name = name.into();
718
719        Self {
720            marker: std::marker::PhantomData,
721            allow_multiple: false,
722            name,
723        }
724    }
725
726    /// Make sure that you call .register() on this from
727    /// eg: mod_kumo::register in order for it to be instantiated
728    /// and visible to the config loader
729    pub fn new_with_multiple<S: Into<Cow<'static, str>>>(name: S) -> Self {
730        let name = name.into();
731
732        Self {
733            marker: std::marker::PhantomData,
734            allow_multiple: true,
735            name,
736        }
737    }
738
739    pub fn register(&self) {
740        if self.allow_multiple {
741            CALLBACK_ALLOWS_MULTIPLE
742                .lock()
743                .insert(self.name.to_string());
744        }
745    }
746
747    pub fn raise_error_if_allow_multiple(&self) -> anyhow::Result<()> {
748        anyhow::ensure!(
749            !self.allow_multiple(),
750            "handler {} is set to allow multiple handlers \
751                    but is registered with a single instance. This indicates that \
752                    register() was not called on the signature when initializing \
753                    the lua context. Please report this issue to the KumoMTA team!",
754            self.name
755        );
756        Ok(())
757    }
758
759    /// Return true if this signature allows multiple instances to be registered
760    /// and called.
761    pub fn allow_multiple(&self) -> bool {
762        self.allow_multiple
763    }
764
765    pub fn name(&self) -> &str {
766        &self.name
767    }
768
769    pub fn decorated_name(&self) -> String {
770        decorate_callback_name(&self.name)
771    }
772}
773
774impl<A, R> CallbackSignature<A, R>
775where
776    A: IntoLuaMulti + Clone,
777    R: FromLuaMulti,
778{
779    /// Calls the callback, passing in the supplied arguments.
780    /// Returns the callback disposition which allows further
781    /// intrepretation of the result, such as deciding whether
782    /// to Default the value based on whether the callback
783    /// was defined or not
784    pub async fn call<'a>(
785        &'a self,
786        lua: &Lua,
787        args: A,
788    ) -> anyhow::Result<CallbackDisposition<'a, R>> {
789        let name = self.name();
790        let decorated_name = self.decorated_name();
791
792        match lua.named_registry_value::<mlua::Value>(&decorated_name)? {
793            Value::Table(tbl) => {
794                for func in tbl.sequence_values::<mlua::Function>().collect::<Vec<_>>() {
795                    let func = func?;
796                    let _timer = latency_timer(name);
797                    let result: mlua::MultiValue = func.call_async(args.clone()).await?;
798                    if result.is_empty() {
799                        // Continue with other handlers
800                        continue;
801                    }
802                    let result = R::from_lua_multi(result, lua)?;
803                    return Ok(CallbackDisposition {
804                        handler_was_defined: true,
805                        result: Some(result),
806                        event_name: name,
807                    });
808                }
809                Ok(CallbackDisposition {
810                    handler_was_defined: false,
811                    result: None,
812                    event_name: name,
813                })
814            }
815            Value::Function(func) => {
816                self.raise_error_if_allow_multiple()?;
817                let _timer = latency_timer(name);
818                Ok(CallbackDisposition {
819                    handler_was_defined: true,
820                    result: Some(func.call_async(args.clone()).await?),
821                    event_name: name,
822                })
823            }
824            _ => Ok(CallbackDisposition {
825                handler_was_defined: false,
826                result: None,
827                event_name: name,
828            }),
829        }
830    }
831}
832
833pub fn does_callback_allow_multiple(name: &str) -> bool {
834    CALLBACK_ALLOWS_MULTIPLE.lock().contains(name)
835}
836
/// Builds the lua registry key for an event's handlers: `kumomta-on-<name>`.
pub fn decorate_callback_name(name: &str) -> String {
    ["kumomta-on-", name].concat()
}
840
841/// Allows reasoning about the result of a callback to decide
842/// how to interpret the result
843pub struct CallbackDisposition<'a, T> {
844    /// Indicates whether the handler was found. If false,
845    /// then result will also be None.
846    /// If true and result.is_none(), it indicates that
847    /// none of the handlers returned a value.
848    pub handler_was_defined: bool,
849    /// The result!
850    pub result: Option<T>,
851    /// Which event was called (or to be called)
852    pub event_name: &'a str,
853}
854
855impl<'a, T> CallbackDisposition<'a, T> {
856    /// Requires that a value be returned.  Translates the disposition
857    /// into appropriate error messaging if no value was returned.
858    pub fn require_value(mut self) -> anyhow::Result<T> {
859        if !self.handler_was_defined {
860            anyhow::bail!("Event {} has not been registered", self.event_name);
861        }
862        match self.result.take() {
863            Some(value) => Ok(value),
864            None => anyhow::bail!("Event {} did not return a value", self.event_name),
865        }
866    }
867}
868impl<'a, T: Default> CallbackDisposition<'a, T> {
869    /// Unwraps the value type, or the Default impl if no value was returned
870    pub fn or_default(self) -> T {
871        self.result.unwrap_or_default()
872    }
873}
874
875pub fn serialize_options() -> mlua::SerializeOptions {
876    mlua::SerializeOptions::new()
877        .serialize_none_to_null(false)
878        .serialize_unit_to_null(false)
879}