cidr_map/
set.rs

1use crate::map::CidrMap;
2use crate::{parse_cidr, Node};
3pub use cidr::AnyIpCidr;
4use serde::{Deserialize, Serialize};
5use std::fmt::Debug;
6use std::net::IpAddr;
7
8#[derive(Serialize, Deserialize, Clone, Default, PartialEq)]
9#[serde(try_from = "Vec<String>", into = "Vec<String>")]
10pub struct CidrSet(CidrMap<()>);
11
12impl CidrSet {
13    pub fn new() -> Self {
14        Self::default()
15    }
16
17    pub fn default_trusted_hosts() -> Self {
18        vec!["127.0.0.1", "::1"].try_into().unwrap()
19    }
20
21    pub fn default_prohibited_hosts() -> Self {
22        vec!["127.0.0.0/8", "::1", "0.0.0.0", "::"]
23            .try_into()
24            .unwrap()
25    }
26
27    pub fn contains(&self, ip: IpAddr) -> bool {
28        self.0.contains(ip)
29    }
30
31    pub fn insert<T: Ord + Into<AnyIpCidr>>(&mut self, value: T) {
32        self.0.insert(value.into(), ());
33    }
34}
35
36impl Debug for CidrSet {
37    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
38        use std::fmt::DebugSet;
39        let mut set = fmt.debug_set();
40
41        fn walk(node: &Node<()>, set: &mut DebugSet) {
42            match node {
43                Node::InnerNode(inner) => {
44                    walk(&inner.children.left, set);
45                    walk(&inner.children.right, set);
46                }
47                Node::Leaf(leaf) => {
48                    set.entry(&leaf.key.to_string());
49                }
50            }
51        }
52
53        if let Some(root) = &self.0.root() {
54            walk(root, &mut set);
55        }
56
57        set.finish()
58    }
59}
60
61impl<T: Ord + Into<AnyIpCidr>, const N: usize> From<[T; N]> for CidrSet {
62    /// Converts a `[T; N]` into a `CidrSet`.
63    fn from(mut arr: [T; N]) -> Self {
64        if N == 0 {
65            return CidrSet::new();
66        }
67
68        // use stable sort to preserve the insertion order.
69        arr.sort();
70        let iter = IntoIterator::into_iter(arr); //.map(|k| k.into());
71        iter.collect()
72    }
73}
74
75impl<S> FromIterator<S> for CidrSet
76where
77    S: Into<AnyIpCidr>,
78{
79    fn from_iter<I: IntoIterator<Item = S>>(iter: I) -> Self {
80        let mut set = CidrMap::new();
81        for entry in iter {
82            set.insert(entry.into(), ());
83        }
84        Self(set)
85    }
86}
87
88impl TryFrom<Vec<&str>> for CidrSet {
89    type Error = String;
90
91    fn try_from(v: Vec<&str>) -> Result<Self, String> {
92        let mut set = CidrMap::new();
93        let mut problems = vec![];
94        for entry in v {
95            match parse_cidr(entry) {
96                Ok(cidr) => {
97                    set.insert(cidr, ());
98                }
99                Err(err) => {
100                    problems.push(format!("{entry}: {err:#}"));
101                }
102            }
103        }
104        if problems.is_empty() {
105            Ok(Self(set))
106        } else {
107            Err(problems.join(", "))
108        }
109    }
110}
111
112impl TryFrom<Vec<String>> for CidrSet {
113    type Error = String;
114
115    fn try_from(v: Vec<std::string::String>) -> Result<Self, String> {
116        let mut set = CidrMap::new();
117        let mut problems = vec![];
118        for entry in v {
119            match parse_cidr(&entry) {
120                Ok(cidr) => {
121                    set.insert(cidr, ());
122                }
123                Err(err) => {
124                    problems.push(format!("{entry}: {err:#}"));
125                }
126            }
127        }
128        if problems.is_empty() {
129            Ok(Self(set))
130        } else {
131            Err(problems.join(", "))
132        }
133    }
134}
135
136impl From<CidrSet> for Vec<String> {
137    fn from(val: CidrSet) -> Self {
138        let mut result = vec![];
139        for (key, _unit) in val.0.iter() {
140            result.push(key.to_string());
141        }
142        result
143    }
144}
145
146impl From<Vec<AnyIpCidr>> for CidrSet {
147    fn from(entries: Vec<AnyIpCidr>) -> Self {
148        entries.into_iter().collect()
149    }
150}
151
152impl From<CidrSet> for Vec<AnyIpCidr> {
153    fn from(val: CidrSet) -> Self {
154        let mut result = vec![];
155        for (key, _unit) in val.0.iter() {
156            result.push(*key);
157        }
158        result
159    }
160}
161
#[cfg(test)]
mod test {
    use super::*;

    /// Parses an address literal used by the assertions below.
    fn ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    #[test]
    fn cidrset_any() {
        let empty_set = CidrSet::new();
        let set_with_any: CidrSet = [AnyIpCidr::Any].into();

        assert!(!empty_set.contains(ip("127.0.0.1")));
        assert!(set_with_any.contains(ip("127.0.0.1")));
    }

    #[test]
    fn cidrset() {
        let set: CidrSet = [
            "127.0.0.1",
            "::1",
            "192.168.1.0/24",
            // This entry is overlapped by the preceding entry
            "192.168.1.24",
            "192.168.3.0/28",
            "10.0.3.0/24",
            "10.0.4.0/24",
            "10.0.7.0/24",
        ]
        .map(|text| parse_cidr(text).unwrap())
        .into();

        // Addresses covered by one of the entries above.
        for present in [
            "127.0.0.1",
            "::1",
            "192.168.1.0",
            "192.168.1.1",
            "192.168.1.100",
            "192.168.1.24",
            "192.168.3.0",
        ] {
            assert!(set.contains(ip(present)));
        }

        // Addresses that fall outside every entry.
        for absent in ["127.0.0.2", "192.168.2.1", "192.168.3.16"] {
            assert!(!set.contains(ip(absent)));
        }

        // Note that the snapshot does not contain 192.168.1.24/32; that
        // overlaps with the broader 192.168.1.0/24 so is "lost"
        // when extracting the information from the set
        k9::snapshot!(
            &set,
            r#"
{
    "10.0.3.0/24",
    "10.0.4.0/24",
    "10.0.7.0/24",
    "127.0.0.1",
    "192.168.1.0/24",
    "192.168.3.0/28",
    "::1",
}
"#
        );
    }
}