config/
lib.rs

1use crate::epoch::{get_current_epoch, ConfigEpoch};
2use crate::pool::{pool_get, pool_put};
3pub use crate::pool::{set_gc_on_put, set_max_age, set_max_spare, set_max_use};
4use anyhow::Context;
5use mlua::{
6    FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, Lua, LuaSerdeExt, MetaMethod, RegistryKey, Table,
7    UserData, UserDataMethods, Value,
8};
9use parking_lot::FairMutex as Mutex;
10pub use paste;
11use prometheus::{CounterVec, HistogramTimer, HistogramVec};
12use serde::Serialize;
13use std::borrow::Cow;
14use std::collections::HashSet;
15use std::path::PathBuf;
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::sync::{LazyLock, Once};
18use std::time::Instant;
19
20pub mod epoch;
21mod pool;
22
/// Path of the policy script executed in every freshly created lua context
/// (see `load_config`); `None` until `set_policy_path` is called.
static POLICY_FILE: LazyLock<Mutex<Option<PathBuf>>> = LazyLock::new(|| Mutex::new(None));
/// Functions added via `register`; `load_config` runs each of them, in order,
/// against every newly created lua state.
static FUNCS: LazyLock<Mutex<Vec<RegisterFunc>>> = LazyLock::new(|| Mutex::new(vec![]));
/// Incremented by `load_config` each time it has to build a brand new lua
/// state because `pool_get` returned nothing.
static LUA_LOAD_COUNT: LazyLock<metrics::Counter> = LazyLock::new(|| {
    metrics::describe_counter!(
        "lua_load_count",
        "how many times the policy lua script has been \
         loaded into a new context"
    );
    metrics::counter!("lua_load_count")
});
/// Gauge of live lua states: incremented at the end of a successful
/// `load_config`, decremented when a `LuaConfigInner` is dropped.
static LUA_COUNT: LazyLock<metrics::Gauge> = LazyLock::new(|| {
    metrics::describe_gauge!("lua_count", "the number of lua contexts currently alive");
    metrics::gauge!("lua_count")
});
/// Names of event signatures registered (via `CallbackSignature::register`)
/// as allowing multiple handlers; queried by `does_callback_allow_multiple`.
static CALLBACK_ALLOWS_MULTIPLE: LazyLock<Mutex<HashSet<String>>> =
    LazyLock::new(|| Mutex::new(HashSet::new()));

/// When true, `load_config` sets the `_VALIDATING_CONFIG` lua global in new
/// contexts. This file only reads it; it is presumably set by the binary's
/// validation mode — TODO confirm.
pub static VALIDATE_ONLY: AtomicBool = AtomicBool::new(false);
/// Latched to true by `set_validation_failed`; read via `validation_failed`.
pub static VALIDATION_FAILED: AtomicBool = AtomicBool::new(false);
/// Per-event callback latency histogram, registered with the default
/// prometheus registry on first use. The `unwrap` panics if registration
/// fails (for example, if the name is already registered).
static LATENCY_HIST: LazyLock<HistogramVec> = LazyLock::new(|| {
    prometheus::register_histogram_vec!(
        "lua_event_latency",
        "how long a given lua event callback took",
        &["event"]
    )
    .unwrap()
});
/// Per-event counter bumped by `latency_timer` before each callback starts.
/// Registered on first use; the `unwrap` panics if registration fails.
static EVENT_STARTED_COUNT: LazyLock<CounterVec> = LazyLock::new(|| {
    prometheus::register_counter_vec!(
        "lua_event_started",
        "Incremented each time we start to call a lua event callback. Use lua_event_latency_count to track completed events",
        &["event"]
    )
    .unwrap()
});

