mod_dns_resolver/
lib.rs

1use anyhow::Context;
2use config::{any_err, get_or_create_sub_module, serialize_options};
3use dns_resolver::{
4    get_resolver, ptr_host, resolve_a_or_aaaa, reverse_ip, set_mx_concurrency_limit,
5    set_mx_negative_cache_ttl, set_mx_timeout, AggregateResolver, HickoryResolver, MailExchanger,
6    Resolver, TestResolver, UnboundResolver,
7};
8use hickory_resolver::config::{NameServerConfig, ResolveHosts, ResolverConfig, ResolverOpts};
9use hickory_resolver::name_server::TokioConnectionProvider;
10use hickory_resolver::proto::xfer::Protocol;
11use hickory_resolver::{Name, TokioResolver};
12use kumo_address::host_or_socket::HostOrSocketAddress;
13use mlua::{Lua, LuaSerdeExt, Value};
14use parking_lot::Mutex;
15use std::collections::HashMap;
16use std::net::{IpAddr, SocketAddr};
17use std::str::FromStr;
18use std::sync::{Arc, LazyLock};
19use std::time::Duration;
20
21static RESOLVERS: LazyLock<Mutex<HashMap<String, Arc<Box<dyn Resolver>>>>> =
22    LazyLock::new(|| Mutex::new(HashMap::new()));
23
24pub fn register(lua: &Lua) -> anyhow::Result<()> {
25    let dns_mod = get_or_create_sub_module(lua, "dns")?;
26
27    dns_mod.set(
28        "lookup_mx",
29        lua.create_async_function(|lua, domain: String| async move {
30            let mx = MailExchanger::resolve(&domain).await.map_err(any_err)?;
31            Ok(lua.to_value_with(&*mx, serialize_options()))
32        })?,
33    )?;
34
35    dns_mod.set(
36        "set_mx_concurrency_limit",
37        lua.create_function(move |_lua, limit: usize| {
38            set_mx_concurrency_limit(limit);
39            Ok(())
40        })?,
41    )?;
42
43    dns_mod.set(
44        "set_mx_timeout",
45        lua.create_function(move |lua, duration: Value| {
46            let duration: duration_serde::Wrap<Duration> = lua.from_value(duration)?;
47            set_mx_timeout(duration.into_inner()).map_err(any_err)
48        })?,
49    )?;
50
51    dns_mod.set(
52        "set_mx_negative_cache_ttl",
53        lua.create_function(move |lua, duration: Value| {
54            let duration: duration_serde::Wrap<Duration> = lua.from_value(duration)?;
55            set_mx_negative_cache_ttl(duration.into_inner()).map_err(any_err)
56        })?,
57    )?;
58
59    fn get_resolver_instance(
60        opt_resolver_name: &Option<String>,
61    ) -> anyhow::Result<Arc<Box<dyn Resolver>>> {
62        if let Some(name) = opt_resolver_name {
63            return RESOLVERS
64                .lock()
65                .get(name)
66                .cloned()
67                .ok_or_else(|| anyhow::anyhow!("resolver {name} is not defined"));
68        }
69
70        Ok(get_resolver())
71    }
72
73    fn get_opt_resolver(
74        opt_resolver_name: &Option<String>,
75    ) -> anyhow::Result<Option<Arc<Box<dyn Resolver>>>> {
76        if let Some(name) = opt_resolver_name {
77            let r = RESOLVERS
78                .lock()
79                .get(name)
80                .cloned()
81                .ok_or_else(|| anyhow::anyhow!("resolver {name} is not defined"))?;
82            Ok(Some(r))
83        } else {
84            Ok(None)
85        }
86    }
87
88    dns_mod.set(
89        "ptr_host",
90        lua.create_function(move |_lua, ip: String| {
91            let ip: IpAddr = ip.parse().map_err(any_err)?;
92            Ok(ptr_host(ip))
93        })?,
94    )?;
95
96    dns_mod.set(
97        "reverse_ip",
98        lua.create_function(move |_lua, ip: String| {
99            let ip: IpAddr = ip.parse().map_err(any_err)?;
100            Ok(reverse_ip(ip))
101        })?,
102    )?;
103
104    dns_mod.set(
105        "rbl_lookup",
106        lua.create_async_function(
107            |_lua, (ip_str, bl_domain, opt_resolver_name): (String, String, Option<String>)| async move {
108                let resolver = get_resolver_instance(&opt_resolver_name).map_err(any_err)?;
109
110                let address: HostOrSocketAddress = ip_str.parse().map_err(any_err)?;
111                let reversed_ip = reverse_ip(address.ip().ok_or_else(||mlua::Error::external(format!("{ip_str} is not a valid IpAddr or SocketAddr")))?);
112                let name = format!("{reversed_ip}.{bl_domain}.");
113
114                let answers = resolver.resolve_ip(&name).await.map_err(any_err)?;
115                match answers.first() {
116                    Some(ip) => {
117                        let txt = resolver.resolve_txt(&name).await.map(|a| a.as_txt().join("")).ok();
118                        Ok((Some(ip.to_string()), txt))
119                    }
120                    None => {
121                        Ok((None, None))
122                    }
123                }
124            },
125        )?,
126    )?;
127
128    dns_mod.set(
129        "lookup_ptr",
130        lua.create_async_function(
131            |lua, (ip_str, opt_resolver_name): (String, Option<String>)| async move {
132                let resolver = get_resolver_instance(&opt_resolver_name).map_err(any_err)?;
133                let addr = std::net::IpAddr::from_str(&ip_str).map_err(any_err)?;
134                let answer = resolver.resolve_ptr(addr).await.map_err(any_err)?;
135                Ok(lua.to_value_with(&*answer, serialize_options()))
136            },
137        )?,
138    )?;
139
140    dns_mod.set(
141        "lookup_txt",
142        lua.create_async_function(
143            |_lua, (domain, opt_resolver_name): (String, Option<String>)| async move {
144                let resolver = get_resolver_instance(&opt_resolver_name).map_err(any_err)?;
145                let answer = resolver.resolve_txt(&domain).await.map_err(any_err)?;
146                Ok(answer.as_txt())
147            },
148        )?,
149    )?;
150
151    dns_mod.set(
152        "lookup_addr",
153        lua.create_async_function(
154            |_lua, (domain, opt_resolver_name): (String, Option<String>)| async move {
155                let opt_resolver = get_opt_resolver(&opt_resolver_name).map_err(any_err)?;
156                let result = resolve_a_or_aaaa(&domain, opt_resolver.as_ref().map(|r| &***r))
157                    .await
158                    .map_err(any_err)?;
159                let result: Vec<String> = result
160                    .into_iter()
161                    .map(|item| item.addr.to_string())
162                    .collect();
163                Ok(result)
164            },
165        )?,
166    )?;
167
168    #[derive(serde::Deserialize, Debug)]
169    #[serde(deny_unknown_fields)]
170    struct TestResolverConfig {
171        zones: Vec<String>,
172    }
173
174    impl TestResolverConfig {
175        fn make_resolver(&self) -> anyhow::Result<TestResolver> {
176            let mut resolver = TestResolver::default();
177
178            for zone in &self.zones {
179                resolver = resolver
180                    .with_zone(zone)
181                    .map_err(|err| anyhow::anyhow!("{err}"))?;
182            }
183
184            Ok(resolver)
185        }
186    }
187
188    #[derive(serde::Deserialize, Debug)]
189    enum KumoResolverConfig {
190        Hickory(DnsConfig),
191        HickorySystemConfig,
192        Unbound(DnsConfig),
193        Test(TestResolverConfig),
194        Aggregate(Vec<KumoResolverConfig>),
195    }
196
197    impl KumoResolverConfig {
198        fn make_resolver(&self) -> anyhow::Result<Box<dyn Resolver>> {
199            match self {
200                Self::Hickory(config) => Ok(Box::new(config.make_hickory()?)),
201                Self::HickorySystemConfig => Ok(Box::new(HickoryResolver::new()?)),
202                Self::Unbound(config) => Ok(Box::new(config.make_unbound()?)),
203                Self::Test(config) => Ok(Box::new(config.make_resolver()?)),
204                Self::Aggregate(config) => {
205                    let mut resolver = AggregateResolver::new();
206                    for c in config {
207                        resolver.push_resolver(c.make_resolver()?);
208                    }
209                    Ok(Box::new(resolver))
210                }
211            }
212        }
213    }
214
215    #[derive(serde::Deserialize, Debug)]
216    #[serde(deny_unknown_fields)]
217    struct DnsConfig {
218        #[serde(default)]
219        domain: Option<String>,
220        #[serde(default)]
221        search: Vec<String>,
222        #[serde(default)]
223        name_servers: Vec<NameServer>,
224        #[serde(default)]
225        options: ResolverOpts,
226    }
227
228    impl DnsConfig {
229        fn make_hickory(&self) -> anyhow::Result<HickoryResolver> {
230            let mut config = ResolverConfig::new();
231            if let Some(dom) = &self.domain {
232                config.set_domain(
233                    Name::from_str_relaxed(&dom).with_context(|| format!("domain: '{dom}'"))?,
234                );
235            }
236            for s in &self.search {
237                let name = Name::from_str_relaxed(&s).with_context(|| format!("search: '{s}'"))?;
238                config.add_search(name);
239            }
240
241            for ns in &self.name_servers {
242                config.add_name_server(match ns {
243                    NameServer::Ip(ip) => {
244                        let ip: SocketAddr =
245                            ip.parse().with_context(|| format!("name server: '{ip}'"))?;
246                        NameServerConfig::new(ip, Protocol::Udp)
247                    }
248                    NameServer::Detailed {
249                        socket_addr,
250                        protocol,
251                        trust_negative_responses,
252                        bind_addr,
253                    } => {
254                        let ip: SocketAddr = socket_addr
255                            .parse()
256                            .with_context(|| format!("name server: '{socket_addr}'"))?;
257                        let mut c = NameServerConfig::new(ip, protocol.clone());
258
259                        c.trust_negative_responses = *trust_negative_responses;
260
261                        if let Some(bind) = bind_addr {
262                            let addr: SocketAddr = bind.parse().with_context(|| {
263                                format!("name server: '{socket_addr}' bind_addr: '{bind}'")
264                            })?;
265                            c.bind_addr.replace(addr);
266                        }
267
268                        c
269                    }
270                });
271            }
272
273            let mut builder =
274                TokioResolver::builder_with_config(config, TokioConnectionProvider::default());
275            *builder.options_mut() = self.options.clone();
276            Ok(HickoryResolver::from(builder.build()))
277        }
278
279        fn make_unbound(&self) -> anyhow::Result<UnboundResolver> {
280            let context = libunbound::Context::new()?;
281
282            for ns in &self.name_servers {
283                let addr = match ns {
284                    NameServer::Ip(ip) => {
285                        ip.parse().with_context(|| format!("name server: '{ip}'"))?
286                    }
287                    NameServer::Detailed { socket_addr, .. } => socket_addr
288                        .parse()
289                        .with_context(|| format!("name server: '{socket_addr}'"))?,
290                };
291                context.set_forward(Some(addr)).context("set_forward")?;
292            }
293
294            // TODO: expose a way to provide unbound configuration
295            // options to this code
296
297            if self.options.validate {
298                context
299                    .add_builtin_trust_anchors()
300                    .context("add_builtin_trust_anchors")?;
301            }
302            if matches!(
303                self.options.use_hosts_file,
304                ResolveHosts::Always | ResolveHosts::Auto
305            ) {
306                context.load_hosts(None).context("load_hosts")?;
307            }
308
309            let context = context
310                .into_async()
311                .context("make async resolver context")?;
312
313            Ok(UnboundResolver::from(context))
314        }
315    }
316
317    #[derive(serde::Deserialize, Debug)]
318    #[serde(untagged)]
319    #[serde(deny_unknown_fields)]
320    enum NameServer {
321        Ip(String),
322        Detailed {
323            socket_addr: String,
324            #[serde(default)]
325            protocol: Protocol,
326            #[serde(default)]
327            trust_negative_responses: bool,
328            #[serde(default)]
329            bind_addr: Option<String>,
330        },
331    }
332
333    dns_mod.set(
334        "configure_resolver",
335        lua.create_function(move |lua, config: mlua::Value| {
336            match lua.from_value::<KumoResolverConfig>(config.clone()) {
337                Ok(config) => {
338                    let resolver = config.make_resolver().map_err(any_err)?;
339                    dns_resolver::reconfigure_resolver(resolver);
340                    Ok(())
341                }
342                Err(err1) => match lua.from_value::<DnsConfig>(config) {
343                    Ok(config) => {
344                        let resolver = config.make_hickory().map_err(any_err)?;
345                        dns_resolver::reconfigure_resolver(resolver);
346                        Ok(())
347                    }
348                    Err(err2) => {
349                        Err(mlua::Error::external(format!("failed to parse config as either KumoResolverConfig ({err1:#}) or DnsConfig ({err2:#})")))
350                    }
351                }
352            }
353
354        })?,
355    )?;
356
357    dns_mod.set(
358        "define_resolver",
359        lua.create_function(move |lua, (name, config): (String, mlua::Value)| {
360            let config = lua
361                .from_value::<KumoResolverConfig>(config.clone())
362                .map_err(any_err)?;
363            let resolver = config.make_resolver().map_err(any_err)?;
364
365            RESOLVERS.lock().insert(name, resolver.into());
366
367            Ok(())
368        })?,
369    )?;
370
371    dns_mod.set(
372        "configure_unbound_resolver",
373        lua.create_function(move |lua, config: mlua::Value| {
374            let config: DnsConfig = lua.from_value(config)?;
375            let resolver = config.make_unbound().map_err(any_err)?;
376            dns_resolver::reconfigure_resolver(resolver);
377            Ok(())
378        })?,
379    )?;
380
381    dns_mod.set(
382        "configure_test_resolver",
383        lua.create_function(move |_lua, zones: Vec<String>| {
384            let config = TestResolverConfig { zones };
385            let resolver = config.make_resolver().map_err(any_err)?;
386            dns_resolver::reconfigure_resolver(resolver);
387            Ok(())
388        })?,
389    )?;
390
391    Ok(())
392}