use bitstring::BitString;
pub use cidr::{AnyIpCidr, IpCidr};
#[cfg(feature = "lua")]
use config::{any_err, get_or_create_sub_module};
#[cfg(feature = "lua")]
use mlua::prelude::LuaUserData;
#[cfg(feature = "lua")]
use mlua::{FromLua, Lua, MetaMethod, UserDataMethods};
#[cfg(feature = "lua")]
use mod_memoize::CacheValue;
use std::net::IpAddr;
use std::str::FromStr;
/// A map from IP networks (CIDRs) to values, stored as a binary trie keyed on the
/// bits of each network prefix.
#[derive(Debug, Clone, PartialEq)]
pub struct CidrMap<V: Clone> {
    /// Root of the trie; `None` while the map is empty.
    root: Option<Node<V>>,
}
/// A node of the [`CidrMap`] trie: either a branch point or a stored entry.
#[derive(Debug, Clone, PartialEq)]
pub enum Node<V: Clone> {
    /// A branch whose two subtrees share this node's key as a common prefix.
    InnerNode(InnerNode<V>),
    /// A stored `(cidr, value)` entry.
    Leaf(Leaf<V>),
}
/// A stored entry of the trie: addresses inside `key` map to `value`.
#[derive(Debug, Clone, PartialEq)]
pub struct Leaf<V: Clone> {
    /// The network this entry covers.
    pub key: AnyIpCidr,
    /// The value associated with that network.
    pub value: V,
}
/// A branch of the trie. `key` is the prefix shared by both children, which
/// differ in the bit immediately following that prefix.
#[derive(Debug, Clone, PartialEq)]
pub struct InnerNode<V: Clone> {
    /// Common prefix of everything stored below this node.
    key: AnyIpCidr,
    /// The two subtrees (boxed so `Node` has a finite size).
    children: Box<Children<V>>,
}
/// The two subtrees of an [`InnerNode`]: `left` holds keys whose bit just past
/// the parent's prefix is clear, `right` those whose bit is set.
#[derive(Debug, Clone, PartialEq)]
struct Children<V: Clone> {
    left: Node<V>,
    right: Node<V>,
}
impl<V: Clone> InnerNode<V> {
    /// The prefix shared by every key in this subtree.
    pub fn key(&self) -> &AnyIpCidr {
        &self.key
    }

    /// Whether `subkey` belongs in the right subtree, i.e. whether its bit
    /// immediately after this node's prefix is set.
    fn goes_right(&self, subkey: &AnyIpCidr) -> bool {
        let branch_bit = self.key.len();
        subkey.get(branch_bit)
    }

