// mod_dns_resolver/lib.rs

use anyhow::Context;
use config::{any_err, get_or_create_sub_module, serialize_options};
use dns_resolver::{
    get_resolver, resolve_a_or_aaaa, set_mx_concurrency_limit, set_mx_negative_cache_ttl,
    set_mx_timeout, HickoryResolver, MailExchanger, TestResolver, UnboundResolver,
};
use hickory_resolver::config::{NameServerConfig, ResolveHosts, ResolverConfig, ResolverOpts};
use hickory_resolver::name_server::TokioConnectionProvider;
use hickory_resolver::proto::xfer::Protocol;
use hickory_resolver::{Name, TokioResolver};
use mlua::{Lua, LuaSerdeExt, Value};
use std::net::SocketAddr;
use std::str::FromStr;
use std::time::Duration;

16pub fn register(lua: &Lua) -> anyhow::Result<()> {
17    let dns_mod = get_or_create_sub_module(lua, "dns")?;
18
19    dns_mod.set(
20        "lookup_mx",
21        lua.create_async_function(|lua, domain: String| async move {
22            let mx = MailExchanger::resolve(&domain).await.map_err(any_err)?;
23            Ok(lua.to_value_with(&*mx, serialize_options()))
24        })?,
25    )?;
26
27    dns_mod.set(
28        "set_mx_concurrency_limit",
29        lua.create_function(move |_lua, limit: usize| {
30            set_mx_concurrency_limit(limit);
31            Ok(())
32        })?,
33    )?;
34
35    dns_mod.set(
36        "set_mx_timeout",
37        lua.create_function(move |lua, duration: Value| {
38            let duration: duration_serde::Wrap<Duration> = lua.from_value(duration)?;
39            set_mx_timeout(duration.into_inner()).map_err(any_err)
40        })?,
41    )?;
42
43    dns_mod.set(
44        "set_mx_negative_cache_ttl",
45        lua.create_function(move |lua, duration: Value| {
46            let duration: duration_serde::Wrap<Duration> = lua.from_value(duration)?;
47            set_mx_negative_cache_ttl(duration.into_inner()).map_err(any_err)
48        })?,
49    )?;
50
51    dns_mod.set(
52        "lookup_ptr",
53        lua.create_async_function(|lua, ip_str: String| async move {
54            let resolver = get_resolver();
55            let addr = std::net::IpAddr::from_str(&ip_str).map_err(any_err)?;
56            let answer = resolver.resolve_ptr(addr).await.map_err(any_err)?;
57            Ok(lua.to_value_with(&*answer, serialize_options()))
58        })?,
59    )?;
60
61    dns_mod.set(
62        "lookup_txt",
63        lua.create_async_function(|_lua, domain: String| async move {
64            let resolver = get_resolver();
65            let answer = resolver.resolve_txt(&domain).await.map_err(any_err)?;
66            Ok(answer.as_txt())
67        })?,
68    )?;
69
70    dns_mod.set(
71        "lookup_addr",
72        lua.create_async_function(|_lua, domain: String| async move {
73            let result = resolve_a_or_aaaa(&domain).await.map_err(any_err)?;
74            let result: Vec<String> = result
75                .into_iter()
76                .map(|item| item.addr.to_string())
77                .collect();
78            Ok(result)
79        })?,
80    )?;
81
82    #[derive(serde::Deserialize, Debug)]
83    #[serde(deny_unknown_fields)]
84    struct DnsConfig {
85        #[serde(default)]
86        domain: Option<String>,
87        #[serde(default)]
88        search: Vec<String>,
89        #[serde(default)]
90        name_servers: Vec<NameServer>,
91        #[serde(default)]
92        options: ResolverOpts,
93    }
94
95    #[derive(serde::Deserialize, Debug)]
96    #[serde(untagged)]
97    #[serde(deny_unknown_fields)]
98    enum NameServer {
99        Ip(String),
100        Detailed {
101            socket_addr: String,
102            #[serde(default)]
103            protocol: Protocol,
104            #[serde(default)]
105            trust_negative_responses: bool,
106            #[serde(default)]
107            bind_addr: Option<String>,
108        },
109    }
110
111    dns_mod.set(
112        "configure_resolver",
113        lua.create_function(move |lua, config: mlua::Value| {
114            let config: DnsConfig = lua.from_value(config)?;
115
116            let mut r_config = ResolverConfig::new();
117            if let Some(dom) = config.domain {
118                r_config.set_domain(
119                    Name::from_str_relaxed(&dom)
120                        .with_context(|| format!("domain: '{dom}'"))
121                        .map_err(any_err)?,
122                );
123            }
124            for s in config.search {
125                let name = Name::from_str_relaxed(&s)
126                    .with_context(|| format!("search: '{s}'"))
127                    .map_err(any_err)?;
128                r_config.add_search(name);
129            }
130
131            for ns in config.name_servers {
132                r_config.add_name_server(match ns {
133                    NameServer::Ip(ip) => {
134                        let ip: SocketAddr = ip
135                            .parse()
136                            .with_context(|| format!("name server: '{ip}'"))
137                            .map_err(any_err)?;
138                        NameServerConfig::new(ip, Protocol::Udp)
139                    }
140                    NameServer::Detailed {
141                        socket_addr,
142                        protocol,
143                        trust_negative_responses,
144                        bind_addr,
145                    } => {
146                        let ip: SocketAddr = socket_addr
147                            .parse()
148                            .with_context(|| format!("name server: '{socket_addr}'"))
149                            .map_err(any_err)?;
150                        let mut c = NameServerConfig::new(ip, protocol);
151
152                        c.trust_negative_responses = trust_negative_responses;
153
154                        if let Some(bind) = bind_addr {
155                            let addr: SocketAddr = bind
156                                .parse()
157                                .with_context(|| {
158                                    format!("name server: '{socket_addr}' bind_addr: '{bind}'")
159                                })
160                                .map_err(any_err)?;
161                            c.bind_addr.replace(addr);
162                        }
163
164                        c
165                    }
166                });
167            }
168
169            let mut builder =
170                TokioResolver::builder_with_config(r_config, TokioConnectionProvider::default());
171            *builder.options_mut() = config.options;
172            dns_resolver::reconfigure_resolver(HickoryResolver::from(builder.build()));
173
174            Ok(())
175        })?,
176    )?;
177
178    dns_mod.set(
179        "configure_unbound_resolver",
180        lua.create_function(move |lua, config: mlua::Value| {
181            let config: DnsConfig = lua.from_value(config)?;
182
183            let context = libunbound::Context::new().map_err(any_err)?;
184
185            for ns in config.name_servers {
186                let addr = match ns {
187                    NameServer::Ip(ip) => ip
188                        .parse()
189                        .with_context(|| format!("name server: '{ip}'"))
190                        .map_err(any_err)?,
191                    NameServer::Detailed { socket_addr, .. } => socket_addr
192                        .parse()
193                        .with_context(|| format!("name server: '{socket_addr}'"))
194                        .map_err(any_err)?,
195                };
196                context
197                    .set_forward(Some(addr))
198                    .context("set_forward")
199                    .map_err(any_err)?;
200            }
201
202            // TODO: expose a way to provide unbound configuration
203            // options to this code
204
205            if config.options.validate {
206                context
207                    .add_builtin_trust_anchors()
208                    .context("add_builtin_trust_anchors")
209                    .map_err(any_err)?;
210            }
211            if matches!(
212                config.options.use_hosts_file,
213                ResolveHosts::Always | ResolveHosts::Auto
214            ) {
215                context
216                    .load_hosts(None)
217                    .context("load_hosts")
218                    .map_err(any_err)?;
219            }
220
221            let context = context
222                .into_async()
223                .context("make async resolver context")
224                .map_err(any_err)?;
225
226            dns_resolver::reconfigure_resolver(UnboundResolver::from(context));
227
228            Ok(())
229        })?,
230    )?;
231
232    dns_mod.set(
233        "configure_test_resolver",
234        lua.create_function(move |_lua, zones: Vec<String>| {
235            let mut resolver = TestResolver::default();
236            for zone in &zones {
237                resolver = resolver.with_zone(zone);
238            }
239
240            dns_resolver::reconfigure_resolver(resolver);
241            Ok(())
242        })?,
243    )?;
244
245    Ok(())
246}