1use config::{any_err, get_or_create_sub_module};
2use mlua::prelude::LuaUserData;
3use mlua::{FromLua, Lua, MetaMethod, UserDataMethods};
4use mod_memoize::CacheValue;
5use regex::{RegexSet, RegexSetBuilder};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[derive(Clone, Serialize, Deserialize)]
10#[serde(try_from = "RegexSetMapBuilder<V>", into = "RegexSetMapBuilder<V>")]
11pub struct RegexSetMap<V: Clone> {
12 set: RegexSet,
13 pattern_to_value: Vec<V>,
14}
15
16impl<V: Clone> RegexSetMap<V> {
17 pub fn lookup(&self, subject: &str) -> Option<&V> {
18 self.set
19 .matches(subject)
20 .into_iter()
21 .next()
22 .and_then(|idx| self.pattern_to_value.get(idx))
23 }
24}
25
26impl<V: Clone> From<RegexSetMap<V>> for RegexSetMapBuilder<V> {
27 fn from(map: RegexSetMap<V>) -> Self {
28 let patterns = map.set.patterns().to_vec();
29 RegexSetMapBuilder {
30 patterns,
31 pattern_to_value: map.pattern_to_value,
32 }
33 }
34}
35
36impl<V: Clone> TryFrom<RegexSetMapBuilder<V>> for RegexSetMap<V> {
37 type Error = String;
38
39 fn try_from(builder: RegexSetMapBuilder<V>) -> Result<RegexSetMap<V>, String> {
40 builder.build()
41 }
42}
43
44#[derive(Serialize, Deserialize)]
45pub struct RegexSetMapBuilder<V: Clone> {
46 patterns: Vec<String>,
47 pattern_to_value: Vec<V>,
48}
49
50impl<V: Clone> Default for RegexSetMapBuilder<V> {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56impl<V: Clone> RegexSetMapBuilder<V> {
57 pub fn new() -> Self {
58 Self {
59 patterns: vec![],
60 pattern_to_value: vec![],
61 }
62 }
63
64 pub fn add_rule<S: Into<String>>(&mut self, rule: S, value: V) {
65 self.patterns.push(rule.into());
66 self.pattern_to_value.push(value);
67 }
68
69 pub fn build(mut self) -> Result<RegexSetMap<V>, String> {
70 self.patterns.shrink_to_fit();
71 self.pattern_to_value.shrink_to_fit();
72
73 let set = RegexSetBuilder::new(self.patterns)
74 .build()
75 .map_err(|err| format!("compiling rules: {err:#}"))?;
76 Ok(RegexSetMap {
77 set,
78 pattern_to_value: self.pattern_to_value,
79 })
80 }
81}
82
83impl LuaUserData for RegexSetMap<CacheValue> {
84 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
85 mod_memoize::Memoized::impl_memoize(methods);
86 methods.add_meta_method(MetaMethod::Index, |lua, this, key: String| {
87 if let Some(value) = this.lookup(&key) {
88 let value = value.as_lua(lua)?;
89 Ok(Some(value))
90 } else {
91 Ok(None)
92 }
93 });
94 }
95}
96
97pub fn register(lua: &Lua) -> anyhow::Result<()> {
98 let module = get_or_create_sub_module(lua, "regex_set_map")?;
99
100 module.set(
101 "new",
102 lua.create_function(|lua, value: Option<HashMap<String, mlua::Value>>| {
103 let mut builder: RegexSetMapBuilder<CacheValue> = RegexSetMapBuilder::new();
104
105 if let Some(value) = value {
106 for (k, v) in value {
107 let v = CacheValue::from_lua(v, lua)?;
108 builder.add_rule(&k, v);
109 }
110 }
111
112 builder.build().map_err(any_err)
113 })?,
114 )?;
115
116 Ok(())
117}
118
#[cfg(test)]
mod test {
    use super::*;

    /// Each rule maps to its own value, and a subject that matches no rule
    /// (here: missing the trailing period) yields `None`.
    #[test]
    fn test_basic_mapping() {
        let mut builder = RegexSetMapBuilder::new();
        builder.add_rule(
            "4\\.2\\.2 The email account that you tried to reach is over quota\\.",
            500,
        );
        builder.add_rule("4\\.2\\.1 <.+>: Recipient address rejected: this mailbox is inactive and has been disabled", 501);
        // The dots of the embedded "4.2.2" are escaped like every other
        // literal dot, so they no longer match arbitrary characters.
        builder.add_rule("4\\.1\\.1 <.*> 4\\.2\\.2 mailbox full\\.", 502);
        let mapper = builder.build().unwrap();

        let corpus = &[
            ("400 4.2.2 The email account that you tried to reach is over quota", None),
            ("400 4.2.2 The email account that you tried to reach is over quota.", Some(500)),
            ("400 4.2.1 <foo>: Recipient address rejected: this mailbox is inactive and has been disabled", Some(501)),
            ("400 4.1.1 <bar> 4.2.2 mailbox full.", Some(502)),
        ];

        for &(input, output) in corpus {
            assert_eq!(
                mapper.lookup(input),
                output.as_ref(),
                "expected {input} -> {output:?}"
            );
        }
    }

    /// When several rules match the same subject, the rule that was added
    /// first determines the result.
    #[test]
    fn test_first_rule_wins_on_overlap() {
        let mut builder = RegexSetMapBuilder::new();
        builder.add_rule("mailbox", 1);
        builder.add_rule("mailbox full", 2);
        let mapper = builder.build().unwrap();

        assert_eq!(mapper.lookup("mailbox full"), Some(&1));
        assert_eq!(mapper.lookup("nothing relevant"), None);
    }

    /// A pattern that is not a valid regular expression is reported by
    /// `build` instead of panicking.
    #[test]
    fn test_invalid_pattern_is_an_error() {
        let mut builder = RegexSetMapBuilder::new();
        builder.add_rule("(", 1);
        assert!(builder.build().is_err());
    }
}