dns_resolver/
resolver.rs

1use async_trait::async_trait;
2use hickory_proto::ProtoErrorKind;
3use hickory_resolver::proto::op::response_code::ResponseCode;
4use hickory_resolver::proto::rr::rdata::{A, AAAA, MX, PTR, TXT};
5#[cfg(feature = "unbound")]
6use hickory_resolver::proto::rr::DNSClass;
7use hickory_resolver::proto::rr::{LowerName, RData, RecordData, RecordSet, RecordType, RrKey};
8use hickory_resolver::proto::serialize::txt::Parser;
9use hickory_resolver::{Name, ResolveError, TokioResolver};
10#[cfg(feature = "unbound")]
11use libunbound::{AsyncContext, Context};
12use std::collections::BTreeMap;
13use std::fmt;
14use std::net::IpAddr;
15use std::str::FromStr;
16use std::time::{Duration, Instant};
17use thiserror::Error;
18
/// Renders an IP address as dot-separated components.
///
/// IPv4 addresses are written as decimal octets (`192.0.2.1`). IPv6
/// addresses are written as individual lowercase hex nibbles, each
/// separated by a dot, most-significant nibble first. When `reverse` is
/// set, the whole component sequence is emitted back to front, which is the
/// label order used under `in-addr.arpa` / `ip6.arpa`.
pub struct IpDisplay {
    /// The address to render.
    pub ip: IpAddr,
    /// Emit the components in reverse order.
    pub reverse: bool,
}

impl fmt::Display for IpDisplay {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.ip {
            IpAddr::V4(addr) => {
                let octets = addr.octets();
                let parts = octets.iter().copied();
                if self.reverse {
                    write_dotted(f, parts.rev(), false)
                } else {
                    write_dotted(f, parts, false)
                }
            }
            IpAddr::V6(addr) => {
                let octets = addr.octets();
                // Each byte contributes its high nibble, then its low nibble.
                let nibbles = octets.iter().flat_map(|&byte| [byte >> 4, byte & 0xf]);
                if self.reverse {
                    write_dotted(f, nibbles.rev(), true)
                } else {
                    write_dotted(f, nibbles, true)
                }
            }
        }
    }
}

/// Writes `parts` separated by `.`, as lowercase hex when `hex` is set and
/// as decimal otherwise.
fn write_dotted(
    f: &mut fmt::Formatter<'_>,
    parts: impl Iterator<Item = u8>,
    hex: bool,
) -> fmt::Result {
    for (index, part) in parts.enumerate() {
        if index > 0 {
            f.write_str(".")?;
        }
        if hex {
            write!(f, "{part:x}")?;
        } else {
            write!(f, "{part}")?;
        }
    }
    Ok(())
}

/// Returns the reverse-lookup (PTR) query name for `ip`: the reversed
/// address followed by `.in-addr.arpa` (IPv4) or `.ip6.arpa` (IPv6).
pub fn ptr_host(ip: IpAddr) -> String {
    let suffix = if ip.is_ipv4() { "in-addr.arpa" } else { "ip6.arpa" };
    format!("{}.{suffix}", reverse_ip(ip))
}

