dns_resolver/
resolver.rs

1use async_trait::async_trait;
2use hickory_proto::ProtoErrorKind;
3use hickory_resolver::proto::op::response_code::ResponseCode;
4use hickory_resolver::proto::rr::rdata::{A, AAAA, MX, PTR, TXT};
5#[cfg(feature = "unbound")]
6use hickory_resolver::proto::rr::DNSClass;
7use hickory_resolver::proto::rr::{LowerName, RData, RecordData, RecordSet, RecordType, RrKey};
8use hickory_resolver::proto::serialize::txt::Parser;
9use hickory_resolver::{Name, ResolveError, TokioResolver};
10#[cfg(feature = "unbound")]
11use libunbound::{AsyncContext, Context};
12use std::collections::BTreeMap;
13use std::fmt;
14use std::net::IpAddr;
15use std::str::FromStr;
16use std::time::{Duration, Instant};
17use thiserror::Error;
18
/// Formats an IP address as dot-separated components: decimal octets for
/// IPv4, hexadecimal nibbles for IPv6.
///
/// When `reverse` is set the components are emitted in reverse order, which
/// is the form [`ptr_host`] uses to build reverse-DNS query names.
pub struct IpDisplay {
    /// The address to format.
    pub ip: IpAddr,
    /// Emit the components least-significant first.
    pub reverse: bool,
}
23
24impl fmt::Display for IpDisplay {
25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26        match self.ip {
27            IpAddr::V4(v4) => {
28                let mut bytes = v4.octets();
29                if self.reverse {
30                    bytes.reverse();
31                }
32                let mut first = true;
33                for byte in bytes {
34                    if !first {
35                        f.write_str(".")?;
36                    }
37                    write!(f, "{byte}")?;
38                    first = false;
39                }
40                Ok(())
41            }
42            IpAddr::V6(v6) => {
43                let mut bytes = v6.octets();
44                if self.reverse {
45                    bytes.reverse();
46                }
47                let mut first = true;
48                for byte in bytes {
49                    if !first {
50                        f.write_str(".")?;
51                    }
52                    let (upper, lower) = (byte >> 4, byte & 0xf);
53                    if self.reverse {
54                        write!(f, "{lower:x}.{upper:x}")?;
55                    } else {
56                        write!(f, "{upper:x}.{lower:x}")?;
57                    }
58                    first = false;
59                }
60                Ok(())
61            }
62        }
63    }
64}
65
66pub fn ptr_host(ip: IpAddr) -> String {
67    let mut out = IpDisplay { ip, reverse: true }.to_string();
68    out.push_str(match ip {
69        IpAddr::V4(_) => ".in-addr.arpa",
70        IpAddr::V6(_) => ".ip6.arpa",
71    });
72    out
73}
74
/// The result of a single DNS query, normalized across the resolver backends.
#[derive(Debug)]
pub struct Answer {
    /// Canonical name reported by the backend, if it provides one.
    pub canon_name: Option<String>,
    /// Record data returned for the query; empty when nothing matched.
    pub records: Vec<RData>,
    /// `true` when the response code was NXDOMAIN (the name does not exist).
    pub nxdomain: bool,
    /// Whether the backend reported the answer as DNSSEC-secure.
    pub secure: bool,
    /// Whether the backend reported DNSSEC validation as failed ("bogus").
    pub bogus: bool,
    /// Backend-provided explanation accompanying a bogus result, if any.
    pub why_bogus: Option<String>,
    /// Instant after which the answer should be treated as stale
    /// (the backends compute it as now + TTL).
    pub expires: Instant,
    /// DNS response code of the reply.
    pub response_code: ResponseCode,
}
86
87impl Answer {
88    pub fn as_txt(&self) -> Vec<String> {
89        let mut result = vec![];
90        for r in &self.records {
91            if let Some(txt) = r.as_txt() {
92                let mut joined = String::new();
93                for t in txt.iter() {
94                    joined.push_str(&String::from_utf8_lossy(t));
95                }
96                result.push(joined);
97            }
98        }
99        result
100    }
101
102    pub fn as_addr(&self) -> Vec<IpAddr> {
103        let mut result = vec![];
104        for r in &self.records {
105            if let Some(a) = r.as_a() {
106                result.push(a.0.into());
107            } else if let Some(a) = r.as_aaaa() {
108                result.push(a.0.into());
109            }
110        }
111        result
112    }
113}
114
/// Errors produced by the [`Resolver`] implementations in this module.
#[derive(Clone, Debug, Error, PartialEq)]
pub enum DnsError {
    /// The supplied name could not be parsed as a DNS name.
    #[error("invalid DNS name: {0}")]
    InvalidName(String),
    /// The lookup failed; the message describes the underlying error.
    #[error("DNS: {0}")]
    ResolveFailed(String),
}
122
123impl DnsError {
124    pub(crate) fn from_resolve(name: &impl fmt::Display, err: ResolveError) -> Self {
125        DnsError::ResolveFailed(format!("failed to query DNS for {name}: {err}"))
126    }
127}
128
129#[async_trait]
130pub trait Resolver: Send + Sync + 'static {
131    async fn resolve_ip(&self, host: &str) -> Result<Vec<IpAddr>, DnsError>;
132
133    async fn resolve_mx(&self, host: &str) -> Result<Vec<Name>, DnsError>;
134
135    async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError>;
136
137    async fn resolve_txt(&self, name: &str) -> Result<Answer, DnsError> {
138        let name = Name::from_utf8(name)
139            .map_err(|err| DnsError::InvalidName(format!("invalid name {name}: {err}")))?;
140        self.resolve(name, RecordType::TXT).await
141    }
142
143    async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError>;
144}
145
/// In-memory [`Resolver`] for tests.
///
/// It serves records loaded via the `with_zone` / `with_txt*` builder methods.
#[derive(Debug, Default)]
pub struct TestResolver {
    /// Zone origin -> records keyed by (owner name, record type).
    records: BTreeMap<Name, BTreeMap<RrKey, RecordSet>>,
}
150
151impl TestResolver {
152    pub fn with_zone(mut self, zone: &str) -> Self {
153        let (mut name, records) = Parser::new(zone, None, None).parse().unwrap();
154        // The parser can create results with varying FQDN-ness, so let's
155        // ensure that they're all marked as FQDN, otherwise our get()
156        // function can fail to resolve data from the zone.
157        name.set_fqdn(true);
158        let fqdn_records = records
159            .into_iter()
160            .map(|(key, value)| {
161                if key.name().is_fqdn() {
162                    (key, value)
163                } else {
164                    let mut name: Name = key.name().into();
165                    name.set_fqdn(true);
166                    (RrKey::new(LowerName::new(&name), key.record_type), value)
167                }
168            })
169            .collect();
170        self.records.insert(name, fqdn_records);
171        self
172    }
173
174    pub fn with_txt(self, domain: &str, value: String) -> Self {
175        self.with_txt_multiple(domain, vec![value])
176    }
177
178    /// Add multiple separate TXT records for the specified domain
179    pub fn with_txt_multiple(mut self, domain: &str, value: Vec<String>) -> Self {
180        let fqdn = format!("{}.", domain);
181        let authority = Name::from_str(&fqdn).unwrap();
182        let key = RrKey {
183            name: LowerName::from_str(&fqdn).unwrap(),
184            record_type: RecordType::TXT,
185        };
186
187        let mut records = RecordSet::new(authority.clone(), RecordType::TXT, 0);
188        for item in value {
189            records.add_rdata(RData::TXT(TXT::new(vec![item])));
190        }
191        self.records
192            .entry(authority)
193            .or_default()
194            .insert(key, records);
195
196        self
197    }
198
199    fn get(&self, full: &Name, record_type: RecordType) -> Result<Answer, DnsError> {
200        let mut full_fqdn = full.clone();
201        full_fqdn.set_fqdn(true);
202        let mut authority = full_fqdn.clone();
203
204        let records = loop {
205            if let Some(records) = self.records.get(&authority) {
206                break records;
207            };
208
209            if authority.num_labels() > 1 {
210                authority = authority.base_name();
211                continue;
212            }
213
214            return Ok(Answer {
215                canon_name: None,
216                records: vec![],
217                nxdomain: true,
218                secure: false,
219                bogus: false,
220                why_bogus: None,
221                expires: Instant::now() + Duration::from_secs(60),
222                response_code: ResponseCode::NXDomain,
223            });
224        };
225
226        let records = records.get(&RrKey {
227            name: LowerName::from(&full_fqdn),
228            record_type,
229        });
230
231        let Some(records) = records else {
232            return Ok(Answer {
233                canon_name: None,
234                records: vec![],
235                nxdomain: false,
236                secure: false,
237                bogus: false,
238                why_bogus: None,
239                expires: Instant::now() + Duration::from_secs(60),
240                response_code: ResponseCode::NoError,
241            });
242        };
243
244        Ok(Answer {
245            canon_name: None,
246            records: records
247                .records_without_rrsigs()
248                .map(|r| r.data().clone())
249                .collect(),
250            nxdomain: false,
251            secure: false,
252            bogus: false,
253            why_bogus: None,
254            expires: Instant::now() + Duration::from_secs(60),
255            response_code: ResponseCode::NoError,
256        })
257    }
258}
259
260#[async_trait]
261impl Resolver for TestResolver {
262    async fn resolve_ip(&self, full: &str) -> Result<Vec<IpAddr>, DnsError> {
263        let name = Name::from_utf8(full)
264            .map_err(|err| DnsError::InvalidName(format!("invalid name {full}: {err}")))?;
265
266        let mut values = vec![];
267        let answer = self.get(&name, RecordType::A)?;
268        for record in answer.records {
269            let a = A::try_borrow(&record).unwrap();
270            values.push(IpAddr::V4(a.0));
271        }
272
273        let answer = self.get(&name, RecordType::AAAA)?;
274        for record in answer.records {
275            let a = AAAA::try_borrow(&record).unwrap();
276            values.push(IpAddr::V6(a.0));
277        }
278
279        Ok(values)
280    }
281
282    async fn resolve_mx(&self, full: &str) -> Result<Vec<Name>, DnsError> {
283        let name = Name::from_utf8(full)
284            .map_err(|err| DnsError::InvalidName(format!("invalid name {full}: {err}")))?;
285
286        let mut values = vec![];
287        let answer = self.get(&name, RecordType::MX)?;
288        for record in answer.records {
289            let mx = MX::try_borrow(&record).unwrap();
290            values.push(mx.exchange().clone());
291        }
292
293        Ok(values)
294    }
295
296    async fn resolve_txt(&self, full: &str) -> Result<Answer, DnsError> {
297        let name = Name::from_utf8(full)
298            .map_err(|err| DnsError::InvalidName(format!("invalid name {full}: {err}")))?;
299        self.get(&name, RecordType::TXT)
300    }
301
302    async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError> {
303        let name = ptr_host(ip);
304        let name = Name::from_utf8(ptr_host(ip))
305            .map_err(|err| DnsError::InvalidName(format!("invalid name {name}: {err}")))?;
306
307        let mut values = vec![];
308        let answer = self.get(&name, RecordType::PTR)?;
309        for record in answer.records {
310            match PTR::try_borrow(&record) {
311                Some(ptr) => values.push(ptr.0.clone()),
312                None => {
313                    println!("invalid record found for PTR record for {ip}");
314                    return Err(DnsError::ResolveFailed(format!(
315                        "invalid record found for PTR record for {ip}"
316                    )));
317                }
318            };
319        }
320
321        Ok(values)
322    }
323
324    async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError> {
325        self.get(&name, rrtype)
326    }
327}
328
#[cfg(feature = "unbound")]
/// [`Resolver`] backed by an asynchronous libunbound context
/// (only compiled with the `unbound` feature).
pub struct UnboundResolver {
    cx: AsyncContext,
}
333
#[cfg(feature = "unbound")]
impl UnboundResolver {
    /// Creates a libunbound-backed resolver.
    ///
    /// The context resolves directly against the root and has libunbound's
    /// built-in trust anchors installed, which enables DNSSEC validation.
    ///
    /// # Errors
    /// Propagates any libunbound error from creating or configuring the context.
    pub fn new() -> Result<Self, libunbound::Error> {
        let context = Context::new()?;
        context.add_builtin_trust_anchors()?;
        let cx = context.into_async()?;
        Ok(Self { cx })
    }
}
346
#[cfg(feature = "unbound")]
#[async_trait]
impl Resolver for UnboundResolver {
    /// Resolves `host` to its IPv4 and IPv6 addresses.
    ///
    /// The A and AAAA queries run concurrently. A failure in one family is
    /// tolerated when the other succeeds. Only when both fail is an error
    /// returned, carrying the A query's error. Undecodable records are skipped.
    async fn resolve_ip(&self, host: &str) -> Result<Vec<IpAddr>, DnsError> {
        let (a, aaaa) = tokio::join!(
            self.cx.resolve(host, RecordType::A, DNSClass::IN),
            self.cx.resolve(host, RecordType::AAAA, DNSClass::IN),
        );

        if let (Err(err), Err(_)) = (&a, &aaaa) {
            return Err(DnsError::ResolveFailed(format!(
                "failed to query DNS for {host}: {err}"
            )));
        }

        let mut records = vec![];
        if let Ok(a) = a {
            records.extend(a.rdata().filter_map(|r| match r {
                Ok(r) => r.as_a().map(|a| IpAddr::from(a.0)),
                Err(_) => None,
            }));
        }
        if let Ok(aaaa) = aaaa {
            records.extend(aaaa.rdata().filter_map(|r| match r {
                Ok(r) => r.as_aaaa().map(|aaaa| IpAddr::from(aaaa.0)),
                Err(_) => None,
            }));
        }

        Ok(records)
    }

