1use bitstring::BitString;
8pub use cidr::{AnyIpCidr, IpCidr};
9#[cfg(feature = "lua")]
10use config::{any_err, get_or_create_sub_module};
11#[cfg(feature = "lua")]
12use mlua::prelude::LuaUserData;
13#[cfg(feature = "lua")]
14use mlua::{FromLua, Lua, MetaMethod, UserDataMethods};
15#[cfg(feature = "lua")]
16use mod_memoize::CacheValue;
17use serde::de::{MapAccess, Visitor};
18use serde::ser::SerializeMap;
19use serde::{Deserialize, Deserializer, Serialize, Serializer};
20use std::fmt::Debug;
21use std::net::IpAddr;
22use std::str::FromStr;
23
24#[derive(Clone, PartialEq)]
25pub struct CidrMap<V>
26where
27 V: Clone,
28{
29 root: Option<Node<V>>,
30}
31
32impl<V> Debug for CidrMap<V>
33where
34 V: Clone + Debug,
35{
36 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
37 use std::fmt::DebugMap;
38 let mut map = fmt.debug_map();
39
40 fn walk<V: Clone + Debug>(node: &Node<V>, map: &mut DebugMap) {
41 match node {
42 Node::InnerNode(inner) => {
43 walk(&inner.children.left, map);
44 walk(&inner.children.right, map);
45 }
46 Node::Leaf(leaf) => {
47 map.key(&leaf.key.to_string());
48 map.value(&leaf.value);
49 }
50 }
51 }
52
53 if let Some(root) = &self.root {
54 walk(root, &mut map);
55 }
56
57 map.finish()
58 }
59}
60
61struct MapVis<T>
62where
63 T: Clone + PartialEq,
64{
65 map: CidrMap<T>,
66}
67
68impl<'de, T> Visitor<'de> for MapVis<T>
69where
70 T: Clone + PartialEq + Deserialize<'de>,
71{
72 type Value = CidrMap<T>;
73
74 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
75 formatter.write_str("a CidrMap")
76 }
77
78 fn visit_map<M>(mut self, mut access: M) -> Result<Self::Value, M::Error>
79 where
80 M: MapAccess<'de>,
81 {
82 while let Some((key, value)) = access.next_entry()? {
83 self.map.insert(key, value);
84 }
85
86 Ok(self.map)
87 }
88}
89
90impl<'de, V> Deserialize<'de> for CidrMap<V>
91where
92 V: Clone + PartialEq + Deserialize<'de>,
93{
94 fn deserialize<D>(deserializer: D) -> Result<Self, <D as Deserializer<'de>>::Error>
95 where
96 D: Deserializer<'de>,
97 {
98 deserializer.deserialize_map(MapVis {
99 map: CidrMap::new(),
100 })
101 }
102}
103
104impl<V> Serialize for CidrMap<V>
105where
106 V: Clone + PartialEq + Serialize,
107{
108 fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
109 where
110 S: Serializer,
111 {
112 let mut map = serializer.serialize_map(None)?;
113 for (k, v) in self.iter() {
114 map.serialize_entry(k, v)?;
115 }
116 map.end()
117 }
118}
119
120#[derive(Debug, Clone, PartialEq)]
123pub enum Node<V>
124where
125 V: Clone,
126{
127 InnerNode(InnerNode<V>),
129 Leaf(Leaf<V>),
131}
132
133#[derive(Clone, Debug, PartialEq)]
135pub struct Leaf<V>
136where
137 V: Clone,
138{
139 pub key: AnyIpCidr,
140 pub value: V,
141}
142
143#[derive(Clone, Debug, PartialEq)]
145pub struct InnerNode<V>
146where
147 V: Clone,
148{
149 pub(crate) key: AnyIpCidr,
150 pub(crate) children: Box<Children<V>>,
151}
152
153#[derive(Clone, Debug, PartialEq)]
154pub(crate) struct Children<V>
155where
156 V: Clone,
157{
158 pub(crate) left: Node<V>,
159 pub(crate) right: Node<V>,
160}
161
162impl<V> InnerNode<V>
163where
164 V: Clone,
165{
166 pub fn key(&self) -> &AnyIpCidr {
167 &self.key
168 }
169
170 pub fn pick_side<'a>(&'a self, subkey: &AnyIpCidr) -> &'a Node<V> {
171 if subkey.get(self.key.len()) {
172 &self.children.right
173 } else {
174 &self.children.left
175 }
176 }
177
178 pub fn pick_side_mut<'a>(&'a mut self, subkey: &AnyIpCidr) -> &'a mut Node<V> {
179 if subkey.get(self.key.len()) {
180 &mut self.children.right
181 } else {
182 &mut self.children.left
183 }
184 }
185
186 pub fn left(&self) -> &Node<V> {
187 &self.children.left
188 }
189
190 pub fn right(&self) -> &Node<V> {
191 &self.children.right
192 }
193}
194
195impl<V> Node<V>
196where
197 V: Clone,
198{
199 fn new_leaf(key: AnyIpCidr, value: V) -> Self {
200 Self::Leaf(Leaf { key, value })
201 }
202
203 fn new_children_unknown_order(
204 shared_prefix_len: usize,
205 a: Node<V>,
206 b: Node<V>,
207 ) -> Box<Children<V>> {
208 let a_right = a.key().get(shared_prefix_len);
209 assert_eq!(!a_right, b.key().get(shared_prefix_len));
210 if a_right {
211 Box::new(Children { left: b, right: a })
212 } else {
213 Box::new(Children { left: a, right: b })
214 }
215 }
216
217 fn new_inner_unknown_order(shared_prefix_len: usize, a: Node<V>, b: Node<V>) -> Node<V> {
218 let mut key = *a.key();
219 key.clip(shared_prefix_len);
220 Node::InnerNode(InnerNode {
221 key,
222 children: Self::new_children_unknown_order(shared_prefix_len, a, b),
223 })
224 }
225
226 pub fn key(&self) -> &AnyIpCidr {
228 match *self {
229 Node::Leaf(ref leaf) => &leaf.key,
230 Node::InnerNode(ref inner) => &inner.key,
231 }
232 }
233
234 fn leaf_ref(&self) -> Option<&Leaf<V>> {
235 match *self {
236 Node::Leaf(ref leaf) => Some(leaf),
237 _ => None,
238 }
239 }
240
241 fn convert_leaf(&mut self, key_len: usize, value: V) {
244 *self = match self {
245 Node::Leaf(leaf) => {
246 let mut leaf = leaf.clone();
247 leaf.key.clip(key_len);
248 leaf.value = value;
249 Node::Leaf(leaf)
250 }
251 Node::InnerNode(inner) => {
252 let mut key = inner.key;
253 key.clip(key_len);
254 Self::new_leaf(key, value)
255 }
256 };
257 }
258
259 fn insert_uncompressed(&mut self, key: AnyIpCidr, value: V)
260 where
261 V: Clone,
262 {
263 let (self_key_len, shared_prefix_len) = {
264 let key_ref = self.key();
265 (key_ref.len(), key_ref.shared_prefix_len(&key))
266 };
267
268 if shared_prefix_len == key.len() {
269 self.convert_leaf(shared_prefix_len, value);
272 } else if shared_prefix_len < self_key_len {
273 debug_assert!(shared_prefix_len < key.len());
274 *self = Self::new_inner_unknown_order(
276 shared_prefix_len,
277 self.clone(),
278 Self::new_leaf(key, value),
279 );
280 } else {
281 debug_assert!(shared_prefix_len == self_key_len);
282 debug_assert!(shared_prefix_len < key.len());
283 match *self {
285 Node::Leaf(_) => {
286 let old_value = self.leaf_ref().unwrap().value.clone();
288 let mut new_node = Self::new_leaf(key, value);
289 for l in (shared_prefix_len..key.len()).rev() {
290 let mut other_key = key;
291 other_key.clip(l + 1);
292 other_key.flip(l);
293 new_node = Self::new_inner_unknown_order(
294 l,
295 new_node,
296 Self::new_leaf(other_key, old_value.clone()),
297 );
298 }
299 *self = new_node;
300 }
301 Node::InnerNode(ref mut inner) => {
302 inner.pick_side_mut(&key).insert_uncompressed(key, value);
303 }
304 }
305 }
306 }
307
308 fn insert(&mut self, key: AnyIpCidr, value: V)
309 where
310 V: Clone + PartialEq,
311 {
312 let (self_key_len, shared_prefix_len) = {
313 let key_ref = self.key();
314 (key_ref.len(), key_ref.shared_prefix_len(&key))
315 };
316
317 if shared_prefix_len == key.len() {
318 self.convert_leaf(shared_prefix_len, value);
321 } else if shared_prefix_len < self_key_len {
323 debug_assert!(shared_prefix_len < key.len());
324 if shared_prefix_len + 1 == self_key_len && shared_prefix_len + 1 == key.len() {
325 if let Node::Leaf(ref mut this) = *self {
326 if this.value == value {
327 this.key.clip(shared_prefix_len);
330 return; }
332 }
333 }
334
335 *self = Self::new_inner_unknown_order(
337 shared_prefix_len,
338 self.clone(),
339 Self::new_leaf(key, value),
340 );
341 } else {
344 debug_assert!(shared_prefix_len == self_key_len);
345 debug_assert!(shared_prefix_len < key.len());
346 match *self {
348 Node::Leaf(_) => {
349 let new_node = {
351 let old_value = &self.leaf_ref().unwrap().value;
352 if *old_value == value {
353 return;
355 }
356 let mut new_node = Self::new_leaf(key, value);
357 for l in (shared_prefix_len..key.len()).rev() {
358 let mut other_key = key;
359 other_key.clip(l + 1);
360 other_key.flip(l);
361 new_node = Self::new_inner_unknown_order(
362 l,
363 new_node,
364 Self::new_leaf(other_key, old_value.clone()),
365 );
366 }
367 new_node
368 };
369 *self = new_node;
370 return;
372 }
373 Node::InnerNode(ref mut inner) => {
374 inner.pick_side_mut(&key).insert(key, value);
375 }
376 }
377 self.compress();
379 }
380 }
381
382 fn compress(&mut self)
383 where
384 V: PartialEq,
385 {
386 let self_key_len = self.key().len();
387
388 let compress = match *self {
393 Node::InnerNode(ref inner) => {
394 let left_value = match inner.children.left {
395 Node::Leaf(ref leaf) if leaf.key.len() == self_key_len + 1 => &leaf.value,
396 _ => return, };
398 let right_value = match inner.children.right {
399 Node::Leaf(ref leaf) if leaf.key.len() == self_key_len + 1 => &leaf.value,
400 _ => return, };
402 left_value == right_value
403 }
404 Node::Leaf(_) => return, };
406 if compress {
407 *self = match self {
408 Node::InnerNode(inner) => match &inner.children.left {
410 Node::Leaf(leaf) => Node::Leaf(Leaf {
411 key: inner.key,
412 value: leaf.value.clone(),
413 }),
414 _ => unreachable!(),
415 },
416 _ => unreachable!(),
417 };
418 }
419 }
420}
421
422impl<V> Default for CidrMap<V>
423where
424 V: Clone,
425{
426 fn default() -> Self {
427 Self { root: None }
428 }
429}
430
431impl<V> CidrMap<V>
432where
433 V: Clone,
434{
435 pub fn new() -> Self {
436 Self { root: None }
437 }
438
439 pub fn contains(&self, ip: IpAddr) -> bool {
440 self.get_prefix_match(ip).is_some()
441 }
442
443 pub fn get_prefix_match(&self, ip: IpAddr) -> Option<&V> {
444 let key: AnyIpCidr = IpCidr::new_host(ip).into();
445 self.get_prefix_match_cidr(&key)
446 }
447
448 pub fn get_prefix_match_cidr(&self, key: &AnyIpCidr) -> Option<&V> {
449 let node = self.root.as_ref()?;
450 Self::find_item(node, key)
451 }
452
453 fn find_item<'a>(node: &'a Node<V>, ip: &AnyIpCidr) -> Option<&'a V> {
454 match node {
455 Node::Leaf(leaf) => {
456 if leaf.key.contains(&ip.first_address().unwrap()) {
457 Some(&leaf.value)
458 } else {
459 None
460 }
461 }
462 Node::InnerNode(inner) => Self::find_item(inner.pick_side(ip), ip),
463 }
464 }
465
466 pub fn insert_uncompressed(&mut self, key: AnyIpCidr, value: V)
471 where
472 V: Clone,
473 {
474 match self.root {
475 None => {
476 self.root = Some(Node::new_leaf(key, value));
477 }
478 Some(ref mut node) => {
479 node.insert_uncompressed(key, value);
480 }
481 }
482 }
483
484 pub fn insert(&mut self, key: AnyIpCidr, value: V)
487 where
488 V: Clone + PartialEq,
489 {
490 match self.root {
491 None => {
492 self.root = Some(Node::new_leaf(key, value));
493 }
494 Some(ref mut node) => {
495 node.insert(key, value);
496 }
497 }
498 }
499
500 pub fn root(&self) -> Option<&Node<V>> {
504 self.root.as_ref()
505 }
506
507 pub fn iter(&self) -> Iter<V> {
509 Iter::new(self)
510 }
511}
512
/// What the tree iterator still has to do for a node on its stack.
#[derive(Copy, Clone, Eq, PartialEq)]
enum Direction {
    /// The node has not been entered yet: visit it or descend to the left.
    Left,
    /// The left subtree has been entered: the right subtree comes next.
    Right,
    /// Nothing is left below this node: return to its parent.
    Up,
}
519
520pub struct Iter<'a, V: 'a>
522where
523 V: Clone,
524{
525 stack: Vec<(Direction, &'a Node<V>)>,
526}
527
528impl<'a, V> Iter<'a, V>
529where
530 V: Clone,
531{
532 pub fn new(tree: &'a CidrMap<V>) -> Self {
534 match tree.root() {
535 None => Iter { stack: Vec::new() },
536 Some(node) => Iter {
537 stack: vec![(Direction::Left, node)],
538 },
539 }
540 }
541}
542
543impl<'a, V> Iterator for Iter<'a, V>
544where
545 V: Clone,
546{
547 type Item = (&'a AnyIpCidr, &'a V);
548
549 fn next(&mut self) -> Option<Self::Item> {
550 if self.stack.is_empty() {
551 return None;
552 }
553
554 while Direction::Up == self.stack[self.stack.len() - 1].0 {
556 if 1 == self.stack.len() {
557 self.stack.clear();
558 return None;
559 }
560
561 self.stack.pop();
562 debug_assert!(!self.stack.is_empty());
564 }
565
566 loop {
567 let top = self.stack.len() - 1;
568 let (dir, node) = self.stack[top];
569
570 debug_assert!(!self.stack.is_empty());
571 match dir {
573 Direction::Left => match *node {
574 Node::InnerNode(ref inner) => {
575 self.stack[top].0 = Direction::Right;
576 self.stack.push((Direction::Left, inner.left()));
577 }
578 Node::Leaf(ref leaf) => {
579 self.stack[top].0 = Direction::Up;
580 return Some((&leaf.key, &leaf.value));
581 }
582 },
583 Direction::Right => match *node {
584 Node::InnerNode(ref inner) => {
585 self.stack[top].0 = Direction::Up;
586 self.stack.push((Direction::Left, inner.right()));
587 }
588 Node::Leaf(_) => unreachable!(),
589 },
590 Direction::Up => unreachable!(),
591 }
592 }
593 }
594}
595
596impl<S, V: Clone + PartialEq> FromIterator<(S, V)> for CidrMap<V>
597where
598 S: Into<AnyIpCidr>,
599{
600 fn from_iter<I: IntoIterator<Item = (S, V)>>(iter: I) -> Self {
601 let mut map = CidrMap::new();
602 for (key, value) in iter {
603 map.insert(key.into(), value);
604 }
605 map
606 }
607}
608
609impl<T: Ord + Into<AnyIpCidr>, const N: usize, V: Clone + Ord> From<[(T, V); N]> for CidrMap<V> {
610 fn from(mut arr: [(T, V); N]) -> Self {
612 if N == 0 {
613 return CidrMap::new();
614 }
615
616 arr.sort();
618 let iter = IntoIterator::into_iter(arr);
619 iter.collect()
620 }
621}
622
623impl<V: Clone> From<CidrMap<V>> for Vec<(AnyIpCidr, V)> {
624 fn from(val: CidrMap<V>) -> Self {
625 let mut result = vec![];
626 for (key, value) in val.iter() {
627 result.push((*key, value.clone()));
628 }
629 result
630 }
631}
632
/// Exposes `CidrMap<CacheValue>` to Lua as userdata with index and
/// new-index metamethods, plus whatever `Memoized::impl_memoize` registers.
#[cfg(feature = "lua")]
impl LuaUserData for CidrMap<CacheValue> {
    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
        mod_memoize::Memoized::impl_memoize(methods);

        // Index: parse the string key, then return the converted value of
        // the matching entry, or `None` when nothing matches.
        methods.add_meta_method(MetaMethod::Index, |lua, this, key: String| {
            let cidr = parse_cidr_from_ip_and_or_port(&key).map_err(any_err)?;
            match this.get_prefix_match_cidr(&cidr) {
                Some(cached) => Ok(Some(cached.as_lua(lua)?)),
                None => Ok(None),
            }
        });

        // NewIndex: parse the string key, convert the Lua value and insert.
        methods.add_meta_method_mut(
            MetaMethod::NewIndex,
            |lua, this, (key, value): (String, mlua::Value)| {
                let cidr = parse_cidr_from_ip_and_or_port(&key).map_err(any_err)?;
                this.insert(cidr, CacheValue::from_lua(value, lua)?);
                Ok(())
            },
        );
    }
}
657
/// Parses `s` as a CIDR, tolerating a trailing port.
///
/// Attempts, in order:
/// 1. `s` as-is via [`parse_cidr`];
/// 2. if `s` starts with `[` and contains `]`, the text between them;
/// 3. the text before the last `:`.
///
/// The first fallback that applies decides the result. If neither applies,
/// the error from step 1 is returned.
#[cfg(feature = "lua")]
fn parse_cidr_from_ip_and_or_port(s: &str) -> anyhow::Result<AnyIpCidr> {
    let err = match parse_cidr(s) {
        Ok(c) => return Ok(c),
        Err(err) => err,
    };

    let bracketed = s.strip_prefix('[').and_then(|rest| rest.split_once(']'));
    if let Some((ip, _port)) = bracketed {
        return parse_cidr(ip).map_err(|err| {
            anyhow::anyhow!(
                "failed to parse '{ip}', the \
                 []-enclosed portion of '{s}', as an IP address: {err:#}"
            )
        });
    }

    if let Some((ip, _port)) = s.rsplit_once(':') {
        return parse_cidr(ip).map_err(|err| {
            anyhow::anyhow!(
                "failed to parse '{ip}', the \
                 :-delimited portion of '{s}', as an IP address: {err:#}"
            )
        });
    }

    Err(err)
}
685
686pub fn parse_cidr(s: &str) -> anyhow::Result<AnyIpCidr> {
692 AnyIpCidr::from_str(s).map_err(|err| {
693 match cidr::parsers::parse_any_cidr_full_ignore_hostbits(
694 s,
695 std::str::FromStr::from_str,
696 std::str::FromStr::from_str,
697 ) {
698 Ok(loose) => {
699 anyhow::anyhow!("{s} is not a valid CIDR: {err:#}. Did you mean {loose}?")
700 }
701 Err(err) => {
702 anyhow::anyhow!("{s} is not a valid CIDR: {err:#}")
703 }
704 }
705 })
706}
707
/// Registers the `make_map` function in the `cidr` Lua sub-module.
///
/// `make_map` takes an optional table of string keys and builds a
/// `CidrMap<CacheValue>` from it: each key is parsed with
/// `parse_cidr_from_ip_and_or_port` and each value converted with
/// `CacheValue::from_lua`. Called without a table it returns an empty map.
#[cfg(feature = "lua")]
pub fn register(lua: &Lua) -> anyhow::Result<()> {
    use std::collections::HashMap;

    let cidr_mod = get_or_create_sub_module(lua, "cidr")?;

    let make_map = lua.create_function(|lua, value: Option<HashMap<String, mlua::Value>>| {
        let mut cmap: CidrMap<mod_memoize::CacheValue> = CidrMap::new();

        for (k, v) in value.into_iter().flatten() {
            let cidr = parse_cidr_from_ip_and_or_port(&k).map_err(any_err)?;
            let cached = CacheValue::from_lua(v, lua)?;
            cmap.insert(cidr, cached);
        }

        Ok(cmap)
    })?;
    cidr_mod.set("make_map", make_map)?;

    Ok(())
}
732
#[cfg(test)]
mod test {
    use super::*;