/// Returns `ip` rendered by [`IpDisplay`] with `reverse` enabled, without
/// any `.arpa` suffix.
pub fn reverse_ip(ip: IpAddr) -> String {
    let display = IpDisplay { ip, reverse: true };
    display.to_string()
}
78
79#[derive(Debug)]
80pub struct Answer {
81    pub canon_name: Option<String>,
82    pub records: Vec<RData>,
83    pub nxdomain: bool,
84    pub secure: bool,
85    pub bogus: bool,
86    pub why_bogus: Option<String>,
87    pub expires: Instant,
88    pub response_code: ResponseCode,
89}
90
91impl Answer {
92    pub fn as_txt(&self) -> Vec<String> {
93        let mut result = vec![];
94        for r in &self.records {
95            if let Some(txt) = r.as_txt() {
96                let mut joined = String::new();
97                for t in txt.iter() {
98                    joined.push_str(&String::from_utf8_lossy(t));
99                }
100                result.push(joined);
101            }
102        }
103        result
104    }
105
106    pub fn as_addr(&self) -> Vec<IpAddr> {
107        let mut result = vec![];
108        for r in &self.records {
109            if let Some(a) = r.as_a() {
110                result.push(a.0.into());
111            } else if let Some(a) = r.as_aaaa() {
112                result.push(a.0.into());
113            }
114        }
115        result
116    }
117}
118
119#[derive(Clone, Debug, Error, PartialEq)]
120pub enum DnsError {
121    #[error("invalid DNS name: {0}")]
122    InvalidName(String),
123    #[error("DNS: {0}")]
124    ResolveFailed(String),
125}
126
127impl DnsError {
128    pub(crate) fn from_resolve(name: &impl fmt::Display, err: ResolveError) -> Self {
129        DnsError::ResolveFailed(format!("failed to query DNS for {name}: {err}"))
130    }
131}
132
133#[async_trait]
134pub trait Resolver: Send + Sync + 'static {
135    async fn resolve_ip(&self, host: &str) -> Result<Vec<IpAddr>, DnsError>;
136
137    async fn resolve_mx(&self, host: &str) -> Result<Vec<Name>, DnsError>;
138
139    async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError>;
140
141    async fn resolve_txt(&self, name: &str) -> Result<Answer, DnsError> {
142        let name = Name::from_str_relaxed(name)
143            .map_err(|err| DnsError::InvalidName(format!("invalid name {name}: {err}")))?;
144        self.resolve(name, RecordType::TXT).await
145    }
146
147    async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError>;
148}
149
150#[async_trait]
151impl Resolver for Box<dyn Resolver> {
152    async fn resolve_ip(&self, host: &str) -> Result<Vec<IpAddr>, DnsError> {
153        (**self).resolve_ip(host).await
154    }
155
156    async fn resolve_mx(&self, host: &str) -> Result<Vec<Name>, DnsError> {
157        (**self).resolve_mx(host).await
158    }
159
160    async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError> {
161        (**self).resolve_ptr(ip).await
162    }
163
164    async fn resolve_txt(&self, name: &str) -> Result<Answer, DnsError> {
165        (**self).resolve_txt(name).await
166    }
167
168    async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError> {
169        (**self).resolve(name, rrtype).await
170    }
171}
172
173#[derive(Debug, Default)]
174pub struct TestResolver {
175    records: BTreeMap<Name, BTreeMap<RrKey, RecordSet>>,
176}
177
178impl TestResolver {
179    pub fn with_zone(mut self, zone: &str) -> Result<Self, String> {
180        let (mut name, records) = Parser::new(zone, None, None)
181            .parse()
182            .map_err(|err| format!("{err:#}"))?;
183        // The parser can create results with varying FQDN-ness, so let's
184        // ensure that they're all marked as FQDN, otherwise our get()
185        // function can fail to resolve data from the zone.
186        name.set_fqdn(true);
187        let fqdn_records = records
188            .into_iter()
189            .map(|(key, value)| {
190                if key.name().is_fqdn() {
191                    (key, value)
192                } else {
193                    let mut name: Name = key.name().into();
194                    name.set_fqdn(true);
195                    (RrKey::new(LowerName::new(&name), key.record_type), value)
196                }
197            })
198            .collect();
199        self.records.insert(name, fqdn_records);
200        Ok(self)
201    }
202
203    pub fn with_txt(self, domain: &str, value: impl Into<String>) -> Self {
204        self.with_txt_multiple(domain, vec![value.into()])
205    }
206
207    /// Add multiple separate TXT records for the specified domain
208    pub fn with_txt_multiple(mut self, domain: &str, value: Vec<String>) -> Self {
209        let fqdn = format!("{}.", domain);
210        let authority = Name::from_str_relaxed(&fqdn).unwrap();
211        let key = RrKey {
212            name: LowerName::from_str(&fqdn).unwrap(),
213            record_type: RecordType::TXT,
214        };
215
216        let mut records = RecordSet::new(authority.clone(), RecordType::TXT, 0);
217        for item in value {
218            records.add_rdata(RData::TXT(TXT::new(vec![item])));
219        }
220        self.records
221            .entry(authority)
222            .or_default()
223            .insert(key, records);
224
225        self
226    }
227
228    fn get(&self, full: &Name, record_type: RecordType) -> Result<Answer, DnsError> {
229        let mut full_fqdn = full.clone();
230        full_fqdn.set_fqdn(true);
231        let mut authority = full_fqdn.clone();
232
233        let records = loop {
234            if let Some(records) = self.records.get(&authority) {
235                break records;
236            };
237
238            if authority.num_labels() > 1 {
239                authority = authority.base_name();
240                continue;
241            }
242
243            return Ok(Answer {
244                canon_name: None,
245                records: vec![],
246                nxdomain: true,
247                secure: false,
248                bogus: false,
249                why_bogus: None,
250                expires: Instant::now() + Duration::from_secs(60),
251                response_code: ResponseCode::NXDomain,
252            });
253        };
254
255        let records = records.get(&RrKey {
256            name: LowerName::from(&full_fqdn),
257            record_type,
258        });
259
260        let Some(records) = records else {
261            return Ok(Answer {
262                canon_name: None,
263                records: vec![],
264                nxdomain: false,
265                secure: false,
266                bogus: false,
267                why_bogus: None,
268                expires: Instant::now() + Duration::from_secs(60),
269                response_code: ResponseCode::NoError,
270            });
271        };
272
273        Ok(Answer {
274            canon_name: None,
275            records: records
276                .records_without_rrsigs()
277                .map(|r| r.data().clone())
278                .collect(),
279            nxdomain: false,
280            secure: false,
281            bogus: false,
282            why_bogus: None,
283            expires: Instant::now() + Duration::from_secs(60),
284            response_code: ResponseCode::NoError,
285        })
286    }
287}
288
289#[async_trait]
290impl Resolver for TestResolver {
291    async fn resolve_ip(&self, full: &str) -> Result<Vec<IpAddr>, DnsError> {
292        let name = Name::from_str_relaxed(full)
293            .map_err(|err| DnsError::InvalidName(format!("invalid name {full}: {err}")))?;
294
295        let mut values = vec![];
296        let answer = self.get(&name, RecordType::A)?;
297        for record in answer.records {
298            let a = A::try_borrow(&record).unwrap();
299            values.push(IpAddr::V4(a.0));
300        }
301
302        let answer = self.get(&name, RecordType::AAAA)?;
303        for record in answer.records {
304            let a = AAAA::try_borrow(&record).unwrap();
305            values.push(IpAddr::V6(a.0));
306        }
307
308        Ok(values)
309    }
310
311    async fn resolve_mx(&self, full: &str) -> Result<Vec<Name>, DnsError> {
312        let name = Name::from_str_relaxed(full)
313            .map_err(|err| DnsError::InvalidName(format!("invalid name {full}: {err}")))?;
314
315        let mut values = vec![];
316        let answer = self.get(&name, RecordType::MX)?;
317        for record in answer.records {
318            let mx = MX::try_borrow(&record).unwrap();
319            values.push(mx.exchange().clone());
320        }
321
322        Ok(values)
323    }
324
325    async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError> {
326        let name = ptr_host(ip);
327        let name = Name::from_str_relaxed(ptr_host(ip))
328            .map_err(|err| DnsError::InvalidName(format!("invalid name {name}: {err}")))?;
329
330        let mut values = vec![];
331        let answer = self.get(&name, RecordType::PTR)?;
332        for record in answer.records {
333            match PTR::try_borrow(&record) {
334                Some(ptr) => values.push(ptr.0.clone()),
335                None => {
336                    println!("invalid record found for PTR record for {ip}");
337                    return Err(DnsError::ResolveFailed(format!(
338                        "invalid record found for PTR record for {ip}"
339                    )));
340                }
341            };
342        }
343
344        Ok(values)
345    }
346
347    async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError> {
348        self.get(&name, rrtype)
349    }
350}
351
/// [`Resolver`] backed by libunbound (available with the `unbound` feature).
#[cfg(feature = "unbound")]
pub struct UnboundResolver {
    cx: AsyncContext,
}

