1use bstr::{BStr, ByteSlice};
2use hickory_resolver::Name;
3use nom::branch::alt;
4use nom::bytes::complete::{take_while1, take_while_m_n};
5use nom::combinator::{map_res, opt, recognize};
6use nom::error::{context, ContextError, ErrorKind, FromExternalError, ParseError as _};
7use nom::multi::{many0, many1};
8use nom::sequence::pair;
9use nom::{Input, Parser as _};
10use nom_locate::LocatedSpan;
11use std::fmt::{self, Debug, Write};
12use std::hash::Hash;
13use std::marker::PhantomData;
14use std::net::{Ipv4Addr, Ipv6Addr};
15use std::str::FromStr;
16
/// Parser input used throughout this module: a byte slice wrapped in a
/// `nom_locate::LocatedSpan` so positions (line/column) can be reported.
pub type Span<'a> = LocatedSpan<&'a [u8]>;
/// `nom::IResult` specialised to this module's [`ParseError`] over [`Span`]s.
pub type IResult<'a, A, B> = nom::IResult<A, B, ParseError<Span<'a>>>;
19
20pub fn make_span(s: &'_ [u8]) -> Span<'_> {
21 Span::new(s)
22}
23
24pub fn tag<E>(tag: &'static str) -> TagParser<E> {
28 TagParser {
29 tag,
30 no_case: false,
31 e: PhantomData,
32 }
33}
34
35pub fn tag_no_case<E>(tag: &'static str) -> TagParser<E> {
36 TagParser {
37 tag,
38 no_case: true,
39 e: PhantomData,
40 }
41}
42
/// Parser returned by [`tag`] and [`tag_no_case`]: matches a fixed literal at
/// the start of the input and reports mismatches as `E` errors that name the
/// expected literal.
pub struct TagParser<E> {
    /// The literal to match.
    tag: &'static str,
    /// When true, matching uses `compare_no_case` instead of `compare`.
    no_case: bool,
    /// Ties the parser to its error type without storing one.
    e: PhantomData<E>,
}
49
50impl<I, Error: nom::error::ParseError<I> + nom::error::FromExternalError<I, String>> nom::Parser<I>
52 for TagParser<Error>
53where
54 I: nom::Input + nom::Compare<&'static str> + nom::AsBytes,
55{
56 type Output = I;
57 type Error = Error;
58
59 fn process<OM: nom::OutputMode>(
60 &mut self,
61 i: I,
62 ) -> nom::PResult<OM, I, Self::Output, Self::Error> {
63 use nom::error::ErrorKind;
64 use nom::{CompareResult, Err, Mode};
65
66 let tag_len = self.tag.input_len();
67
68 let compare_result = if self.no_case {
69 i.compare_no_case(self.tag)
70 } else {
71 i.compare(self.tag)
72 };
73
74 match compare_result {
75 CompareResult::Ok => Ok((i.take_from(tag_len), OM::Output::bind(|| i.take(tag_len)))),
76 CompareResult::Incomplete => Err(Err::Error(OM::Error::bind(|| {
77 Error::from_external_error(
78 i,
79 ErrorKind::Fail,
80 format!(
81 "expected \"{}\" but ran out of input",
82 self.tag.escape_debug()
83 ),
84 )
85 }))),
86
87 CompareResult::Error => {
88 let available = i.take(i.input_len().min(tag_len));
89 Err(Err::Error(OM::Error::bind(|| {
90 Error::from_external_error(
91 i,
92 ErrorKind::Fail,
93 format!(
94 "expected \"{}\" but found {:?}",
95 self.tag.escape_debug(),
96 BStr::new(available.as_bytes())
97 ),
98 )
99 })))
100 }
101 }
102 }
103}
104
/// One entry in a [`ParseError`] trace.
#[derive(Debug)]
pub enum ParseErrorKind {
    /// A label attached by nom's `context` combinator (via `add_context`).
    Context(&'static str),
    /// A specific character was expected (via `from_char`).
    Char(char),
    /// A bare nom error kind (via `from_error_kind` / `append`).
    Nom(ErrorKind),
    /// An error with a human-readable reason (via `from_external_error` or
    /// `make_context_error`).
    External { kind: ErrorKind, reason: String },
}
112
/// Error type for the parsers in this module.
///
/// Holds an accumulated list of `(position, kind)` entries. Constructors start
/// the list with a single entry; `append` and `add_context` push further
/// entries onto the end.
#[derive(Debug)]
pub struct ParseError<I: Debug> {
    pub errors: Vec<(I, ParseErrorKind)>,
}
117
118impl<I: Debug> ContextError<I> for ParseError<I> {
119 fn add_context(input: I, ctx: &'static str, mut other: Self) -> Self {
120 other.errors.push((input, ParseErrorKind::Context(ctx)));
121 other
122 }
123}
124
125impl<I: Debug> nom::error::ParseError<I> for ParseError<I> {
126 fn from_error_kind(input: I, kind: ErrorKind) -> Self {
127 Self {
128 errors: vec![(input, ParseErrorKind::Nom(kind))],
129 }
130 }
131
132 fn append(input: I, kind: ErrorKind, mut other: Self) -> Self {
133 other.errors.push((input, ParseErrorKind::Nom(kind)));
134 other
135 }
136
137 fn from_char(input: I, c: char) -> Self {
138 Self {
139 errors: vec![(input, ParseErrorKind::Char(c))],
140 }
141 }
142}
143
144impl<I: Debug, E: std::fmt::Display> nom::error::FromExternalError<I, E> for ParseError<I> {
145 fn from_external_error(input: I, kind: ErrorKind, err: E) -> Self {
146 Self {
147 errors: vec![(
148 input,
149 ParseErrorKind::External {
150 kind,
151 reason: format!("{err:#}"),
152 },
153 )],
154 }
155 }
156}
157
158pub fn make_context_error<S: Into<String>>(
159 input: Span<'_>,
160 reason: S,
161) -> nom::Err<ParseError<Span<'_>>> {
162 nom::Err::Error(ParseError {
163 errors: vec![(
164 input,
165 ParseErrorKind::External {
166 kind: nom::error::ErrorKind::Fail,
167 reason: reason.into(),
168 },
169 )],
170 })
171}
172
173pub fn explain_nom(input: Span, err: nom::Err<ParseError<Span<'_>>>) -> String {
174 match err {
175 nom::Err::Error(e) => {
176 let mut result = String::new();
177 let mut lines_shown = vec![];
178
179 for (span, kind) in e.errors.iter() {
180 if input.is_empty() {
181 match kind {
182 ParseErrorKind::Char(c) => {
183 write!(&mut result, "Error expected '{c}', got empty input\n\n")
184 }
185 ParseErrorKind::Context(s) => {
186 write!(&mut result, "Error in {s}, got empty input\n\n")
187 }
188 ParseErrorKind::External { kind, reason } => {
189 write!(&mut result, "Error {reason} {kind:?}, got empty input\n\n")
190 }
191 ParseErrorKind::Nom(e) => {
192 write!(&mut result, "Error in {e:?}, got empty input\n\n")
193 }
194 }
195 .ok();
196 continue;
197 }
198
199 let line_number = span.location_line();
200 let input_line = span.get_line_beginning();
201 let mut line = String::new();
205 for (start, end, c) in input_line.char_indices() {
206 let c = match c {
207 '\t' => '\u{2409}',
208 '\r' => '\u{240d}',
209 '\n' => '\u{240a}',
210 c => c,
211 };
212
213 if c == std::char::REPLACEMENT_CHARACTER {
214 let bytes = &input_line[start..end];
215 for b in bytes.iter() {
216 line.push_str(&format!("\\x{b:02X}"));
217 }
218 } else {
219 line.push(c);
220 }
221 }
222
223 let column = span.get_utf8_column();
224
225 lines_shown.push(line_number);
226
227 let mut caret = " ".repeat(column.saturating_sub(1));
228 caret.push('^');
229 for _ in 1..span.fragment().len() {
230 caret.push('_')
231 }
232
233 match kind {
234 ParseErrorKind::Char(expected) => {
235 if let Some(actual) = span.fragment().chars().next() {
236 write!(
237 &mut result,
238 "Error at line {line_number}:\n\
239 {line}\n\
240 {caret}\n\
241 expected '{expected}', found {actual}\n\n",
242 )
243 } else {
244 write!(
245 &mut result,
246 "Error at line {line_number}:\n\
247 {line}\n\
248 {caret}\n\
249 expected '{expected}', got end of input\n\n",
250 )
251 }
252 }
253 ParseErrorKind::Context(context) => {
254 write!(&mut result, "while parsing {context}\n")
255 }
256 ParseErrorKind::External { kind: _, reason } => {
257 write!(
258 &mut result,
259 "Error at line {line_number}, {reason}:\n\
260 {line}\n\
261 {caret}\n\n",
262 )
263 }
264 ParseErrorKind::Nom(nom_err) => {
265 write!(
266 &mut result,
267 "Error at line {line_number}, in {nom_err:?}:\n\
268 {line}\n\
269 {caret}\n\n",
270 )
271 }
272 }
273 .ok();
274 }
275 result
276 }
277 _ => format!("{err:#}"),
278 }
279}
280
281pub fn utf8_non_ascii(input: Span) -> IResult<Span, Span> {
291 use nom::Err;
292
293 match input.char_indices().next() {
294 Some((start, end, c)) => {
295 let len = end - start;
296 if c as u32 <= 0x7f {
297 return Err(Err::Error(ParseError::from_error_kind(
299 input,
300 ErrorKind::Fail,
301 )));
302 }
303 let slice = &input[start..end];
304 if c == std::char::REPLACEMENT_CHARACTER {
305 let mut verify = [0u8; 4];
306 if slice != c.encode_utf8(&mut verify).as_bytes() {
307 return Err(Err::Error(ParseError::from_error_kind(
310 input,
311 ErrorKind::Fail,
312 )));
313 }
314 }
315 Ok((input.take_from(len), input.take(len)))
317 }
318 None => {
319 Err(Err::Error(ParseError::from_error_kind(
321 input,
322 ErrorKind::Eof,
323 )))
324 }
325 }
326}
327
328fn snum(input: Span) -> IResult<Span, Span> {
329 take_while_m_n(1, 3, |c: u8| c.is_ascii_digit()).parse(input)
330}
331
332pub fn ipv4_address(input: Span) -> IResult<Span, Ipv4Addr> {
333 context(
334 "ipv4_address",
335 map_res(
336 recognize((snum, tag("."), snum, tag("."), snum, tag("."), snum)),
337 |matched| {
338 let v4str = std::str::from_utf8(&matched).expect("can only be ascii");
339 v4str.parse().map_err(|err| {
340 nom::Err::Error(ParseError::from_external_error(
341 input,
342 ErrorKind::Fail,
343 format!("invalid ipv4_address: {err}"),
344 ))
345 })
346 },
347 ),
348 )
349 .parse(input)
350}
351
352pub fn ipv6_address(input: Span) -> IResult<Span, Ipv6Addr> {
353 context(
354 "ipv6_address",
355 map_res(
356 take_while1(|c: u8| c.is_ascii_hexdigit() || c == b':' || c == b'.'),
357 |matched: Span| {
358 let v6str = std::str::from_utf8(&matched).expect("can only be ascii");
359 v6str.parse().map_err(|err| {
360 nom::Err::Error(ParseError::from_external_error(
361 input,
362 ErrorKind::Fail,
363 format!("invalid ipv6_address: {err}"),
364 ))
365 })
366 },
367 ),
368 )
369 .parse(input)
370}
371
/// A domain name that has passed `Name::from_str_relaxed`.
///
/// Its `FromStr` impl stores the result of `Name::to_ascii()`, so the derived
/// equality and hashing compare that normalized ASCII text rather than the
/// original spelling.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct DomainString(String);
377
378impl fmt::Display for DomainString {
379 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
380 f.write_str(&self.0)
381 }
382}
383
384impl DomainString {
385 pub fn name(&self) -> Name {
386 Name::from_str_relaxed(&self.0)
387 .expect("cannot construct DomainString with an invalid domain name")
388 }
389
390 pub fn as_str(&self) -> &str {
392 &self.0
393 }
394}
395
396impl FromStr for DomainString {
397 type Err = String;
398 fn from_str(s: &str) -> Result<Self, Self::Err> {
399 let name = Name::from_str_relaxed(s)?;
400 Ok(Self(name.to_ascii()))
401 }
402}
403
404impl From<DomainString> for Name {
405 fn from(val: DomainString) -> Self {
406 val.name()
407 }
408}
409
410impl From<&DomainString> for Name {
411 fn from(val: &DomainString) -> Self {
412 val.name()
413 }
414}
415
416fn let_dig(input: Span) -> IResult<Span, Span> {
418 recognize(alt((
419 take_while_m_n(1, 1, |c: u8| c.is_ascii_alphanumeric()),
420 utf8_non_ascii,
421 )))
422 .parse(input)
423}
424
425fn ldh_str(input: Span) -> IResult<Span, Span> {
431 recognize(many1(alt((
432 take_while_m_n(1, 1, |c: u8| {
433 c.is_ascii_alphanumeric() || c == b'-' || c == b'_'
434 }),
435 utf8_non_ascii,
436 ))))
437 .parse(input)
438}
439
440fn sub_domain(input: Span) -> IResult<Span, Span> {
442 recognize(pair(let_dig, opt(ldh_str))).parse(input)
443}
444
445pub fn domain_name(input: Span) -> IResult<Span, DomainString> {
447 context(
448 "domain-name",
449 map_res(
450 recognize(pair(sub_domain, many0(pair(tag("."), sub_domain)))),
451 |matched: Span| match std::str::from_utf8(&matched) {
452 Ok(s) => s.parse().map_err(|err| {
453 nom::Err::Error(ParseError::from_external_error(
454 input,
455 ErrorKind::Fail,
456 format!("invalid domain name: {err}"),
457 ))
458 }),
459 Err(err) => Err(nom::Err::Error(ParseError::from_external_error(
460 input,
461 ErrorKind::Fail,
462 format!("invalid domain name: {err}"),
463 ))),
464 },
465 ),
466 )
467 .parse(input)
468}
469
#[cfg(test)]
mod tests {
    use super::*;