/// Signature of a function that installs bindings into a new lua state;
/// see `register` and `load_config`.
pub type RegisterFunc = fn(&Lua) -> anyhow::Result<()>;
60
61fn latency_timer(label: &str) -> HistogramTimer {
62    EVENT_STARTED_COUNT
63        .get_metric_with_label_values(&[label])
64        .expect("to get counter")
65        .inc();
66    LATENCY_HIST
67        .get_metric_with_label_values(&[label])
68        .expect("to get histo")
69        .start_timer()
70}
71
/// The state behind a single lua context handed out as a [`LuaConfig`].
#[derive(Debug)]
struct LuaConfigInner {
    /// The lua state holding the loaded policy.
    lua: Lua,
    /// When the state was created in `load_config`. Presumably consulted by
    /// the pool for its max-age policy — TODO confirm in pool.rs.
    created: Instant,
    /// Starts at 1 in `load_config`. Presumably updated/checked by the pool
    /// for its max-use policy — TODO confirm in pool.rs.
    use_count: usize,
    /// The config epoch captured (via `get_current_epoch`) at creation time.
    epoch: ConfigEpoch,
}

impl Drop for LuaConfigInner {
    // Balances the `LUA_COUNT.increment` done at the end of a successful
    // `load_config`, so the `lua_count` gauge tracks live states.
    fn drop(&mut self) {
        LUA_COUNT.decrement(1.);
    }
}

