config/
lib.rs

1use crate::epoch::{get_current_epoch, ConfigEpoch};
2use crate::pool::{pool_get, pool_put};
3pub use crate::pool::{set_gc_on_put, set_max_age, set_max_spare, set_max_use};
4use anyhow::Context;
5use kumo_prometheus::declare_metric;
6use kumo_prometheus::prometheus::HistogramTimer;
7use mlua::{
8    FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, Lua, LuaSerdeExt, MetaMethod, RegistryKey, Table,
9    UserData, UserDataMethods, Value,
10};
11use parking_lot::FairMutex as Mutex;
12pub use pastey as paste;
13use serde::Serialize;
14use std::borrow::Cow;
15use std::collections::HashSet;
16use std::path::PathBuf;
17use std::sync::atomic::{AtomicBool, Ordering};
18use std::sync::{LazyLock, Once};
19use std::time::Instant;
20
21pub mod epoch;
22mod pool;
23
24static POLICY_FILE: LazyLock<Mutex<Option<PathBuf>>> = LazyLock::new(|| Mutex::new(None));
25static FUNCS: LazyLock<Mutex<Vec<RegisterFunc>>> = LazyLock::new(|| Mutex::new(vec![]));
26
27declare_metric! {
28/// how many times the policy lua script has been loaded into a new context
29static LUA_LOAD_COUNT: IntCounter("lua_load_count");
30}
31
32declare_metric! {
33/// the number of lua contexts currently alive
34static LUA_COUNT: IntGauge("lua_count");
35}
36
37static CALLBACK_ALLOWS_MULTIPLE: LazyLock<Mutex<HashSet<String>>> =
38    LazyLock::new(|| Mutex::new(HashSet::new()));
39
40pub static VALIDATE_ONLY: AtomicBool = AtomicBool::new(false);
41pub static VALIDATION_FAILED: AtomicBool = AtomicBool::new(false);
42
43declare_metric! {
44/// how long a given lua event callback took
45static LATENCY_HIST: HistogramVec(
46        "lua_event_latency",
47        &["event"]);
48}
49
50declare_metric! {
51/// Incremented each time we start to call a lua event callback. Use lua_event_latency_count to track completed events
52static EVENT_STARTED_COUNT: CounterVec(
53        "lua_event_started",
54        &["event"]
55    );
56}
57
58pub type RegisterFunc = fn(&Lua) -> anyhow::Result<()>;
59
60fn latency_timer(label: &str) -> HistogramTimer {
61    EVENT_STARTED_COUNT
62        .get_metric_with_label_values(&[label])
63        .expect("to get counter")
64        .inc();
65    LATENCY_HIST
66        .get_metric_with_label_values(&[label])
67        .expect("to get histo")
68        .start_timer()
69}
70
71#[derive(Debug)]
72struct LuaConfigInner {
73    lua: Lua,
74    created: Instant,
75    use_count: usize,
76    epoch: ConfigEpoch,
77}
78
79impl Drop for LuaConfigInner {
80    fn drop(&mut self) {
81        LUA_COUNT.dec();
82    }
83}
84
85#[derive(Debug)]
86pub struct LuaConfig {
87    inner: Option<LuaConfigInner>,
88}
89
90pub async fn set_policy_path(path: PathBuf) -> anyhow::Result<()> {
91    POLICY_FILE.lock().replace(path);
92    let config = load_config().await?;
93    config.put();
94    Ok(())
95}
96
97fn get_policy_path() -> Option<PathBuf> {
98    POLICY_FILE.lock().clone()
99}
100
101fn get_funcs() -> Vec<RegisterFunc> {
102    FUNCS.lock().clone()
103}
104pub fn is_validating() -> bool {
105    VALIDATE_ONLY.load(Ordering::Relaxed)
106}
107
108pub fn validation_failed() -> bool {
109    VALIDATION_FAILED.load(Ordering::Relaxed)
110}
111
112pub fn set_validation_failed() {
113    VALIDATION_FAILED.store(true, Ordering::Relaxed)
114}
115
116pub async fn load_config() -> anyhow::Result<LuaConfig> {
117    if let Some(pool) = pool_get() {
118        return Ok(pool);
119    }
120
121    LUA_LOAD_COUNT.inc();
122    let lua = Lua::new();
123    let created = Instant::now();
124    let epoch = get_current_epoch();
125
126    {
127        let globals = lua.globals();
128
129        if is_validating() {
130            globals.set("_VALIDATING_CONFIG", true)?;
131        }
132
133        let package: Table = globals.get("package")?;
134        let package_path: String = package.get("path")?;
135        let mut path_array: Vec<String> = package_path.split(";").map(|s| s.to_owned()).collect();
136
137        fn prefix_path(array: &mut Vec<String>, path: &str) {
138            array.insert(0, format!("{}/?.lua", path));
139            array.insert(1, format!("{}/?/init.lua", path));
140        }
141
142        prefix_path(&mut path_array, "/opt/kumomta/etc/policy");
143        prefix_path(&mut path_array, "/opt/kumomta/share");
144
145        #[cfg(debug_assertions)]
146        prefix_path(&mut path_array, "assets");
147
148        package.set("path", path_array.join(";"))?;
149    }
150
151    register_declared_events();
152
153    for func in get_funcs() {
154        (func)(&lua)?;
155    }
156
157    if let Some(policy) = get_policy_path() {
158        let code = tokio::fs::read_to_string(&policy)
159            .await
160            .with_context(|| format!("reading policy file {policy:?}"))?;
161
162        let func = {
163            let chunk = lua.load(&code);
164            let chunk = chunk.set_name(policy.to_string_lossy());
165            chunk.into_function()?
166        };
167
168        let _timer = latency_timer("context-creation");
169        func.call_async::<()>(()).await?;
170    }
171    LUA_COUNT.inc();
172
173    Ok(LuaConfig {
174        inner: Some(LuaConfigInner {
175            lua,
176            created,
177            use_count: 1,
178            epoch,
179        }),
180    })
181}
182
183pub fn register(func: RegisterFunc) {
184    FUNCS.lock().push(func);
185}
186
187impl LuaConfig {
188    fn set_current_event(&mut self, name: &str) -> mlua::Result<()> {
189        self.inner
190            .as_mut()
191            .unwrap()
192            .lua
193            .globals()
194            .set("_KUMO_CURRENT_EVENT", name.to_string())
195    }
196
197    /// Convert an array of args into a MultiValue that can be passed
198    /// to a callback signature
199    pub fn convert_args_to_multi<A: Serialize>(
200        &self,
201        args: &[A],
202    ) -> anyhow::Result<mlua::MultiValue> {
203        let lua = self.inner.as_ref().unwrap();
204        let mut arg_vec = vec![];
205        for a in args.iter() {
206            arg_vec.push(lua.lua.to_value(a)?);
207        }
208        Ok(mlua::MultiValue::from_vec(arg_vec))
209    }
210
211    /// Intended to be used together with kumo.spawn_task
212    pub async fn convert_args_and_call_callback<A: Serialize>(
213        &mut self,
214        sig: &CallbackSignature<Value, ()>,
215        args: A,
216    ) -> anyhow::Result<()> {
217        let lua = self.inner.as_mut().unwrap();
218        let args = lua.lua.to_value(&args)?;
219
220        let name = sig.name();
221        let decorated_name = sig.decorated_name();
222
223        match lua
224            .lua
225            .named_registry_value::<mlua::Function>(&decorated_name)
226        {
227            Ok(func) => {
228                let _timer = latency_timer(name);
229                Ok(func.call_async(args).await?)
230            }
231            _ => anyhow::bail!("{name} has not been registered"),
232        }
233    }
234
235    /// Explicitly put the config object back into its containing pool.
236    /// Ideally we'd do this automatically when the object is dropped,
237    /// but lua's garbage collection makes this problematic:
238    /// if a future whose graph contains an async lua call within
239    /// this config object is cancelled (eg: simply stopped without
240    /// calling it again), and the config object is not explicitly garbage
241    /// collected, any futures and data owned by any dependencies of
242    /// the cancelled future remain alive until the next gc run,
243    /// which can cause things like async locks and semaphores to
244    /// have a lifetime extended by the maximum age of the lua context.
245    ///
246    /// The combat this, consumers of LuaConfig should explicitly
247    /// call `config.put()` after successfully using the config
248    /// object.
249    ///
250    /// Or framing it another way: consumers must not call `config.put()`
251    /// if a transitive dep might have been cancelled.
252    pub fn put(mut self) {
253        if let Some(inner) = self.inner.take() {
254            pool_put(inner);
255        }
256    }
257
258    pub async fn call_callback<'a, A: IntoLuaMulti + Clone, R: FromLuaMulti>(
259        &mut self,
260        sig: &'a CallbackSignature<A, R>,
261        args: A,
262    ) -> anyhow::Result<CallbackDisposition<'a, R>> {
263        let name = sig.name();
264        self.set_current_event(name)?;
265        let lua = self.inner.as_mut().unwrap();
266        sig.call(&lua.lua, args).await
267    }
268
269    pub async fn async_call_callback<A: IntoLuaMulti + Clone, R: FromLuaMulti + Default>(
270        &mut self,
271        sig: &CallbackSignature<A, R>,
272        args: A,
273    ) -> anyhow::Result<R> {
274        let name = sig.name();
275        self.set_current_event(name)?;
276        let lua = self.inner.as_mut().unwrap();
277        Ok(sig.call(&lua.lua, args).await?.or_default())
278    }
279
280    pub async fn async_call_callback_non_default<A: IntoLuaMulti + Clone, R: FromLuaMulti>(
281        &mut self,
282        sig: &CallbackSignature<A, R>,
283        args: A,
284    ) -> anyhow::Result<R> {
285        let name = sig.name();
286        self.set_current_event(name)?;
287        let lua = self.inner.as_mut().unwrap();
288        sig.call(&lua.lua, args).await?.require_value()
289    }
290
291    pub async fn async_call_callback_non_default_opt<A: IntoLuaMulti + Clone, R: FromLua>(
292        &mut self,
293        sig: &CallbackSignature<A, Option<R>>,
294        args: A,
295    ) -> anyhow::Result<Option<R>> {
296        let name = sig.name();
297        self.set_current_event(name)?;
298        let lua = self.inner.as_mut().unwrap();
299        let result = sig.call(&lua.lua, args).await?;
300        match result.result {
301            None => Ok(None),
302            Some(result) => Ok(result),
303        }
304    }
305
306    pub fn remove_registry_value(&mut self, value: RegistryKey) -> anyhow::Result<()> {
307        Ok(self
308            .inner
309            .as_mut()
310            .unwrap()
311            .lua
312            .remove_registry_value(value)?)
313    }
314
315    /// Call a constructor registered via `on`. Returns a registry key that can be
316    /// used to reference the returned value again later on this same Lua instance
317    pub async fn async_call_ctor<A: IntoLuaMulti + Clone>(
318        &mut self,
319        sig: &CallbackSignature<A, Value>,
320        args: A,
321    ) -> anyhow::Result<RegistryKey> {
322        let name = sig.name();
323        anyhow::ensure!(
324            !sig.allow_multiple(),
325            "ctor event signature for {name} is defined as allow_multiple, which is not supported"
326        );
327
328        let decorated_name = sig.decorated_name();
329        self.set_current_event(name)?;
330
331        let inner = self.inner.as_mut().unwrap();
332
333        let func = inner
334            .lua
335            .named_registry_value::<mlua::Function>(&decorated_name)?;
336
337        let _timer = latency_timer(name);
338        let value: Value = func.call_async(args.clone()).await?;
339        drop(func);
340
341        Ok(inner.lua.create_registry_value(value)?)
342    }
343
344    /// Operate on an object/value that was previously constructed via
345    /// async_call_ctor.
346    pub async fn with_registry_value<F, R, FUT>(
347        &mut self,
348        value: &RegistryKey,
349        func: F,
350    ) -> anyhow::Result<R>
351    where
352        R: FromLuaMulti,
353        F: FnOnce(Value) -> anyhow::Result<FUT>,
354        FUT: std::future::Future<Output = anyhow::Result<R>>,
355    {
356        let inner = self.inner.as_mut().unwrap();
357        let value = inner.lua.registry_value(value)?;
358        let future = (func)(value)?;
359        future.await
360    }
361}
362
363pub fn get_or_create_module(lua: &Lua, name: &str) -> anyhow::Result<mlua::Table> {
364    let globals = lua.globals();
365    let package: Table = globals.get("package")?;
366    let loaded: Table = package.get("loaded")?;
367
368    let module = loaded.get(name)?;
369    match module {
370        Value::Nil => {
371            let module = lua.create_table()?;
372            loaded.set(name, module.clone())?;
373            Ok(module)
374        }
375        Value::Table(table) => Ok(table),
376        wat => anyhow::bail!(
377            "cannot register module {} as package.loaded.{} is already set to a value of type {}",
378            name,
379            name,
380            wat.type_name()
381        ),
382    }
383}
384
385/// Given a name path like `foo` or `foo.bar.baz`, sets up the module
386/// registry hierarchy to instantiate that path.
387/// Returns the leaf node of that path to allow the caller to
388/// register/assign functions etc. into it
389pub fn get_or_create_sub_module(lua: &Lua, name_path: &str) -> anyhow::Result<mlua::Table> {
390    let mut parent = get_or_create_module(lua, "kumo")?;
391    let mut path_so_far = String::new();
392
393    for name in name_path.split('.') {
394        if !path_so_far.is_empty() {
395            path_so_far.push('.');
396        }
397        path_so_far.push_str(name);
398
399        let sub = parent.get(name)?;
400        match sub {
401            Value::Nil => {
402                let sub = lua.create_table()?;
403                parent.set(name, sub.clone())?;
404                parent = sub;
405            }
406            Value::Table(sub) => {
407                parent = sub;
408            }
409            wat => anyhow::bail!(
410                "cannot register module kumo.{path_so_far} as it is already set to a value of type {}",
411                wat.type_name()
412            ),
413        }
414    }
415
416    Ok(parent)
417}
418
419/// Helper for mapping back to lua errors
420pub fn any_err<E: std::fmt::Display>(err: E) -> mlua::Error {
421    mlua::Error::external(format!("{err:#}"))
422}
423
424/// Provides implementations of __pairs, __index and __len metamethods
425/// for a type that is Serialize and UserData.
426/// Neither implementation is considered to be ideal, as we must
427/// first serialize the value into a json Value which is then either
428/// iterated over, or indexed to produce the appropriate result for
429/// the metamethod.
430pub fn impl_pairs_and_index<T, M>(methods: &mut M)
431where
432    T: UserData + Serialize,
433    M: UserDataMethods<T>,
434{
435    methods.add_meta_method(MetaMethod::Pairs, move |lua, this, _: ()| {
436        let Ok(serde_json::Value::Object(map)) = serde_json::to_value(this).map_err(any_err) else {
437            return Err(mlua::Error::external("must serialize to Map"));
438        };
439
440        let mut value_iter = map.into_iter();
441
442        let iter_func = lua.create_function_mut(
443            move |lua, (_state, _control): (Value, Value)| match value_iter.next() {
444                Some((key, value)) => {
445                    let key = lua.to_value(&key)?;
446                    let value = lua.to_value(&value)?;
447                    Ok((key, value))
448                }
449                None => Ok((Value::Nil, Value::Nil)),
450            },
451        )?;
452
453        Ok((Value::Function(iter_func), Value::Nil, Value::Nil))
454    });
455
456    methods.add_meta_method(MetaMethod::Index, move |lua, this, field: Value| {
457        let value = lua.to_value(this)?;
458        match value {
459            Value::Table(t) => t.get(field),
460            _ => Ok(Value::Nil),
461        }
462    });
463
464    methods.add_meta_method(MetaMethod::Len, move |lua, this, _: ()| {
465        let value = lua.to_value(this)?;
466        match value {
467            Value::Table(v) => v.len(),
468            Value::String(v) => Ok(v.as_bytes().len() as i64),
469            _ => Ok(0),
470        }
471    });
472}
473
474/// This function will try to obtain a native lua representation
475/// of the provided value. It does this by attempting to iterate
476/// the pairs of any userdata it finds as either the value itself
477/// or the values of a table value by recursively applying
478/// materialize_to_lua_value to the value.
479/// This produces a lua value that can then be processed by the
480/// Deserialize impl on Value.
481pub fn materialize_to_lua_value(lua: &Lua, value: mlua::Value) -> mlua::Result<mlua::Value> {
482    match value {
483        mlua::Value::UserData(ud) => {
484            let mt = ud.metatable()?;
485            let Ok(pairs) = mt.get::<mlua::Function>("__pairs") else {
486                let value = ud.into_lua(lua)?;
487                return Err(mlua::Error::external(format!(
488                    "cannot materialize_to_lua_value {value:?} \
489                     because it has no __pairs metamethod"
490                )));
491            };
492            let tbl = lua.create_table()?;
493            let (iter_func, state, mut control): (mlua::Function, mlua::Value, mlua::Value) =
494                pairs.call(mlua::Value::UserData(ud.clone()))?;
495
496            loop {
497                let (k, v): (mlua::Value, mlua::Value) =
498                    iter_func.call((state.clone(), control))?;
499                if k.is_nil() {
500                    break;
501                }
502
503                tbl.set(k.clone(), materialize_to_lua_value(lua, v)?)?;
504                control = k;
505            }
506
507            Ok(mlua::Value::Table(tbl))
508        }
509        mlua::Value::Table(t) => {
510            let tbl = lua.create_table()?;
511            for pair in t.pairs::<mlua::Value, mlua::Value>() {
512                let (k, v) = pair?;
513                tbl.set(k.clone(), materialize_to_lua_value(lua, v)?)?;
514            }
515            Ok(mlua::Value::Table(tbl))
516        }
517        value => Ok(value),
518    }
519}
520
521/// Helper wrapper type for passing/returning serde encoded values from/to lua
522pub struct SerdeWrappedValue<T>(pub T);
523
524impl<T: serde::Serialize> serde::Serialize for SerdeWrappedValue<T> {
525    fn serialize<S>(
526        &self,
527        s: S,
528    ) -> Result<<S as serde::Serializer>::Ok, <S as serde::Serializer>::Error>
529    where
530        S: serde::Serializer,
531    {
532        self.0.serialize(s)
533    }
534}
535
536impl<T: Clone> Clone for SerdeWrappedValue<T> {
537    fn clone(&self) -> Self {
538        SerdeWrappedValue(self.0.clone())
539    }
540}
541
542impl<T: Default> Default for SerdeWrappedValue<T> {
543    fn default() -> Self {
544        SerdeWrappedValue(Default::default())
545    }
546}
547
548impl<T: serde::Serialize> SerdeWrappedValue<T> {
549    pub fn to_lua_value(&self, lua: &Lua) -> mlua::Result<mlua::Value> {
550        lua.to_value_with(&self.0, serialize_options())
551    }
552}
553
554impl<T: serde::Serialize> IntoLua for SerdeWrappedValue<T> {
555    fn into_lua(self, lua: &Lua) -> mlua::Result<mlua::Value> {
556        lua.to_value_with(&self.0, serialize_options())
557    }
558}
559
560impl<T: serde::de::DeserializeOwned> FromLua for SerdeWrappedValue<T> {
561    fn from_lua(value: mlua::Value, lua: &Lua) -> mlua::Result<SerdeWrappedValue<T>> {
562        let inner: T = from_lua_value(lua, value)?;
563        Ok(SerdeWrappedValue(inner))
564    }
565}
566
567impl<T> std::ops::Deref for SerdeWrappedValue<T> {
568    type Target = T;
569    fn deref(&self) -> &T {
570        &self.0
571    }
572}
573
574impl<T> std::ops::DerefMut for SerdeWrappedValue<T> {
575    fn deref_mut(&mut self) -> &mut T {
576        &mut self.0
577    }
578}
579
580/// Convert from a lua value to a deserializable type,
581/// with a slightly more helpful error message in case of failure.
582/// NOTE: the ", while processing" portion of the error messages generated
583/// here is coupled with a regex in typing.lua!
584pub fn from_lua_value<R>(lua: &Lua, value: mlua::Value) -> mlua::Result<R>
585where
586    R: serde::de::DeserializeOwned,
587{
588    let value_cloned = value.clone();
589    match lua.from_value(value) {
590        Ok(r) => Ok(r),
591        Err(err) => match materialize_to_lua_value(lua, value_cloned.clone()) {
592            Ok(materialized) => match lua.from_value(materialized.clone()) {
593                Ok(r) => Ok(r),
594                Err(err) => {
595                    let mut serializer = serde_json::Serializer::new(Vec::new());
596                    let serialized = match materialized.serialize(&mut serializer) {
597                        Ok(_) => String::from_utf8_lossy(&serializer.into_inner()).to_string(),
598                        Err(err) => format!("<unable to encode as json: {err:#}>"),
599                    };
600                    Err(mlua::Error::external(format!(
601                        "{err:#}, while processing {serialized}"
602                    )))
603                }
604            },
605            Err(materialize_err) => Err(mlua::Error::external(format!(
606                "{err:#}, while processing a userdata. \
607                    Additionally, encountered {materialize_err:#} \
608                    when trying to iterate the pairs of that userdata"
609            ))),
610        },
611    }
612}
613
614/// CallbackSignature is a bit sugar to aid with statically typing event callback
615/// function invocation.
616///
617/// The idea is that you declare a signature instance that is typed
618/// with its argument tuple (A), and its return type tuple (R).
619///
620/// The signature instance can then be used to invoke the callback by name.
621///
622/// The register method allows pre-registering events so that `kumo.on`
623/// can reason about them better.  The main function enabled by this is
624/// `allow_multiple`; when that is set to true, `kumo.on` will allow
625/// recording multiple callback instances, calling them in sequence
626/// until one of them returns a value.
627pub struct CallbackSignature<A, R>
628where
629    A: IntoLuaMulti,
630    R: FromLuaMulti,
631{
632    marker: std::marker::PhantomData<(A, R)>,
633    allow_multiple: bool,
634    name: Cow<'static, str>,
635}
636
637#[linkme::distributed_slice]
638pub static CALLBACK_SIGNATURES: [fn()];
639
640/// Helper for declaring a named event handler callback signature.
641///
642/// Usage looks like:
643///
644/// ```rust,ignore
645/// declare_event! {
646/// pub static GET_Q_CONFIG_SIG: Multiple(
647///         "get_queue_config",
648///         domain: &'static str,
649///         tenant: Option<&'static str>,
650///         campaign: Option<&'static str>,
651///         routing_domain: Option<&'static str>,
652///     ) -> QueueConfig;
653/// }
654/// ```
655///
656/// A handler can be either `Single` or `Multiple`, indicating whether
657/// only a single registration or multiple registrations are permitted.
658/// The string literal is the name of the event, followed by a fn-style
659/// parameter list which names each parameter in sequence, followed by
660/// the return value.  The names are not currently used in any way,
661/// but enhance the readability of the code.
662///
663/// In addition to declaring the signature in a global, some glue
664/// is generated that will register the signature appropriately
665/// so that lua knows whether it is single or multiple and can
666/// act appropriately when `kumo.on` is called.
667#[macro_export]
668macro_rules! declare_event {
669    ($vis:vis static $sym:ident: Multiple($name:literal $(,)? $($param_name:ident: $args:ty),* $(,)? ) -> $ret:ty;) => {
670        $vis static $sym: ::std::sync::LazyLock<
671            $crate::CallbackSignature<($($args),*), $ret>> =
672                ::std::sync::LazyLock::new(|| $crate::CallbackSignature::new_with_multiple($name));
673
674        $crate::paste::paste! {
675            #[linkme::distributed_slice($crate::CALLBACK_SIGNATURES)]
676            static [<CALLBACK_SIG_REGISTER_ $sym>]: fn() = || {
677                $sym.register();
678            };
679        }
680    };
681    ($vis:vis static $sym:ident: Single($name:literal $(,)? $($param_name:ident: $args:ty),* $(,)? ) -> $ret:ty;) => {
682        $vis static $sym: ::std::sync::LazyLock<
683            $crate::CallbackSignature<($($args),*), $ret>> =
684                ::std::sync::LazyLock::new(|| $crate::CallbackSignature::new($name));
685
686        $crate::paste::paste! {
687            #[linkme::distributed_slice($crate::CALLBACK_SIGNATURES)]
688            static [<CALLBACK_SIG_REGISTER_ $sym>]: fn() = || {
689                $sym.register();
690            };
691        }
692    };
693}
694
695/// For each event handler CallbackSignature that was declared via
696/// `declare_event!`, call its `.register()` method to register
697/// it so that `kumo.on` can give appropriate messaging if misused,
698/// and so that runtime dispatch will work correctly.
699///
700/// This should be called once, prior to running any lua code.
701fn register_declared_events() {
702    static ONCE: Once = Once::new();
703    ONCE.call_once(|| {
704        for reg_func in CALLBACK_SIGNATURES {
705            reg_func();
706        }
707    });
708}
709
710impl<A, R> CallbackSignature<A, R>
711where
712    A: IntoLuaMulti,
713    R: FromLuaMulti,
714{
715    pub fn new<S: Into<Cow<'static, str>>>(name: S) -> Self {
716        let name = name.into();
717
718        Self {
719            marker: std::marker::PhantomData,
720            allow_multiple: false,
721            name,
722        }
723    }
724
725    /// Make sure that you call .register() on this from
726    /// eg: mod_kumo::register in order for it to be instantiated
727    /// and visible to the config loader
728    pub fn new_with_multiple<S: Into<Cow<'static, str>>>(name: S) -> Self {
729        let name = name.into();
730
731        Self {
732            marker: std::marker::PhantomData,
733            allow_multiple: true,
734            name,
735        }
736    }
737
738    pub fn register(&self) {
739        if self.allow_multiple {
740            CALLBACK_ALLOWS_MULTIPLE
741                .lock()
742                .insert(self.name.to_string());
743        }
744    }
745
746    pub fn raise_error_if_allow_multiple(&self) -> anyhow::Result<()> {
747        anyhow::ensure!(
748            !self.allow_multiple(),
749            "handler {} is set to allow multiple handlers \
750                    but is registered with a single instance. This indicates that \
751                    register() was not called on the signature when initializing \
752                    the lua context. Please report this issue to the KumoMTA team!",
753            self.name
754        );
755        Ok(())
756    }
757
758    /// Return true if this signature allows multiple instances to be registered
759    /// and called.
760    pub fn allow_multiple(&self) -> bool {
761        self.allow_multiple
762    }
763
764    pub fn name(&self) -> &str {
765        &self.name
766    }
767
768    pub fn decorated_name(&self) -> String {
769        decorate_callback_name(&self.name)
770    }
771}
772
773impl<A, R> CallbackSignature<A, R>
774where
775    A: IntoLuaMulti + Clone,
776    R: FromLuaMulti,
777{
778    /// Calls the callback, passing in the supplied arguments.
779    /// Returns the callback disposition which allows further
780    /// intrepretation of the result, such as deciding whether
781    /// to Default the value based on whether the callback
782    /// was defined or not
783    pub async fn call<'a>(
784        &'a self,
785        lua: &Lua,
786        args: A,
787    ) -> anyhow::Result<CallbackDisposition<'a, R>> {
788        let name = self.name();
789        let decorated_name = self.decorated_name();
790
791        match lua.named_registry_value::<mlua::Value>(&decorated_name)? {
792            Value::Table(tbl) => {
793                for func in tbl.sequence_values::<mlua::Function>().collect::<Vec<_>>() {
794                    let func = func?;
795                    let _timer = latency_timer(name);
796                    let result: mlua::MultiValue = func.call_async(args.clone()).await?;
797                    if result.is_empty() {
798                        // Continue with other handlers
799                        continue;
800                    }
801                    let result = R::from_lua_multi(result, lua)?;
802                    return Ok(CallbackDisposition {
803                        handler_was_defined: true,
804                        result: Some(result),
805                        event_name: name,
806                    });
807                }
808                Ok(CallbackDisposition {
809                    handler_was_defined: false,
810                    result: None,
811                    event_name: name,
812                })
813            }
814            Value::Function(func) => {
815                self.raise_error_if_allow_multiple()?;
816                let _timer = latency_timer(name);
817                Ok(CallbackDisposition {
818                    handler_was_defined: true,
819                    result: Some(func.call_async(args.clone()).await?),
820                    event_name: name,
821                })
822            }
823            _ => Ok(CallbackDisposition {
824                handler_was_defined: false,
825                result: None,
826                event_name: name,
827            }),
828        }
829    }
830}
831
832pub fn does_callback_allow_multiple(name: &str) -> bool {
833    CALLBACK_ALLOWS_MULTIPLE.lock().contains(name)
834}
835
836pub fn decorate_callback_name(name: &str) -> String {
837    format!("kumomta-on-{name}")
838}
839
840/// Allows reasoning about the result of a callback to decide
841/// how to interpret the result
842pub struct CallbackDisposition<'a, T> {
843    /// Indicates whether the handler was found. If false,
844    /// then result will also be None.
845    /// If true and result.is_none(), it indicates that
846    /// none of the handlers returned a value.
847    pub handler_was_defined: bool,
848    /// The result!
849    pub result: Option<T>,
850    /// Which event was called (or to be called)
851    pub event_name: &'a str,
852}
853
854impl<'a, T> CallbackDisposition<'a, T> {
855    /// Requires that a value be returned.  Translates the disposition
856    /// into appropriate error messaging if no value was returned.
857    pub fn require_value(mut self) -> anyhow::Result<T> {
858        if !self.handler_was_defined {
859            anyhow::bail!("Event {} has not been registered", self.event_name);
860        }
861        match self.result.take() {
862            Some(value) => Ok(value),
863            None => anyhow::bail!("Event {} did not return a value", self.event_name),
864        }
865    }
866}
867impl<'a, T: Default> CallbackDisposition<'a, T> {
868    /// Unwraps the value type, or the Default impl if no value was returned
869    pub fn or_default(self) -> T {
870        self.result.unwrap_or_default()
871    }
872}
873
874pub fn serialize_options() -> mlua::SerializeOptions {
875    mlua::SerializeOptions::new()
876        .serialize_none_to_null(false)
877        .serialize_unit_to_null(false)
878}