#[cfg(feature = "unbound")]
impl UnboundResolver {
    /// Creates a resolver that resolves directly against the root, with
    /// libunbound's built-in trust anchors loaded so DNSSEC is enabled.
    ///
    /// # Errors
    /// Propagates any libunbound error from context creation, trust-anchor
    /// setup, or conversion to an async context.
    pub fn new() -> Result<Self, libunbound::Error> {
        let context = Context::new()?;
        context.add_builtin_trust_anchors()?;
        let cx = context.into_async()?;
        Ok(Self { cx })
    }
}
369
#[cfg(feature = "unbound")]
#[async_trait]
impl Resolver for UnboundResolver {
    /// Queries A and AAAA concurrently and merges the addresses.
    ///
    /// Rdata entries yielded as `Err` are skipped. If only one of the two
    /// lookups fails, the other's addresses are still returned; an error
    /// (carrying the A lookup's failure) is reported only when both fail.
    async fn resolve_ip(&self, host: &str) -> Result<Vec<IpAddr>, DnsError> {
        let (a, aaaa) = tokio::join!(
            self.cx.resolve(host, RecordType::A, DNSClass::IN),
            self.cx.resolve(host, RecordType::AAAA, DNSClass::IN),
        );

        let mut records = vec![];
        match (a, aaaa) {
            (Ok(a), Ok(aaaa)) => {
                records.extend(a.rdata().filter_map(|r| match r {
                    Ok(r) => r.as_a().map(|a| IpAddr::from(a.0)),
                    Err(_) => None,
                }));
                records.extend(aaaa.rdata().filter_map(|r| match r {
                    Ok(r) => r.as_aaaa().map(|aaaa| IpAddr::from(aaaa.0)),
                    Err(_) => None,
                }));
            }
            (Ok(a), Err(_)) => {
                records.extend(a.rdata().filter_map(|r| match r {
                    Ok(r) => r.as_a().map(|a| IpAddr::from(a.0)),
                    Err(_) => None,
                }));
            }
            (Err(_), Ok(aaaa)) => {
                records.extend(aaaa.rdata().filter_map(|r| match r {
                    Ok(r) => r.as_aaaa().map(|aaaa| IpAddr::from(aaaa.0)),
                    Err(_) => None,
                }));
            }
            (Err(err), Err(_)) => {
                return Err(DnsError::ResolveFailed(format!(
                    "failed to query DNS for {host}: {err}"
                )))
            }
        }

        Ok(records)
    }

