// nom_utils/lib.rs

1use bstr::{BStr, ByteSlice};
2use hickory_resolver::Name;
3use nom::branch::alt;
4use nom::bytes::complete::{take_while1, take_while_m_n};
5use nom::combinator::{map_res, opt, recognize};
6use nom::error::{context, ContextError, ErrorKind, FromExternalError, ParseError as _};
7use nom::multi::{many0, many1};
8use nom::sequence::pair;
9use nom::{Input, Parser as _};
10use nom_locate::LocatedSpan;
11use std::fmt::{self, Debug, Write};
12use std::hash::Hash;
13use std::marker::PhantomData;
14use std::net::{Ipv4Addr, Ipv6Addr};
15use std::str::FromStr;
16
17pub type Span<'a> = LocatedSpan<&'a [u8]>;
18pub type IResult<'a, A, B> = nom::IResult<A, B, ParseError<Span<'a>>>;
19
20pub fn make_span(s: &'_ [u8]) -> Span<'_> {
21    Span::new(s)
22}
23
24/// Like nom::bytes::complete::tag, except that we print what the tag
25/// was expecting if there was an error.
26/// I feel like this should be the default behavior TBH.
27pub fn tag<E>(tag: &'static str) -> TagParser<E> {
28    TagParser {
29        tag,
30        no_case: false,
31        e: PhantomData,
32    }
33}
34
35pub fn tag_no_case<E>(tag: &'static str) -> TagParser<E> {
36    TagParser {
37        tag,
38        no_case: true,
39        e: PhantomData,
40    }
41}
42
/// Parser returned by [`tag`] and [`tag_no_case`]. It implements
/// `nom::Parser` so that a mismatch reports the literal that was expected.
pub struct TagParser<E> {
    /// Literal text that must appear at the start of the input.
    tag: &'static str,
    /// When `true`, the input is compared with `compare_no_case` instead of `compare`.
    no_case: bool,
    /// Ties the parser to its error type without storing a value of it.
    e: std::marker::PhantomData<E>,
}
49
50/// All this fuss to show what we expected for the TagParser impl
51impl<I, Error: nom::error::ParseError<I> + nom::error::FromExternalError<I, String>> nom::Parser<I>
52    for TagParser<Error>
53where
54    I: nom::Input + nom::Compare<&'static str> + nom::AsBytes,
55{
56    type Output = I;
57    type Error = Error;
58
59    fn process<OM: nom::OutputMode>(
60        &mut self,
61        i: I,
62    ) -> nom::PResult<OM, I, Self::Output, Self::Error> {
63        use nom::error::ErrorKind;
64        use nom::{CompareResult, Err, Mode};
65
66        let tag_len = self.tag.input_len();
67
68        let compare_result = if self.no_case {
69            i.compare_no_case(self.tag)
70        } else {
71            i.compare(self.tag)
72        };
73
74        match compare_result {
75            CompareResult::Ok => Ok((i.take_from(tag_len), OM::Output::bind(|| i.take(tag_len)))),
76            CompareResult::Incomplete => Err(Err::Error(OM::Error::bind(|| {
77                Error::from_external_error(
78                    i,
79                    ErrorKind::Fail,
80                    format!(
81                        "expected \"{}\" but ran out of input",
82                        self.tag.escape_debug()
83                    ),
84                )
85            }))),
86
87            CompareResult::Error => {
88                let available = i.take(i.input_len().min(tag_len));
89                Err(Err::Error(OM::Error::bind(|| {
90                    Error::from_external_error(
91                        i,
92                        ErrorKind::Fail,
93                        format!(
94                            "expected \"{}\" but found {:?}",
95                            self.tag.escape_debug(),
96                            BStr::new(available.as_bytes())
97                        ),
98                    )
99                })))
100            }
101        }
102    }
103}
104
105#[derive(Debug)]
106pub enum ParseErrorKind {
107    Context(&'static str),
108    Char(char),
109    Nom(ErrorKind),
110    External { kind: ErrorKind, reason: String },
111}
112
113#[derive(Debug)]
114pub struct ParseError<I: Debug> {
115    pub errors: Vec<(I, ParseErrorKind)>,
116}
117
118impl<I: Debug> ContextError<I> for ParseError<I> {
119    fn add_context(input: I, ctx: &'static str, mut other: Self) -> Self {
120        other.errors.push((input, ParseErrorKind::Context(ctx)));
121        other
122    }
123}
124
125impl<I: Debug> nom::error::ParseError<I> for ParseError<I> {
126    fn from_error_kind(input: I, kind: ErrorKind) -> Self {
127        Self {
128            errors: vec![(input, ParseErrorKind::Nom(kind))],
129        }
130    }
131
132    fn append(input: I, kind: ErrorKind, mut other: Self) -> Self {
133        other.errors.push((input, ParseErrorKind::Nom(kind)));
134        other
135    }
136
137    fn from_char(input: I, c: char) -> Self {
138        Self {
139            errors: vec![(input, ParseErrorKind::Char(c))],
140        }
141    }
142}
143
144impl<I: Debug, E: std::fmt::Display> nom::error::FromExternalError<I, E> for ParseError<I> {
145    fn from_external_error(input: I, kind: ErrorKind, err: E) -> Self {
146        Self {
147            errors: vec![(
148                input,
149                ParseErrorKind::External {
150                    kind,
151                    reason: format!("{err:#}"),
152                },
153            )],
154        }
155    }
156}
157
158pub fn make_context_error<S: Into<String>>(
159    input: Span<'_>,
160    reason: S,
161) -> nom::Err<ParseError<Span<'_>>> {
162    nom::Err::Error(ParseError {
163        errors: vec![(
164            input,
165            ParseErrorKind::External {
166                kind: nom::error::ErrorKind::Fail,
167                reason: reason.into(),
168            },
169        )],
170    })
171}
172
173pub fn explain_nom(input: Span, err: nom::Err<ParseError<Span<'_>>>) -> String {
174    match err {
175        nom::Err::Error(e) => {
176            let mut result = String::new();
177            let mut lines_shown = vec![];
178
179            for (span, kind) in e.errors.iter() {
180                if input.is_empty() {
181                    match kind {
182                        ParseErrorKind::Char(c) => {
183                            write!(&mut result, "Error expected '{c}', got empty input\n\n")
184                        }
185                        ParseErrorKind::Context(s) => {
186                            write!(&mut result, "Error in {s}, got empty input\n\n")
187                        }
188                        ParseErrorKind::External { kind, reason } => {
189                            write!(&mut result, "Error {reason} {kind:?}, got empty input\n\n")
190                        }
191                        ParseErrorKind::Nom(e) => {
192                            write!(&mut result, "Error in {e:?}, got empty input\n\n")
193                        }
194                    }
195                    .ok();
196                    continue;
197                }
198
199                let line_number = span.location_line();
200                let input_line = span.get_line_beginning();
201                // Remap \t in particular, because it can render as multiple
202                // columns and defeat the column number calculation provided
203                // by the Span type
204                let mut line = String::new();
205                for (start, end, c) in input_line.char_indices() {
206                    let c = match c {
207                        '\t' => '\u{2409}',
208                        '\r' => '\u{240d}',
209                        '\n' => '\u{240a}',
210                        c => c,
211                    };
212
213                    if c == std::char::REPLACEMENT_CHARACTER {
214                        let bytes = &input_line[start..end];
215                        for b in bytes.iter() {
216                            line.push_str(&format!("\\x{b:02X}"));
217                        }
218                    } else {
219                        line.push(c);
220                    }
221                }
222
223                let column = span.get_utf8_column();
224
225                lines_shown.push(line_number);
226
227                let mut caret = " ".repeat(column.saturating_sub(1));
228                caret.push('^');
229                for _ in 1..span.fragment().len() {
230                    caret.push('_')
231                }
232
233                match kind {
234                    ParseErrorKind::Char(expected) => {
235                        if let Some(actual) = span.fragment().chars().next() {
236                            write!(
237                                &mut result,
238                                "Error at line {line_number}:\n\
239                                    {line}\n\
240                                    {caret}\n\
241                                    expected '{expected}', found {actual}\n\n",
242                            )
243                        } else {
244                            write!(
245                                &mut result,
246                                "Error at line {line_number}:\n\
247                                    {line}\n\
248                                    {caret}\n\
249                                    expected '{expected}', got end of input\n\n",
250                            )
251                        }
252                    }
253                    ParseErrorKind::Context(context) => {
254                        write!(&mut result, "while parsing {context}\n")
255                    }
256                    ParseErrorKind::External { kind: _, reason } => {
257                        write!(
258                            &mut result,
259                            "Error at line {line_number}, {reason}:\n\
260                                {line}\n\
261                                {caret}\n\n",
262                        )
263                    }
264                    ParseErrorKind::Nom(nom_err) => {
265                        write!(
266                            &mut result,
267                            "Error at line {line_number}, in {nom_err:?}:\n\
268                                {line}\n\
269                                {caret}\n\n",
270                        )
271                    }
272                }
273                .ok();
274            }
275            result
276        }
277        _ => format!("{err:#}"),
278    }
279}
280
281/// See the following RFCs:
282/// * <https://datatracker.ietf.org/doc/html/rfc6531#section-3.3>
283/// * <https://datatracker.ietf.org/doc/html/rfc6532#section-3.1>
284/// * <https://datatracker.ietf.org/doc/html/rfc3629#section-4>
285/// which define a bunch of ABNF, but then caps it off with:
286/// > The authoritative definition of UTF-8 is in [UNICODE].  This
287/// > grammar is believed to describe the same thing Unicode describes, but
288/// > does not claim to be authoritative.  Implementors are urged to rely
289/// > on the authoritative source, rather than on this ABNF.
290pub fn utf8_non_ascii(input: Span) -> IResult<Span, Span> {
291    use nom::Err;
292
293    match input.char_indices().next() {
294        Some((start, end, c)) => {
295            let len = end - start;
296            if c as u32 <= 0x7f {
297                // It's ASCII, therefore doesn't match as utf8_non_ascii
298                return Err(Err::Error(ParseError::from_error_kind(
299                    input,
300                    ErrorKind::Fail,
301                )));
302            }
303            let slice = &input[start..end];
304            if c == std::char::REPLACEMENT_CHARACTER {
305                let mut verify = [0u8; 4];
306                if slice != c.encode_utf8(&mut verify).as_bytes() {
307                    // The original sequence wasn't REPLACEMENT_CHARACTER,
308                    // therefore the input is not valid UTF-8
309                    return Err(Err::Error(ParseError::from_error_kind(
310                        input,
311                        ErrorKind::Fail,
312                    )));
313                }
314            }
315            // slice is the first UTF-8 character in the input
316            Ok((input.take_from(len), input.take(len)))
317        }
318        None => {
319            // There's no input, therefore we cannot match
320            Err(Err::Error(ParseError::from_error_kind(
321                input,
322                ErrorKind::Eof,
323            )))
324        }
325    }
326}
327
328fn snum(input: Span) -> IResult<Span, Span> {
329    take_while_m_n(1, 3, |c: u8| c.is_ascii_digit()).parse(input)
330}
331
332pub fn ipv4_address(input: Span) -> IResult<Span, Ipv4Addr> {
333    context(
334        "ipv4_address",
335        map_res(
336            recognize((snum, tag("."), snum, tag("."), snum, tag("."), snum)),
337            |matched| {
338                let v4str = std::str::from_utf8(&matched).expect("can only be ascii");
339                v4str.parse().map_err(|err| {
340                    nom::Err::Error(ParseError::from_external_error(
341                        input,
342                        ErrorKind::Fail,
343                        format!("invalid ipv4_address: {err}"),
344                    ))
345                })
346            },
347        ),
348    )
349    .parse(input)
350}
351
352pub fn ipv6_address(input: Span) -> IResult<Span, Ipv6Addr> {
353    context(
354        "ipv6_address",
355        map_res(
356            take_while1(|c: u8| c.is_ascii_hexdigit() || c == b':' || c == b'.'),
357            |matched: Span| {
358                let v6str = std::str::from_utf8(&matched).expect("can only be ascii");
359                v6str.parse().map_err(|err| {
360                    nom::Err::Error(ParseError::from_external_error(
361                        input,
362                        ErrorKind::Fail,
363                        format!("invalid ipv6_address: {err}"),
364                    ))
365                })
366            },
367        ),
368    )
369    .parse(input)
370}
371
/// A validated DNS domain name, held in normalized ASCII (punycode) form.
///
/// Construction goes through `FromStr`, which stores the `to_ascii()`
/// rendering of the parsed hickory `Name`. The original spelling (possibly
/// a UTF-8 U-label, possibly mixed case) is not kept. Equality and hashing
/// therefore compare normalized text.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DomainString(String);
377
378impl fmt::Display for DomainString {
379    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
380        f.write_str(&self.0)
381    }
382}
383
384impl DomainString {
385    pub fn name(&self) -> Name {
386        Name::from_str_relaxed(&self.0)
387            .expect("cannot construct DomainString with an invalid domain name")
388    }
389
390    /// Returns a reference to the normalized (ASCII/punycode) domain string.
391    pub fn as_str(&self) -> &str {
392        &self.0
393    }
394}
395
396impl FromStr for DomainString {
397    type Err = String;
398    fn from_str(s: &str) -> Result<Self, Self::Err> {
399        let name = Name::from_str_relaxed(s)?;
400        Ok(Self(name.to_ascii()))
401    }
402}
403
404impl From<DomainString> for Name {
405    fn from(val: DomainString) -> Self {
406        val.name()
407    }
408}
409
410impl From<&DomainString> for Name {
411    fn from(val: &DomainString) -> Self {
412        val.name()
413    }
414}
415
416/// `let-dig = ALPHA / DIGIT / UTF8-non-ASCII`
417fn let_dig(input: Span) -> IResult<Span, Span> {
418    recognize(alt((
419        take_while_m_n(1, 1, |c: u8| c.is_ascii_alphanumeric()),
420        utf8_non_ascii,
421    )))
422    .parse(input)
423}
424
425/// `ldh-str = *( ALPHA / DIGIT / "-" / UTF8-non-ASCII )`  (one or more)
426///
427/// As an extension to the mail RFCs, we allow for underscore
428/// in domain names, as those are a commonly deployed name, despite it
429/// being in violation of the DNS RFCs.
430fn ldh_str(input: Span) -> IResult<Span, Span> {
431    recognize(many1(alt((
432        take_while_m_n(1, 1, |c: u8| {
433            c.is_ascii_alphanumeric() || c == b'-' || c == b'_'
434        }),
435        utf8_non_ascii,
436    ))))
437    .parse(input)
438}
439
440/// `sub-domain = let-dig [ ldh-str ]`
441fn sub_domain(input: Span) -> IResult<Span, Span> {
442    recognize(pair(let_dig, opt(ldh_str))).parse(input)
443}
444
445/// `domain = sub-domain *( "." sub-domain )`
446pub fn domain_name(input: Span) -> IResult<Span, DomainString> {
447    context(
448        "domain-name",
449        map_res(
450            recognize(pair(sub_domain, many0(pair(tag("."), sub_domain)))),
451            |matched: Span| match std::str::from_utf8(&matched) {
452                Ok(s) => s.parse().map_err(|err| {
453                    nom::Err::Error(ParseError::from_external_error(
454                        input,
455                        ErrorKind::Fail,
456                        format!("invalid domain name: {err}"),
457                    ))
458                }),
459                Err(err) => Err(nom::Err::Error(ParseError::from_external_error(
460                    input,
461                    ErrorKind::Fail,
462                    format!("invalid domain name: {err}"),
463                ))),
464            },
465        ),
466    )
467    .parse(input)
468}
469
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ipv4_parse() {
        // A well-formed dotted quad yields the corresponding address.
        let (_rest, parsed) = ipv4_address(make_span(b"192.168.1.1")).unwrap();
        k9::assert_equal!(parsed, Ipv4Addr::new(192, 168, 1, 1));
    }

    #[test]
    fn test_ipv6_parse() {
        // The fully expanded and the compressed spelling of one address
        // must parse to equal values.
        let expanded = make_span(b"2001:0db8:0000:0000:0000:0000:0000:0001");
        let compressed = make_span(b"2001:db8::1");
        let (_, long_form) = ipv6_address(expanded).unwrap();
        let (_, short_form) = ipv6_address(compressed).unwrap();
        k9::assert_equal!(long_form, short_form);
    }

    #[test]
    fn test_domain_string_partial_eq() {
        // Case differences disappear after normalization.
        let upper: DomainString = "EXAMPLE.COM".parse().unwrap();
        let lower: DomainString = "example.com".parse().unwrap();

        assert_eq!(upper, lower);
    }

    #[test]
    fn test_domain_string_partial_eq_idna() {
        // A U-label and its punycode A-label normalize to the same value.
        let unicode: DomainString = "münchen.de".parse().unwrap();
        let punycode: DomainString = "xn--mnchen-3ya.de".parse().unwrap();

        assert_eq!(unicode, punycode);
    }
}