1use anyhow::Context;
2use config::{any_err, get_or_create_sub_module, serialize_options};
3use dns_resolver::{
4 get_resolver, resolve_a_or_aaaa, set_mx_concurrency_limit, set_mx_negative_cache_ttl,
5 set_mx_timeout, HickoryResolver, MailExchanger, TestResolver, UnboundResolver,
6};
7use hickory_resolver::config::{NameServerConfig, ResolveHosts, ResolverConfig, ResolverOpts};
8use hickory_resolver::name_server::TokioConnectionProvider;
9use hickory_resolver::proto::xfer::Protocol;
10use hickory_resolver::{Name, TokioResolver};
11use mlua::{Lua, LuaSerdeExt, Value};
12use std::net::SocketAddr;
13use std::str::FromStr;
14use std::time::Duration;
15
16pub fn register(lua: &Lua) -> anyhow::Result<()> {
17 let dns_mod = get_or_create_sub_module(lua, "dns")?;
18
19 dns_mod.set(
20 "lookup_mx",
21 lua.create_async_function(|lua, domain: String| async move {
22 let mx = MailExchanger::resolve(&domain).await.map_err(any_err)?;
23 Ok(lua.to_value_with(&*mx, serialize_options()))
24 })?,
25 )?;
26
27 dns_mod.set(
28 "set_mx_concurrency_limit",
29 lua.create_function(move |_lua, limit: usize| {
30 set_mx_concurrency_limit(limit);
31 Ok(())
32 })?,
33 )?;
34
35 dns_mod.set(
36 "set_mx_timeout",
37 lua.create_function(move |lua, duration: Value| {
38 let duration: duration_serde::Wrap<Duration> = lua.from_value(duration)?;
39 set_mx_timeout(duration.into_inner()).map_err(any_err)
40 })?,
41 )?;
42
43 dns_mod.set(
44 "set_mx_negative_cache_ttl",
45 lua.create_function(move |lua, duration: Value| {
46 let duration: duration_serde::Wrap<Duration> = lua.from_value(duration)?;
47 set_mx_negative_cache_ttl(duration.into_inner()).map_err(any_err)
48 })?,
49 )?;
50
51 dns_mod.set(
52 "lookup_ptr",
53 lua.create_async_function(|lua, ip_str: String| async move {
54 let resolver = get_resolver();
55 let addr = std::net::IpAddr::from_str(&ip_str).map_err(any_err)?;
56 let answer = resolver.resolve_ptr(addr).await.map_err(any_err)?;
57 Ok(lua.to_value_with(&*answer, serialize_options()))
58 })?,
59 )?;
60
61 dns_mod.set(
62 "lookup_txt",
63 lua.create_async_function(|_lua, domain: String| async move {
64 let resolver = get_resolver();
65 let answer = resolver.resolve_txt(&domain).await.map_err(any_err)?;
66 Ok(answer.as_txt())
67 })?,
68 )?;
69
70 dns_mod.set(
71 "lookup_addr",
72 lua.create_async_function(|_lua, domain: String| async move {
73 let result = resolve_a_or_aaaa(&domain).await.map_err(any_err)?;
74 let result: Vec<String> = result
75 .into_iter()
76 .map(|item| item.addr.to_string())
77 .collect();
78 Ok(result)
79 })?,
80 )?;
81
82 #[derive(serde::Deserialize, Debug)]
83 #[serde(deny_unknown_fields)]
84 struct DnsConfig {
85 #[serde(default)]
86 domain: Option<String>,
87 #[serde(default)]
88 search: Vec<String>,
89 #[serde(default)]
90 name_servers: Vec<NameServer>,
91 #[serde(default)]
92 options: ResolverOpts,
93 }
94
95 #[derive(serde::Deserialize, Debug)]
96 #[serde(untagged)]
97 #[serde(deny_unknown_fields)]
98 enum NameServer {
99 Ip(String),
100 Detailed {
101 socket_addr: String,
102 #[serde(default)]
103 protocol: Protocol,
104 #[serde(default)]
105 trust_negative_responses: bool,
106 #[serde(default)]
107 bind_addr: Option<String>,
108 },
109 }
110
111 dns_mod.set(
112 "configure_resolver",
113 lua.create_function(move |lua, config: mlua::Value| {
114 let config: DnsConfig = lua.from_value(config)?;
115
116 let mut r_config = ResolverConfig::new();
117 if let Some(dom) = config.domain {
118 r_config.set_domain(
119 Name::from_str_relaxed(&dom)
120 .with_context(|| format!("domain: '{dom}'"))
121 .map_err(any_err)?,
122 );
123 }
124 for s in config.search {
125 let name = Name::from_str_relaxed(&s)
126 .with_context(|| format!("search: '{s}'"))
127 .map_err(any_err)?;
128 r_config.add_search(name);
129 }
130
131 for ns in config.name_servers {
132 r_config.add_name_server(match ns {
133 NameServer::Ip(ip) => {
134 let ip: SocketAddr = ip
135 .parse()
136 .with_context(|| format!("name server: '{ip}'"))
137 .map_err(any_err)?;
138 NameServerConfig::new(ip, Protocol::Udp)
139 }
140 NameServer::Detailed {
141 socket_addr,
142 protocol,
143 trust_negative_responses,
144 bind_addr,
145 } => {
146 let ip: SocketAddr = socket_addr
147 .parse()
148 .with_context(|| format!("name server: '{socket_addr}'"))
149 .map_err(any_err)?;
150 let mut c = NameServerConfig::new(ip, protocol);
151
152 c.trust_negative_responses = trust_negative_responses;
153
154 if let Some(bind) = bind_addr {
155 let addr: SocketAddr = bind
156 .parse()
157 .with_context(|| {
158 format!("name server: '{socket_addr}' bind_addr: '{bind}'")
159 })
160 .map_err(any_err)?;
161 c.bind_addr.replace(addr);
162 }
163
164 c
165 }
166 });
167 }
168
169 let mut builder =
170 TokioResolver::builder_with_config(r_config, TokioConnectionProvider::default());
171 *builder.options_mut() = config.options;
172 dns_resolver::reconfigure_resolver(HickoryResolver::from(builder.build()));
173
174 Ok(())
175 })?,
176 )?;
177
178 dns_mod.set(
179 "configure_unbound_resolver",
180 lua.create_function(move |lua, config: mlua::Value| {
181 let config: DnsConfig = lua.from_value(config)?;
182
183 let context = libunbound::Context::new().map_err(any_err)?;
184
185 for ns in config.name_servers {
186 let addr = match ns {
187 NameServer::Ip(ip) => ip
188 .parse()
189 .with_context(|| format!("name server: '{ip}'"))
190 .map_err(any_err)?,
191 NameServer::Detailed { socket_addr, .. } => socket_addr
192 .parse()
193 .with_context(|| format!("name server: '{socket_addr}'"))
194 .map_err(any_err)?,
195 };
196 context
197 .set_forward(Some(addr))
198 .context("set_forward")
199 .map_err(any_err)?;
200 }
201
202 if config.options.validate {
206 context
207 .add_builtin_trust_anchors()
208 .context("add_builtin_trust_anchors")
209 .map_err(any_err)?;
210 }
211 if matches!(
212 config.options.use_hosts_file,
213 ResolveHosts::Always | ResolveHosts::Auto
214 ) {
215 context
216 .load_hosts(None)
217 .context("load_hosts")
218 .map_err(any_err)?;
219 }
220
221 let context = context
222 .into_async()
223 .context("make async resolver context")
224 .map_err(any_err)?;
225
226 dns_resolver::reconfigure_resolver(UnboundResolver::from(context));
227
228 Ok(())
229 })?,
230 )?;
231
232 dns_mod.set(
233 "configure_test_resolver",
234 lua.create_function(move |_lua, zones: Vec<String>| {
235 let mut resolver = TestResolver::default();
236 for zone in &zones {
237 resolver = resolver.with_zone(zone);
238 }
239
240 dns_resolver::reconfigure_resolver(resolver);
241 Ok(())
242 })?,
243 )?;
244
245 Ok(())
246}