    /// Returns the exchange host names of `host`'s MX records.
    async fn resolve_mx(&self, host: &str) -> Result<Vec<Name>, DnsError> {
        // The rdata filter below keeps only MX data, so the query must ask
        // for MX records (querying A always produced an empty list).
        let answer = self
            .cx
            .resolve(host, RecordType::MX, DNSClass::IN)
            .await
            .map_err(|err| {
                DnsError::ResolveFailed(format!("failed to query DNS for {host}: {err}"))
            })?;

        Ok(answer
            .rdata()
            .filter_map(|r| match r {
                Ok(r) => r.as_mx().map(|mx| mx.exchange().clone()),
                Err(_) => None,
            })
            .collect())
    }

    /// Looks up PTR records under the reverse-lookup name of `ip`.
    async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError> {
        let name = ptr_host(ip);
        let answer = self
            .cx
            .resolve(&name, RecordType::PTR, DNSClass::IN)
            .await
            .map_err(|err| {
                DnsError::ResolveFailed(format!("failed to query DNS for {name}: {err}"))
            })?;

        Ok(answer
            .rdata()
            .filter_map(|r| match r {
                Ok(r) => r.as_ptr().map(|ptr| ptr.0.clone()),
                Err(_) => None,
            })
            .collect())
    }

