1use anyhow::Context;
2use config::{any_err, get_or_create_sub_module, serialize_options};
3use dns_resolver::{
4 get_resolver, ptr_host, resolve_a_or_aaaa, reverse_ip, set_mx_concurrency_limit,
5 set_mx_negative_cache_ttl, set_mx_timeout, AggregateResolver, HickoryResolver, MailExchanger,
6 Resolver, TestResolver, UnboundResolver,
7};
8use hickory_resolver::config::{NameServerConfig, ResolveHosts, ResolverConfig, ResolverOpts};
9use hickory_resolver::name_server::TokioConnectionProvider;
10use hickory_resolver::proto::xfer::Protocol;
11use hickory_resolver::{Name, TokioResolver};
12use kumo_address::host_or_socket::HostOrSocketAddress;
13use mlua::{Lua, LuaSerdeExt, Value};
14use parking_lot::Mutex;
15use std::collections::HashMap;
16use std::net::{IpAddr, SocketAddr};
17use std::str::FromStr;
18use std::sync::{Arc, LazyLock};
19use std::time::Duration;
20
21static RESOLVERS: LazyLock<Mutex<HashMap<String, Arc<Box<dyn Resolver>>>>> =
22 LazyLock::new(|| Mutex::new(HashMap::new()));
23
24pub fn register(lua: &Lua) -> anyhow::Result<()> {
25 let dns_mod = get_or_create_sub_module(lua, "dns")?;
26
27 dns_mod.set(
28 "lookup_mx",
29 lua.create_async_function(|lua, domain: String| async move {
30 let mx = MailExchanger::resolve(&domain).await.map_err(any_err)?;
31 Ok(lua.to_value_with(&*mx, serialize_options()))
32 })?,
33 )?;
34
35 dns_mod.set(
36 "set_mx_concurrency_limit",
37 lua.create_function(move |_lua, limit: usize| {
38 set_mx_concurrency_limit(limit);
39 Ok(())
40 })?,
41 )?;
42
43 dns_mod.set(
44 "set_mx_timeout",
45 lua.create_function(move |lua, duration: Value| {
46 let duration: duration_serde::Wrap<Duration> = lua.from_value(duration)?;
47 set_mx_timeout(duration.into_inner()).map_err(any_err)
48 })?,
49 )?;
50
51 dns_mod.set(
52 "set_mx_negative_cache_ttl",
53 lua.create_function(move |lua, duration: Value| {
54 let duration: duration_serde::Wrap<Duration> = lua.from_value(duration)?;
55 set_mx_negative_cache_ttl(duration.into_inner()).map_err(any_err)
56 })?,
57 )?;
58
59 fn get_resolver_instance(
60 opt_resolver_name: &Option<String>,
61 ) -> anyhow::Result<Arc<Box<dyn Resolver>>> {
62 if let Some(name) = opt_resolver_name {
63 return RESOLVERS
64 .lock()
65 .get(name)
66 .cloned()
67 .ok_or_else(|| anyhow::anyhow!("resolver {name} is not defined"));
68 }
69
70 Ok(get_resolver())
71 }
72
73 fn get_opt_resolver(
74 opt_resolver_name: &Option<String>,
75 ) -> anyhow::Result<Option<Arc<Box<dyn Resolver>>>> {
76 if let Some(name) = opt_resolver_name {
77 let r = RESOLVERS
78 .lock()
79 .get(name)
80 .cloned()
81 .ok_or_else(|| anyhow::anyhow!("resolver {name} is not defined"))?;
82 Ok(Some(r))
83 } else {
84 Ok(None)
85 }
86 }
87
88 dns_mod.set(
89 "ptr_host",
90 lua.create_function(move |_lua, ip: String| {
91 let ip: IpAddr = ip.parse().map_err(any_err)?;
92 Ok(ptr_host(ip))
93 })?,
94 )?;
95
96 dns_mod.set(
97 "reverse_ip",
98 lua.create_function(move |_lua, ip: String| {
99 let ip: IpAddr = ip.parse().map_err(any_err)?;
100 Ok(reverse_ip(ip))
101 })?,
102 )?;
103
104 dns_mod.set(
105 "rbl_lookup",
106 lua.create_async_function(
107 |_lua, (ip_str, bl_domain, opt_resolver_name): (String, String, Option<String>)| async move {
108 let resolver = get_resolver_instance(&opt_resolver_name).map_err(any_err)?;
109
110 let address: HostOrSocketAddress = ip_str.parse().map_err(any_err)?;
111 let reversed_ip = reverse_ip(address.ip().ok_or_else(||mlua::Error::external(format!("{ip_str} is not a valid IpAddr or SocketAddr")))?);
112 let name = format!("{reversed_ip}.{bl_domain}.");
113
114 let answers = resolver.resolve_ip(&name).await.map_err(any_err)?;
115 match answers.first() {
116 Some(ip) => {
117 let txt = resolver.resolve_txt(&name).await.map(|a| a.as_txt().join("")).ok();
118 Ok((Some(ip.to_string()), txt))
119 }
120 None => {
121 Ok((None, None))
122 }
123 }
124 },
125 )?,
126 )?;
127
128 dns_mod.set(
129 "lookup_ptr",
130 lua.create_async_function(
131 |lua, (ip_str, opt_resolver_name): (String, Option<String>)| async move {
132 let resolver = get_resolver_instance(&opt_resolver_name).map_err(any_err)?;
133 let addr = std::net::IpAddr::from_str(&ip_str).map_err(any_err)?;
134 let answer = resolver.resolve_ptr(addr).await.map_err(any_err)?;
135 Ok(lua.to_value_with(&*answer, serialize_options()))
136 },
137 )?,
138 )?;
139
140 dns_mod.set(
141 "lookup_txt",
142 lua.create_async_function(
143 |_lua, (domain, opt_resolver_name): (String, Option<String>)| async move {
144 let resolver = get_resolver_instance(&opt_resolver_name).map_err(any_err)?;
145 let answer = resolver.resolve_txt(&domain).await.map_err(any_err)?;
146 Ok(answer.as_txt())
147 },
148 )?,
149 )?;
150
151 dns_mod.set(
152 "lookup_addr",
153 lua.create_async_function(
154 |_lua, (domain, opt_resolver_name): (String, Option<String>)| async move {
155 let opt_resolver = get_opt_resolver(&opt_resolver_name).map_err(any_err)?;
156 let result = resolve_a_or_aaaa(&domain, opt_resolver.as_ref().map(|r| &***r))
157 .await
158 .map_err(any_err)?;
159 let result: Vec<String> = result
160 .into_iter()
161 .map(|item| item.addr.to_string())
162 .collect();
163 Ok(result)
164 },
165 )?,
166 )?;
167
168 #[derive(serde::Deserialize, Debug)]
169 #[serde(deny_unknown_fields)]
170 struct TestResolverConfig {
171 zones: Vec<String>,
172 }
173
174 impl TestResolverConfig {
175 fn make_resolver(&self) -> anyhow::Result<TestResolver> {
176 let mut resolver = TestResolver::default();
177
178 for zone in &self.zones {
179 resolver = resolver
180 .with_zone(zone)
181 .map_err(|err| anyhow::anyhow!("{err}"))?;
182 }
183
184 Ok(resolver)
185 }
186 }
187
188 #[derive(serde::Deserialize, Debug)]
189 enum KumoResolverConfig {
190 Hickory(DnsConfig),
191 HickorySystemConfig,
192 Unbound(DnsConfig),
193 Test(TestResolverConfig),
194 Aggregate(Vec<KumoResolverConfig>),
195 }
196
197 impl KumoResolverConfig {
198 fn make_resolver(&self) -> anyhow::Result<Box<dyn Resolver>> {
199 match self {
200 Self::Hickory(config) => Ok(Box::new(config.make_hickory()?)),
201 Self::HickorySystemConfig => Ok(Box::new(HickoryResolver::new()?)),
202 Self::Unbound(config) => Ok(Box::new(config.make_unbound()?)),
203 Self::Test(config) => Ok(Box::new(config.make_resolver()?)),
204 Self::Aggregate(config) => {
205 let mut resolver = AggregateResolver::new();
206 for c in config {
207 resolver.push_resolver(c.make_resolver()?);
208 }
209 Ok(Box::new(resolver))
210 }
211 }
212 }
213 }
214
215 #[derive(serde::Deserialize, Debug)]
216 #[serde(deny_unknown_fields)]
217 struct DnsConfig {
218 #[serde(default)]
219 domain: Option<String>,
220 #[serde(default)]
221 search: Vec<String>,
222 #[serde(default)]
223 name_servers: Vec<NameServer>,
224 #[serde(default)]
225 options: ResolverOpts,
226 }
227
228 impl DnsConfig {
229 fn make_hickory(&self) -> anyhow::Result<HickoryResolver> {
230 let mut config = ResolverConfig::new();
231 if let Some(dom) = &self.domain {
232 config.set_domain(
233 Name::from_str_relaxed(&dom).with_context(|| format!("domain: '{dom}'"))?,
234 );
235 }
236 for s in &self.search {
237 let name = Name::from_str_relaxed(&s).with_context(|| format!("search: '{s}'"))?;
238 config.add_search(name);
239 }
240
241 for ns in &self.name_servers {
242 config.add_name_server(match ns {
243 NameServer::Ip(ip) => {
244 let ip: SocketAddr =
245 ip.parse().with_context(|| format!("name server: '{ip}'"))?;
246 NameServerConfig::new(ip, Protocol::Udp)
247 }
248 NameServer::Detailed {
249 socket_addr,
250 protocol,
251 trust_negative_responses,
252 bind_addr,
253 } => {
254 let ip: SocketAddr = socket_addr
255 .parse()
256 .with_context(|| format!("name server: '{socket_addr}'"))?;
257 let mut c = NameServerConfig::new(ip, protocol.clone());
258
259 c.trust_negative_responses = *trust_negative_responses;
260
261 if let Some(bind) = bind_addr {
262 let addr: SocketAddr = bind.parse().with_context(|| {
263 format!("name server: '{socket_addr}' bind_addr: '{bind}'")
264 })?;
265 c.bind_addr.replace(addr);
266 }
267
268 c
269 }
270 });
271 }
272
273 let mut builder =
274 TokioResolver::builder_with_config(config, TokioConnectionProvider::default());
275 *builder.options_mut() = self.options.clone();
276 Ok(HickoryResolver::from(builder.build()))
277 }
278
279 fn make_unbound(&self) -> anyhow::Result<UnboundResolver> {
280 let context = libunbound::Context::new()?;
281
282 for ns in &self.name_servers {
283 let addr = match ns {
284 NameServer::Ip(ip) => {
285 ip.parse().with_context(|| format!("name server: '{ip}'"))?
286 }
287 NameServer::Detailed { socket_addr, .. } => socket_addr
288 .parse()
289 .with_context(|| format!("name server: '{socket_addr}'"))?,
290 };
291 context.set_forward(Some(addr)).context("set_forward")?;
292 }
293
294 if self.options.validate {
298 context
299 .add_builtin_trust_anchors()
300 .context("add_builtin_trust_anchors")?;
301 }
302 if matches!(
303 self.options.use_hosts_file,
304 ResolveHosts::Always | ResolveHosts::Auto
305 ) {
306 context.load_hosts(None).context("load_hosts")?;
307 }
308
309 let context = context
310 .into_async()
311 .context("make async resolver context")?;
312
313 Ok(UnboundResolver::from(context))
314 }
315 }
316
317 #[derive(serde::Deserialize, Debug)]
318 #[serde(untagged)]
319 #[serde(deny_unknown_fields)]
320 enum NameServer {
321 Ip(String),
322 Detailed {
323 socket_addr: String,
324 #[serde(default)]
325 protocol: Protocol,
326 #[serde(default)]
327 trust_negative_responses: bool,
328 #[serde(default)]
329 bind_addr: Option<String>,
330 },
331 }
332
333 dns_mod.set(
334 "configure_resolver",
335 lua.create_function(move |lua, config: mlua::Value| {
336 match lua.from_value::<KumoResolverConfig>(config.clone()) {
337 Ok(config) => {
338 let resolver = config.make_resolver().map_err(any_err)?;
339 dns_resolver::reconfigure_resolver(resolver);
340 Ok(())
341 }
342 Err(err1) => match lua.from_value::<DnsConfig>(config) {
343 Ok(config) => {
344 let resolver = config.make_hickory().map_err(any_err)?;
345 dns_resolver::reconfigure_resolver(resolver);
346 Ok(())
347 }
348 Err(err2) => {
349 Err(mlua::Error::external(format!("failed to parse config as either KumoResolverConfig ({err1:#}) or DnsConfig ({err2:#})")))
350 }
351 }
352 }
353
354 })?,
355 )?;
356
357 dns_mod.set(
358 "define_resolver",
359 lua.create_function(move |lua, (name, config): (String, mlua::Value)| {
360 let config = lua
361 .from_value::<KumoResolverConfig>(config.clone())
362 .map_err(any_err)?;
363 let resolver = config.make_resolver().map_err(any_err)?;
364
365 RESOLVERS.lock().insert(name, resolver.into());
366
367 Ok(())
368 })?,
369 )?;
370
371 dns_mod.set(
372 "configure_unbound_resolver",
373 lua.create_function(move |lua, config: mlua::Value| {
374 let config: DnsConfig = lua.from_value(config)?;
375 let resolver = config.make_unbound().map_err(any_err)?;
376 dns_resolver::reconfigure_resolver(resolver);
377 Ok(())
378 })?,
379 )?;
380
381 dns_mod.set(
382 "configure_test_resolver",
383 lua.create_function(move |_lua, zones: Vec<String>| {
384 let config = TestResolverConfig { zones };
385 let resolver = config.make_resolver().map_err(any_err)?;
386 dns_resolver::reconfigure_resolver(resolver);
387 Ok(())
388 })?,
389 )?;
390
391 Ok(())
392}