    /// Returns the MX exchange host names for `host`.
    ///
    /// Undecodable records are skipped.
    async fn resolve_mx(&self, host: &str) -> Result<Vec<Name>, DnsError> {
        // Must query MX: the previous A query never yielded MX rdata, so this
        // always returned an empty list.
        let answer = self
            .cx
            .resolve(host, RecordType::MX, DNSClass::IN)
            .await
            .map_err(|err| {
                DnsError::ResolveFailed(format!("failed to query DNS for {host}: {err}"))
            })?;

        Ok(answer
            .rdata()
            .filter_map(|r| match r {
                Ok(r) => r.as_mx().map(|mx| mx.exchange().clone()),
                Err(_) => None,
            })
            .collect())
    }

    /// Returns the PTR targets for `ip`; undecodable records are skipped.
    async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError> {
        let name = ptr_host(ip);
        let answer = self
            .cx
            .resolve(&name, RecordType::PTR, DNSClass::IN)
            .await
            .map_err(|err| {
                DnsError::ResolveFailed(format!("failed to query DNS for {name}: {err}"))
            })?;

        Ok(answer
            .rdata()
            .filter_map(|r| match r {
                Ok(r) => r.as_ptr().map(|ptr| ptr.0.clone()),
                Err(_) => None,
            })
            .collect())
    }

