domain_map/
lib.rs

//! This module provides a simple data structure that can store
//! values associated with domain-name-style keys.
//! Wildcard keys are supported.
4use config::get_or_create_sub_module;
5use mlua::prelude::LuaUserData;
6use mlua::{FromLua, Lua, MetaMethod, UserDataMethods};
7use mod_memoize::CacheValue;
8use serde::{Deserialize, Serialize};
9use std::collections::{BTreeMap, HashMap};
10use std::fmt::Debug;
11
/// One label of the domain tree.
///
/// No trait bounds are placed on the struct itself; `derive(Clone)` adds
/// `V: Clone` to the generated `Clone` impl only.
#[derive(Clone)]
struct Node<V> {
    /// Value stored for the pattern that ends at this label, if any.
    /// Nodes that only exist as parents of longer patterns have `None`.
    value: Option<V>,
    /// The label itself (duplicates the key in the parent's map; kept
    /// for diagnostics).
    label: String,
    /// Labels that appear immediately to the left of this one.
    children: HashMap<String, Self>,
}
18
19impl<V: Debug + Clone> Debug for Node<V> {
20    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
21        fmt.debug_struct("Node")
22            .field("value", &self.value)
23            .field("label", &self.label)
24            .field("children", &self.children)
25            .finish()
26    }
27}
28
29/// A DomainMap is conceptually similar to a HashMap. The differences
30/// are that the keys are always domain name strings like "example.com"
31/// and that a lookup that doesn't have an exact match in the map is
32/// allowed to resolve through a wildcard entry, such as "*.example.com",
33/// if one has been inserted.
34/// A lookup for "example.com" will not match the wildcard "*.example.com"
35/// because it has fewer segments than the wildcard entry.
36#[derive(Default, Clone, Serialize, Deserialize)]
37#[serde(from = "BTreeMap<String, V>", into = "BTreeMap<String,V>")]
38pub struct DomainMap<V: Clone> {
39    top: HashMap<String, Node<V>>,
40}
41
42impl LuaUserData for DomainMap<CacheValue> {
43    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
44        mod_memoize::Memoized::impl_memoize(methods);
45        methods.add_meta_method(MetaMethod::Index, |lua, this, key: String| {
46            if let Some(value) = this.get(&key) {
47                let value = value.as_lua(lua)?;
48                Ok(Some(value))
49            } else {
50                Ok(None)
51            }
52        });
53        methods.add_meta_method_mut(
54            MetaMethod::NewIndex,
55            |lua, this, (key, value): (String, mlua::Value)| {
56                let value = CacheValue::from_lua(value, lua)?;
57                this.insert(&key, value);
58                Ok(())
59            },
60        );
61    }
62}
63
64impl<V: Debug + Clone> Debug for DomainMap<V> {
65    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
66        fmt.debug_struct("DomainMap")
67            .field("top", &self.top)
68            .finish()
69    }
70}
71
72impl<V: Clone> DomainMap<V> {
73    pub fn new() -> Self {
74        Self {
75            top: HashMap::new(),
76        }
77    }
78
79    pub fn insert(&mut self, pattern: &str, value: V) {
80        let mut current = &mut self.top;
81        let mut iter = pattern.rsplit('.').peekable();
82        while let Some(seg) = iter.next() {
83            let node = current.entry(seg.to_string()).or_insert_with(|| Node {
84                value: None,
85                label: seg.to_string(),
86                children: HashMap::new(),
87            });
88
89            if iter.peek().is_none() {
90                // No further segments: this is where we set our value
91                node.value.replace(value);
92                return;
93            }
94            current = &mut node.children;
95        }
96    }
97
98    pub fn get(&self, pattern: &str) -> Option<&V> {
99        let mut current = &self.top;
100        let mut iter = pattern.rsplit('.').peekable();
101        while let Some(seg) = iter.next() {
102            match current.get(seg) {
103                Some(node) => {
104                    if iter.peek().is_none() {
105                        // This node holds our exact match
106                        return node.value.as_ref();
107                    }
108                    current = &node.children;
109                    continue;
110                }
111                None => {
112                    // No exact match; see if there is a wildcard
113                    let wild = current.get("*")?;
114                    return wild.value.as_ref();
115                }
116            }
117        }
118        None
119    }
120}
121
122impl<V: Clone> From<BTreeMap<String, V>> for DomainMap<V> {
123    fn from(map: BTreeMap<String, V>) -> Self {
124        let mut result = DomainMap::new();
125        for (k, v) in map {
126            result.insert(&k, v);
127        }
128        result
129    }
130}
131
132fn walk<'a, V: Clone>(
133    nodes: &'a HashMap<String, Node<V>>,
134    stack: &mut Vec<&'a str>,
135    result: &mut BTreeMap<String, V>,
136) {
137    for (key, value) in nodes {
138        stack.insert(0, key);
139        if let Some(v) = &value.value {
140            result.insert(stack.join("."), v.clone());
141        }
142        walk(&value.children, stack, result);
143        stack.remove(0);
144    }
145}
146
147impl<V: Clone> From<DomainMap<V>> for BTreeMap<String, V> {
148    fn from(map: DomainMap<V>) -> Self {
149        let mut result = BTreeMap::new();
150        let mut stack = vec![];
151        walk(&map.top, &mut stack, &mut result);
152
153        result
154    }
155}
156
157pub fn register(lua: &Lua) -> anyhow::Result<()> {
158    let dmap_mod = get_or_create_sub_module(lua, "domain_map")?;
159
160    dmap_mod.set(
161        "new",
162        lua.create_function(|lua, value: Option<HashMap<String, mlua::Value>>| {
163            let mut dmap: DomainMap<mod_memoize::CacheValue> = DomainMap::new();
164
165            if let Some(value) = value {
166                for (k, v) in value {
167                    let v = CacheValue::from_lua(v, lua)?;
168                    dmap.insert(&k, v);
169                }
170            }
171
172            Ok(dmap)
173        })?,
174    )?;
175
176    Ok(())
177}
178
179#[cfg(test)]
180mod tests {
181    use super::*;
182
183    #[test]
184    fn basic() {
185        let mut map: DomainMap<u32> = DomainMap::new();
186        map.insert("*.example.com", 42);
187        map.insert("example.com", 24);
188        map.insert("omg.wtf.woot.example.com", 128);
189        println!("{map:#?}");
190
191        assert_eq!(map.get("foo.com"), None);
192        assert_eq!(map.get("example.com"), Some(&24));
193        assert_eq!(map.get("lemon.example.com"), Some(&42));
194        assert_eq!(map.get("lemon.cake.example.com"), Some(&42));
195        assert_eq!(map.get("woot.example.com"), None);
196        assert_eq!(map.get("wtf.woot.example.com"), None);
197        assert_eq!(map.get("omg.wtf.woot.example.com"), Some(&128));
198
199        let serialized: BTreeMap<_, _> = map.into();
200        k9::snapshot!(
201            &serialized,
202            r#"
203{
204    "*.example.com": 42,
205    "example.com": 24,
206    "omg.wtf.woot.example.com": 128,
207}
208"#
209        );
210
211        let round_trip: DomainMap<_> = serialized.into();
212        assert_eq!(round_trip.get("lemon.example.com"), Some(&42));
213
214        let serialized_again: BTreeMap<_, _> = round_trip.into();
215        k9::snapshot!(
216            &serialized_again,
217            r#"
218{
219    "*.example.com": 42,
220    "example.com": 24,
221    "omg.wtf.woot.example.com": 128,
222}
223"#
224        );
225    }
226}