    /// A dotted-quad literal parses to the matching `Ipv4Addr`.
    #[test]
    fn test_ipv4_parse() {
        let parsed = ipv4_address(make_span(b"192.168.1.1"));
        let (_rest, addr) = parsed.unwrap();
        k9::assert_equal!(addr, Ipv4Addr::new(192, 168, 1, 1));
    }

    /// The fully expanded and `::`-compressed spellings denote the same address.
    #[test]
    fn test_ipv6_parse() {
        fn parse(text: &[u8]) -> Ipv6Addr {
            ipv6_address(make_span(text)).unwrap().1
        }

        let expanded = parse(b"2001:0db8:0000:0000:0000:0000:0000:0001");
        let compressed = parse(b"2001:db8::1");
        k9::assert_equal!(expanded, compressed);
    }

    /// Names differing only in ASCII case compare equal.
    #[test]
    fn test_domain_string_partial_eq() {
        let upper: DomainString = "EXAMPLE.COM".parse().unwrap();
        let lower: DomainString = "example.com".parse().unwrap();

        assert_eq!(upper, lower);
    }

    /// A Unicode name and its punycode form compare equal.
    #[test]
    fn test_domain_string_partial_eq_idna() {
        let unicode: DomainString = "münchen.de".parse().unwrap();
        let punycode: DomainString = "xn--mnchen-3ya.de".parse().unwrap();

        assert_eq!(unicode, punycode);
    }
}