/// A lua context with the policy loaded, obtained from `load_config`.
///
/// `inner` is only taken by `LuaConfig::put`, which consumes `self`. There is
/// no `Drop` impl for this type in this file, so a `LuaConfig` dropped
/// without `put` destroys its lua state instead of returning it to the pool.
#[derive(Debug)]
pub struct LuaConfig {
    inner: Option<LuaConfigInner>,
}
90
91pub async fn set_policy_path(path: PathBuf) -> anyhow::Result<()> {
92    POLICY_FILE.lock().replace(path);
93    let config = load_config().await?;
94    config.put();
95    Ok(())
96}
97
98fn get_policy_path() -> Option<PathBuf> {
99    POLICY_FILE.lock().clone()
100}
101
102fn get_funcs() -> Vec<RegisterFunc> {
103    FUNCS.lock().clone()
104}
105pub fn is_validating() -> bool {
106    VALIDATE_ONLY.load(Ordering::Relaxed)
107}
108
109pub fn validation_failed() -> bool {
110    VALIDATION_FAILED.load(Ordering::Relaxed)
111}
112
113pub fn set_validation_failed() {
114    VALIDATION_FAILED.store(true, Ordering::Relaxed)
115}
116
117pub async fn load_config() -> anyhow::Result<LuaConfig> {
118    if let Some(pool) = pool_get() {
119        return Ok(pool);
120    }
121
122    LUA_LOAD_COUNT.increment(1);
123    let lua = Lua::new();
124    let created = Instant::now();
125    let epoch = get_current_epoch();
126
127    {
128        let globals = lua.globals();
129
130        if is_validating() {
131            globals.set("_VALIDATING_CONFIG", true)?;
132        }
133
134        let package: Table = globals.get("package")?;
135        let package_path: String = package.get("path")?;
136        let mut path_array: Vec<String> = package_path.split(";").map(|s| s.to_owned()).collect();
137
138        fn prefix_path(array: &mut Vec<String>, path: &str) {
139            array.insert(0, format!("{}/?.lua", path));
140            array.insert(1, format!("{}/?/init.lua", path));
141        }
142
143        prefix_path(&mut path_array, "/opt/kumomta/etc/policy");
144        prefix_path(&mut path_array, "/opt/kumomta/share");
145
146        #[cfg(debug_assertions)]
147        prefix_path(&mut path_array, "assets");
148
149        package.set("path", path_array.join(";"))?;
150    }
151
152    register_declared_events();
153
154    for func in get_funcs() {
155        (func)(&lua)?;
156    }
157
158    if let Some(policy) = get_policy_path() {
159        let code = tokio::fs::read_to_string(&policy)
160            .await
161            .with_context(|| format!("reading policy file {policy:?}"))?;
162
163        let func = {
164            let chunk = lua.load(&code);
165            let chunk = chunk.set_name(policy.to_string_lossy());
166            chunk.into_function()?
167        };
168
169        let _timer = latency_timer("context-creation");
170        func.call_async::<()>(()).await?;
171    }
172    LUA_COUNT.increment(1.);
173
174    Ok(LuaConfig {
175        inner: Some(LuaConfigInner {
176            lua,
177            created,
178            use_count: 1,
179            epoch,
180        }),
181    })
182}
183
184pub fn register(func: RegisterFunc) {
185    FUNCS.lock().push(func);
186}
187
188impl LuaConfig {
189    fn set_current_event(&mut self, name: &str) -> mlua::Result<()> {
190        self.inner
191            .as_mut()
192            .unwrap()
193            .lua
194            .globals()
195            .set("_KUMO_CURRENT_EVENT", name.to_string())
196    }
197
198    /// Intended to be used together with kumo.spawn_task
199    pub async fn convert_args_and_call_callback<A: Serialize>(
200        &mut self,
201        sig: &CallbackSignature<Value, ()>,
202        args: A,
203    ) -> anyhow::Result<()> {
204        let lua = self.inner.as_mut().unwrap();
205        let args = lua.lua.to_value(&args)?;
206
207        let name = sig.name();
208        let decorated_name = sig.decorated_name();
209
210        match lua
211            .lua
212            .named_registry_value::<mlua::Function>(&decorated_name)
213        {
214            Ok(func) => {
215                let _timer = latency_timer(name);
216                Ok(func.call_async(args).await?)
217            }
218            _ => anyhow::bail!("{name} has not been registered"),
219        }
220    }
221
222    /// Explicitly put the config object back into its containing pool.
223    /// Ideally we'd do this automatically when the object is dropped,
224    /// but lua's garbage collection makes this problematic:
225    /// if a future whose graph contains an async lua call within
226    /// this config object is cancelled (eg: simply stopped without
227    /// calling it again), and the config object is not explicitly garbage
228    /// collected, any futures and data owned by any dependencies of
229    /// the cancelled future remain alive until the next gc run,
230    /// which can cause things like async locks and semaphores to
231    /// have a lifetime extended by the maximum age of the lua context.
232    ///
233    /// The combat this, consumers of LuaConfig should explicitly
234    /// call `config.put()` after successfully using the config
235    /// object.
236    ///
237    /// Or framing it another way: consumers must not call `config.put()`
238    /// if a transitive dep might have been cancelled.
239    pub fn put(mut self) {
240        if let Some(inner) = self.inner.take() {
241            pool_put(inner);
242        }
243    }
244
245    pub async fn async_call_callback<A: IntoLuaMulti + Clone, R: FromLuaMulti + Default>(
246        &mut self,
247        sig: &CallbackSignature<A, R>,
248        args: A,
249    ) -> anyhow::Result<R> {
250        let name = sig.name();
251        self.set_current_event(name)?;
252        let lua = self.inner.as_mut().unwrap();
253        async_call_callback(&lua.lua, sig, args).await
254    }
255
256    pub async fn async_call_callback_non_default<A: IntoLuaMulti + Clone, R: FromLuaMulti>(
257        &mut self,
258        sig: &CallbackSignature<A, R>,
259        args: A,
260    ) -> anyhow::Result<R> {
261        let name = sig.name();
262        self.set_current_event(name)?;
263        let lua = self.inner.as_mut().unwrap();
264        async_call_callback_non_default(&lua.lua, sig, args).await
265    }
266
267    pub async fn async_call_callback_non_default_opt<A: IntoLuaMulti + Clone, R: FromLua>(
268        &mut self,
269        sig: &CallbackSignature<A, Option<R>>,
270        args: A,
271    ) -> anyhow::Result<Option<R>> {
272        let name = sig.name();
273        let decorated_name = sig.decorated_name();
274        self.set_current_event(name)?;
275        let lua = self.inner.as_mut().unwrap();
276
277        match lua
278            .lua
279            .named_registry_value::<mlua::Value>(&decorated_name)?
280        {
281            Value::Table(tbl) => {
282                for func in tbl.sequence_values::<mlua::Function>().collect::<Vec<_>>() {
283                    let func = func?;
284                    let _timer = latency_timer(name);
285                    let result: mlua::MultiValue = func.call_async(args.clone()).await?;
286                    if result.is_empty() {
287                        // Continue with other handlers
288                        continue;
289                    }
290                    let result = R::from_lua_multi(result, &lua.lua)?;
291                    return Ok(Some(result));
292                }
293                Ok(None)
294            }
295            Value::Function(func) => {
296                sig.raise_error_if_allow_multiple()?;
297                let _timer = latency_timer(name);
298                let value: Value = func.call_async(args.clone()).await?;
299
300                match value {
301                    Value::Nil => Ok(None),
302                    value => {
303                        let result = R::from_lua(value, &lua.lua)?;
304                        Ok(Some(result))
305                    }
306                }
307            }
308            _ => Ok(None),
309        }
310    }
311
312    pub fn remove_registry_value(&mut self, value: RegistryKey) -> anyhow::Result<()> {
313        Ok(self
314            .inner
315            .as_mut()
316            .unwrap()
317            .lua
318            .remove_registry_value(value)?)
319    }
320
321    /// Call a constructor registered via `on`. Returns a registry key that can be
322    /// used to reference the returned value again later on this same Lua instance
323    pub async fn async_call_ctor<A: IntoLuaMulti + Clone>(
324        &mut self,
325        sig: &CallbackSignature<A, Value>,
326        args: A,
327    ) -> anyhow::Result<RegistryKey> {
328        let name = sig.name();
329        anyhow::ensure!(
330            !sig.allow_multiple(),
331            "ctor event signature for {name} is defined as allow_multiple, which is not supported"
332        );
333
334        let decorated_name = sig.decorated_name();
335        self.set_current_event(name)?;
336
337        let inner = self.inner.as_mut().unwrap();
338
339        let func = inner
340            .lua
341            .named_registry_value::<mlua::Function>(&decorated_name)?;
342
343        let _timer = latency_timer(name);
344        let value: Value = func.call_async(args.clone()).await?;
345        drop(func);
346
347        Ok(inner.lua.create_registry_value(value)?)
348    }
349
350    /// Operate on an object/value that was previously constructed via
351    /// async_call_ctor.
352    pub async fn with_registry_value<F, R, FUT>(
353        &mut self,
354        value: &RegistryKey,
355        func: F,
356    ) -> anyhow::Result<R>
357    where
358        R: FromLuaMulti,
359        F: FnOnce(Value) -> anyhow::Result<FUT>,
360        FUT: std::future::Future<Output = anyhow::Result<R>>,
361    {
362        let inner = self.inner.as_mut().unwrap();
363        let value = inner.lua.registry_value(value)?;
364        let future = (func)(value)?;
365        future.await
366    }
367}
368
369pub async fn async_call_callback<A: IntoLuaMulti + Clone, R: FromLuaMulti + Default>(
370    lua: &Lua,
371    sig: &CallbackSignature<A, R>,
372    args: A,
373) -> anyhow::Result<R> {
374    let name = sig.name();
375    let decorated_name = sig.decorated_name();
376
377    match lua.named_registry_value::<mlua::Value>(&decorated_name)? {
378        Value::Table(tbl) => {
379            for func in tbl.sequence_values::<mlua::Function>().collect::<Vec<_>>() {
380                let func = func?;
381                let _timer = latency_timer(name);
382                let result: mlua::MultiValue = func.call_async(args.clone()).await?;
383                if result.is_empty() {
384                    // Continue with other handlers
385                    continue;
386                }
387                let result = R::from_lua_multi(result, lua)?;
388                return Ok(result);
389            }
390            Ok(R::default())
391        }
392        Value::Function(func) => {
393            sig.raise_error_if_allow_multiple()?;
394            let _timer = latency_timer(name);
395            Ok(func.call_async(args.clone()).await?)
396        }
397        _ => Ok(R::default()),
398    }
399}
400
401pub async fn async_call_callback_non_default<A: IntoLuaMulti + Clone, R: FromLuaMulti>(
402    lua: &Lua,
403    sig: &CallbackSignature<A, R>,
404    args: A,
405) -> anyhow::Result<R> {
406    let name = sig.name();
407    let decorated_name = sig.decorated_name();
408
409    match lua.named_registry_value::<mlua::Value>(&decorated_name)? {
410        Value::Table(tbl) => {
411            for func in tbl.sequence_values::<mlua::Function>().collect::<Vec<_>>() {
412                let func = func?;
413                let _timer = latency_timer(name);
414                let result: mlua::MultiValue = func.call_async(args.clone()).await?;
415                if result.is_empty() {
416                    // Continue with other handlers
417                    continue;
418                }
419                let result = R::from_lua_multi(result, lua)?;
420                return Ok(result);
421            }
422            anyhow::bail!("invalid return type for {name} event");
423        }
424        Value::Function(func) => {
425            sig.raise_error_if_allow_multiple()?;
426            let _timer = latency_timer(name);
427            Ok(func.call_async(args.clone()).await?)
428        }
429        _ => anyhow::bail!("Event {name} has not been registered"),
430    }
431}
432
433pub fn get_or_create_module(lua: &Lua, name: &str) -> anyhow::Result<mlua::Table> {
434    let globals = lua.globals();
435    let package: Table = globals.get("package")?;
436    let loaded: Table = package.get("loaded")?;
437
438    let module = loaded.get(name)?;
439    match module {
440        Value::Nil => {
441            let module = lua.create_table()?;
442            loaded.set(name, module.clone())?;
443            Ok(module)
444        }
445        Value::Table(table) => Ok(table),
446        wat => anyhow::bail!(
447            "cannot register module {} as package.loaded.{} is already set to a value of type {}",
448            name,
449            name,
450            wat.type_name()
451        ),
452    }
453}
454
455/// Given a name path like `foo` or `foo.bar.baz`, sets up the module
456/// registry hierarchy to instantiate that path.
457/// Returns the leaf node of that path to allow the caller to
458/// register/assign functions etc. into it
459pub fn get_or_create_sub_module(lua: &Lua, name_path: &str) -> anyhow::Result<mlua::Table> {
460    let mut parent = get_or_create_module(lua, "kumo")?;
461    let mut path_so_far = String::new();
462
463    for name in name_path.split('.') {
464        if !path_so_far.is_empty() {
465            path_so_far.push('.');
466        }
467        path_so_far.push_str(name);
468
469        let sub = parent.get(name)?;
470        match sub {
471            Value::Nil => {
472                let sub = lua.create_table()?;
473                parent.set(name, sub.clone())?;
474                parent = sub;
475            }
476            Value::Table(sub) => {
477                parent = sub;
478            }
479            wat => anyhow::bail!(
480                "cannot register module kumo.{path_so_far} as it is already set to a value of type {}",
481                wat.type_name()
482            ),
483        }
484    }
485
486    Ok(parent)
487}
488
489/// Helper for mapping back to lua errors
490pub fn any_err<E: std::fmt::Display>(err: E) -> mlua::Error {
491    mlua::Error::external(format!("{err:#}"))
492}
493
494/// Provides implementations of __pairs, __index and __len metamethods
495/// for a type that is Serialize and UserData.
496/// Neither implementation is considered to be ideal, as we must
497/// first serialize the value into a json Value which is then either
498/// iterated over, or indexed to produce the appropriate result for
499/// the metamethod.
500pub fn impl_pairs_and_index<T, M>(methods: &mut M)
501where
502    T: UserData + Serialize,
503    M: UserDataMethods<T>,
504{
505    methods.add_meta_method(MetaMethod::Pairs, move |lua, this, _: ()| {
506        let Ok(serde_json::Value::Object(map)) = serde_json::to_value(this).map_err(any_err) else {
507            return Err(mlua::Error::external("must serialize to Map"));
508        };
509
510        let mut value_iter = map.into_iter();
511
512        let iter_func = lua.create_function_mut(
513            move |lua, (_state, _control): (Value, Value)| match value_iter.next() {
514                Some((key, value)) => {
515                    let key = lua.to_value(&key)?;
516                    let value = lua.to_value(&value)?;
517                    Ok((key, value))
518                }
519                None => Ok((Value::Nil, Value::Nil)),
520            },
521        )?;
522
523        Ok((Value::Function(iter_func), Value::Nil, Value::Nil))
524    });
525
526    methods.add_meta_method(MetaMethod::Index, move |lua, this, field: Value| {
527        let value = lua.to_value(this)?;
528        match value {
529            Value::Table(t) => t.get(field),
530            _ => Ok(Value::Nil),
531        }
532    });
533
534    methods.add_meta_method(MetaMethod::Len, move |lua, this, _: ()| {
535        let value = lua.to_value(this)?;
536        match value {
537            Value::Table(v) => v.len(),
538            Value::String(v) => Ok(v.as_bytes().len() as i64),
539            _ => Ok(0),
540        }
541    });
542}
543
544/// This function will try to obtain a native lua representation
545/// of the provided value. It does this by attempting to iterate
546/// the pairs of any userdata it finds as either the value itself
547/// or the values of a table value by recursively applying
548/// materialize_to_lua_value to the value.
549/// This produces a lua value that can then be processed by the
550/// Deserialize impl on Value.
551pub fn materialize_to_lua_value(lua: &Lua, value: mlua::Value) -> mlua::Result<mlua::Value> {
552    match value {
553        mlua::Value::UserData(ud) => {
554            let mt = ud.metatable()?;
555            let Ok(pairs) = mt.get::<mlua::Function>("__pairs") else {
556                let value = ud.into_lua(lua)?;
557                return Err(mlua::Error::external(format!(
558                    "cannot materialize_to_lua_value {value:?} \
559                     because it has no __pairs metamethod"
560                )));
561            };
562            let tbl = lua.create_table()?;
563            let (iter_func, state, mut control): (mlua::Function, mlua::Value, mlua::Value) =
564                pairs.call(mlua::Value::UserData(ud.clone()))?;
565
566            loop {
567                let (k, v): (mlua::Value, mlua::Value) =
568                    iter_func.call((state.clone(), control))?;
569                if k.is_nil() {
570                    break;
571                }
572
573                tbl.set(k.clone(), materialize_to_lua_value(lua, v)?)?;
574                control = k;
575            }
576
577            Ok(mlua::Value::Table(tbl))
578        }
579        mlua::Value::Table(t) => {
580            let tbl = lua.create_table()?;
581            for pair in t.pairs::<mlua::Value, mlua::Value>() {
582                let (k, v) = pair?;
583                tbl.set(k.clone(), materialize_to_lua_value(lua, v)?)?;
584            }
585            Ok(mlua::Value::Table(tbl))
586        }
587        value => Ok(value),
588    }
589}
590
591/// Helper wrapper type for passing/returning serde encoded values from/to lua
592pub struct SerdeWrappedValue<T>(pub T);
593
594impl<T: serde::Serialize> SerdeWrappedValue<T> {
595    pub fn to_lua_value(&self, lua: &Lua) -> mlua::Result<mlua::Value> {
596        lua.to_value_with(&self.0, serialize_options())
597    }
598}
599
600impl<T: serde::Serialize> IntoLua for SerdeWrappedValue<T> {
601    fn into_lua(self, lua: &Lua) -> mlua::Result<mlua::Value> {
602        lua.to_value_with(&self.0, serialize_options())
603    }
604}
605
606impl<T: serde::de::DeserializeOwned> FromLua for SerdeWrappedValue<T> {
607    fn from_lua(value: mlua::Value, lua: &Lua) -> mlua::Result<SerdeWrappedValue<T>> {
608        let inner: T = from_lua_value(lua, value)?;
609        Ok(SerdeWrappedValue(inner))
610    }
611}
612
613impl<T> std::ops::Deref for SerdeWrappedValue<T> {
614    type Target = T;
615    fn deref(&self) -> &T {
616        &self.0
617    }
618}
619
620impl<T> std::ops::DerefMut for SerdeWrappedValue<T> {
621    fn deref_mut(&mut self) -> &mut T {
622        &mut self.0
623    }
624}
625
626/// Convert from a lua value to a deserializable type,
627/// with a slightly more helpful error message in case of failure.
628/// NOTE: the ", while processing" portion of the error messages generated
629/// here is coupled with a regex in typing.lua!
630pub fn from_lua_value<R>(lua: &Lua, value: mlua::Value) -> mlua::Result<R>
631where
632    R: serde::de::DeserializeOwned,
633{
634    let value_cloned = value.clone();
635    match lua.from_value(value) {
636        Ok(r) => Ok(r),
637        Err(err) => match materialize_to_lua_value(lua, value_cloned.clone()) {
638            Ok(materialized) => match lua.from_value(materialized.clone()) {
639                Ok(r) => Ok(r),
640                Err(err) => {
641                    let mut serializer = serde_json::Serializer::new(Vec::new());
642                    let serialized = match materialized.serialize(&mut serializer) {
643                        Ok(_) => String::from_utf8_lossy(&serializer.into_inner()).to_string(),
644                        Err(err) => format!("<unable to encode as json: {err:#}>"),
645                    };
646                    Err(mlua::Error::external(format!(
647                        "{err:#}, while processing {serialized}"
648                    )))
649                }
650            },
651            Err(materialize_err) => Err(mlua::Error::external(format!(
652                "{err:#}, while processing a userdata. \
653                    Additionally, encountered {materialize_err:#} \
654                    when trying to iterate the pairs of that userdata"
655            ))),
656        },
657    }
658}
659
/// CallbackSignature is a bit sugar to aid with statically typing event callback
/// function invocation.
///
/// The idea is that you declare a signature instance that is typed
/// with its argument tuple (A), and its return type tuple (R).
///
/// The signature instance can then be used to invoke the callback by name.
///
/// The register method allows pre-registering events so that `kumo.on`
/// can reason about them better.  The main function enabled by this is
/// `allow_multiple`; when that is set to true, `kumo.on` will allow
/// recording multiple callback instances, calling them in sequence
/// until one of them returns a value.
pub struct CallbackSignature<A, R>
where
    A: IntoLuaMulti,
    R: FromLuaMulti,
{
    /// Ties the argument (`A`) and return (`R`) types to the signature
    /// without storing values of those types.
    marker: std::marker::PhantomData<(A, R)>,
    /// Whether multiple handlers may be registered for this event.
    allow_multiple: bool,
    /// The undecorated event name; `decorated_name()` derives the lua
    /// registry key from it.
    name: Cow<'static, str>,
}