    /// Performs an `rrtype` lookup for `name`.
    ///
    /// The returned [`Answer`] carries libunbound's canonical name, NXDOMAIN,
    /// DNSSEC (secure/bogus) and rcode status. Undecodable records are skipped.
    async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError> {
        let name = name.to_ascii();
        let answer = self
            .cx
            .resolve(&name, rrtype, DNSClass::IN)
            .await
            .map_err(|err| {
                DnsError::ResolveFailed(format!("failed to query DNS for {name}: {err}"))
            })?;

        let records: Vec<RData> = answer.rdata().filter_map(Result::ok).collect();

        Ok(Answer {
            canon_name: answer.canon_name().map(|s| s.to_string()),
            records,
            nxdomain: answer.nxdomain(),
            secure: answer.secure(),
            bogus: answer.bogus(),
            why_bogus: answer.why_bogus().map(|s| s.to_string()),
            response_code: answer.rcode(),
            // NOTE(review): confirm the type returned by `ttl()`; if it is
            // signed, a negative value would wrap in this cast.
            expires: Instant::now() + Duration::from_secs(answer.ttl() as u64),
        })
    }
}
456
#[cfg(feature = "unbound")]
impl From<AsyncContext> for UnboundResolver {
    /// Wraps an already-configured asynchronous libunbound context.
    fn from(context: AsyncContext) -> Self {
        UnboundResolver { cx: context }
    }
}
463
/// [`Resolver`] backed by Hickory's Tokio resolver.
pub struct HickoryResolver {
    inner: TokioResolver,
}
467
468impl HickoryResolver {
469    pub fn new() -> Result<Self, hickory_resolver::ResolveError> {
470        Ok(Self {
471            inner: TokioResolver::builder_tokio()?.build(),
472        })
473    }
474}
475
476#[async_trait]
477impl Resolver for HickoryResolver {
478    async fn resolve_ip(&self, host: &str) -> Result<Vec<IpAddr>, DnsError> {
479        let name = Name::from_utf8(host)
480            .map_err(|err| DnsError::InvalidName(format!("invalid name {host}: {err}")))?;
481
482        self.inner
483            .lookup_ip(name)
484            .await
485            .map_err(|err| DnsError::from_resolve(&host, err))?
486            .into_iter()
487            .map(Ok)
488            .collect()
489    }
490
491    async fn resolve_mx(&self, host: &str) -> Result<Vec<Name>, DnsError> {
492        let name = Name::from_utf8(host)
493            .map_err(|err| DnsError::InvalidName(format!("invalid name {host}: {err}")))?;
494
495        self.inner
496            .mx_lookup(name)
497            .await
498            .map_err(|err| DnsError::from_resolve(&host, err))?
499            .into_iter()
500            .map(|mx| Ok(mx.exchange().clone()))
501            .collect()
502    }
503
504    async fn resolve_ptr(&self, ip: IpAddr) -> Result<Vec<Name>, DnsError> {
505        self.inner
506            .reverse_lookup(ip)
507            .await
508            .map_err(|err| DnsError::from_resolve(&ip, err))?
509            .into_iter()
510            .map(|ptr| Ok(ptr.0))
511            .collect()
512    }
513
514    async fn resolve(&self, name: Name, rrtype: RecordType) -> Result<Answer, DnsError> {
515        match self.inner.lookup(name.clone(), rrtype).await {
516            Ok(result) => {
517                let expires = result.valid_until();
518                let records = result.iter().cloned().collect();
519                Ok(Answer {
520                    canon_name: None,
521                    records,
522                    nxdomain: false,
523                    secure: false,
524                    bogus: false,
525                    why_bogus: None,
526                    expires,
527                    response_code: ResponseCode::NoError,
528                })
529            }
530            Err(err) => match err.proto().map(|err| err.kind()) {
531                Some(ProtoErrorKind::NoRecordsFound {
532                    negative_ttl,
533                    response_code,
534                    ..
535                }) => Ok(Answer {
536                    canon_name: None,
537                    records: vec![],
538                    nxdomain: *response_code == ResponseCode::NXDomain,
539                    secure: false,
540                    bogus: false,
541                    why_bogus: None,
542                    response_code: *response_code,
543                    expires: Instant::now()
544                        + Duration::from_secs(negative_ttl.unwrap_or(60) as u64),
545                }),
546                _ => Err(DnsError::from_resolve(&name, err)),
547            },
548        }
549    }
550}
551
552impl From<TokioResolver> for HickoryResolver {
553    fn from(inner: TokioResolver) -> Self {
554        Self { inner }
555    }
556}
557
#[cfg(test)]
mod tests {
    use super::ptr_host;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    /// Reverse names for one IPv4 and one IPv6 documentation address.
    #[test]
    fn test_ptr_host() {
        let v4 = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 1));
        assert_eq!(ptr_host(v4), "1.2.0.192.in-addr.arpa");

        let v6: Ipv6Addr = "2001:db8::1".parse().unwrap();
        assert_eq!(
            ptr_host(IpAddr::V6(v6)),
            "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa"
        );
    }
}