    /// The child subtree that `subkey` would live in.
    pub fn pick_side<'a>(&'a self, subkey: &AnyIpCidr) -> &'a Node<V> {
        if self.goes_right(subkey) {
            self.right()
        } else {
            self.left()
        }
    }

    /// Mutable counterpart of [`InnerNode::pick_side`].
    pub fn pick_side_mut<'a>(&'a mut self, subkey: &AnyIpCidr) -> &'a mut Node<V> {
        if self.goes_right(subkey) {
            &mut self.children.right
        } else {
            &mut self.children.left
        }
    }

    /// The subtree for keys whose branch bit is clear.
    pub fn left(&self) -> &Node<V> {
        &self.children.left
    }

    /// The subtree for keys whose branch bit is set.
    pub fn right(&self) -> &Node<V> {
        &self.children.right
    }
}
/// Internal trie operations used by [`CidrMap`]: node construction, key access,
/// and the recursive insert / compress logic.
impl<V> Node<V>
where
V: Clone,
{
/// Creates a leaf that maps the addresses in `key` to `value`.
fn new_leaf(key: AnyIpCidr, value: V) -> Self {
Self::Leaf(Leaf { key, value })
}
/// Boxes `a` and `b` as the two children of an inner node whose prefix is
/// `shared_prefix_len` bits long. The node whose key has bit `shared_prefix_len`
/// set becomes `right`; the other becomes `left`.
///
/// Panics (via `assert_eq!`) if both keys have the same value at that bit,
/// i.e. if they do not actually diverge there.
fn new_children_unknown_order(
shared_prefix_len: usize,
a: Node<V>,
b: Node<V>,
) -> Box<Children<V>> {
let a_right = a.key().get(shared_prefix_len);
assert_eq!(!a_right, b.key().get(shared_prefix_len));
if a_right {
Box::new(Children { left: b, right: a })
} else {
Box::new(Children { left: a, right: b })
}
}
/// Joins `a` and `b` under a new inner node keyed by `a`'s key clipped to
/// `shared_prefix_len` bits. Callers pass the length of the prefix both keys share.
fn new_inner_unknown_order(shared_prefix_len: usize, a: Node<V>, b: Node<V>) -> Node<V> {
let mut key = a.key().clone();
key.clip(shared_prefix_len);
Node::InnerNode(InnerNode {
key,
children: Self::new_children_unknown_order(shared_prefix_len, a, b),
})
}
/// The key of this node: the stored CIDR for a leaf, or the shared prefix for an
/// inner node.
pub fn key(&self) -> &AnyIpCidr {
match *self {
Node::Leaf(ref leaf) => &leaf.key,
Node::InnerNode(ref inner) => &inner.key,
}
}
/// Returns the leaf payload, or `None` for an inner node.
fn leaf_ref(&self) -> Option<&Leaf<V>> {
match *self {
Node::Leaf(ref leaf) => Some(leaf),
_ => None,
}
}
/// Replaces this node with a single leaf holding `value`. For an inner node this
/// discards its whole subtree. The new leaf's key is this node's key clipped to
/// `key_len` bits.
fn convert_leaf(&mut self, key_len: usize, value: V) {
*self = match self {
Node::Leaf(leaf) => {
let mut leaf = leaf.clone();
leaf.key.clip(key_len);
leaf.value = value;
Node::Leaf(leaf)
}
Node::InnerNode(inner) => {
let mut key = inner.key;
key.clip(key_len);
Self::new_leaf(key, value)
}
};
}
/// Inserts `key -> value` into this subtree. Unlike [`Node::insert`], it never
/// merges leaves that end up holding equal values.
fn insert_uncompressed(&mut self, key: AnyIpCidr, value: V)
where
V: Clone,
{
let (self_key_len, shared_prefix_len) = {
let key_ref = self.key();
(key_ref.len(), key_ref.shared_prefix_len(&key))
};
if shared_prefix_len == key.len() {
// `key` is a prefix of (or equal to) this node's key. The new entry covers
// this entire subtree, so collapse it into one leaf.
self.convert_leaf(shared_prefix_len, value);
} else if shared_prefix_len < self_key_len {
// The keys diverge before this node's prefix ends. Join the existing
// subtree and the new leaf under a new inner node at the divergence point.
debug_assert!(shared_prefix_len < key.len());
*self = Self::new_inner_unknown_order(
shared_prefix_len,
self.clone(),
Self::new_leaf(key, value),
);
} else {
// This node's key is a strict prefix of `key`.
debug_assert!(shared_prefix_len == self_key_len);
debug_assert!(shared_prefix_len < key.len());
match *self {
Node::Leaf(_) => {
// Split the leaf. Start from the new leaf and walk back up one bit at a
// time. At each level `l`, pair the current node with a sibling leaf (the
// same prefix with bit `l` flipped) that keeps the old value, so the rest
// of the old range stays covered.
let old_value = self.leaf_ref().unwrap().value.clone();
let mut new_node = Self::new_leaf(key.clone(), value);
for l in (shared_prefix_len..key.len()).rev() {
let mut other_key = key.clone();
other_key.clip(l + 1);
other_key.flip(l);
new_node = Self::new_inner_unknown_order(
l,
new_node,
Self::new_leaf(other_key, old_value.clone()),
);
}
*self = new_node;
}
Node::InnerNode(ref mut inner) => {
// Descend into the child on `key`'s side of this node's branch bit.
inner.pick_side_mut(&key).insert_uncompressed(key, value);
}
}
}
}
/// Inserts `key -> value` into this subtree and keeps it compact:
/// - sibling leaves with equal values are merged into their common prefix;
/// - a leaf that already holds `value` absorbs a more specific key instead of
///   being split.
fn insert(&mut self, key: AnyIpCidr, value: V)
where
V: Clone + PartialEq,
{
let (self_key_len, shared_prefix_len) = {
let key_ref = self.key();
(key_ref.len(), key_ref.shared_prefix_len(&key))
};
if shared_prefix_len == key.len() {
// `key` covers this entire subtree: collapse it into one leaf.
self.convert_leaf(shared_prefix_len, value);
} else if shared_prefix_len < self_key_len {
debug_assert!(shared_prefix_len < key.len());
// Both keys are exactly one bit longer than their shared prefix, so they
// are the two halves of that prefix. If this node is a leaf with the same
// value, widen it to the shared prefix instead of adding a second leaf.
if shared_prefix_len + 1 == self_key_len && shared_prefix_len + 1 == key.len() {
if let Node::Leaf(ref mut this) = *self {
if this.value == value {
this.key.clip(shared_prefix_len);
return; }
}
}
*self = Self::new_inner_unknown_order(
shared_prefix_len,
self.clone(),
Self::new_leaf(key, value),
);
} else {
debug_assert!(shared_prefix_len == self_key_len);
debug_assert!(shared_prefix_len < key.len());
match *self {
Node::Leaf(_) => {
let new_node = {
let old_value = &self.leaf_ref().unwrap().value;
// The existing leaf already maps all of `key` to `value`.
if *old_value == value {
return;
}
// Same leaf-splitting walk as in `insert_uncompressed`.
let mut new_node = Self::new_leaf(key.clone(), value);
for l in (shared_prefix_len..key.len()).rev() {
let mut other_key = key.clone();
other_key.clip(l + 1);
other_key.flip(l);
new_node = Self::new_inner_unknown_order(
l,
new_node,
Self::new_leaf(other_key, old_value.clone()),
);
}
new_node
};
*self = new_node;
return;
}
Node::InnerNode(ref mut inner) => {
inner.pick_side_mut(&key).insert(key, value);
}
}
// The child may now be a leaf that matches its sibling. Try to merge them.
self.compress();
}
}
/// Merges this inner node into a single leaf when both of its children are
/// leaves exactly one bit longer than its key and hold equal values. The new
/// leaf is keyed by this node's key. Only this node is examined; the check does
/// not recurse.
fn compress(&mut self)
where
V: PartialEq,
{
let self_key_len = self.key().len();
let compress = match *self {
Node::InnerNode(ref inner) => {
let left_value = match inner.children.left {
Node::Leaf(ref leaf) if leaf.key.len() == self_key_len + 1 => &leaf.value,
_ => return, };
let right_value = match inner.children.right {
Node::Leaf(ref leaf) if leaf.key.len() == self_key_len + 1 => &leaf.value,
_ => return, };
left_value == right_value
}
Node::Leaf(_) => return, };
if compress {
*self = match self {
Node::InnerNode(inner) => match &inner.children.left {
Node::Leaf(leaf) => Node::Leaf(Leaf {
key: inner.key.clone(),
value: leaf.value.clone(),
}),
_ => unreachable!(),
},
_ => unreachable!(),
};
}
}
}
impl<V: Clone> Default for CidrMap<V> {
    /// An empty map, identical to [`CidrMap::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<V> CidrMap<V>
where
V: Clone,
{
pub fn new() -> Self {
Self { root: None }
}
pub fn contains(&self, ip: IpAddr) -> bool {
self.get_prefix_match(ip).is_some()
}
pub fn get_prefix_match(&self, ip: IpAddr) -> Option<&V> {
let key: AnyIpCidr = IpCidr::new_host(ip).into();
self.get_prefix_match_cidr(&key)
}
pub fn get_prefix_match_cidr(&self, key: &AnyIpCidr) -> Option<&V> {
let node = self.root.as_ref()?;
Self::find_item(node, &key)
}
fn find_item<'a>(node: &'a Node<V>, ip: &AnyIpCidr) -> Option<&'a V> {
match node {
Node::Leaf(leaf) => {
if leaf.key.contains(&ip.first_address().unwrap()) {
Some(&leaf.value)
} else {
None
}
}
Node::InnerNode(inner) => Self::find_item(inner.pick_side(&ip), ip),
}
}
pub fn insert_uncompressed(&mut self, key: AnyIpCidr, value: V)
where
V: Clone,
{
match self.root {
None => {
self.root = Some(Node::new_leaf(key, value));
}
Some(ref mut node) => {
node.insert_uncompressed(key, value);
}
}
}
pub fn insert(&mut self, key: AnyIpCidr, value: V)
where
V: Clone + PartialEq,
{
match self.root {
None => {
self.root = Some(Node::new_leaf(key, value));
}
Some(ref mut node) => {
node.insert(key, value);
}
}
}
pub fn root(&self) -> Option<&Node<V>> {
self.root.as_ref()
}
pub fn iter(&self) -> Iter<V> {
Iter::new(self)
}
}
/// What [`Iter`] should do the next time it reaches a node on its stack.
#[derive(Copy, Clone, Eq, PartialEq)]
enum Direction {
    /// Descend into the left child, or yield the node if it is a leaf.
    Left,
    /// Descend into the right child.
    Right,
    /// The node is finished; pop it.
    Up,
}
/// Iterator over the `(cidr, value)` entries of a [`CidrMap`]. It visits left
/// subtrees before right ones. Created by [`CidrMap::iter`].
pub struct Iter<'a, V: Clone + 'a> {
    /// Explicit depth-first stack: each node paired with its next traversal step.
    stack: Vec<(Direction, &'a Node<V>)>,
}
impl<'a, V: Clone> Iter<'a, V> {
    /// Starts a traversal at the root of `tree`. An empty tree yields nothing.
    pub fn new(tree: &'a CidrMap<V>) -> Self {
        let stack = match tree.root() {
            Some(root) => vec![(Direction::Left, root)],
            None => Vec::new(),
        };
        Iter { stack }
    }
}
/// Depth-first traversal that yields each leaf's `(key, value)`, visiting left
/// subtrees before right ones.
///
/// Each stack entry records the next step for its node:
/// - `Left`: descend left, or yield the node if it is a leaf;
/// - `Right`: descend right;
/// - `Up`: finished; pop it on the next call.
impl<'a, V> Iterator for Iter<'a, V>
where
V: Clone,
{
type Item = (&'a AnyIpCidr, &'a V);
fn next(&mut self) -> Option<Self::Item> {
if self.stack.is_empty() {
return None;
}
// Unwind finished nodes. If only the root remains and it is finished, the
// traversal is complete. Clearing the stack makes every later call return
// `None` through the check above.
while Direction::Up == self.stack[self.stack.len() - 1].0 {
if 1 == self.stack.len() {
self.stack.clear();
return None;
}
self.stack.pop();
debug_assert!(!self.stack.is_empty());
}
// Descend until a leaf is reached. An inner node records the next child to
// visit before its current child is pushed.
loop {
let top = self.stack.len() - 1;
let (dir, node) = self.stack[top];
debug_assert!(!self.stack.is_empty());
match dir {
Direction::Left => match *node {
Node::InnerNode(ref inner) => {
self.stack[top].0 = Direction::Right;
self.stack.push((Direction::Left, inner.left()));
}
Node::Leaf(ref leaf) => {
// Mark done so the next call pops past it.
self.stack[top].0 = Direction::Up;
return Some((&leaf.key, &leaf.value));
}
},
Direction::Right => match *node {
Node::InnerNode(ref inner) => {
self.stack[top].0 = Direction::Up;
self.stack.push((Direction::Left, inner.right()));
}
// Leaves go straight from `Left` to `Up`.
Node::Leaf(_) => unreachable!(),
},
// `Up` entries were all popped before the loop, and each push puts a
// `Left` entry on top.
Direction::Up => unreachable!(),
}
}
}
}
impl<S, V> FromIterator<(S, V)> for CidrMap<V>
where
    S: Into<AnyIpCidr>,
    V: Clone + Eq,
{
    /// Builds a map by calling [`CidrMap::insert`] on each pair in iteration
    /// order, so later pairs take precedence over overlapping earlier ones.
    fn from_iter<I: IntoIterator<Item = (S, V)>>(iter: I) -> Self {
        iter.into_iter()
            .fold(CidrMap::new(), |mut map, (cidr, value)| {
                map.insert(cidr.into(), value);
                map
            })
    }
}
impl<T: Ord + Into<AnyIpCidr>, const N: usize, V: Clone + Ord> From<[(T, V); N]> for CidrMap<V> {
fn from(mut arr: [(T, V); N]) -> Self {
if N == 0 {
return CidrMap::new();
}
arr.sort();
let iter = IntoIterator::into_iter(arr).map(|k| k);
iter.collect()
}
}
/// Flattens a map into its `(cidr, value)` entries in iteration order.
///
/// Implemented as `From` rather than `Into`; the standard blanket impl still
/// provides `CidrMap<V>: Into<Vec<(AnyIpCidr, V)>>`, so `map.into()` keeps working.
impl<V: Clone> From<CidrMap<V>> for Vec<(AnyIpCidr, V)> {
    fn from(map: CidrMap<V>) -> Self {
        map.iter()
            .map(|(cidr, value)| (*cidr, value.clone()))
            .collect()
    }
}
/// Lua bindings: memoization support plus `map[key]` lookup and
/// `map[key] = value` assignment. Keys are parsed as CIDRs, optionally with a port.
#[cfg(feature = "lua")]
impl LuaUserData for CidrMap<CacheValue> {
    fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
        mod_memoize::Memoized::impl_memoize(methods);