/// Registration thunks emitted by `declare_event!` (one per declared
/// signature); `register_declared_events` runs each of them once.
#[linkme::distributed_slice]
pub static CALLBACK_SIGNATURES: [fn()];
685
/// Helper for declaring a named event handler callback signature.
///
/// Usage looks like:
///
/// ```rust,ignore
/// declare_event! {
/// pub static GET_Q_CONFIG_SIG: Multiple(
///         "get_queue_config",
///         domain: &'static str,
///         tenant: Option<&'static str>,
///         campaign: Option<&'static str>,
///         routing_domain: Option<&'static str>,
///     ) -> QueueConfig;
/// }
/// ```
///
/// A handler can be either `Single` or `Multiple`, indicating whether
/// only a single registration or multiple registrations are permitted.
/// The string literal is the name of the event, followed by a fn-style
/// parameter list which names each parameter in sequence, followed by
/// the return value.  The names are not currently used in any way,
/// but enhance the readability of the code.
///
/// In addition to declaring the signature in a global, some glue
/// is generated that will register the signature appropriately
/// so that lua knows whether it is single or multiple and can
/// act appropriately when `kumo.on` is called.
///
/// Notes:
/// * The argument type is generated as `($($args),*)`; with exactly one
///   parameter that is the parenthesized type itself, not a 1-tuple.
/// * The expansion names `linkme::distributed_slice` directly (not via
///   `$crate`), so the invoking crate must depend on `linkme`.
#[macro_export]
macro_rules! declare_event {
    // `Multiple`: signature built with `new_with_multiple`, plus a
    // `CALLBACK_SIG_REGISTER_<SYM>` thunk placed in `CALLBACK_SIGNATURES`.
    ($vis:vis static $sym:ident: Multiple($name:literal $(,)? $($param_name:ident: $args:ty),* $(,)? ) -> $ret:ty;) => {
        $vis static $sym: ::std::sync::LazyLock<
            $crate::CallbackSignature<($($args),*), $ret>> =
                ::std::sync::LazyLock::new(|| $crate::CallbackSignature::new_with_multiple($name));

        $crate::paste::paste! {
            #[linkme::distributed_slice($crate::CALLBACK_SIGNATURES)]
            static [<CALLBACK_SIG_REGISTER_ $sym>]: fn() = || {
                $sym.register();
            };
        }
    };
    // `Single`: identical shape, but the signature is built with `new`.
    ($vis:vis static $sym:ident: Single($name:literal $(,)? $($param_name:ident: $args:ty),* $(,)? ) -> $ret:ty;) => {
        $vis static $sym: ::std::sync::LazyLock<
            $crate::CallbackSignature<($($args),*), $ret>> =
                ::std::sync::LazyLock::new(|| $crate::CallbackSignature::new($name));

        $crate::paste::paste! {
            #[linkme::distributed_slice($crate::CALLBACK_SIGNATURES)]
            static [<CALLBACK_SIG_REGISTER_ $sym>]: fn() = || {
                $sym.register();
            };
        }
    };
}
740
741/// For each event handler CallbackSignature that was declared via
742/// `declare_event!`, call its `.register()` method to register
743/// it so that `kumo.on` can give appropriate messaging if misused,
744/// and so that runtime dispatch will work correctly.
745///
746/// This should be called once, prior to running any lua code.
747fn register_declared_events() {
748    static ONCE: Once = Once::new();
749    ONCE.call_once(|| {
750        for reg_func in CALLBACK_SIGNATURES {
751            reg_func();
752        }
753    });
754}
755
756impl<A, R> CallbackSignature<A, R>
757where
758    A: IntoLuaMulti,
759    R: FromLuaMulti,
760{
761    pub fn new<S: Into<Cow<'static, str>>>(name: S) -> Self {
762        let name = name.into();
763
764        Self {
765            marker: std::marker::PhantomData,
766            allow_multiple: false,
767            name,
768        }
769    }
770
771    /// Make sure that you call .register() on this from
772    /// eg: mod_kumo::register in order for it to be instantiated
773    /// and visible to the config loader
774    pub fn new_with_multiple<S: Into<Cow<'static, str>>>(name: S) -> Self {
775        let name = name.into();
776
777        Self {
778            marker: std::marker::PhantomData,
779            allow_multiple: true,
780            name,
781        }
782    }
783
784    pub fn register(&self) {
785        if self.allow_multiple {
786            CALLBACK_ALLOWS_MULTIPLE
787                .lock()
788                .insert(self.name.to_string());
789        }
790    }
791
792    pub fn raise_error_if_allow_multiple(&self) -> anyhow::Result<()> {
793        anyhow::ensure!(
794            !self.allow_multiple(),
795            "handler {} is set to allow multiple handlers \
796                    but is registered with a single instance. This indicates that \
797                    register() was not called on the signature when initializing \
798                    the lua context. Please report this issue to the KumoMTA team!",
799            self.name
800        );
801        Ok(())
802    }
803
804    /// Return true if this signature allows multiple instances to be registered
805    /// and called.
806    pub fn allow_multiple(&self) -> bool {
807        self.allow_multiple
808    }
809
810    pub fn name(&self) -> &str {
811        &self.name
812    }
813
814    pub fn decorated_name(&self) -> String {
815        decorate_callback_name(&self.name)
816    }
817}
818
819pub fn does_callback_allow_multiple(name: &str) -> bool {
820    CALLBACK_ALLOWS_MULTIPLE.lock().contains(name)
821}
822
/// Maps an event name to the lua registry key that stores its handler(s),
/// e.g. `"init"` becomes `"kumomta-on-init"`.
pub fn decorate_callback_name(name: &str) -> String {
    let decorated = format!("kumomta-on-{name}");
    decorated
}
826
827pub fn serialize_options() -> mlua::SerializeOptions {
828    mlua::SerializeOptions::new()
829        .serialize_none_to_null(false)
830        .serialize_unit_to_null(false)
831}