mod_dns_resolver/lib.rs

1use anyhow::Context;
2use config::{any_err, get_or_create_sub_module, serialize_options};
3use dns_resolver::{
4    get_resolver, resolve_a_or_aaaa, set_mx_concurrency_limit, set_mx_negative_cache_ttl,
5    set_mx_timeout, HickoryResolver, MailExchanger, TestResolver, UnboundResolver,
6};
7use hickory_resolver::config::{NameServerConfig, ResolveHosts, ResolverConfig, ResolverOpts};
8use hickory_resolver::name_server::TokioConnectionProvider;
9use hickory_resolver::proto::xfer::Protocol;
10use hickory_resolver::{Name, TokioResolver};
11use mlua::{Lua, LuaSerdeExt, Value};
12use std::net::SocketAddr;
13use std::time::Duration;
14
15pub fn register(lua: &Lua) -> anyhow::Result<()> {
16    let dns_mod = get_or_create_sub_module(lua, "dns")?;
17
18    dns_mod.set(
19        "lookup_mx",
20        lua.create_async_function(|lua, domain: String| async move {
21            let mx = MailExchanger::resolve(&domain).await.map_err(any_err)?;
22            Ok(lua.to_value_with(&*mx, serialize_options()))
23        })?,
24    )?;
25
26    dns_mod.set(
27        "set_mx_concurrency_limit",
28        lua.create_function(move |_lua, limit: usize| {
29            set_mx_concurrency_limit(limit);
30            Ok(())
31        })?,
32    )?;
33
34    dns_mod.set(
35        "set_mx_timeout",
36        lua.create_function(move |lua, duration: Value| {
37            let duration: duration_serde::Wrap<Duration> = lua.from_value(duration)?;
38            set_mx_timeout(duration.into_inner()).map_err(any_err)
39        })?,
40    )?;
41
42    dns_mod.set(
43        "set_mx_negative_cache_ttl",
44        lua.create_function(move |lua, duration: Value| {
45            let duration: duration_serde::Wrap<Duration> = lua.from_value(duration)?;
46            set_mx_negative_cache_ttl(duration.into_inner()).map_err(any_err)
47        })?,
48    )?;
49
50    dns_mod.set(
51        "lookup_txt",
52        lua.create_async_function(|_lua, domain: String| async move {
53            let resolver = get_resolver();
54            let answer = resolver.resolve_txt(&domain).await.map_err(any_err)?;
55            Ok(answer.as_txt())
56        })?,
57    )?;
58
59    dns_mod.set(
60        "lookup_addr",
61        lua.create_async_function(|_lua, domain: String| async move {
62            let result = resolve_a_or_aaaa(&domain).await.map_err(any_err)?;
63            let result: Vec<String> = result
64                .into_iter()
65                .map(|item| item.addr.to_string())
66                .collect();
67            Ok(result)
68        })?,
69    )?;
70
71    #[derive(serde::Deserialize, Debug)]
72    #[serde(deny_unknown_fields)]
73    struct DnsConfig {
74        #[serde(default)]
75        domain: Option<String>,
76        #[serde(default)]
77        search: Vec<String>,
78        #[serde(default)]
79        name_servers: Vec<NameServer>,
80        #[serde(default)]
81        options: ResolverOpts,
82    }
83
84    #[derive(serde::Deserialize, Debug)]
85    #[serde(untagged)]
86    #[serde(deny_unknown_fields)]
87    enum NameServer {
88        Ip(String),
89        Detailed {
90            socket_addr: String,
91            #[serde(default)]
92            protocol: Protocol,
93            #[serde(default)]
94            trust_negative_responses: bool,
95            #[serde(default)]
96            bind_addr: Option<String>,
97        },
98    }
99
100    dns_mod.set(
101        "configure_resolver",
102        lua.create_function(move |lua, config: mlua::Value| {
103            let config: DnsConfig = lua.from_value(config)?;
104
105            let mut r_config = ResolverConfig::new();
106            if let Some(dom) = config.domain {
107                r_config.set_domain(
108                    Name::from_str_relaxed(&dom)
109                        .with_context(|| format!("domain: '{dom}'"))
110                        .map_err(any_err)?,
111                );
112            }
113            for s in config.search {
114                let name = Name::from_str_relaxed(&s)
115                    .with_context(|| format!("search: '{s}'"))
116                    .map_err(any_err)?;
117                r_config.add_search(name);
118            }
119
120            for ns in config.name_servers {
121                r_config.add_name_server(match ns {
122                    NameServer::Ip(ip) => {
123                        let ip: SocketAddr = ip
124                            .parse()
125                            .with_context(|| format!("name server: '{ip}'"))
126                            .map_err(any_err)?;
127                        NameServerConfig::new(ip, Protocol::Udp)
128                    }
129                    NameServer::Detailed {
130                        socket_addr,
131                        protocol,
132                        trust_negative_responses,
133                        bind_addr,
134                    } => {
135                        let ip: SocketAddr = socket_addr
136                            .parse()
137                            .with_context(|| format!("name server: '{socket_addr}'"))
138                            .map_err(any_err)?;
139                        let mut c = NameServerConfig::new(ip, protocol);
140
141                        c.trust_negative_responses = trust_negative_responses;
142
143                        if let Some(bind) = bind_addr {
144                            let addr: SocketAddr = bind
145                                .parse()
146                                .with_context(|| {
147                                    format!("name server: '{socket_addr}' bind_addr: '{bind}'")
148                                })
149                                .map_err(any_err)?;
150                            c.bind_addr.replace(addr);
151                        }
152
153                        c
154                    }
155                });
156            }
157
158            let mut builder =
159                TokioResolver::builder_with_config(r_config, TokioConnectionProvider::default());
160            *builder.options_mut() = config.options;
161            dns_resolver::reconfigure_resolver(HickoryResolver::from(builder.build()));
162
163            Ok(())
164        })?,
165    )?;
166
167    dns_mod.set(
168        "configure_unbound_resolver",
169        lua.create_function(move |lua, config: mlua::Value| {
170            let config: DnsConfig = lua.from_value(config)?;
171
172            let context = libunbound::Context::new().map_err(any_err)?;
173
174            for ns in config.name_servers {
175                let addr = match ns {
176                    NameServer::Ip(ip) => ip
177                        .parse()
178                        .with_context(|| format!("name server: '{ip}'"))
179                        .map_err(any_err)?,
180                    NameServer::Detailed { socket_addr, .. } => socket_addr
181                        .parse()
182                        .with_context(|| format!("name server: '{socket_addr}'"))
183                        .map_err(any_err)?,
184                };
185                context
186                    .set_forward(Some(addr))
187                    .context("set_forward")
188                    .map_err(any_err)?;
189            }
190
191            // TODO: expose a way to provide unbound configuration
192            // options to this code
193
194            if config.options.validate {
195                context
196                    .add_builtin_trust_anchors()
197                    .context("add_builtin_trust_anchors")
198                    .map_err(any_err)?;
199            }
200            if matches!(
201                config.options.use_hosts_file,
202                ResolveHosts::Always | ResolveHosts::Auto
203            ) {
204                context
205                    .load_hosts(None)
206                    .context("load_hosts")
207                    .map_err(any_err)?;
208            }
209
210            let context = context
211                .into_async()
212                .context("make async resolver context")
213                .map_err(any_err)?;
214
215            dns_resolver::reconfigure_resolver(UnboundResolver::from(context));
216
217            Ok(())
218        })?,
219    )?;
220
221    dns_mod.set(
222        "configure_test_resolver",
223        lua.create_function(move |_lua, zones: Vec<String>| {
224            let mut resolver = TestResolver::default();
225            for zone in &zones {
226                resolver = resolver.with_zone(zone);
227            }
228
229            dns_resolver::reconfigure_resolver(resolver);
230            Ok(())
231        })?,
232    )?;
233
234    Ok(())
235}