1use async_trait::async_trait;
2use hickory_proto::ProtoErrorKind;
3use hickory_resolver::proto::op::response_code::ResponseCode;
4use hickory_resolver::proto::rr::rdata::{A, AAAA, MX, PTR, TXT};
5#[cfg(feature = "unbound")]
6use hickory_resolver::proto::rr::DNSClass;
7use hickory_resolver::proto::rr::{LowerName, RData, RecordData, RecordSet, RecordType, RrKey};
8use hickory_resolver::proto::serialize::txt::Parser;
9use hickory_resolver::{Name, ResolveError, TokioResolver};
10#[cfg(feature = "unbound")]
11use libunbound::{AsyncContext, Context};
12use std::collections::BTreeMap;
13use std::fmt;
14use std::net::IpAddr;
15use std::str::FromStr;
16use std::time::{Duration, Instant};
17use thiserror::Error;
18
/// Renders an IP address as dot-separated components, optionally in reverse order.
///
/// IPv4 addresses are written as their four decimal octets. IPv6 addresses are written
/// as 32 lowercase hexadecimal nibbles (two per byte, high nibble first). With `reverse`
/// set, the whole component sequence is flipped, which yields the label order used for
/// reverse-DNS (PTR) names.
pub struct IpDisplay {
    /// The address to render.
    pub ip: IpAddr,
    /// When `true`, emit components from least- to most-significant.
    pub reverse: bool,
}

