1use async_trait::async_trait;
2use hickory_proto::ProtoErrorKind;
3use hickory_resolver::proto::op::response_code::ResponseCode;
4use hickory_resolver::proto::rr::rdata::{A, AAAA, MX, PTR, TXT};
5#[cfg(feature = "unbound")]
6use hickory_resolver::proto::rr::DNSClass;
7use hickory_resolver::proto::rr::{LowerName, RData, RecordData, RecordSet, RecordType, RrKey};
8use hickory_resolver::proto::serialize::txt::Parser;
9use hickory_resolver::{Name, ResolveError, TokioResolver};
10#[cfg(feature = "unbound")]
11use libunbound::{AsyncContext, Context};
12use std::collections::BTreeMap;
13use std::fmt;
14use std::net::IpAddr;
15use std::str::FromStr;
16use std::time::{Duration, Instant};
17use thiserror::Error;
18
/// Formats an IP address as dot-separated labels, optionally in reverse order.
///
/// IPv4 addresses are written octet by octet in decimal; IPv6 addresses are
/// written nibble by nibble in lowercase hexadecimal. With `reverse` set, the
/// label order is flipped, which is the form [`ptr_host`] uses for reverse-DNS
/// names.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct IpDisplay {
    /// The address to format.
    pub ip: IpAddr,
    /// Whether to emit the labels in reverse order.
    pub reverse: bool,
}
23
24impl fmt::Display for IpDisplay {
25 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26 match self.ip {
27 IpAddr::V4(v4) => {
28 let mut bytes = v4.octets();
29 if self.reverse {
30 bytes.reverse();
31 }
32 let mut first = true;
33 for byte in bytes {
34 if !first {
35 f.write_str(".")?;
36 }
37 write!(f, "{byte}")?;
38 first = false;
39 }
40 Ok(())
41 }
42 IpAddr::V6(v6) => {
43 let mut bytes = v6.octets();
44 if self.reverse {
45 bytes.reverse();
46 }
47 let mut first = true;
48 for byte in bytes {
49 if !first {
50 f.write_str(".")?;
51 }
52 let (upper, lower) = (byte >> 4, byte & 0xf);
53 if self.reverse {
54 write!(f, "{lower:x}.{upper:x}")?;
55 } else {
56 write!(f, "{upper:x}.{lower:x}")?;
57 }
58 first = false;
59 }
60 Ok(())
61 }
62 }
63 }
64}
65
66pub fn ptr_host(ip: IpAddr) -> String {
67 let mut out = reverse_ip(ip);
68 out.push_str(match ip {
69 IpAddr::V4(_) => ".in-addr.arpa",
70 IpAddr::V6(_) => ".ip6.arpa",
71 });
72 out
73}
74
75pub fn reverse_ip(ip: IpAddr) -> String {
76 IpDisplay { ip, reverse: true }.to_string()
77}
78
/// The outcome of a single DNS query, normalized across the resolver backends
/// in this module.
#[derive(Debug)]
pub struct Answer {
    /// Canonical name reported by the backend; only `UnboundResolver` fills
    /// this in, the other backends here leave it `None`.
    pub canon_name: Option<String>,
    /// Record data returned for the query (may be empty).
    pub records: Vec<RData>,
    /// `true` when the queried name does not exist (NXDOMAIN).
    pub nxdomain: bool,
    /// Security flag reported by libunbound (`answer.secure()`); always `false`
    /// for the Hickory and test backends.
    pub secure: bool,
    /// Bogus flag reported by libunbound (`answer.bogus()`); always `false` for
    /// the Hickory and test backends.
    pub bogus: bool,
    /// libunbound's explanation for a bogus result (`answer.why_bogus()`), if any.
    pub why_bogus: Option<String>,
    /// When this answer expires: derived from the backend's TTL information, or
    /// 60 seconds from creation when no TTL is available.
    pub expires: Instant,
    /// DNS response code of the query.
    pub response_code: ResponseCode,
}
90
91impl Answer {
92 pub fn as_txt(&self) -> Vec<String> {
93 let mut result = vec![];
94 for r in &self.records {
95 if let Some(txt) = r.as_txt() {
96 let mut joined = String::new();
97 for t in txt.iter() {
98 joined.push_str(&String::from_utf8_lossy(t));
99 }
100 result.push(joined);
101 }
102 }
103 result
104 }
105
106 pub fn as_addr(&self) -> Vec<IpAddr> {
107 let mut result = vec![];
108 for r in &self.records {
109 if let Some(a) = r.as_a() {
110 result.push(a.0.into());
111 } else if let Some(a) = r.as_aaaa() {
112 result.push(a.0.into());
113 }
114 }
115 result
116 }
117}
118
/// Errors returned by the [`Resolver`] implementations in this module.
#[derive(Clone, Debug, Error, PartialEq)]
pub enum DnsError {
    /// The supplied host or name could not be parsed as a DNS name; the string
    /// carries the offending input and the parser's message.
    #[error("invalid DNS name: {0}")]
    InvalidName(String),
    /// The lookup failed in the backend; the string carries the details.
    #[error("DNS: {0}")]
    ResolveFailed(String),
}
126
127impl DnsError {
128 pub(crate) fn from_resolve(name: &impl fmt::Display, err: ResolveError) -> Self {
129 DnsError::ResolveFailed(format!("failed to query DNS for {name}: {err}"))
130 }
131}
132
133#[async_trait]
134pub trait Resolver: Send + Sync + 'static {
135 async fn resolve_ip(&self, host: &str) -> Result<Vec<IpAddr>, DnsError>;
136
137 async fn resolve_mx(&self, host: &str) -> Result<Vec<Name>, DnsError>;
138
139 async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError>;
140
141 async fn resolve_txt(&self, name: &str) -> Result<Answer, DnsError> {
142 let name = Name::from_str_relaxed(name)
143 .map_err(|err| DnsError::InvalidName(format!("invalid name {name}: {err}")))?;
144 self.resolve(name, RecordType::TXT).await
145 }
146
147 async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError>;
148}
149
150#[async_trait]
151impl Resolver for Box<dyn Resolver> {
152 async fn resolve_ip(&self, host: &str) -> Result<Vec<IpAddr>, DnsError> {
153 (**self).resolve_ip(host).await
154 }
155
156 async fn resolve_mx(&self, host: &str) -> Result<Vec<Name>, DnsError> {
157 (**self).resolve_mx(host).await
158 }
159
160 async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError> {
161 (**self).resolve_ptr(ip).await
162 }
163
164 async fn resolve_txt(&self, name: &str) -> Result<Answer, DnsError> {
165 (**self).resolve_txt(name).await
166 }
167
168 async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError> {
169 (**self).resolve(name, rrtype).await
170 }
171}
172
/// In-memory [`Resolver`] for tests, populated through `with_zone`,
/// `with_txt` and `with_txt_multiple`.
#[derive(Debug, Default)]
pub struct TestResolver {
    /// Record sets grouped by fully qualified zone origin, then keyed by owner
    /// name and record type.
    records: BTreeMap<Name, BTreeMap<RrKey, RecordSet>>,
}
177
178impl TestResolver {
179 pub fn with_zone(mut self, zone: &str) -> Result<Self, String> {
180 let (mut name, records) = Parser::new(zone, None, None)
181 .parse()
182 .map_err(|err| format!("{err:#}"))?;
183 name.set_fqdn(true);
187 let fqdn_records = records
188 .into_iter()
189 .map(|(key, value)| {
190 if key.name().is_fqdn() {
191 (key, value)
192 } else {
193 let mut name: Name = key.name().into();
194 name.set_fqdn(true);
195 (RrKey::new(LowerName::new(&name), key.record_type), value)
196 }
197 })
198 .collect();
199 self.records.insert(name, fqdn_records);
200 Ok(self)
201 }
202
203 pub fn with_txt(self, domain: &str, value: impl Into<String>) -> Self {
204 self.with_txt_multiple(domain, vec![value.into()])
205 }
206
207 pub fn with_txt_multiple(mut self, domain: &str, value: Vec<String>) -> Self {
209 let fqdn = format!("{}.", domain);
210 let authority = Name::from_str_relaxed(&fqdn).unwrap();
211 let key = RrKey {
212 name: LowerName::from_str(&fqdn).unwrap(),
213 record_type: RecordType::TXT,
214 };
215
216 let mut records = RecordSet::new(authority.clone(), RecordType::TXT, 0);
217 for item in value {
218 records.add_rdata(RData::TXT(TXT::new(vec![item])));
219 }
220 self.records
221 .entry(authority)
222 .or_default()
223 .insert(key, records);
224
225 self
226 }
227
228 fn get(&self, full: &Name, record_type: RecordType) -> Result<Answer, DnsError> {
229 let mut full_fqdn = full.clone();
230 full_fqdn.set_fqdn(true);
231 let mut authority = full_fqdn.clone();
232
233 let records = loop {
234 if let Some(records) = self.records.get(&authority) {
235 break records;
236 };
237
238 if authority.num_labels() > 1 {
239 authority = authority.base_name();
240 continue;
241 }
242
243 return Ok(Answer {
244 canon_name: None,
245 records: vec![],
246 nxdomain: true,
247 secure: false,
248 bogus: false,
249 why_bogus: None,
250 expires: Instant::now() + Duration::from_secs(60),
251 response_code: ResponseCode::NXDomain,
252 });
253 };
254
255 let records = records.get(&RrKey {
256 name: LowerName::from(&full_fqdn),
257 record_type,
258 });
259
260 let Some(records) = records else {
261 return Ok(Answer {
262 canon_name: None,
263 records: vec![],
264 nxdomain: false,
265 secure: false,
266 bogus: false,
267 why_bogus: None,
268 expires: Instant::now() + Duration::from_secs(60),
269 response_code: ResponseCode::NoError,
270 });
271 };
272
273 Ok(Answer {
274 canon_name: None,
275 records: records
276 .records_without_rrsigs()
277 .map(|r| r.data().clone())
278 .collect(),
279 nxdomain: false,
280 secure: false,
281 bogus: false,
282 why_bogus: None,
283 expires: Instant::now() + Duration::from_secs(60),
284 response_code: ResponseCode::NoError,
285 })
286 }
287}
288
289#[async_trait]
290impl Resolver for TestResolver {
291 async fn resolve_ip(&self, full: &str) -> Result<Vec<IpAddr>, DnsError> {
292 let name = Name::from_str_relaxed(full)
293 .map_err(|err| DnsError::InvalidName(format!("invalid name {full}: {err}")))?;
294
295 let mut values = vec![];
296 let answer = self.get(&name, RecordType::A)?;
297 for record in answer.records {
298 let a = A::try_borrow(&record).unwrap();
299 values.push(IpAddr::V4(a.0));
300 }
301
302 let answer = self.get(&name, RecordType::AAAA)?;
303 for record in answer.records {
304 let a = AAAA::try_borrow(&record).unwrap();
305 values.push(IpAddr::V6(a.0));
306 }
307
308 Ok(values)
309 }
310
311 async fn resolve_mx(&self, full: &str) -> Result<Vec<Name>, DnsError> {
312 let name = Name::from_str_relaxed(full)
313 .map_err(|err| DnsError::InvalidName(format!("invalid name {full}: {err}")))?;
314
315 let mut values = vec![];
316 let answer = self.get(&name, RecordType::MX)?;
317 for record in answer.records {
318 let mx = MX::try_borrow(&record).unwrap();
319 values.push(mx.exchange().clone());
320 }
321
322 Ok(values)
323 }
324
325 async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError> {
326 let name = ptr_host(ip);
327 let name = Name::from_str_relaxed(ptr_host(ip))
328 .map_err(|err| DnsError::InvalidName(format!("invalid name {name}: {err}")))?;
329
330 let mut values = vec![];
331 let answer = self.get(&name, RecordType::PTR)?;
332 for record in answer.records {
333 match PTR::try_borrow(&record) {
334 Some(ptr) => values.push(ptr.0.clone()),
335 None => {
336 println!("invalid record found for PTR record for {ip}");
337 return Err(DnsError::ResolveFailed(format!(
338 "invalid record found for PTR record for {ip}"
339 )));
340 }
341 };
342 }
343
344 Ok(values)
345 }
346
347 async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError> {
348 self.get(&name, rrtype)
349 }
350}
351
/// [`Resolver`] backed by a libunbound async context (available with the
/// `unbound` feature).
#[cfg(feature = "unbound")]
pub struct UnboundResolver {
    /// The libunbound async context used for every query.
    cx: AsyncContext,
}
356
#[cfg(feature = "unbound")]
impl UnboundResolver {
    /// Creates a resolver from a fresh libunbound context with the library's
    /// built-in trust anchors added.
    ///
    /// # Errors
    /// Propagates the `libunbound::Error` from creating the context, adding the
    /// trust anchors, or converting the context to its async form.
    pub fn new() -> Result<Self, libunbound::Error> {
        let context = Context::new()?;
        // The anchors are added on the synchronous context, before conversion.
        context.add_builtin_trust_anchors()?;
        let cx = context.into_async()?;
        Ok(UnboundResolver { cx })
    }
}
369
#[cfg(feature = "unbound")]
#[async_trait]
impl Resolver for UnboundResolver {
    /// Resolves `host` to its IPv4 and IPv6 addresses.
    ///
    /// The A and AAAA queries run concurrently. If only one of them fails, that
    /// family contributes no addresses. Only when both fail is an error
    /// returned, and it carries the A query's error. Record data that fails to
    /// decode is skipped.
    async fn resolve_ip(&self, host: &str) -> Result<Vec<IpAddr>, DnsError> {
        let (a, aaaa) = tokio::join!(
            self.cx.resolve(host, RecordType::A, DNSClass::IN),
            self.cx.resolve(host, RecordType::AAAA, DNSClass::IN),
        );

        let (a, aaaa) = match (a, aaaa) {
            (Err(err), Err(_)) => {
                return Err(DnsError::ResolveFailed(format!(
                    "failed to query DNS for {host}: {err}"
                )))
            }
            (a, aaaa) => (a.ok(), aaaa.ok()),
        };

        let mut records = vec![];
        if let Some(a) = a {
            records.extend(
                a.rdata()
                    .flatten()
                    .filter_map(|r| r.as_a().map(|a| IpAddr::from(a.0))),
            );
        }
        if let Some(aaaa) = aaaa {
            records.extend(
                aaaa.rdata()
                    .flatten()
                    .filter_map(|r| r.as_aaaa().map(|aaaa| IpAddr::from(aaaa.0))),
            );
        }

        Ok(records)
    }

