1use crate::map::CidrMap;
2use crate::{parse_cidr, Node};
3pub use cidr::AnyIpCidr;
4use serde::{Deserialize, Serialize};
5use std::fmt::Debug;
6use std::net::IpAddr;
7
8#[derive(Serialize, Deserialize, Clone, Default, PartialEq)]
9#[serde(try_from = "Vec<String>", into = "Vec<String>")]
10pub struct CidrSet(CidrMap<()>);
11
12impl CidrSet {
13 pub fn new() -> Self {
14 Self::default()
15 }
16
17 pub fn default_trusted_hosts() -> Self {
18 vec!["127.0.0.1", "::1"].try_into().unwrap()
19 }
20
21 pub fn default_prohibited_hosts() -> Self {
22 vec!["127.0.0.0/8", "::1", "0.0.0.0", "::"]
23 .try_into()
24 .unwrap()
25 }
26
27 pub fn contains(&self, ip: IpAddr) -> bool {
28 self.0.contains(ip)
29 }
30
31 pub fn insert<T: Ord + Into<AnyIpCidr>>(&mut self, value: T) {
32 self.0.insert(value.into(), ());
33 }
34}
35
36impl Debug for CidrSet {
37 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
38 use std::fmt::DebugSet;
39 let mut set = fmt.debug_set();
40
41 fn walk(node: &Node<()>, set: &mut DebugSet) {
42 match node {
43 Node::InnerNode(inner) => {
44 walk(&inner.children.left, set);
45 walk(&inner.children.right, set);
46 }
47 Node::Leaf(leaf) => {
48 set.entry(&leaf.key.to_string());
49 }
50 }
51 }
52
53 if let Some(root) = &self.0.root() {
54 walk(root, &mut set);
55 }
56
57 set.finish()
58 }
59}
60
61impl<T: Ord + Into<AnyIpCidr>, const N: usize> From<[T; N]> for CidrSet {
62 fn from(mut arr: [T; N]) -> Self {
64 if N == 0 {
65 return CidrSet::new();
66 }
67
68 arr.sort();
70 let iter = IntoIterator::into_iter(arr); iter.collect()
72 }
73}
74
75impl<S> FromIterator<S> for CidrSet
76where
77 S: Into<AnyIpCidr>,
78{
79 fn from_iter<I: IntoIterator<Item = S>>(iter: I) -> Self {
80 let mut set = CidrMap::new();
81 for entry in iter {
82 set.insert(entry.into(), ());
83 }
84 Self(set)
85 }
86}
87
88impl TryFrom<Vec<&str>> for CidrSet {
89 type Error = String;
90
91 fn try_from(v: Vec<&str>) -> Result<Self, String> {
92 let mut set = CidrMap::new();
93 let mut problems = vec![];
94 for entry in v {
95 match parse_cidr(entry) {
96 Ok(cidr) => {
97 set.insert(cidr, ());
98 }
99 Err(err) => {
100 problems.push(format!("{entry}: {err:#}"));
101 }
102 }
103 }
104 if problems.is_empty() {
105 Ok(Self(set))
106 } else {
107 Err(problems.join(", "))
108 }
109 }
110}
111
112impl TryFrom<Vec<String>> for CidrSet {
113 type Error = String;
114
115 fn try_from(v: Vec<std::string::String>) -> Result<Self, String> {
116 let mut set = CidrMap::new();
117 let mut problems = vec![];
118 for entry in v {
119 match parse_cidr(&entry) {
120 Ok(cidr) => {
121 set.insert(cidr, ());
122 }
123 Err(err) => {
124 problems.push(format!("{entry}: {err:#}"));
125 }
126 }
127 }
128 if problems.is_empty() {
129 Ok(Self(set))
130 } else {
131 Err(problems.join(", "))
132 }
133 }
134}
135
136impl From<CidrSet> for Vec<String> {
137 fn from(val: CidrSet) -> Self {
138 let mut result = vec![];
139 for (key, _unit) in val.0.iter() {
140 result.push(key.to_string());
141 }
142 result
143 }
144}
145
146impl From<Vec<AnyIpCidr>> for CidrSet {
147 fn from(entries: Vec<AnyIpCidr>) -> Self {
148 entries.into_iter().collect()
149 }
150}
151
152impl From<CidrSet> for Vec<AnyIpCidr> {
153 fn from(val: CidrSet) -> Self {
154 let mut result = vec![];
155 for (key, _unit) in val.0.iter() {
156 result.push(*key);
157 }
158 result
159 }
160}
161
#[cfg(test)]
mod test {
    use super::*;

    /// Parses an IP address literal for use in membership assertions.
    fn ip(text: &str) -> std::net::IpAddr {
        text.parse().unwrap()
    }

    /// A set holding `AnyIpCidr::Any` matches an address that an empty set
    /// does not.
    #[test]
    fn cidrset_any() {
        let empty_set = CidrSet::new();
        let set_with_any = CidrSet::from([AnyIpCidr::Any]);

        assert!(!empty_set.contains(ip("127.0.0.1")));
        assert!(set_with_any.contains(ip("127.0.0.1")));
    }

    /// Exercises membership for single addresses and networks, then pins the
    /// `Debug` rendering of the set.
    #[test]
    fn cidrset() {
        let entries = [
            "127.0.0.1",
            "::1",
            "192.168.1.0/24",
            "192.168.1.24",
            "192.168.3.0/28",
            "10.0.3.0/24",
            "10.0.4.0/24",
            "10.0.7.0/24",
        ]
        .map(|text| parse_cidr(text).unwrap());
        let set = CidrSet::from(entries);

        // Single-address entries match only themselves.
        assert!(set.contains(ip("127.0.0.1")));
        assert!(!set.contains(ip("127.0.0.2")));
        assert!(set.contains(ip("::1")));

        // No entry covers 192.168.2.x.
        assert!(!set.contains(ip("192.168.2.1")));

        // Every address inside 192.168.1.0/24 matches.
        assert!(set.contains(ip("192.168.1.0")));
        assert!(set.contains(ip("192.168.1.1")));
        assert!(set.contains(ip("192.168.1.100")));
        assert!(set.contains(ip("192.168.1.24")));

        // A /28 spans 16 addresses, so .0 is inside and .16 is outside.
        assert!(set.contains(ip("192.168.3.0")));
        assert!(!set.contains(ip("192.168.3.16")));

        // Note: `192.168.1.24` is not listed as an entry of its own here,
        // presumably because `192.168.1.0/24` already covers it.
        k9::snapshot!(
            &set,
            r#"
{
    "10.0.3.0/24",
    "10.0.4.0/24",
    "10.0.7.0/24",
    "127.0.0.1",
    "192.168.1.0/24",
    "192.168.3.0/28",
    "::1",
}
"#
        );
    }
}