use crate::pool::{pool_get, pool_put};
pub use crate::pool::{set_gc_on_put, set_max_age, set_max_spare, set_max_use};
use anyhow::Context;
use mlua::{
FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, Lua, LuaSerdeExt, MetaMethod, RegistryKey, Table,
UserData, UserDataMethods, Value,
};
use parking_lot::FairMutex as Mutex;
use prometheus::{CounterVec, HistogramTimer, HistogramVec};
use serde::Serialize;
use std::borrow::Cow;
use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::LazyLock;
use std::time::Instant;
pub mod epoch;
mod pool;
/// Path of the policy Lua script that `load_config` executes in every newly
/// created context. `None` until `set_policy_path` stores one.
static POLICY_FILE: LazyLock<Mutex<Option<PathBuf>>> = LazyLock::new(|| Mutex::new(None));
/// Initializers added via `register`; `load_config` calls each of them, in
/// registration order, on every new Lua context before running the policy.
static FUNCS: LazyLock<Mutex<Vec<RegisterFunc>>> = LazyLock::new(|| Mutex::new(vec![]));
/// Counts fresh Lua contexts built by `load_config`. It is incremented only
/// when `pool_get` had nothing to hand back, and it counts attempts, so a
/// context whose policy later fails to load is still included.
static LUA_LOAD_COUNT: LazyLock<metrics::Counter> = LazyLock::new(|| {
// The description is registered lazily, on first use of this counter.
metrics::describe_counter!(
"lua_load_count",
"how many times the policy lua script has been \
loaded into a new context"
);
metrics::counter!("lua_load_count")
});
/// Gauge of live Lua contexts. It is incremented after `load_config`
/// successfully builds a context and decremented in `LuaConfigInner::drop`.
static LUA_COUNT: LazyLock<metrics::Gauge> = LazyLock::new(|| {
metrics::describe_gauge!("lua_count", "the number of lua contexts currently alive");
metrics::gauge!("lua_count")
});
/// Undecorated names of callbacks whose `CallbackSignature` was built with
/// `new_with_multiple` and then `register()`ed; queried by
/// `does_callback_allow_multiple`.
static CALLBACK_ALLOWS_MULTIPLE: LazyLock<Mutex<HashSet<String>>> =
LazyLock::new(|| Mutex::new(HashSet::new()));
/// When true, `load_config` sets the `_VALIDATING_CONFIG` global to `true` in
/// each new Lua context. Read through `is_validating`.
pub static VALIDATE_ONLY: AtomicBool = AtomicBool::new(false);
/// Set by `set_validation_failed` and read by `validation_failed`. Nothing in
/// this file ever clears it.
pub static VALIDATION_FAILED: AtomicBool = AtomicBool::new(false);
/// Duration of Lua event callbacks, labelled by event name. Timers are
/// obtained through `latency_timer`.
///
/// Registered in the default prometheus registry on first access; the
/// `.unwrap()` panics if that registration fails (for example because the
/// name is already registered).
static LATENCY_HIST: LazyLock<HistogramVec> = LazyLock::new(|| {
prometheus::register_histogram_vec!(
"lua_event_latency",
"how long a given lua event callback took",
&["event"]
)
.unwrap()
});
/// Number of callback invocations started, labelled by event name.
/// Comparing it with the histogram's `_count` shows how many are in flight.
/// Registration panics on failure, in the same way as `LATENCY_HIST`.
static EVENT_STARTED_COUNT: LazyLock<CounterVec> = LazyLock::new(|| {
prometheus::register_counter_vec!(
"lua_event_started",
"Incremented each time we start to call a lua event callback. Use lua_event_latency_count to track completed events",
&["event"]
)
.unwrap()
});
/// Initializer passed to `register`. `load_config` calls it with each new Lua
/// context, and an `Err` aborts creation of that context.
pub type RegisterFunc = fn(&Lua) -> anyhow::Result<()>;
/// Bump the "started" counter for `label` and start a latency timer for it.
///
/// The returned `HistogramTimer` records the elapsed time into the
/// `lua_event_latency` histogram when it is dropped.
///
/// # Panics
/// Panics if either metric vector rejects the label values.
fn latency_timer(label: &str) -> HistogramTimer {
    let labels = [label];

    let started = EVENT_STARTED_COUNT
        .get_metric_with_label_values(&labels)
        .expect("to get counter");
    started.inc();

    let histogram = LATENCY_HIST
        .get_metric_with_label_values(&labels)
        .expect("to get histo");
    histogram.start_timer()
}
/// The Lua context behind a `LuaConfig`, plus bookkeeping fields.
#[derive(Debug)]
struct LuaConfigInner {
// The Lua state with the policy loaded into it.
lua: Lua,
// When the context was built in `load_config`. It is not read in this file;
// NOTE(review): presumably consulted by the `pool` module's max-age logic.
created: Instant,
// Starts at 1 in `load_config`. It is not updated in this file;
// NOTE(review): presumably maintained by `pool` for max-use limits.
use_count: usize,
}
impl Drop for LuaConfigInner {
    /// Keep the `lua_count` gauge in step with the number of live contexts.
    fn drop(&mut self) {
        let live_contexts = &*LUA_COUNT;
        live_contexts.decrement(1.0);
    }
}
/// Handle to a policy-loaded Lua context obtained from `load_config`.
///
/// Dropping the handle passes the context to `pool_put` rather than dropping
/// it directly.
#[derive(Debug)]
pub struct LuaConfig {
// `Some` for the whole life of the handle; only `Drop` takes it.
inner: Option<LuaConfigInner>,
}
impl Drop for LuaConfig {
    /// Return the underlying context to the pool.
    fn drop(&mut self) {
        let Some(inner) = self.inner.take() else {
            return;
        };
        pool_put(inner);
    }
}
/// Store `path` as the policy script and then obtain a config via
/// `load_config`. The config is dropped immediately; only the error is
/// surfaced.
///
/// # Errors
/// Propagates any error from `load_config`.
pub async fn set_policy_path(path: PathBuf) -> anyhow::Result<()> {
    {
        // Scope the guard so it is released before the await below.
        let mut slot = POLICY_FILE.lock();
        *slot = Some(path);
    }
    load_config().await.map(|_config| ())
}
/// Snapshot of the configured policy path, if any.
fn get_policy_path() -> Option<PathBuf> {
    Option::clone(&POLICY_FILE.lock())
}
/// Copy of the registered initializers. The copy lets them run without
/// holding the lock.
fn get_funcs() -> Vec<RegisterFunc> {
    FUNCS.lock().to_vec()
}
/// Whether the process is only validating the configuration
/// (reads `VALIDATE_ONLY`).
pub fn is_validating() -> bool {
    let flag = &VALIDATE_ONLY;
    flag.load(Ordering::Relaxed)
}
/// Whether `set_validation_failed` has been called.
pub fn validation_failed() -> bool {
    let flag = &VALIDATION_FAILED;
    flag.load(Ordering::Relaxed)
}
/// Record that validation failed. Nothing in this module resets the flag.
pub fn set_validation_failed() {
    let flag = &VALIDATION_FAILED;
    flag.store(true, Ordering::Relaxed);
}
/// Prepend KumoMTA's Lua module search roots to `package.path`.
///
/// Every root `R` contributes `R/?.lua` and `R/?/init.lua`. The search order is
/// `assets` (debug builds only), then `/opt/kumomta/share`, then
/// `/opt/kumomta/etc/policy`, followed by the pre-existing `package.path`.
fn prepend_search_roots(package: &Table) -> anyhow::Result<()> {
    let roots: &[&str] = if cfg!(debug_assertions) {
        &["assets", "/opt/kumomta/share", "/opt/kumomta/etc/policy"]
    } else {
        &["/opt/kumomta/share", "/opt/kumomta/etc/policy"]
    };

    let existing: String = package.get("path")?;
    let search_path: Vec<String> = roots
        .iter()
        .flat_map(|root| [format!("{root}/?.lua"), format!("{root}/?/init.lua")])
        .chain(existing.split(';').map(str::to_owned))
        .collect();
    package.set("path", search_path.join(";"))?;
    Ok(())
}

