mod_dns_resolver/
lib.rs

1use anyhow::Context;
2use config::{any_err, get_or_create_sub_module, serialize_options, SerdeWrappedValue};
3use dns_resolver::{
4    get_resolver, ptr_host, resolve_a_or_aaaa, reverse_ip, set_mx_concurrency_limit,
5    set_mx_negative_cache_ttl, set_mx_timeout, AggregateResolver, HickoryResolver,
6    IpLookupStrategy, MailExchanger, Resolver, TestResolver, UnboundResolver,
7};
8use hickory_resolver::config::{NameServerConfig, ResolveHosts, ResolverConfig, ResolverOpts};
9use hickory_resolver::name_server::TokioConnectionProvider;
10use hickory_resolver::proto::xfer::Protocol;
11use hickory_resolver::{Name, TokioResolver};
12use kumo_address::host_or_socket::HostOrSocketAddress;
13use mlua::{Lua, LuaSerdeExt, Value};
14use parking_lot::Mutex;
15use std::collections::HashMap;
16use std::net::{IpAddr, SocketAddr};
17use std::str::FromStr;
18use std::sync::{Arc, LazyLock};
19use std::time::Duration;
20
21static RESOLVERS: LazyLock<Mutex<HashMap<String, Arc<Box<dyn Resolver>>>>> =
22    LazyLock::new(|| Mutex::new(HashMap::new()));
23
24pub fn get_resolver_instance(
25    opt_resolver_name: &Option<String>,
26) -> anyhow::Result<Arc<Box<dyn Resolver>>> {
27    if let Some(name) = opt_resolver_name {
28        return RESOLVERS
29            .lock()
30            .get(name)
31            .cloned()
32            .ok_or_else(|| anyhow::anyhow!("resolver {name} is not defined"));
33    }
34
35    Ok(get_resolver())
36}
37
38pub fn get_opt_resolver(
39    opt_resolver_name: &Option<String>,
40) -> anyhow::Result<Option<Arc<Box<dyn Resolver>>>> {
41    if let Some(name) = opt_resolver_name {
42        let r = RESOLVERS
43            .lock()
44            .get(name)
45            .cloned()
46            .ok_or_else(|| anyhow::anyhow!("resolver {name} is not defined"))?;
47        Ok(Some(r))
48    } else {
49        Ok(None)
50    }
51}
52
53pub fn register(lua: &Lua) -> anyhow::Result<()> {
54    let dns_mod = get_or_create_sub_module(lua, "dns")?;
55
56    dns_mod.set(
57        "lookup_mx",
58        lua.create_async_function(|lua, domain: String| async move {
59            let mx = MailExchanger::resolve(&domain).await.map_err(any_err)?;
60            Ok(lua.to_value_with(&*mx, serialize_options()))
61        })?,
62    )?;
63
64    dns_mod.set(
65        "set_mx_concurrency_limit",
66        lua.create_function(move |_lua, limit: usize| {
67            set_mx_concurrency_limit(limit);
68            Ok(())
69        })?,
70    )?;
71
72    dns_mod.set(
73        "set_mx_timeout",
74        lua.create_function(move |lua, duration: Value| {
75            let duration: duration_serde::Wrap<Duration> = lua.from_value(duration)?;
76            set_mx_timeout(duration.into_inner()).map_err(any_err)
77        })?,
78    )?;
79
80    dns_mod.set(
81        "set_mx_negative_cache_ttl",
82        lua.create_function(move |lua, duration: Value| {
83            let duration: duration_serde::Wrap<Duration> = lua.from_value(duration)?;
84            set_mx_negative_cache_ttl(duration.into_inner()).map_err(any_err)
85        })?,
86    )?;
87
88    dns_mod.set(
89        "ptr_host",
90        lua.create_function(move |_lua, ip: String| {
91            let ip: IpAddr = ip.parse().map_err(any_err)?;
92            Ok(ptr_host(ip))
93        })?,
94    )?;
95
96    dns_mod.set(
97        "reverse_ip",
98        lua.create_function(move |_lua, ip: String| {
99            let ip: IpAddr = ip.parse().map_err(any_err)?;
100            Ok(reverse_ip(ip))
101        })?,
102    )?;
103
104    dns_mod.set(
105        "rbl_lookup",
106        lua.create_async_function(
107            |_lua, (ip_str, bl_domain, opt_resolver_name): (String, String, Option<String>)| async move {
108                let resolver = get_resolver_instance(&opt_resolver_name).map_err(any_err)?;
109
110                let address: HostOrSocketAddress = ip_str.parse().map_err(any_err)?;
111                let reversed_ip = reverse_ip(address.ip().ok_or_else(||mlua::Error::external(format!("{ip_str} is not a valid IpAddr or SocketAddr")))?);
112                let name = format!("{reversed_ip}.{bl_domain}.");
113
114                let answers = resolver.resolve_ip(&name).await.map_err(any_err)?;
115                match answers.first() {
116                    Some(ip) => {
117                        let txt = resolver.resolve_txt(&name).await.map(|a| a.as_txt().join("")).ok();
118                        Ok((Some(ip.to_string()), txt))
119                    }
120                    None => {
121                        Ok((None, None))
122                    }
123                }
124            },
125        )?,
126    )?;
127
128    dns_mod.set(
129        "lookup_ptr",
130        lua.create_async_function(
131            |lua, (ip_str, opt_resolver_name): (String, Option<String>)| async move {
132                let resolver = get_resolver_instance(&opt_resolver_name).map_err(any_err)?;
133                let addr = std::net::IpAddr::from_str(&ip_str).map_err(any_err)?;
134                let answer = resolver.resolve_ptr(addr).await.map_err(any_err)?;
135                Ok(lua.to_value_with(&*answer, serialize_options()))
136            },
137        )?,
138    )?;
139
140    dns_mod.set(
141        "lookup_txt",
142        lua.create_async_function(
143            |_lua, (domain, opt_resolver_name): (String, Option<String>)| async move {
144                let resolver = get_resolver_instance(&opt_resolver_name).map_err(any_err)?;
145                let answer = resolver.resolve_txt(&domain).await.map_err(any_err)?;
146                Ok(answer.as_txt())
147            },
148        )?,
149    )?;
150
151    dns_mod.set(
152        "lookup_addr",
153        lua.create_async_function(
154            |_lua,
155             (domain, opt_resolver_name, strategy): (
156                String,
157                Option<String>,
158                Option<SerdeWrappedValue<IpLookupStrategy>>,
159            )| async move {
160                let opt_resolver = get_opt_resolver(&opt_resolver_name).map_err(any_err)?;
161                let result = resolve_a_or_aaaa(
162                    &domain,
163                    opt_resolver.as_ref().map(|r| &***r),
164                    strategy.map(|v| v.0).unwrap_or_default(),
165                )
166                .await
167                .map_err(any_err)?;
168                let result: Vec<String> = result
169                    .into_iter()
170                    .map(|item| item.addr.to_string())
171                    .collect();
172                Ok(result)
173            },
174        )?,
175    )?;
176
177    #[derive(serde::Deserialize, Debug)]
178    #[serde(deny_unknown_fields)]
179    struct TestResolverConfig {
180        zones: Vec<String>,
181    }
182
183    impl TestResolverConfig {
184        fn make_resolver(&self) -> anyhow::Result<TestResolver> {
185            let mut resolver = TestResolver::default();
186
187            for zone in &self.zones {
188                resolver = resolver
189                    .with_zone(zone)
190                    .map_err(|err| anyhow::anyhow!("{err}"))?;
191            }
192
193            Ok(resolver)
194        }
195    }
196
197    #[derive(serde::Deserialize, Debug)]
198    enum KumoResolverConfig {
199        Hickory(DnsConfig),
200        HickorySystemConfig,
201        Unbound(DnsConfig),
202        Test(TestResolverConfig),
203        Aggregate(Vec<KumoResolverConfig>),
204    }
205
206    impl KumoResolverConfig {
207        fn make_resolver(&self) -> anyhow::Result<Box<dyn Resolver>> {
208            match self {
209                Self::Hickory(config) => Ok(Box::new(config.make_hickory()?)),
210                Self::HickorySystemConfig => Ok(Box::new(HickoryResolver::new()?)),
211                Self::Unbound(config) => Ok(Box::new(config.make_unbound()?)),
212                Self::Test(config) => Ok(Box::new(config.make_resolver()?)),
213                Self::Aggregate(config) => {
214                    let mut resolver = AggregateResolver::new();
215                    for c in config {
216                        resolver.push_resolver(c.make_resolver()?);
217                    }
218                    Ok(Box::new(resolver))
219                }
220            }
221        }
222    }
223
224    #[derive(serde::Deserialize, Debug)]
225    #[serde(deny_unknown_fields)]
226    struct DnsConfig {
227        #[serde(default)]
228        domain: Option<String>,
229        #[serde(default)]
230        search: Vec<String>,
231        #[serde(default)]
232        name_servers: Vec<NameServer>,
233        #[serde(default)]
234        options: ResolverOpts,
235    }
236
237    impl DnsConfig {
238        fn make_hickory(&self) -> anyhow::Result<HickoryResolver> {
239            let mut config = ResolverConfig::new();
240            if let Some(dom) = &self.domain {
241                config.set_domain(
242                    Name::from_str_relaxed(&dom).with_context(|| format!("domain: '{dom}'"))?,
243                );
244            }
245            for s in &self.search {
246                let name = Name::from_str_relaxed(&s).with_context(|| format!("search: '{s}'"))?;
247                config.add_search(name);
248            }
249
250            for ns in &self.name_servers {
251                config.add_name_server(match ns {
252                    NameServer::Ip(ip) => {
253                        let ip: SocketAddr =
254                            ip.parse().with_context(|| format!("name server: '{ip}'"))?;
255                        NameServerConfig::new(ip, Protocol::Udp)
256                    }
257                    NameServer::Detailed {
258                        socket_addr,
259                        protocol,
260                        trust_negative_responses,
261                        bind_addr,
262                    } => {
263                        let ip: SocketAddr = socket_addr
264                            .parse()
265                            .with_context(|| format!("name server: '{socket_addr}'"))?;
266                        let mut c = NameServerConfig::new(ip, protocol.clone());
267
268                        c.trust_negative_responses = *trust_negative_responses;
269
270                        if let Some(bind) = bind_addr {
271                            let addr: SocketAddr = bind.parse().with_context(|| {
272                                format!("name server: '{socket_addr}' bind_addr: '{bind}'")
273                            })?;
274                            c.bind_addr.replace(addr);
275                        }
276
277                        c
278                    }
279                });
280            }
281
282            let mut builder =
283                TokioResolver::builder_with_config(config, TokioConnectionProvider::default());
284            *builder.options_mut() = self.options.clone();
285            Ok(HickoryResolver::from(builder.build()))
286        }
287
288        fn make_unbound(&self) -> anyhow::Result<UnboundResolver> {
289            let context = libunbound::Context::new()?;
290
291            for ns in &self.name_servers {
292                let addr = match ns {
293                    NameServer::Ip(ip) => {
294                        ip.parse().with_context(|| format!("name server: '{ip}'"))?
295                    }
296                    NameServer::Detailed { socket_addr, .. } => socket_addr
297                        .parse()
298                        .with_context(|| format!("name server: '{socket_addr}'"))?,
299                };
300                context.set_forward(Some(addr)).context("set_forward")?;
301            }
302
303            // TODO: expose a way to provide unbound configuration
304            // options to this code
305
306            if self.options.validate {
307                context
308                    .add_builtin_trust_anchors()
309                    .context("add_builtin_trust_anchors")?;
310            }
311            if matches!(
312                self.options.use_hosts_file,
313                ResolveHosts::Always | ResolveHosts::Auto
314            ) {
315                context.load_hosts(None).context("load_hosts")?;
316            }
317
318            let context = context
319                .into_async()
320                .context("make async resolver context")?;
321
322            Ok(UnboundResolver::from(context))
323        }
324    }
325
326    #[derive(serde::Deserialize, Debug)]
327    #[serde(untagged)]
328    #[serde(deny_unknown_fields)]
329    enum NameServer {
330        Ip(String),
331        Detailed {
332            socket_addr: String,
333            #[serde(default)]
334            protocol: Protocol,
335            #[serde(default)]
336            trust_negative_responses: bool,
337            #[serde(default)]
338            bind_addr: Option<String>,
339        },
340    }
341
342    dns_mod.set(
343        "configure_resolver",
344        lua.create_function(move |lua, config: mlua::Value| {
345            match lua.from_value::<KumoResolverConfig>(config.clone()) {
346                Ok(config) => {
347                    let resolver = config.make_resolver().map_err(any_err)?;
348                    dns_resolver::reconfigure_resolver(resolver);
349                    Ok(())
350                }
351                Err(err1) => match lua.from_value::<DnsConfig>(config) {
352                    Ok(config) => {
353                        let resolver = config.make_hickory().map_err(any_err)?;
354                        dns_resolver::reconfigure_resolver(resolver);
355                        Ok(())
356                    }
357                    Err(err2) => {
358                        Err(mlua::Error::external(format!("failed to parse config as either KumoResolverConfig ({err1:#}) or DnsConfig ({err2:#})")))
359                    }
360                }
361            }
362
363        })?,
364    )?;
365
366    dns_mod.set(
367        "define_resolver",
368        lua.create_function(move |lua, (name, config): (String, mlua::Value)| {
369            let config = lua
370                .from_value::<KumoResolverConfig>(config.clone())
371                .map_err(any_err)?;
372            let resolver = config.make_resolver().map_err(any_err)?;
373
374            RESOLVERS.lock().insert(name, resolver.into());
375
376            Ok(())
377        })?,
378    )?;
379
380    dns_mod.set(
381        "configure_unbound_resolver",
382        lua.create_function(move |lua, config: mlua::Value| {
383            let config: DnsConfig = lua.from_value(config)?;
384            let resolver = config.make_unbound().map_err(any_err)?;
385            dns_resolver::reconfigure_resolver(resolver);
386            Ok(())
387        })?,
388    )?;
389
390    dns_mod.set(
391        "configure_test_resolver",
392        lua.create_function(move |_lua, zones: Vec<String>| {
393            let config = TestResolverConfig { zones };
394            let resolver = config.make_resolver().map_err(any_err)?;
395            dns_resolver::reconfigure_resolver(resolver);
396            Ok(())
397        })?,
398    )?;
399
400    Ok(())
401}