impl fmt::Display for IpDisplay {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.ip {
            IpAddr::V4(v4) => {
                let mut octets = v4.octets();
                if self.reverse {
                    octets.reverse();
                }
                for (idx, octet) in octets.iter().enumerate() {
                    if idx > 0 {
                        f.write_str(".")?;
                    }
                    write!(f, "{octet}")?;
                }
            }
            IpAddr::V6(v6) => {
                // Expand every byte into (high, low) nibbles, most-significant first.
                let mut nibbles = [0u8; 32];
                for (idx, byte) in v6.octets().iter().enumerate() {
                    nibbles[2 * idx] = byte >> 4;
                    nibbles[2 * idx + 1] = byte & 0xf;
                }
                // Reversing the flat nibble sequence is equivalent to reversing the bytes
                // and emitting each byte low-nibble-first.
                if self.reverse {
                    nibbles.reverse();
                }
                for (idx, nibble) in nibbles.iter().enumerate() {
                    if idx > 0 {
                        f.write_str(".")?;
                    }
                    write!(f, "{nibble:x}")?;
                }
            }
        }
        Ok(())
    }
}
65
66pub fn ptr_host(ip: IpAddr) -> String {
67 let mut out = IpDisplay { ip, reverse: true }.to_string();
68 out.push_str(match ip {
69 IpAddr::V4(_) => ".in-addr.arpa",
70 IpAddr::V6(_) => ".ip6.arpa",
71 });
72 out
73}
74
/// The result of a single DNS lookup, as returned by every [`Resolver`] backend.
#[derive(Debug)]
pub struct Answer {
    /// Canonical name reported for the query; only the Unbound backend fills this in,
    /// the other backends always leave it `None`.
    pub canon_name: Option<String>,
    /// Record data returned for the queried name and type.
    pub records: Vec<RData>,
    /// Whether the lookup reported that the queried name does not exist (NXDOMAIN).
    pub nxdomain: bool,
    /// Copied from libunbound's `secure()` flag; the other backends always set `false`.
    pub secure: bool,
    /// Copied from libunbound's `bogus()` flag; the other backends always set `false`.
    pub bogus: bool,
    /// libunbound's explanation for a bogus answer, if any; `None` for other backends.
    pub why_bogus: Option<String>,
    /// When this answer expires: derived from the answer TTL (Unbound), the lookup's
    /// `valid_until` or negative TTL (Hickory), or a fixed 60 seconds (test resolver).
    pub expires: Instant,
    /// DNS response code of the lookup.
    pub response_code: ResponseCode,
}
86
87impl Answer {
88 pub fn as_txt(&self) -> Vec<String> {
89 let mut result = vec![];
90 for r in &self.records {
91 if let Some(txt) = r.as_txt() {
92 let mut joined = String::new();
93 for t in txt.iter() {
94 joined.push_str(&String::from_utf8_lossy(t));
95 }
96 result.push(joined);
97 }
98 }
99 result
100 }
101
102 pub fn as_addr(&self) -> Vec<IpAddr> {
103 let mut result = vec![];
104 for r in &self.records {
105 if let Some(a) = r.as_a() {
106 result.push(a.0.into());
107 } else if let Some(a) = r.as_aaaa() {
108 result.push(a.0.into());
109 }
110 }
111 result
112 }
113}
114
/// Errors returned by [`Resolver`] implementations.
#[derive(Clone, Debug, Error, PartialEq)]
pub enum DnsError {
    /// The supplied string could not be parsed as a DNS name; the payload carries the
    /// offending input and the parser's message.
    #[error("invalid DNS name: {0}")]
    InvalidName(String),
    /// The lookup itself failed; the payload describes the queried name and the cause.
    #[error("DNS: {0}")]
    ResolveFailed(String),
}
122
123impl DnsError {
124 pub(crate) fn from_resolve(name: &impl fmt::Display, err: ResolveError) -> Self {
125 DnsError::ResolveFailed(format!("failed to query DNS for {name}: {err}"))
126 }
127}
128
129#[async_trait]
130pub trait Resolver: Send + Sync + 'static {
131 async fn resolve_ip(&self, host: &str) -> Result<Vec<IpAddr>, DnsError>;
132
133 async fn resolve_mx(&self, host: &str) -> Result<Vec<Name>, DnsError>;
134
135 async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError>;
136
137 async fn resolve_txt(&self, name: &str) -> Result<Answer, DnsError> {
138 let name = Name::from_utf8(name)
139 .map_err(|err| DnsError::InvalidName(format!("invalid name {name}: {err}")))?;
140 self.resolve(name, RecordType::TXT).await
141 }
142
143 async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError>;
144}
145
/// An in-memory [`Resolver`] for tests, answering only from zones and TXT records
/// registered through its `with_*` builder methods.
#[derive(Debug, Default)]
pub struct TestResolver {
    // Authority (zone) name -> record sets in that zone, keyed by (owner name, record type).
    records: BTreeMap<Name, BTreeMap<RrKey, RecordSet>>,
}
150
151impl TestResolver {
152 pub fn with_zone(mut self, zone: &str) -> Self {
153 let (mut name, records) = Parser::new(zone, None, None).parse().unwrap();
154 name.set_fqdn(true);
158 let fqdn_records = records
159 .into_iter()
160 .map(|(key, value)| {
161 if key.name().is_fqdn() {
162 (key, value)
163 } else {
164 let mut name: Name = key.name().into();
165 name.set_fqdn(true);
166 (RrKey::new(LowerName::new(&name), key.record_type), value)
167 }
168 })
169 .collect();
170 self.records.insert(name, fqdn_records);
171 self
172 }
173
174 pub fn with_txt(self, domain: &str, value: String) -> Self {
175 self.with_txt_multiple(domain, vec![value])
176 }
177
178 pub fn with_txt_multiple(mut self, domain: &str, value: Vec<String>) -> Self {
180 let fqdn = format!("{}.", domain);
181 let authority = Name::from_str(&fqdn).unwrap();
182 let key = RrKey {
183 name: LowerName::from_str(&fqdn).unwrap(),
184 record_type: RecordType::TXT,
185 };
186
187 let mut records = RecordSet::new(authority.clone(), RecordType::TXT, 0);
188 for item in value {
189 records.add_rdata(RData::TXT(TXT::new(vec![item])));
190 }
191 self.records
192 .entry(authority)
193 .or_default()
194 .insert(key, records);
195
196 self
197 }
198
199 fn get(&self, full: &Name, record_type: RecordType) -> Result<Answer, DnsError> {
200 let mut full_fqdn = full.clone();
201 full_fqdn.set_fqdn(true);
202 let mut authority = full_fqdn.clone();
203
204 let records = loop {
205 if let Some(records) = self.records.get(&authority) {
206 break records;
207 };
208
209 if authority.num_labels() > 1 {
210 authority = authority.base_name();
211 continue;
212 }
213
214 return Ok(Answer {
215 canon_name: None,
216 records: vec![],
217 nxdomain: true,
218 secure: false,
219 bogus: false,
220 why_bogus: None,
221 expires: Instant::now() + Duration::from_secs(60),
222 response_code: ResponseCode::NXDomain,
223 });
224 };
225
226 let records = records.get(&RrKey {
227 name: LowerName::from(&full_fqdn),
228 record_type,
229 });
230
231 let Some(records) = records else {
232 return Ok(Answer {
233 canon_name: None,
234 records: vec![],
235 nxdomain: false,
236 secure: false,
237 bogus: false,
238 why_bogus: None,
239 expires: Instant::now() + Duration::from_secs(60),
240 response_code: ResponseCode::NoError,
241 });
242 };
243
244 Ok(Answer {
245 canon_name: None,
246 records: records
247 .records_without_rrsigs()
248 .map(|r| r.data().clone())
249 .collect(),
250 nxdomain: false,
251 secure: false,
252 bogus: false,
253 why_bogus: None,
254 expires: Instant::now() + Duration::from_secs(60),
255 response_code: ResponseCode::NoError,
256 })
257 }
258}
259
260#[async_trait]
261impl Resolver for TestResolver {
262 async fn resolve_ip(&self, full: &str) -> Result<Vec<IpAddr>, DnsError> {
263 let name = Name::from_utf8(full)
264 .map_err(|err| DnsError::InvalidName(format!("invalid name {full}: {err}")))?;
265
266 let mut values = vec![];
267 let answer = self.get(&name, RecordType::A)?;
268 for record in answer.records {
269 let a = A::try_borrow(&record).unwrap();
270 values.push(IpAddr::V4(a.0));
271 }
272
273 let answer = self.get(&name, RecordType::AAAA)?;
274 for record in answer.records {
275 let a = AAAA::try_borrow(&record).unwrap();
276 values.push(IpAddr::V6(a.0));
277 }
278
279 Ok(values)
280 }
281
282 async fn resolve_mx(&self, full: &str) -> Result<Vec<Name>, DnsError> {
283 let name = Name::from_utf8(full)
284 .map_err(|err| DnsError::InvalidName(format!("invalid name {full}: {err}")))?;
285
286 let mut values = vec![];
287 let answer = self.get(&name, RecordType::MX)?;
288 for record in answer.records {
289 let mx = MX::try_borrow(&record).unwrap();
290 values.push(mx.exchange().clone());
291 }
292
293 Ok(values)
294 }
295
296 async fn resolve_txt(&self, full: &str) -> Result<Answer, DnsError> {
297 let name = Name::from_utf8(full)
298 .map_err(|err| DnsError::InvalidName(format!("invalid name {full}: {err}")))?;
299 self.get(&name, RecordType::TXT)
300 }
301
302 async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError> {
303 let name = ptr_host(ip);
304 let name = Name::from_utf8(ptr_host(ip))
305 .map_err(|err| DnsError::InvalidName(format!("invalid name {name}: {err}")))?;
306
307 let mut values = vec![];
308 let answer = self.get(&name, RecordType::PTR)?;
309 for record in answer.records {
310 match PTR::try_borrow(&record) {
311 Some(ptr) => values.push(ptr.0.clone()),
312 None => {
313 println!("invalid record found for PTR record for {ip}");
314 return Err(DnsError::ResolveFailed(format!(
315 "invalid record found for PTR record for {ip}"
316 )));
317 }
318 };
319 }
320
321 Ok(values)
322 }
323
324 async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError> {
325 self.get(&name, rrtype)
326 }
327}
328
/// A [`Resolver`] backed by an asynchronous libunbound context.
#[cfg(feature = "unbound")]
pub struct UnboundResolver {
    cx: AsyncContext,
}
333
#[cfg(feature = "unbound")]
impl UnboundResolver {
    /// Creates a resolver from a fresh libunbound context with libunbound's built-in
    /// trust anchors added, then switches the context into async mode.
    pub fn new() -> Result<Self, libunbound::Error> {
        let context = Context::new()?;
        context.add_builtin_trust_anchors()?;
        let cx = context.into_async()?;
        Ok(Self { cx })
    }
}
346
#[cfg(feature = "unbound")]
#[async_trait]
impl Resolver for UnboundResolver {
    /// Looks up A and AAAA records concurrently and returns every address found.
    ///
    /// One failed lookup is tolerated. An error is returned only when both fail, and in
    /// that case the A lookup's error is reported. Record entries returned as `Err` are
    /// skipped.
    async fn resolve_ip(&self, host: &str) -> Result<Vec<IpAddr>, DnsError> {
        let (a, aaaa) = tokio::join!(
            self.cx.resolve(host, RecordType::A, DNSClass::IN),
            self.cx.resolve(host, RecordType::AAAA, DNSClass::IN),
        );

        if let (Err(err), Err(_)) = (&a, &aaaa) {
            return Err(DnsError::ResolveFailed(format!(
                "failed to query DNS for {host}: {err}"
            )));
        }

        let mut records = vec![];
        if let Ok(answer) = a {
            records.extend(answer.rdata().filter_map(|r| match r {
                Ok(r) => r.as_a().map(|a| IpAddr::from(a.0)),
                Err(_) => None,
            }));
        }
        if let Ok(answer) = aaaa {
            records.extend(answer.rdata().filter_map(|r| match r {
                Ok(r) => r.as_aaaa().map(|aaaa| IpAddr::from(aaaa.0)),
                Err(_) => None,
            }));
        }

        Ok(records)
    }