    /// A CIDR with non-zero host bits is rejected with a suggestion.
    #[test]
    fn test_parse_error_message() {
        let err = parse_cidr("10.0.0.1/24").unwrap_err();
        assert_eq!(
            err.to_string(),
            "10.0.0.1/24 is not a valid CIDR: host part of address was not zero. Did you mean 10.0.0.0/24?"
        );
    }

    /// Builds a map from overlapping entries, then checks lookups and the
    /// resulting tree layout.
    #[test]
    fn cidrmap() {
        let set: CidrMap<&str> = [
            (parse_cidr("127.0.0.1").unwrap(), "loopbackv4"),
            (parse_cidr("::1").unwrap(), "loopbackv6"),
            (parse_cidr("192.168.1.0/24").unwrap(), ".1"),
            (parse_cidr("192.168.1.24").unwrap(), ".1"),
            (parse_cidr("192.168.3.0/28").unwrap(), ".3"),
            (parse_cidr("192.168.3.2").unwrap(), ".3.split"),
            (parse_cidr("10.0.3.0/24").unwrap(), "10.3"),
            (parse_cidr("10.0.4.0/24").unwrap(), "10.4"),
            (parse_cidr("10.0.7.0/24").unwrap(), "10.7"),
        ]
        .into();

        fn get<'a>(set: &'a CidrMap<&str>, key: &str) -> Option<&'a str> {
            let key = key.parse().unwrap();
            set.get_prefix_match(key).copied()
        }

        // (address to look up, expected value)
        let cases: [(&str, Option<&str>); 11] = [
            ("127.0.0.1", Some("loopbackv4")),
            ("127.0.0.2", None),
            ("::1", Some("loopbackv6")),
            ("192.168.2.1", None),
            ("192.168.1.0", Some(".1")),
            ("192.168.1.1", Some(".1")),
            ("192.168.1.100", Some(".1")),
            ("192.168.1.24", Some(".1")),
            ("192.168.3.0", Some(".3")),
            ("192.168.3.16", None),
            ("192.168.3.2", Some(".3.split")),
        ];
        for &(addr, expected) in &cases {
            assert_eq!(get(&set, addr), expected, "lookup of {}", addr);
        }

        k9::snapshot!(
            &set,
            r#"
{
    "10.0.3.0/24": "10.3",
    "10.0.4.0/24": "10.4",
    "10.0.7.0/24": "10.7",
    "127.0.0.1": "loopbackv4",
    "192.168.1.0/24": ".1",
    "192.168.3.0/31": ".3",
    "192.168.3.2": ".3.split",
    "192.168.3.3": ".3",
    "192.168.3.4/30": ".3",
    "192.168.3.8/29": ".3",
    "::1": "loopbackv6",
}
"#
        );
    }
}