1use config::get_or_create_sub_module;
5use mlua::prelude::LuaUserData;
6use mlua::{FromLua, Lua, MetaMethod, UserDataMethods};
7use mod_memoize::CacheValue;
8use serde::{Deserialize, Serialize};
9use std::collections::{BTreeMap, HashMap};
10use std::fmt::Debug;
11
/// One label of a domain name stored in the `DomainMap` trie.
///
/// Names are stored right-to-left, so a node's `children` are keyed by the
/// label that appears immediately to the *left* of this one
/// (e.g. the `example` node holds `www` for `www.example.com`).
#[derive(Clone)]
struct Node<V>
where
    V: Clone,
{
    /// Value for the name that ends at this node, if one was inserted.
    value: Option<V>,
    /// This node's label; the same string as its key in the parent map.
    label: String,
    /// Child nodes keyed by the next label to the left.
    children: HashMap<String, Self>,
}
18
19impl<V: Debug + Clone> Debug for Node<V> {
20 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
21 fmt.debug_struct("Node")
22 .field("value", &self.value)
23 .field("label", &self.label)
24 .field("children", &self.children)
25 .finish()
26 }
27}
28
29#[derive(Default, Clone, Serialize, Deserialize)]
37#[serde(from = "BTreeMap<String, V>", into = "BTreeMap<String,V>")]
38pub struct DomainMap<V: Clone> {
39 top: HashMap<String, Node<V>>,
40}
41
42impl LuaUserData for DomainMap<CacheValue> {
43 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
44 mod_memoize::Memoized::impl_memoize(methods);
45 methods.add_meta_method(MetaMethod::Index, |lua, this, key: String| {
46 if let Some(value) = this.get(&key) {
47 let value = value.as_lua(lua)?;
48 Ok(Some(value))
49 } else {
50 Ok(None)
51 }
52 });
53 methods.add_meta_method_mut(
54 MetaMethod::NewIndex,
55 |lua, this, (key, value): (String, mlua::Value)| {
56 let value = CacheValue::from_lua(value, lua)?;
57 this.insert(&key, value);
58 Ok(())
59 },
60 );
61 }
62}
63
64impl<V: Debug + Clone> Debug for DomainMap<V> {
65 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
66 fmt.debug_struct("DomainMap")
67 .field("top", &self.top)
68 .finish()
69 }
70}
71
72impl<V: Clone> DomainMap<V> {
73 pub fn new() -> Self {
74 Self {
75 top: HashMap::new(),
76 }
77 }
78
79 pub fn insert(&mut self, pattern: &str, value: V) {
80 let mut current = &mut self.top;
81 let mut iter = pattern.rsplit('.').peekable();
82 while let Some(seg) = iter.next() {
83 let node = current.entry(seg.to_string()).or_insert_with(|| Node {
84 value: None,
85 label: seg.to_string(),
86 children: HashMap::new(),
87 });
88
89 if iter.peek().is_none() {
90 node.value.replace(value);
92 return;
93 }
94 current = &mut node.children;
95 }
96 }
97
98 pub fn get(&self, pattern: &str) -> Option<&V> {
99 let mut current = &self.top;
100 let mut iter = pattern.rsplit('.').peekable();
101 while let Some(seg) = iter.next() {
102 match current.get(seg) {
103 Some(node) => {
104 if iter.peek().is_none() {
105 return node.value.as_ref();
107 }
108 current = &node.children;
109 continue;
110 }
111 None => {
112 let wild = current.get("*")?;
114 return wild.value.as_ref();
115 }
116 }
117 }
118 None
119 }
120}
121
122impl<V: Clone> From<BTreeMap<String, V>> for DomainMap<V> {
123 fn from(map: BTreeMap<String, V>) -> Self {
124 let mut result = DomainMap::new();
125 for (k, v) in map {
126 result.insert(&k, v);
127 }
128 result
129 }
130}
131
132fn walk<'a, V: Clone>(
133 nodes: &'a HashMap<String, Node<V>>,
134 stack: &mut Vec<&'a str>,
135 result: &mut BTreeMap<String, V>,
136) {
137 for (key, value) in nodes {
138 stack.insert(0, key);
139 if let Some(v) = &value.value {
140 result.insert(stack.join("."), v.clone());
141 }
142 walk(&value.children, stack, result);
143 stack.remove(0);
144 }
145}
146
147impl<V: Clone> From<DomainMap<V>> for BTreeMap<String, V> {
148 fn from(map: DomainMap<V>) -> Self {
149 let mut result = BTreeMap::new();
150 let mut stack = vec![];
151 walk(&map.top, &mut stack, &mut result);
152
153 result
154 }
155}
156
157pub fn register(lua: &Lua) -> anyhow::Result<()> {
158 let dmap_mod = get_or_create_sub_module(lua, "domain_map")?;
159
160 dmap_mod.set(
161 "new",
162 lua.create_function(|lua, value: Option<HashMap<String, mlua::Value>>| {
163 let mut dmap: DomainMap<mod_memoize::CacheValue> = DomainMap::new();
164
165 if let Some(value) = value {
166 for (k, v) in value {
167 let v = CacheValue::from_lua(v, lua)?;
168 dmap.insert(&k, v);
169 }
170 }
171
172 Ok(dmap)
173 })?,
174 )?;
175
176 Ok(())
177}
178
179#[cfg(test)]
180mod tests {
181 use super::*;
182
183 #[test]
184 fn basic() {
185 let mut map: DomainMap<u32> = DomainMap::new();
186 map.insert("*.example.com", 42);
187 map.insert("example.com", 24);
188 map.insert("omg.wtf.woot.example.com", 128);
189 println!("{map:#?}");
190
191 assert_eq!(map.get("foo.com"), None);
192 assert_eq!(map.get("example.com"), Some(&24));
193 assert_eq!(map.get("lemon.example.com"), Some(&42));
194 assert_eq!(map.get("lemon.cake.example.com"), Some(&42));
195 assert_eq!(map.get("woot.example.com"), None);
196 assert_eq!(map.get("wtf.woot.example.com"), None);
197 assert_eq!(map.get("omg.wtf.woot.example.com"), Some(&128));
198
199 let serialized: BTreeMap<_, _> = map.into();
200 k9::snapshot!(
201 &serialized,
202 r#"
203{
204 "*.example.com": 42,
205 "example.com": 24,
206 "omg.wtf.woot.example.com": 128,
207}
208"#
209 );
210
211 let round_trip: DomainMap<_> = serialized.into();
212 assert_eq!(round_trip.get("lemon.example.com"), Some(&42));
213
214 let serialized_again: BTreeMap<_, _> = round_trip.into();
215 k9::snapshot!(
216 &serialized_again,
217 r#"
218{
219 "*.example.com": 42,
220 "example.com": 24,
221 "omg.wtf.woot.example.com": 128,
222}
223"#
224 );
225 }
226}