cidr_map/
set.rs

1use crate::map::CidrMap;
2use crate::{parse_cidr, Node};
3pub use cidr::AnyIpCidr;
4use serde::{Deserialize, Serialize};
5use std::fmt::Debug;
6use std::net::IpAddr;
7
8#[derive(Serialize, Deserialize, Clone, Default, PartialEq)]
9#[serde(try_from = "Vec<String>", into = "Vec<String>")]
10pub struct CidrSet(CidrMap<()>);
11
12impl CidrSet {
13    pub fn new() -> Self {
14        Self::default()
15    }
16
17    pub fn default_trusted_hosts() -> Self {
18        vec!["127.0.0.1", "::1"].try_into().unwrap()
19    }
20
21    pub fn default_prohibited_hosts() -> Self {
22        vec!["127.0.0.0/8", "::1"].try_into().unwrap()
23    }
24
25    pub fn contains(&self, ip: IpAddr) -> bool {
26        self.0.contains(ip)
27    }
28
29    pub fn insert<T: Ord + Into<AnyIpCidr>>(&mut self, value: T) {
30        self.0.insert(value.into(), ());
31    }
32}
33
34impl Debug for CidrSet {
35    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
36        use std::fmt::DebugSet;
37        let mut set = fmt.debug_set();
38
39        fn walk(node: &Node<()>, set: &mut DebugSet) {
40            match node {
41                Node::InnerNode(inner) => {
42                    walk(&inner.children.left, set);
43                    walk(&inner.children.right, set);
44                }
45                Node::Leaf(leaf) => {
46                    set.entry(&leaf.key.to_string());
47                }
48            }
49        }
50
51        if let Some(root) = &self.0.root() {
52            walk(root, &mut set);
53        }
54
55        set.finish()
56    }
57}
58
59impl<T: Ord + Into<AnyIpCidr>, const N: usize> From<[T; N]> for CidrSet {
60    /// Converts a `[T; N]` into a `CidrSet`.
61    fn from(mut arr: [T; N]) -> Self {
62        if N == 0 {
63            return CidrSet::new();
64        }
65
66        // use stable sort to preserve the insertion order.
67        arr.sort();
68        let iter = IntoIterator::into_iter(arr); //.map(|k| k.into());
69        iter.collect()
70    }
71}
72
73impl<S> FromIterator<S> for CidrSet
74where
75    S: Into<AnyIpCidr>,
76{
77    fn from_iter<I: IntoIterator<Item = S>>(iter: I) -> Self {
78        let mut set = CidrMap::new();
79        for entry in iter {
80            set.insert(entry.into(), ());
81        }
82        Self(set)
83    }
84}
85
86impl TryFrom<Vec<&str>> for CidrSet {
87    type Error = String;
88
89    fn try_from(v: Vec<&str>) -> Result<Self, String> {
90        let mut set = CidrMap::new();
91        let mut problems = vec![];
92        for entry in v {
93            match parse_cidr(entry) {
94                Ok(cidr) => {
95                    set.insert(cidr, ());
96                }
97                Err(err) => {
98                    problems.push(format!("{entry}: {err:#}"));
99                }
100            }
101        }
102        if problems.is_empty() {
103            Ok(Self(set))
104        } else {
105            Err(problems.join(", "))
106        }
107    }
108}
109
110impl TryFrom<Vec<String>> for CidrSet {
111    type Error = String;
112
113    fn try_from(v: Vec<std::string::String>) -> Result<Self, String> {
114        let mut set = CidrMap::new();
115        let mut problems = vec![];
116        for entry in v {
117            match parse_cidr(&entry) {
118                Ok(cidr) => {
119                    set.insert(cidr, ());
120                }
121                Err(err) => {
122                    problems.push(format!("{entry}: {err:#}"));
123                }
124            }
125        }
126        if problems.is_empty() {
127            Ok(Self(set))
128        } else {
129            Err(problems.join(", "))
130        }
131    }
132}
133
134impl From<CidrSet> for Vec<String> {
135    fn from(val: CidrSet) -> Self {
136        let mut result = vec![];
137        for (key, _unit) in val.0.iter() {
138            result.push(key.to_string());
139        }
140        result
141    }
142}
143
144impl From<Vec<AnyIpCidr>> for CidrSet {
145    fn from(entries: Vec<AnyIpCidr>) -> Self {
146        entries.into_iter().collect()
147    }
148}
149
150impl From<CidrSet> for Vec<AnyIpCidr> {
151    fn from(val: CidrSet) -> Self {
152        let mut result = vec![];
153        for (key, _unit) in val.0.iter() {
154            result.push(*key);
155        }
156        result
157    }
158}
159
#[cfg(test)]
mod test {
    use super::*;

    /// Parses an IP address literal used by the tests.
    fn addr(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    #[test]
    fn cidrset_any() {
        let set_with_any: CidrSet = [AnyIpCidr::Any].into();
        let empty_set = CidrSet::new();

        // An empty set matches nothing; the `Any` wildcard matches everything.
        assert!(!empty_set.contains(addr("127.0.0.1")));
        assert!(set_with_any.contains(addr("127.0.0.1")));
    }

    #[test]
    fn cidrset() {
        let set: CidrSet = [
            "127.0.0.1",
            "::1",
            "192.168.1.0/24",
            // This entry is overlapped by the preceding entry
            "192.168.1.24",
            "192.168.3.0/28",
            "10.0.3.0/24",
            "10.0.4.0/24",
            "10.0.7.0/24",
        ]
        .map(|text| parse_cidr(text).unwrap())
        .into();

        // The bare 127.0.0.1 entry does not cover its neighbour 127.0.0.2.
        assert!(set.contains(addr("127.0.0.1")));
        assert!(!set.contains(addr("127.0.0.2")));
        assert!(set.contains(addr("::1")));

        // An address outside every stored network is rejected.
        assert!(!set.contains(addr("192.168.2.1")));

        // Everything inside 192.168.1.0/24 matches, including the
        // overlapped 192.168.1.24 entry.
        for inside in ["192.168.1.0", "192.168.1.1", "192.168.1.100", "192.168.1.24"] {
            assert!(set.contains(addr(inside)));
        }

        // 192.168.3.0/28 covers .0 through .15 only.
        assert!(set.contains(addr("192.168.3.0")));
        assert!(!set.contains(addr("192.168.3.16")));

        // Note that the snapshot does not contain 192.168.1.24/32; that
        // overlaps with the broader 192.168.1.0/24 so is "lost"
        // when extracting the information from the set
        k9::snapshot!(
            &set,
            r#"
{
    "10.0.3.0/24",
    "10.0.4.0/24",
    "10.0.7.0/24",
    "127.0.0.1",
    "192.168.1.0/24",
    "192.168.3.0/28",
    "::1",
}
"#
        );
    }
}
220}