    /// Returns the exchange names of the MX records for `host`.
    async fn resolve_mx(&self, host: &str) -> Result<Vec<Name>, DnsError> {
        // Must query MX: filtering an A answer through `as_mx` can never match anything.
        let answer = self
            .cx
            .resolve(host, RecordType::MX, DNSClass::IN)
            .await
            .map_err(|err| {
                DnsError::ResolveFailed(format!("failed to query DNS for {host}: {err}"))
            })?;

        Ok(answer
            .rdata()
            .filter_map(|r| match r {
                Ok(r) => r.as_mx().map(|mx| mx.exchange().clone()),
                Err(_) => None,
            })
            .collect())
    }

    /// Returns the names of the PTR records for the reverse name of `ip`.
    async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError> {
        let name = ptr_host(ip);
        let answer = self
            .cx
            .resolve(&name, RecordType::PTR, DNSClass::IN)
            .await
            .map_err(|err| {
                DnsError::ResolveFailed(format!("failed to query DNS for {name}: {err}"))
            })?;

        Ok(answer
            .rdata()
            .filter_map(|r| match r {
                Ok(r) => r.as_ptr().map(|ptr| ptr.0.clone()),
                Err(_) => None,
            })
            .collect())
    }

    /// Performs a raw lookup and copies libunbound's result flags into an [`Answer`].
    async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError> {
        let name = name.to_ascii();
        let answer = self
            .cx
            .resolve(&name, rrtype, DNSClass::IN)
            .await
            .map_err(|err| {
                DnsError::ResolveFailed(format!("failed to query DNS for {name}: {err}"))
            })?;

        // Record entries returned as `Err` are dropped rather than failing the lookup.
        let records: Vec<RData> = answer.rdata().filter_map(Result::ok).collect();

        Ok(Answer {
            canon_name: answer.canon_name().map(|s| s.to_string()),
            records,
            nxdomain: answer.nxdomain(),
            secure: answer.secure(),
            bogus: answer.bogus(),
            why_bogus: answer.why_bogus().map(|s| s.to_string()),
            response_code: answer.rcode(),
            expires: Instant::now() + Duration::from_secs(answer.ttl() as u64),
        })
    }
}
456
#[cfg(feature = "unbound")]
impl From<AsyncContext> for UnboundResolver {
    /// Wraps an existing async context as-is; unlike [`UnboundResolver::new`], no trust
    /// anchors are added here.
    fn from(context: AsyncContext) -> Self {
        Self { cx: context }
    }
}
463
/// A [`Resolver`] backed by Hickory's Tokio resolver.
pub struct HickoryResolver {
    inner: TokioResolver,
}
467
468impl HickoryResolver {
469 pub fn new() -> Result<Self, hickory_resolver::ResolveError> {
470 Ok(Self {
471 inner: TokioResolver::builder_tokio()?.build(),
472 })
473 }
474}
475
476#[async_trait]
477impl Resolver for HickoryResolver {
478 async fn resolve_ip(&self, host: &str) -> Result<Vec<IpAddr>, DnsError> {
479 let name = Name::from_utf8(host)
480 .map_err(|err| DnsError::InvalidName(format!("invalid name {host}: {err}")))?;
481
482 self.inner
483 .lookup_ip(name)
484 .await
485 .map_err(|err| DnsError::from_resolve(&host, err))?
486 .into_iter()
487 .map(Ok)
488 .collect()
489 }
490
491 async fn resolve_mx(&self, host: &str) -> Result<Vec<Name>, DnsError> {
492 let name = Name::from_utf8(host)
493 .map_err(|err| DnsError::InvalidName(format!("invalid name {host}: {err}")))?;
494
495 self.inner
496 .mx_lookup(name)
497 .await
498 .map_err(|err| DnsError::from_resolve(&host, err))?
499 .into_iter()
500 .map(|mx| Ok(mx.exchange().clone()))
501 .collect()
502 }
503
504 async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError> {
505 self.inner
506 .reverse_lookup(ip)
507 .await
508 .map_err(|err| DnsError::from_resolve(&ip, err))?
509 .into_iter()
510 .map(|ptr| Ok(ptr.0))
511 .collect()
512 }
513
514 async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError> {
515 match self.inner.lookup(name.clone(), rrtype).await {
516 Ok(result) => {
517 let expires = result.valid_until();
518 let records = result.iter().cloned().collect();
519 Ok(Answer {
520 canon_name: None,
521 records,
522 nxdomain: false,
523 secure: false,
524 bogus: false,
525 why_bogus: None,
526 expires,
527 response_code: ResponseCode::NoError,
528 })
529 }
530 Err(err) => match err.proto().map(|err| err.kind()) {
531 Some(ProtoErrorKind::NoRecordsFound {
532 negative_ttl,
533 response_code,
534 ..
535 }) => Ok(Answer {
536 canon_name: None,
537 records: vec![],
538 nxdomain: *response_code == ResponseCode::NXDomain,
539 secure: false,
540 bogus: false,
541 why_bogus: None,
542 response_code: *response_code,
543 expires: Instant::now()
544 + Duration::from_secs(negative_ttl.unwrap_or(60) as u64),
545 }),
546 _ => Err(DnsError::from_resolve(&name, err)),
547 },
548 }
549 }
550}
551
552impl From<TokioResolver> for HickoryResolver {
553 fn from(inner: TokioResolver) -> Self {
554 Self { inner }
555 }
556}
557
#[cfg(test)]
mod tests {
    use super::ptr_host;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
    use std::str::FromStr;

    /// Checks the reverse-DNS names produced for one IPv4 and one IPv6 address.
    #[test]
    fn test_ptr_host() {
        let cases: [(IpAddr, &str); 2] = [
            (Ipv4Addr::new(192, 0, 2, 1).into(), "1.2.0.192.in-addr.arpa"),
            (
                Ipv6Addr::from_str("2001:db8::1").unwrap().into(),
                "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa",
            ),
        ];

        for (ip, expected) in cases {
            assert_eq!(ptr_host(ip), expected, "PTR name for {ip}");
        }
    }
}