1use crate::map::CidrMap;
2use crate::{parse_cidr, Node};
3pub use cidr::AnyIpCidr;
4use serde::{Deserialize, Serialize};
5use std::fmt::Debug;
6use std::net::IpAddr;
7
8#[derive(Serialize, Deserialize, Clone, Default, PartialEq)]
9#[serde(try_from = "Vec<String>", into = "Vec<String>")]
10pub struct CidrSet(CidrMap<()>);
11
12impl CidrSet {
13 pub fn new() -> Self {
14 Self::default()
15 }
16
17 pub fn default_trusted_hosts() -> Self {
18 vec!["127.0.0.1", "::1"].try_into().unwrap()
19 }
20
21 pub fn default_prohibited_hosts() -> Self {
22 vec!["127.0.0.0/8", "::1"].try_into().unwrap()
23 }
24
25 pub fn contains(&self, ip: IpAddr) -> bool {
26 self.0.contains(ip)
27 }
28
29 pub fn insert<T: Ord + Into<AnyIpCidr>>(&mut self, value: T) {
30 self.0.insert(value.into(), ());
31 }
32}
33
34impl Debug for CidrSet {
35 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
36 use std::fmt::DebugSet;
37 let mut set = fmt.debug_set();
38
39 fn walk(node: &Node<()>, set: &mut DebugSet) {
40 match node {
41 Node::InnerNode(inner) => {
42 walk(&inner.children.left, set);
43 walk(&inner.children.right, set);
44 }
45 Node::Leaf(leaf) => {
46 set.entry(&leaf.key.to_string());
47 }
48 }
49 }
50
51 if let Some(root) = &self.0.root() {
52 walk(root, &mut set);
53 }
54
55 set.finish()
56 }
57}
58
59impl<T: Ord + Into<AnyIpCidr>, const N: usize> From<[T; N]> for CidrSet {
60 fn from(mut arr: [T; N]) -> Self {
62 if N == 0 {
63 return CidrSet::new();
64 }
65
66 arr.sort();
68 let iter = IntoIterator::into_iter(arr); iter.collect()
70 }
71}
72
73impl<S> FromIterator<S> for CidrSet
74where
75 S: Into<AnyIpCidr>,
76{
77 fn from_iter<I: IntoIterator<Item = S>>(iter: I) -> Self {
78 let mut set = CidrMap::new();
79 for entry in iter {
80 set.insert(entry.into(), ());
81 }
82 Self(set)
83 }
84}
85
86impl TryFrom<Vec<&str>> for CidrSet {
87 type Error = String;
88
89 fn try_from(v: Vec<&str>) -> Result<Self, String> {
90 let mut set = CidrMap::new();
91 let mut problems = vec![];
92 for entry in v {
93 match parse_cidr(entry) {
94 Ok(cidr) => {
95 set.insert(cidr, ());
96 }
97 Err(err) => {
98 problems.push(format!("{entry}: {err:#}"));
99 }
100 }
101 }
102 if problems.is_empty() {
103 Ok(Self(set))
104 } else {
105 Err(problems.join(", "))
106 }
107 }
108}
109
110impl TryFrom<Vec<String>> for CidrSet {
111 type Error = String;
112
113 fn try_from(v: Vec<std::string::String>) -> Result<Self, String> {
114 let mut set = CidrMap::new();
115 let mut problems = vec![];
116 for entry in v {
117 match parse_cidr(&entry) {
118 Ok(cidr) => {
119 set.insert(cidr, ());
120 }
121 Err(err) => {
122 problems.push(format!("{entry}: {err:#}"));
123 }
124 }
125 }
126 if problems.is_empty() {
127 Ok(Self(set))
128 } else {
129 Err(problems.join(", "))
130 }
131 }
132}
133
134impl From<CidrSet> for Vec<String> {
135 fn from(val: CidrSet) -> Self {
136 let mut result = vec![];
137 for (key, _unit) in val.0.iter() {
138 result.push(key.to_string());
139 }
140 result
141 }
142}
143
144impl From<Vec<AnyIpCidr>> for CidrSet {
145 fn from(entries: Vec<AnyIpCidr>) -> Self {
146 entries.into_iter().collect()
147 }
148}
149
150impl From<CidrSet> for Vec<AnyIpCidr> {
151 fn from(val: CidrSet) -> Self {
152 let mut result = vec![];
153 for (key, _unit) in val.0.iter() {
154 result.push(*key);
155 }
156 result
157 }
158}
159
#[cfg(test)]
mod test {
    use super::*;

    /// Parses an IP address literal, panicking on malformed input.
    fn ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    /// An empty set matches nothing, while a set holding `AnyIpCidr::Any`
    /// matches an arbitrary address.
    #[test]
    fn cidrset_any() {
        let empty_set = CidrSet::new();
        let set_with_any = CidrSet::from([AnyIpCidr::Any]);

        assert!(!empty_set.contains(ip("127.0.0.1")));
        assert!(set_with_any.contains(ip("127.0.0.1")));
    }

    /// Exercises membership for single hosts and networks, then snapshots
    /// the Debug rendering of the set.
    #[test]
    fn cidrset() {
        let set = CidrSet::from(
            [
                "127.0.0.1",
                "::1",
                "192.168.1.0/24",
                "192.168.1.24",
                "192.168.3.0/28",
                "10.0.3.0/24",
                "10.0.4.0/24",
                "10.0.7.0/24",
            ]
            .map(|text| parse_cidr(text).unwrap()),
        );

        // Single-host entries match only themselves.
        assert!(set.contains(ip("127.0.0.1")));
        assert!(!set.contains(ip("127.0.0.2")));
        assert!(set.contains(ip("::1")));

        // No entry covers 192.168.2.x.
        assert!(!set.contains(ip("192.168.2.1")));

        // Every address inside the /24 matches.
        assert!(set.contains(ip("192.168.1.0")));
        assert!(set.contains(ip("192.168.1.1")));
        assert!(set.contains(ip("192.168.1.100")));
        assert!(set.contains(ip("192.168.1.24")));

        // The /28 covers .0 through .15 only.
        assert!(set.contains(ip("192.168.3.0")));
        assert!(!set.contains(ip("192.168.3.16")));

        // NOTE(review): the snapshot literal is reproduced verbatim; its
        // whitespace must match the set's Debug rendering - confirm.
        k9::snapshot!(
            &set,
            r#"
{
 "10.0.3.0/24",
 "10.0.4.0/24",
 "10.0.7.0/24",
 "127.0.0.1",
 "192.168.1.0/24",
 "192.168.3.0/28",
 "::1",
}
"#
        );
    }
}