1use anyhow::Context;
2use config::{any_err, get_or_create_sub_module, serialize_options};
3use dns_resolver::{
4 get_resolver, resolve_a_or_aaaa, set_mx_concurrency_limit, set_mx_negative_cache_ttl,
5 set_mx_timeout, HickoryResolver, MailExchanger, TestResolver, UnboundResolver,
6};
7use hickory_resolver::config::{NameServerConfig, ResolveHosts, ResolverConfig, ResolverOpts};
8use hickory_resolver::name_server::TokioConnectionProvider;
9use hickory_resolver::proto::xfer::Protocol;
10use hickory_resolver::{Name, TokioResolver};
11use mlua::{Lua, LuaSerdeExt, Value};
12use std::net::SocketAddr;
13use std::time::Duration;
14
15pub fn register(lua: &Lua) -> anyhow::Result<()> {
16 let dns_mod = get_or_create_sub_module(lua, "dns")?;
17
18 dns_mod.set(
19 "lookup_mx",
20 lua.create_async_function(|lua, domain: String| async move {
21 let mx = MailExchanger::resolve(&domain).await.map_err(any_err)?;
22 Ok(lua.to_value_with(&*mx, serialize_options()))
23 })?,
24 )?;
25
26 dns_mod.set(
27 "set_mx_concurrency_limit",
28 lua.create_function(move |_lua, limit: usize| {
29 set_mx_concurrency_limit(limit);
30 Ok(())
31 })?,
32 )?;
33
34 dns_mod.set(
35 "set_mx_timeout",
36 lua.create_function(move |lua, duration: Value| {
37 let duration: duration_serde::Wrap<Duration> = lua.from_value(duration)?;
38 set_mx_timeout(duration.into_inner()).map_err(any_err)
39 })?,
40 )?;
41
42 dns_mod.set(
43 "set_mx_negative_cache_ttl",
44 lua.create_function(move |lua, duration: Value| {
45 let duration: duration_serde::Wrap<Duration> = lua.from_value(duration)?;
46 set_mx_negative_cache_ttl(duration.into_inner()).map_err(any_err)
47 })?,
48 )?;
49
50 dns_mod.set(
51 "lookup_txt",
52 lua.create_async_function(|_lua, domain: String| async move {
53 let resolver = get_resolver();
54 let answer = resolver.resolve_txt(&domain).await.map_err(any_err)?;
55 Ok(answer.as_txt())
56 })?,
57 )?;
58
59 dns_mod.set(
60 "lookup_addr",
61 lua.create_async_function(|_lua, domain: String| async move {
62 let result = resolve_a_or_aaaa(&domain).await.map_err(any_err)?;
63 let result: Vec<String> = result
64 .into_iter()
65 .map(|item| item.addr.to_string())
66 .collect();
67 Ok(result)
68 })?,
69 )?;
70
71 #[derive(serde::Deserialize, Debug)]
72 #[serde(deny_unknown_fields)]
73 struct DnsConfig {
74 #[serde(default)]
75 domain: Option<String>,
76 #[serde(default)]
77 search: Vec<String>,
78 #[serde(default)]
79 name_servers: Vec<NameServer>,
80 #[serde(default)]
81 options: ResolverOpts,
82 }
83
84 #[derive(serde::Deserialize, Debug)]
85 #[serde(untagged)]
86 #[serde(deny_unknown_fields)]
87 enum NameServer {
88 Ip(String),
89 Detailed {
90 socket_addr: String,
91 #[serde(default)]
92 protocol: Protocol,
93 #[serde(default)]
94 trust_negative_responses: bool,
95 #[serde(default)]
96 bind_addr: Option<String>,
97 },
98 }
99
100 dns_mod.set(
101 "configure_resolver",
102 lua.create_function(move |lua, config: mlua::Value| {
103 let config: DnsConfig = lua.from_value(config)?;
104
105 let mut r_config = ResolverConfig::new();
106 if let Some(dom) = config.domain {
107 r_config.set_domain(
108 Name::from_str_relaxed(&dom)
109 .with_context(|| format!("domain: '{dom}'"))
110 .map_err(any_err)?,
111 );
112 }
113 for s in config.search {
114 let name = Name::from_str_relaxed(&s)
115 .with_context(|| format!("search: '{s}'"))
116 .map_err(any_err)?;
117 r_config.add_search(name);
118 }
119
120 for ns in config.name_servers {
121 r_config.add_name_server(match ns {
122 NameServer::Ip(ip) => {
123 let ip: SocketAddr = ip
124 .parse()
125 .with_context(|| format!("name server: '{ip}'"))
126 .map_err(any_err)?;
127 NameServerConfig::new(ip, Protocol::Udp)
128 }
129 NameServer::Detailed {
130 socket_addr,
131 protocol,
132 trust_negative_responses,
133 bind_addr,
134 } => {
135 let ip: SocketAddr = socket_addr
136 .parse()
137 .with_context(|| format!("name server: '{socket_addr}'"))
138 .map_err(any_err)?;
139 let mut c = NameServerConfig::new(ip, protocol);
140
141 c.trust_negative_responses = trust_negative_responses;
142
143 if let Some(bind) = bind_addr {
144 let addr: SocketAddr = bind
145 .parse()
146 .with_context(|| {
147 format!("name server: '{socket_addr}' bind_addr: '{bind}'")
148 })
149 .map_err(any_err)?;
150 c.bind_addr.replace(addr);
151 }
152
153 c
154 }
155 });
156 }
157
158 let mut builder =
159 TokioResolver::builder_with_config(r_config, TokioConnectionProvider::default());
160 *builder.options_mut() = config.options;
161 dns_resolver::reconfigure_resolver(HickoryResolver::from(builder.build()));
162
163 Ok(())
164 })?,
165 )?;
166
167 dns_mod.set(
168 "configure_unbound_resolver",
169 lua.create_function(move |lua, config: mlua::Value| {
170 let config: DnsConfig = lua.from_value(config)?;
171
172 let context = libunbound::Context::new().map_err(any_err)?;
173
174 for ns in config.name_servers {
175 let addr = match ns {
176 NameServer::Ip(ip) => ip
177 .parse()
178 .with_context(|| format!("name server: '{ip}'"))
179 .map_err(any_err)?,
180 NameServer::Detailed { socket_addr, .. } => socket_addr
181 .parse()
182 .with_context(|| format!("name server: '{socket_addr}'"))
183 .map_err(any_err)?,
184 };
185 context
186 .set_forward(Some(addr))
187 .context("set_forward")
188 .map_err(any_err)?;
189 }
190
191 if config.options.validate {
195 context
196 .add_builtin_trust_anchors()
197 .context("add_builtin_trust_anchors")
198 .map_err(any_err)?;
199 }
200 if matches!(
201 config.options.use_hosts_file,
202 ResolveHosts::Always | ResolveHosts::Auto
203 ) {
204 context
205 .load_hosts(None)
206 .context("load_hosts")
207 .map_err(any_err)?;
208 }
209
210 let context = context
211 .into_async()
212 .context("make async resolver context")
213 .map_err(any_err)?;
214
215 dns_resolver::reconfigure_resolver(UnboundResolver::from(context));
216
217 Ok(())
218 })?,
219 )?;
220
221 dns_mod.set(
222 "configure_test_resolver",
223 lua.create_function(move |_lua, zones: Vec<String>| {
224 let mut resolver = TestResolver::default();
225 for zone in &zones {
226 resolver = resolver.with_zone(zone);
227 }
228
229 dns_resolver::reconfigure_resolver(resolver);
230 Ok(())
231 })?,
232 )?;
233
234 Ok(())
235}