/// Obtain a Lua context with the policy loaded.
///
/// A context from the pool (`pool_get`) is returned when one is available.
/// Otherwise a new one is built with these steps:
/// 1. `_VALIDATING_CONFIG` is set when validating.
/// 2. The KumoMTA search roots are prepended to `package.path`.
/// 3. Every `register`ed initializer runs.
/// 4. The policy file (if configured) is executed, timed as
///    `context-creation`.
///
/// # Errors
/// Fails if any Lua operation or initializer errors, or if the policy file
/// cannot be read, compiled or executed.
pub async fn load_config() -> anyhow::Result<LuaConfig> {
    if let Some(pooled) = pool_get() {
        return Ok(pooled);
    }

    LUA_LOAD_COUNT.increment(1);
    let lua = Lua::new();
    let created = Instant::now();

    {
        // Keep these Lua handles scoped so they are not held across awaits.
        let globals = lua.globals();
        if is_validating() {
            globals.set("_VALIDATING_CONFIG", true)?;
        }
        let package: Table = globals.get("package")?;
        prepend_search_roots(&package)?;
    }

    get_funcs()
        .into_iter()
        .try_for_each(|initializer| initializer(&lua))?;

    if let Some(policy) = get_policy_path() {
        let source = tokio::fs::read_to_string(&policy)
            .await
            .with_context(|| format!("reading policy file {policy:?}"))?;
        let entry_point = lua
            .load(&source)
            .set_name(policy.to_string_lossy())
            .into_function()?;
        let _timer = latency_timer("context-creation");
        entry_point.call_async::<()>(()).await?;
    }

    LUA_COUNT.increment(1.);
    let inner = LuaConfigInner {
        lua,
        created,
        use_count: 1,
    };
    Ok(LuaConfig { inner: Some(inner) })
}
/// Add `func` to the initializers that `load_config` runs on every new Lua
/// context.
pub fn register(func: RegisterFunc) {
    let mut registered = FUNCS.lock();
    registered.push(func);
}
impl LuaConfig {
fn set_current_event(&mut self, name: &str) -> mlua::Result<()> {
self.inner
.as_mut()
.unwrap()
.lua
.globals()
.set("_KUMO_CURRENT_EVENT", name.to_string())
}
pub async fn convert_args_and_call_callback<A: Serialize>(
&mut self,
sig: &CallbackSignature<Value, ()>,
args: A,
) -> anyhow::Result<()> {
let lua = self.inner.as_mut().unwrap();
let args = lua.lua.to_value(&args)?;
let name = sig.name();
let decorated_name = sig.decorated_name();
match lua
.lua
.named_registry_value::<mlua::Function>(&decorated_name)
{
Ok(func) => {
let _timer = latency_timer(name);
Ok(func.call_async(args).await?)
}
_ => anyhow::bail!("{name} has not been registered"),
}
}
pub async fn async_call_callback<A: IntoLuaMulti + Clone, R: FromLuaMulti + Default>(
&mut self,
sig: &CallbackSignature<A, R>,
args: A,
) -> anyhow::Result<R> {
let name = sig.name();
self.set_current_event(name)?;
let lua = self.inner.as_mut().unwrap();
async_call_callback(&lua.lua, sig, args).await
}
pub async fn async_call_callback_non_default<A: IntoLuaMulti + Clone, R: FromLuaMulti>(
&mut self,
sig: &CallbackSignature<A, R>,
args: A,
) -> anyhow::Result<R> {
let name = sig.name();
self.set_current_event(name)?;
let lua = self.inner.as_mut().unwrap();
async_call_callback_non_default(&lua.lua, sig, args).await
}
pub async fn async_call_callback_non_default_opt<A: IntoLuaMulti + Clone, R: FromLua>(
&mut self,
sig: &CallbackSignature<A, Option<R>>,
args: A,
) -> anyhow::Result<Option<R>> {
let name = sig.name();
let decorated_name = sig.decorated_name();
self.set_current_event(name)?;
let lua = self.inner.as_mut().unwrap();
match lua
.lua
.named_registry_value::<mlua::Value>(&decorated_name)?
{
Value::Table(tbl) => {
for func in tbl.sequence_values::<mlua::Function>().collect::<Vec<_>>() {
let func = func?;
let _timer = latency_timer(name);
let result: mlua::MultiValue = func.call_async(args.clone()).await?;
if result.is_empty() {
continue;
}
let result = R::from_lua_multi(result, &lua.lua)?;
return Ok(Some(result));
}
Ok(None)
}
Value::Function(func) => {
sig.raise_error_if_allow_multiple()?;
let _timer = latency_timer(name);
let value: Value = func.call_async(args.clone()).await?;
match value {
Value::Nil => Ok(None),
value => {
let result = R::from_lua(value, &lua.lua)?;
Ok(Some(result))
}
}
}
_ => Ok(None),
}
}
pub fn remove_registry_value(&mut self, value: RegistryKey) -> anyhow::Result<()> {
Ok(self
.inner
.as_mut()
.unwrap()
.lua
.remove_registry_value(value)?)
}
pub async fn async_call_ctor<A: IntoLuaMulti + Clone>(
&mut self,
sig: &CallbackSignature<A, Value>,
args: A,
) -> anyhow::Result<RegistryKey> {
let name = sig.name();
anyhow::ensure!(
!sig.allow_multiple(),
"ctor event signature for {name} is defined as allow_multiple, which is not supported"
);
let decorated_name = sig.decorated_name();
self.set_current_event(name)?;
let inner = self.inner.as_mut().unwrap();
let func = inner
.lua
.named_registry_value::<mlua::Function>(&decorated_name)?;
let _timer = latency_timer(name);
let value: Value = func.call_async(args.clone()).await?;
drop(func);
Ok(inner.lua.create_registry_value(value)?)
}
pub async fn with_registry_value<F, R, FUT>(
&mut self,
value: &RegistryKey,
func: F,
) -> anyhow::Result<R>
where
R: FromLuaMulti,
F: FnOnce(Value) -> anyhow::Result<FUT>,
FUT: std::future::Future<Output = anyhow::Result<R>>,
{
let inner = self.inner.as_mut().unwrap();
let value = inner.lua.registry_value(value)?;
let future = (func)(value)?;
future.await
}
}
/// Invoke the handler(s) registered in `lua` for `sig` and convert the result.
///
/// * A table of handlers (an `allow_multiple` signature): they are called in
///   sequence order, and the first one that returns at least one value wins.
///   If none returns a value, `R::default()` is returned.
/// * A single function: it is called, and its result is converted to `R`.
/// * Anything else (nothing registered): `R::default()`.
///
/// # Errors
/// Returns Lua errors from lookup, invocation or conversion. Also fails via
/// `raise_error_if_allow_multiple` when a lone function is registered for an
/// `allow_multiple` signature.
pub async fn async_call_callback<A: IntoLuaMulti + Clone, R: FromLuaMulti + Default>(
    lua: &Lua,
    sig: &CallbackSignature<A, R>,
    args: A,
) -> anyhow::Result<R> {
    let name = sig.name();
    let handlers = match lua.named_registry_value::<mlua::Value>(&sig.decorated_name())? {
        Value::Function(single) => {
            sig.raise_error_if_allow_multiple()?;
            let _timer = latency_timer(name);
            return Ok(single.call_async(args).await?);
        }
        Value::Table(handlers) => handlers,
        _ => return Ok(R::default()),
    };

    // Materialize the sequence first so no table iterator is held across awaits.
    let candidates = handlers
        .sequence_values::<mlua::Function>()
        .collect::<Vec<_>>();
    for candidate in candidates {
        let handler = candidate?;
        let _timer = latency_timer(name);
        let returned: mlua::MultiValue = handler.call_async(args.clone()).await?;
        if !returned.is_empty() {
            return Ok(R::from_lua_multi(returned, lua)?);
        }
    }
    Ok(R::default())
}
/// Like `async_call_callback`, but a missing answer is an error instead of
/// `R::default()`.
///
/// # Errors
/// * The table handlers are all tried and none returns a value:
///   "invalid return type for {name} event".
/// * Nothing is registered: "Event {name} has not been registered".
/// * Any Lua lookup, call or conversion error.
/// * A lone function is registered for an `allow_multiple` signature.
pub async fn async_call_callback_non_default<A: IntoLuaMulti + Clone, R: FromLuaMulti>(
    lua: &Lua,
    sig: &CallbackSignature<A, R>,
    args: A,
) -> anyhow::Result<R> {
    let name = sig.name();
    let handlers = match lua.named_registry_value::<mlua::Value>(&sig.decorated_name())? {
        Value::Function(single) => {
            sig.raise_error_if_allow_multiple()?;
            let _timer = latency_timer(name);
            return Ok(single.call_async(args).await?);
        }
        Value::Table(handlers) => handlers,
        _ => anyhow::bail!("Event {name} has not been registered"),
    };

    let candidates = handlers
        .sequence_values::<mlua::Function>()
        .collect::<Vec<_>>();
    for candidate in candidates {
        let handler = candidate?;
        let _timer = latency_timer(name);
        let returned: mlua::MultiValue = handler.call_async(args.clone()).await?;
        if !returned.is_empty() {
            return Ok(R::from_lua_multi(returned, lua)?);
        }
    }
    anyhow::bail!("invalid return type for {name} event");
}
pub fn get_or_create_module(lua: &Lua, name: &str) -> anyhow::Result<mlua::Table> {
let globals = lua.globals();
let package: Table = globals.get("package")?;
let loaded: Table = package.get("loaded")?;
let module = loaded.get(name)?;
match module {
Value::Nil => {
let module = lua.create_table()?;
loaded.set(name, module.clone())?;
Ok(module)
}
Value::Table(table) => Ok(table),
wat => anyhow::bail!(
"cannot register module {} as package.loaded.{} is already set to a value of type {}",
name,
name,
wat.type_name()
),
}
}
pub fn get_or_create_sub_module(lua: &Lua, name_path: &str) -> anyhow::Result<mlua::Table> {
let mut parent = get_or_create_module(lua, "kumo")?;
let mut path_so_far = String::new();
for name in name_path.split('.') {
if !path_so_far.is_empty() {
path_so_far.push('.');
}
path_so_far.push_str(name);
let sub = parent.get(name)?;
match sub {
Value::Nil => {
let sub = lua.create_table()?;
parent.set(name, sub.clone())?;
parent = sub;
}
Value::Table(sub) => {
parent = sub;
}
wat => anyhow::bail!(
"cannot register module kumo.{path_so_far} as it is already set to a value of type {}",
wat.type_name()
),
}
}
Ok(parent)
}
pub fn any_err<E: std::fmt::Display>(err: E) -> mlua::Error {
mlua::Error::external(format!("{err:#}"))
}
/// Give a `Serialize`-able userdata type `__pairs` and `__index` metamethods
/// backed by its serde representation.
///
/// * `__pairs` serializes `this` to a JSON object and iterates its
///   `(key, value)` entries, converting each to Lua.
/// * `__index` converts `this` to a Lua value with `to_value` and indexes it.
///   If the result is not a table, it yields `nil`.
///
/// Both metamethods re-serialize the whole value on every call.
pub fn impl_pairs_and_index<T, M>(methods: &mut M)
where
    T: UserData + Serialize,
    M: UserDataMethods<T>,
{
    methods.add_meta_method(MetaMethod::Pairs, move |lua, this, _: ()| {
        // Propagate real serialization failures. Only a successfully
        // serialized non-object is reported as "must serialize to Map".
        let serde_json::Value::Object(map) = serde_json::to_value(this).map_err(any_err)? else {
            return Err(mlua::Error::external("must serialize to Map"));
        };
        let mut entries = map.into_iter();
        let iter_func = lua.create_function_mut(
            move |lua, (_state, _control): (Value, Value)| match entries.next() {
                Some((key, value)) => Ok((lua.to_value(&key)?, lua.to_value(&value)?)),
                // A nil key tells the generic `for` loop that iteration is done.
                None => Ok((Value::Nil, Value::Nil)),
            },
        )?;
        Ok((Value::Function(iter_func), Value::Nil, Value::Nil))
    });
    methods.add_meta_method(MetaMethod::Index, move |lua, this, field: Value| {
        match lua.to_value(this)? {
            Value::Table(t) => t.get(field),
            _ => Ok(Value::Nil),
        }
    });
}
/// Recursively copy `value` into plain Lua tables.
///
/// * Tables are copied entry by entry, with each value materialized in turn.
/// * Userdata is expanded by driving its `__pairs` metamethod.
/// * Every other value is returned unchanged.
///
/// # Errors
/// Fails for userdata without a `__pairs` metamethod, and on any Lua error
/// raised while iterating or building tables.
pub fn materialize_to_lua_value(lua: &Lua, value: mlua::Value) -> mlua::Result<mlua::Value> {
    let ud = match value {
        mlua::Value::Table(source) => {
            let copy = lua.create_table()?;
            for entry in source.pairs::<mlua::Value, mlua::Value>() {
                let (key, item) = entry?;
                copy.set(key.clone(), materialize_to_lua_value(lua, item)?)?;
            }
            return Ok(mlua::Value::Table(copy));
        }
        mlua::Value::UserData(ud) => ud,
        other => return Ok(other),
    };

    let metatable = ud.metatable()?;
    let pairs = match metatable.get::<mlua::Function>("__pairs") {
        Ok(pairs) => pairs,
        Err(_) => {
            let described = ud.into_lua(lua)?;
            return Err(mlua::Error::external(format!(
                "cannot materialize_to_lua_value {described:?} \
                 because it has no __pairs metamethod"
            )));
        }
    };

    let copy = lua.create_table()?;
    // Emulate a generic `for`: __pairs returns (next, state, initial control).
    let (next, state, mut control): (mlua::Function, mlua::Value, mlua::Value) =
        pairs.call(mlua::Value::UserData(ud.clone()))?;
    loop {
        let (key, item): (mlua::Value, mlua::Value) = next.call((state.clone(), control))?;
        if key.is_nil() {
            break;
        }
        copy.set(key.clone(), materialize_to_lua_value(lua, item)?)?;
        control = key;
    }
    Ok(mlua::Value::Table(copy))
}
pub fn from_lua_value<R>(lua: &Lua, value: mlua::Value) -> mlua::Result<R>
where
R: serde::de::DeserializeOwned,
{
let value_cloned = value.clone();
match lua.from_value(value) {
Ok(r) => Ok(r),
Err(err) => match materialize_to_lua_value(lua, value_cloned.clone()) {
Ok(materialized) => match lua.from_value(materialized.clone()) {
Ok(r) => Ok(r),
Err(err) => {
let mut serializer = serde_json::Serializer::new(Vec::new());
let serialized = match materialized.serialize(&mut serializer) {
Ok(_) => String::from_utf8_lossy(&serializer.into_inner()).to_string(),
Err(err) => format!("<unable to encode as json: {err:#}>"),
};
Err(mlua::Error::external(format!(
"{err:#}, while processing {serialized}"
)))
}
},
Err(materialize_err) => Err(mlua::Error::external(format!(
"{err:#}, while processing a userdata. \
Additionally, encountered {materialize_err:#} \
when trying to iterate the pairs of that userdata"
))),
},
}
}
/// Typed description of a named Lua event callback.
///
/// `A` is the argument type passed to the handler, and `R` is the type its
/// result converts into. Handlers live in the Lua named registry under
/// `decorated_name()`.
pub struct CallbackSignature<A, R>
where
    A: IntoLuaMulti,
    R: FromLuaMulti,
{
    /// Undecorated event name.
    name: Cow<'static, str>,
    /// True if several handlers may be registered for this event.
    allow_multiple: bool,
    /// Ties the signature to `A` and `R` without storing values of either.
    marker: std::marker::PhantomData<(A, R)>,
}
impl<A, R> CallbackSignature<A, R>
where
    A: IntoLuaMulti,
    R: FromLuaMulti,
{
    /// Common constructor behind `new` and `new_with_multiple`.
    fn from_parts(name: Cow<'static, str>, allow_multiple: bool) -> Self {
        Self {
            name,
            allow_multiple,
            marker: std::marker::PhantomData,
        }
    }

    /// Signature for an event that takes exactly one handler.
    pub fn new<S: Into<Cow<'static, str>>>(name: S) -> Self {
        Self::from_parts(name.into(), false)
    }

    /// Signature for an event that accepts multiple handlers.
    pub fn new_with_multiple<S: Into<Cow<'static, str>>>(name: S) -> Self {
        Self::from_parts(name.into(), true)
    }

    /// Record this event in the global allow-multiple set. This does nothing
    /// for single-handler signatures.
    pub fn register(&self) {
        if !self.allow_multiple {
            return;
        }
        let mut allows_multiple = CALLBACK_ALLOWS_MULTIPLE.lock();
        allows_multiple.insert(self.name.to_string());
    }

    /// Error out when this signature allows multiple handlers. Callers use
    /// this when they find a lone function registered, which means
    /// `register()` was skipped.
    pub fn raise_error_if_allow_multiple(&self) -> anyhow::Result<()> {
        if self.allow_multiple() {
            anyhow::bail!(
                "handler {} is set to allow multiple handlers \
                but is registered with a single instance. This indicates that \
                register() was not called on the signature when initializing \
                the lua context. Please report this issue to the KumoMTA team!",
                self.name
            );
        }
        Ok(())
    }

    /// Whether several handlers may be registered for this event.
    pub fn allow_multiple(&self) -> bool {
        self.allow_multiple
    }

    /// The undecorated event name.
    pub fn name(&self) -> &str {
        self.name.as_ref()
    }

    /// The registry key under which handlers for this event are stored.
    pub fn decorated_name(&self) -> String {
        decorate_callback_name(self.name())
    }
}
/// True if a signature named `name` was registered with `allow_multiple`.
pub fn does_callback_allow_multiple(name: &str) -> bool {
    let allows_multiple = CALLBACK_ALLOWS_MULTIPLE.lock();
    allows_multiple.contains(name)
}
/// Registry key for the handler(s) of event `name`: `kumomta-on-<name>`.
pub fn decorate_callback_name(name: &str) -> String {
    const PREFIX: &str = "kumomta-on-";
    let mut decorated = String::with_capacity(PREFIX.len() + name.len());
    decorated.push_str(PREFIX);
    decorated.push_str(name);
    decorated
}
/// Serialization options used when converting Rust values to Lua.
///
/// Both `None` and `()` are kept as `nil` rather than the null sentinel.
pub fn serialize_options() -> mlua::SerializeOptions {
    let options = mlua::SerializeOptions::new();
    let options = options.serialize_none_to_null(false);
    options.serialize_unit_to_null(false)
}