    /// Looks up the MX records of `host` and returns their exchange names.
    /// Record data that fails to decode is skipped.
    async fn resolve_mx(&self, host: &str) -> Result<Vec<Name>, DnsError> {
        // Must query MX: the results are filtered with `as_mx()` below.
        let answer = self
            .cx
            .resolve(host, RecordType::MX, DNSClass::IN)
            .await
            .map_err(|err| {
                DnsError::ResolveFailed(format!("failed to query DNS for {host}: {err}"))
            })?;

        Ok(answer
            .rdata()
            .flatten()
            .filter_map(|r| r.as_mx().map(|mx| mx.exchange().clone()))
            .collect())
    }

    /// Looks up the PTR names under the reverse-DNS name of `ip`.
    /// Record data that fails to decode is skipped.
    async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError> {
        let name = ptr_host(ip);
        let answer = self
            .cx
            .resolve(&name, RecordType::PTR, DNSClass::IN)
            .await
            .map_err(|err| {
                DnsError::ResolveFailed(format!("failed to query DNS for {name}: {err}"))
            })?;

        Ok(answer
            .rdata()
            .flatten()
            .filter_map(|r| r.as_ptr().map(|ptr| ptr.0.clone()))
            .collect())
    }

    /// Performs a raw `rrtype` lookup. It carries over libunbound's canonical
    /// name, NXDOMAIN, secure and bogus flags, response code, and TTL-based
    /// expiry. Record data that fails to decode is skipped.
    async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError> {
        let name = name.to_ascii();
        let answer = self
            .cx
            .resolve(&name, rrtype, DNSClass::IN)
            .await
            .map_err(|err| {
                DnsError::ResolveFailed(format!("failed to query DNS for {name}: {err}"))
            })?;

        let records: Vec<RData> = answer.rdata().flatten().collect();

        Ok(Answer {
            canon_name: answer.canon_name().map(|s| s.to_string()),
            records,
            nxdomain: answer.nxdomain(),
            secure: answer.secure(),
            bogus: answer.bogus(),
            why_bogus: answer.why_bogus().map(|s| s.to_string()),
            response_code: answer.rcode(),
            expires: Instant::now() + Duration::from_secs(answer.ttl() as u64),
        })
    }
}
479
#[cfg(feature = "unbound")]
impl From<AsyncContext> for UnboundResolver {
    /// Wraps an existing async context as-is. Unlike `UnboundResolver::new`,
    /// no trust anchors are added here.
    fn from(context: AsyncContext) -> Self {
        UnboundResolver { cx: context }
    }
}
486
/// [`Resolver`] backed by hickory's Tokio-based resolver.
pub struct HickoryResolver {
    /// The underlying hickory resolver that performs every lookup.
    inner: TokioResolver,
}
490
491impl HickoryResolver {
492 pub fn new() -> Result<Self, hickory_resolver::ResolveError> {
493 Ok(Self {
494 inner: TokioResolver::builder_tokio()?.build(),
495 })
496 }
497}
498
499#[async_trait]
500impl Resolver for HickoryResolver {
501 async fn resolve_ip(&self, host: &str) -> Result<Vec<IpAddr>, DnsError> {
502 let name = Name::from_str_relaxed(host)
503 .map_err(|err| DnsError::InvalidName(format!("invalid name {host}: {err}")))?;
504
505 match self.inner.lookup_ip(name.clone()).await {
506 Ok(result) => Ok(result.into_iter().collect()),
507 Err(err) => match err.proto().map(|err| err.kind()) {
508 Some(ProtoErrorKind::NoRecordsFound { .. }) => Ok(vec![]),
509 _ => Err(DnsError::from_resolve(&name, err)),
510 },
511 }
512 }
513
514 async fn resolve_mx(&self, host: &str) -> Result<Vec<Name>, DnsError> {
515 let name = Name::from_str_relaxed(host)
516 .map_err(|err| DnsError::InvalidName(format!("invalid name {host}: {err}")))?;
517
518 match self.inner.mx_lookup(name.clone()).await {
519 Ok(result) => Ok(result.into_iter().map(|mx| mx.exchange().clone()).collect()),
520 Err(err) => match err.proto().map(|err| err.kind()) {
521 Some(ProtoErrorKind::NoRecordsFound { .. }) => Ok(vec![]),
522 _ => Err(DnsError::from_resolve(&name, err)),
523 },
524 }
525 }
526
527 async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError> {
528 match self.inner.reverse_lookup(ip).await {
529 Ok(result) => Ok(result.into_iter().map(|ptr| ptr.0).collect()),
530 Err(err) => match err.proto().map(|err| err.kind()) {
531 Some(ProtoErrorKind::NoRecordsFound { .. }) => Ok(vec![]),
532 _ => Err(DnsError::from_resolve(&ip, err)),
533 },
534 }
535 }
536
537 async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError> {
538 match self.inner.lookup(name.clone(), rrtype).await {
539 Ok(result) => {
540 let expires = result.valid_until();
541 let records = result.iter().cloned().collect();
542 Ok(Answer {
543 canon_name: None,
544 records,
545 nxdomain: false,
546 secure: false,
547 bogus: false,
548 why_bogus: None,
549 expires,
550 response_code: ResponseCode::NoError,
551 })
552 }
553 Err(err) => match err.proto().map(|err| err.kind()) {
554 Some(ProtoErrorKind::NoRecordsFound {
555 negative_ttl,
556 response_code,
557 ..
558 }) => Ok(Answer {
559 canon_name: None,
560 records: vec![],
561 nxdomain: *response_code == ResponseCode::NXDomain,
562 secure: false,
563 bogus: false,
564 why_bogus: None,
565 response_code: *response_code,
566 expires: Instant::now()
567 + Duration::from_secs(negative_ttl.unwrap_or(60) as u64),
568 }),
569 _ => Err(DnsError::from_resolve(&name, err)),
570 },
571 }
572 }
573}
574
575impl From<TokioResolver> for HickoryResolver {
576 fn from(inner: TokioResolver) -> Self {
577 Self { inner }
578 }
579}
580
581pub struct AggregateResolver {
589 resolvers: Vec<Box<dyn Resolver>>,
590}
591
592impl AggregateResolver {
593 pub fn new() -> Self {
594 Self { resolvers: vec![] }
595 }
596
597 pub fn push_resolver(&mut self, resolver: Box<dyn Resolver>) {
598 self.resolvers.push(resolver);
599 }
600}
601
602#[async_trait]
603impl Resolver for AggregateResolver {
604 async fn resolve_ip(&self, full: &str) -> Result<Vec<IpAddr>, DnsError> {
605 let mut errors = vec![];
606 for resolver in &self.resolvers {
607 match resolver.resolve_ip(full).await {
608 Ok(ips) if ips.is_empty() => {}
609 Ok(ips) => {
610 return Ok(ips);
611 }
612 Err(err) => errors.push(err),
613 }
614 }
615
616 if let Some(err) = errors.pop() {
617 Err(err)
618 } else {
619 Ok(vec![])
620 }
621 }
622
623 async fn resolve_mx(&self, full: &str) -> Result<Vec<Name>, DnsError> {
624 let mut errors = vec![];
625 for resolver in &self.resolvers {
626 match resolver.resolve_mx(full).await {
627 Ok(ips) if ips.is_empty() => {}
628 Ok(ips) => {
629 return Ok(ips);
630 }
631 Err(err) => errors.push(err),
632 }
633 }
634
635 if let Some(err) = errors.pop() {
636 Err(err)
637 } else {
638 Ok(vec![])
639 }
640 }
641
642 async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError> {
643 let mut errors = vec![];
644 for resolver in &self.resolvers {
645 match resolver.resolve_ptr(ip).await {
646 Ok(ips) if ips.is_empty() => {}
647 Ok(ips) => {
648 return Ok(ips);
649 }
650 Err(err) => errors.push(err),
651 }
652 }
653
654 if let Some(err) = errors.pop() {
655 Err(err)
656 } else {
657 Ok(vec![])
658 }
659 }
660
661 async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError> {
662 let mut errors = vec![];
663 let mut nxdomain = None;
664 for resolver in &self.resolvers {
665 match resolver.resolve(name.clone(), rrtype).await {
666 Ok(answer) => {
667 if !answer.nxdomain {
668 return Ok(answer);
669 }
670 nxdomain.replace(answer);
671 }
672 Err(err) => errors.push(err),
673 }
674 }
675
676 if let Some(err) = errors.pop() {
677 Err(err)
678 } else {
679 match nxdomain.take() {
680 Some(answer) => Ok(answer),
681 None => Ok(Answer {
682 canon_name: None,
683 records: vec![],
684 nxdomain: true,
685 secure: false,
686 bogus: false,
687 why_bogus: None,
688 expires: Instant::now() + Duration::from_secs(60),
689 response_code: ResponseCode::NXDomain,
690 }),
691 }
692 }
693 }
694}
695
#[cfg(test)]
mod tests {
    use super::{ptr_host, reverse_ip, IpDisplay};
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
    use std::str::FromStr;