        // `map[key]`: the value of the entry covering `key`, or nil.
        methods.add_meta_method(MetaMethod::Index, |lua, this, key: String| {
            let cidr = parse_cidr_from_ip_and_or_port(&key).map_err(any_err)?;
            match this.get_prefix_match_cidr(&cidr) {
                Some(found) => Ok(Some(found.as_lua(lua)?)),
                None => Ok(None),
            }
        });

        // `map[key] = value`: stores `value` under the parsed CIDR.
        methods.add_meta_method_mut(
            MetaMethod::NewIndex,
            |lua, this, (key, value): (String, mlua::Value)| {
                let cidr = parse_cidr_from_ip_and_or_port(&key).map_err(any_err)?;
                let converted = CacheValue::from_lua(value, lua)?;
                this.insert(cidr, converted);
                Ok(())
            },
        );
    }
}
/// Parses `s` as a CIDR. If that fails, retries with a trailing port stripped,
/// accepting either `[addr]` / `[addr]:port` or `addr:port`.
///
/// If no recognizable port form is found, the error from the plain parse is
/// returned.
#[cfg(feature = "lua")]
fn parse_cidr_from_ip_and_or_port(s: &str) -> anyhow::Result<AnyIpCidr> {
    let direct_err = match parse_cidr(s) {
        Ok(cidr) => return Ok(cidr),
        Err(e) => e,
    };

    // Bracketed form; whatever follows `]` is ignored as the port part.
    if let Some(bracketed) = s.strip_prefix('[') {
        if let Some((ip, _port)) = bracketed.split_once(']') {
            return parse_cidr(ip).map_err(|err| {
                anyhow::anyhow!(
                    "failed to parse '{ip}', the \
                    []-enclosed portion of '{s}', as an IP address: {err:#}"
                )
            });
        }
    }

    // Unbracketed form: split at the last ':'.
    match s.rsplit_once(':') {
        Some((ip, _port)) => parse_cidr(ip).map_err(|err| {
            anyhow::anyhow!(
                "failed to parse '{ip}', the \
                :-delimited portion of '{s}', as an IP address: {err:#}"
            )
        }),
        None => Err(direct_err),
    }
}
/// Parses `s` strictly as a CIDR (`AnyIpCidr::from_str`).
///
/// On failure, the input is re-parsed while ignoring host bits, purely to improve
/// the error message. If that lenient parse succeeds, the message suggests the
/// corrected network ("Did you mean ...?").
pub fn parse_cidr(s: &str) -> anyhow::Result<AnyIpCidr> {
    AnyIpCidr::from_str(s).map_err(|strict_err| {
        let lenient = cidr::parsers::parse_any_cidr_full_ignore_hostbits(
            s,
            std::str::FromStr::from_str,
            std::str::FromStr::from_str,
        );
        match lenient {
            Ok(suggestion) => anyhow::anyhow!(
                "{s} is not a valid CIDR: {strict_err:#}. Did you mean {suggestion}?"
            ),
            // NB: this arm reports the lenient parser's error, not `strict_err`.
            Err(lenient_err) => anyhow::anyhow!("{s} is not a valid CIDR: {lenient_err:#}"),
        }
    })
}
/// Registers the `cidr` Lua module, exposing `cidr.make_map(table?)`.
///
/// `make_map` builds a `CidrMap` from an optional table of `key -> value`. Keys
/// may be CIDRs or addresses with an optional port.
#[cfg(feature = "lua")]
pub fn register(lua: &Lua) -> anyhow::Result<()> {
    use std::collections::HashMap;
    let cidr_mod = get_or_create_sub_module(lua, "cidr")?;
    cidr_mod.set(
        "make_map",
        lua.create_function(|lua, value: Option<HashMap<String, mlua::Value>>| {
            let mut cmap: CidrMap<mod_memoize::CacheValue> = CidrMap::new();
            if let Some(value) = value {
                let mut entries = Vec::with_capacity(value.len());
                for (text, v) in value {
                    let cidr = parse_cidr_from_ip_and_or_port(&text).map_err(any_err)?;
                    entries.push((cidr, text, v));
                }
                // `HashMap` iteration order is randomized, but `insert` is
                // order-sensitive: inserting a prefix replaces everything already
                // stored beneath it. Insert broader prefixes (shorter `len()`)
                // first so more specific entries override them. Ties are broken by
                // key text, so the resulting map is deterministic.
                entries.sort_by(|a, b| a.0.len().cmp(&b.0.len()).then_with(|| a.1.cmp(&b.1)));
                for (cidr, _text, v) in entries {
                    let v = CacheValue::from_lua(v, lua)?;
                    cmap.insert(cidr, v);
                }
            }
            Ok(cmap)
        })?,
    )?;
    Ok(())
}
// Unit tests for CIDR parsing and for map lookup / decomposition.
#[cfg(test)]
mod test {
use super::*;
// A CIDR with host bits set is rejected, and the error suggests the corrected network.
#[test]
fn test_parse_error_message() {
assert_eq!(
parse_cidr("10.0.0.1/24").unwrap_err().to_string(),
"10.0.0.1/24 is not a valid CIDR: host part of address was not zero. Did you mean 10.0.0.0/24?"
);
}
// Builds a map from an array of overlapping entries, checks point lookups, and
// snapshots the flattened trie. The snapshot shows how 192.168.3.0/28 was split
// around the more specific 192.168.3.2 entry.
#[test]
fn cidrmap() {
let set: CidrMap<&str> = [
(parse_cidr("127.0.0.1").unwrap(), "loopbackv4"),
(parse_cidr("::1").unwrap(), "loopbackv6"),
(parse_cidr("192.168.1.0/24").unwrap(), ".1"),
(parse_cidr("192.168.1.24").unwrap(), ".1"),
(parse_cidr("192.168.3.0/28").unwrap(), ".3"),
(parse_cidr("192.168.3.2").unwrap(), ".3.split"),
(parse_cidr("10.0.3.0/24").unwrap(), "10.3"),
(parse_cidr("10.0.4.0/24").unwrap(), "10.4"),
(parse_cidr("10.0.7.0/24").unwrap(), "10.7"),
]
.into();
// Looks up a single IP address given as text.
fn get<'a>(set: &'a CidrMap<&str>, key: &str) -> Option<&'a str> {
let key = key.parse().unwrap();
set.get_prefix_match(key).copied()
}
assert_eq!(get(&set, "127.0.0.1"), Some("loopbackv4"));
assert_eq!(get(&set, "127.0.0.2"), None);
assert_eq!(get(&set, "::1"), Some("loopbackv6"));
assert_eq!(get(&set, "192.168.2.1"), None);
assert_eq!(get(&set, "192.168.1.0"), Some(".1"));
assert_eq!(get(&set, "192.168.1.1"), Some(".1"));
assert_eq!(get(&set, "192.168.1.100"), Some(".1"));
assert_eq!(get(&set, "192.168.1.24"), Some(".1"));
assert_eq!(get(&set, "192.168.3.0"), Some(".3"));
assert_eq!(get(&set, "192.168.3.16"), None);
assert_eq!(get(&set, "192.168.3.2"), Some(".3.split"));
// The inline snapshot below must stay byte-identical; do not reformat it.
let decompose: Vec<(AnyIpCidr, &str)> = set.into();
k9::snapshot!(
decompose,
r#"
[
(
V4(
10.0.3.0/24,
),
"10.3",
),
(
V4(
10.0.4.0/24,
),
"10.4",
),
(
V4(
10.0.7.0/24,
),
"10.7",
),
(
V4(
127.0.0.1/32,
),
"loopbackv4",
),
(
V4(
192.168.1.0/24,
),
".1",
),
(
V4(
192.168.3.0/31,
),
".3",
),
(
V4(
192.168.3.2/32,
),
".3.split",
),
(
V4(
192.168.3.3/32,
),
".3",
),
(
V4(
192.168.3.4/30,
),
".3",
),
(
V4(
192.168.3.8/29,
),
".3",
),
(
V6(
::1/128,
),
"loopbackv6",
),
]
"#
);
}
}