    /// Raw query that carries unbound's canonical name, NXDOMAIN flag,
    /// DNSSEC status, rcode and TTL into the returned [`Answer`].
    async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError> {
        let name = name.to_ascii();
        let answer = self
            .cx
            .resolve(&name, rrtype, DNSClass::IN)
            .await
            .map_err(|err| {
                DnsError::ResolveFailed(format!("failed to query DNS for {name}: {err}"))
            })?;

        // Entries that `rdata()` yields as `Err` are dropped.
        let records: Vec<RData> = answer.rdata().filter_map(Result::ok).collect();

        Ok(Answer {
            canon_name: answer.canon_name().map(|s| s.to_string()),
            records,
            nxdomain: answer.nxdomain(),
            secure: answer.secure(),
            bogus: answer.bogus(),
            why_bogus: answer.why_bogus().map(|s| s.to_string()),
            response_code: answer.rcode(),
            expires: Instant::now() + Duration::from_secs(answer.ttl() as u64),
        })
    }
}
479
/// Wraps an already-configured libunbound async context.
#[cfg(feature = "unbound")]
impl From<AsyncContext> for UnboundResolver {
    fn from(context: AsyncContext) -> Self {
        UnboundResolver { cx: context }
    }
}
486
487pub struct HickoryResolver {
488    inner: TokioResolver,
489}
490
491impl HickoryResolver {
492    pub fn new() -> Result<Self, hickory_resolver::ResolveError> {
493        Ok(Self {
494            inner: TokioResolver::builder_tokio()?.build(),
495        })
496    }
497}
498
499#[async_trait]
500impl Resolver for HickoryResolver {
501    async fn resolve_ip(&self, host: &str) -> Result<Vec<IpAddr>, DnsError> {
502        let name = Name::from_str_relaxed(host)
503            .map_err(|err| DnsError::InvalidName(format!("invalid name {host}: {err}")))?;
504
505        match self.inner.lookup_ip(name.clone()).await {
506            Ok(result) => Ok(result.into_iter().collect()),
507            Err(err) => match err.proto().map(|err| err.kind()) {
508                Some(ProtoErrorKind::NoRecordsFound { .. }) => Ok(vec![]),
509                _ => Err(DnsError::from_resolve(&name, err)),
510            },
511        }
512    }
513
514    async fn resolve_mx(&self, host: &str) -> Result<Vec<Name>, DnsError> {
515        let name = Name::from_str_relaxed(host)
516            .map_err(|err| DnsError::InvalidName(format!("invalid name {host}: {err}")))?;
517
518        match self.inner.mx_lookup(name.clone()).await {
519            Ok(result) => Ok(result.into_iter().map(|mx| mx.exchange().clone()).collect()),
520            Err(err) => match err.proto().map(|err| err.kind()) {
521                Some(ProtoErrorKind::NoRecordsFound { .. }) => Ok(vec![]),
522                _ => Err(DnsError::from_resolve(&name, err)),
523            },
524        }
525    }
526
527    async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError> {
528        match self.inner.reverse_lookup(ip).await {
529            Ok(result) => Ok(result.into_iter().map(|ptr| ptr.0).collect()),
530            Err(err) => match err.proto().map(|err| err.kind()) {
531                Some(ProtoErrorKind::NoRecordsFound { .. }) => Ok(vec![]),
532                _ => Err(DnsError::from_resolve(&ip, err)),
533            },
534        }
535    }
536
537    async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError> {
538        match self.inner.lookup(name.clone(), rrtype).await {
539            Ok(result) => {
540                let expires = result.valid_until();
541                let records = result.iter().cloned().collect();
542                Ok(Answer {
543                    canon_name: None,
544                    records,
545                    nxdomain: false,
546                    secure: false,
547                    bogus: false,
548                    why_bogus: None,
549                    expires,
550                    response_code: ResponseCode::NoError,
551                })
552            }
553            Err(err) => match err.proto().map(|err| err.kind()) {
554                Some(ProtoErrorKind::NoRecordsFound {
555                    negative_ttl,
556                    response_code,
557                    ..
558                }) => Ok(Answer {
559                    canon_name: None,
560                    records: vec![],
561                    nxdomain: *response_code == ResponseCode::NXDomain,
562                    secure: false,
563                    bogus: false,
564                    why_bogus: None,
565                    response_code: *response_code,
566                    expires: Instant::now()
567                        + Duration::from_secs(negative_ttl.unwrap_or(60) as u64),
568                }),
569                _ => Err(DnsError::from_resolve(&name, err)),
570            },
571        }
572    }
573}
574
575impl From<TokioResolver> for HickoryResolver {
576    fn from(inner: TokioResolver) -> Self {
577        Self { inner }
578    }
579}
580
581/// AggregateResolver aggregates the results from multiple
582/// resolver instances.
583/// This is most useful when you want to overlay or otherwise
584/// force in test data to take precedence over real resolver
585/// results.  In that situation, you'd push a TestResolver
586/// ahead of the HickoryResolver that you want to use for
587/// real DNS resolution.
588pub struct AggregateResolver {
589    resolvers: Vec<Box<dyn Resolver>>,
590}
591
592impl AggregateResolver {
593    pub fn new() -> Self {
594        Self { resolvers: vec![] }
595    }
596
597    pub fn push_resolver(&mut self, resolver: Box<dyn Resolver>) {
598        self.resolvers.push(resolver);
599    }
600}
601
602#[async_trait]
603impl Resolver for AggregateResolver {
604    async fn resolve_ip(&self, full: &str) -> Result<Vec<IpAddr>, DnsError> {
605        let mut errors = vec![];
606        for resolver in &self.resolvers {
607            match resolver.resolve_ip(full).await {
608                Ok(ips) if ips.is_empty() => {}
609                Ok(ips) => {
610                    return Ok(ips);
611                }
612                Err(err) => errors.push(err),
613            }
614        }
615
616        if let Some(err) = errors.pop() {
617            Err(err)
618        } else {
619            Ok(vec![])
620        }
621    }
622
623    async fn resolve_mx(&self, full: &str) -> Result<Vec<Name>, DnsError> {
624        let mut errors = vec![];
625        for resolver in &self.resolvers {
626            match resolver.resolve_mx(full).await {
627                Ok(ips) if ips.is_empty() => {}
628                Ok(ips) => {
629                    return Ok(ips);
630                }
631                Err(err) => errors.push(err),
632            }
633        }
634
635        if let Some(err) = errors.pop() {
636            Err(err)
637        } else {
638            Ok(vec![])
639        }
640    }
641
642    async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError> {
643        let mut errors = vec![];
644        for resolver in &self.resolvers {
645            match resolver.resolve_ptr(ip).await {
646                Ok(ips) if ips.is_empty() => {}
647                Ok(ips) => {
648                    return Ok(ips);
649                }
650                Err(err) => errors.push(err),
651            }
652        }
653
654        if let Some(err) = errors.pop() {
655            Err(err)
656        } else {
657            Ok(vec![])
658        }
659    }
660
661    async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError> {
662        let mut errors = vec![];
663        let mut nxdomain = None;
664        for resolver in &self.resolvers {
665            match resolver.resolve(name.clone(), rrtype).await {
666                Ok(answer) => {
667                    if !answer.nxdomain {
668                        return Ok(answer);
669                    }
670                    nxdomain.replace(answer);
671                }
672                Err(err) => errors.push(err),
673            }
674        }
675
676        if let Some(err) = errors.pop() {
677            Err(err)
678        } else {
679            match nxdomain.take() {
680                Some(answer) => Ok(answer),
681                None => Ok(Answer {
682                    canon_name: None,
683                    records: vec![],
684                    nxdomain: true,
685                    secure: false,
686                    bogus: false,
687                    why_bogus: None,
688                    expires: Instant::now() + Duration::from_secs(60),
689                    response_code: ResponseCode::NXDomain,
690                }),
691            }
692        }
693    }
694}
695
#[cfg(test)]
mod tests {
    use super::ptr_host;
    use std::net::IpAddr;

    /// Checks the reverse-lookup names produced for one IPv4 and one IPv6
    /// address.
    #[test]
    fn test_ptr_host() {
        let cases = [
            ("192.0.2.1", "1.2.0.192.in-addr.arpa"),
            (
                "2001:db8::1",
                "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa",
            ),
        ];
        for (addr, expected) in cases {
            let ip: IpAddr = addr.parse().expect("test address must parse");
            assert_eq!(ptr_host(ip), expected, "ptr_host({addr})");
        }
    }
}