    /// Parses an IPv6 literal used by the tests below.
    fn parse_v6(s: &str) -> IpAddr {
        Ipv6Addr::from_str(s).unwrap().into()
    }

    #[test]
    fn test_ptr_host() {
        assert_eq!(
            ptr_host(Ipv4Addr::new(192, 0, 2, 1).into()),
            "1.2.0.192.in-addr.arpa"
        );
        assert_eq!(
            ptr_host(Ipv6Addr::from_str("2001:db8::1").unwrap().into()),
            "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa"
        );
    }

    #[test]
    fn test_reverse_ip() {
        assert_eq!(reverse_ip(Ipv4Addr::new(192, 0, 2, 1).into()), "1.2.0.192");
        assert_eq!(
            reverse_ip(parse_v6("2001:db8::1")),
            format!("1.0.{}8.b.d.0.1.0.0.2", "0.".repeat(22))
        );
    }

    #[test]
    fn test_ip_display_forward() {
        let v4 = IpDisplay {
            ip: Ipv4Addr::new(192, 0, 2, 1).into(),
            reverse: false,
        };
        assert_eq!(v4.to_string(), "192.0.2.1");

        let v6 = IpDisplay {
            ip: parse_v6("2001:db8::1"),
            reverse: false,
        };
        assert_eq!(
            v6.to_string(),
            format!("2.0.0.1.0.d.b.8.{}0.1", "0.".repeat(22))
        );
    }
}