kumo_address/
socket.rs

1use crate::host::{AddressParseError, HostAddress};
2use serde::{Deserialize, Serialize};
3use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
4use std::os::unix::net::SocketAddr as UnixSocketAddr;
5use std::path::Path;
6use std::str::FromStr;
7
8#[derive(Clone, Serialize, Deserialize)]
9#[serde(try_from = "String", into = "String")]
10pub enum SocketAddress {
11    UnixDomain(Box<UnixSocketAddr>),
12    V4(std::net::SocketAddrV4),
13    V6(std::net::SocketAddrV6),
14}
15
16impl std::fmt::Debug for SocketAddress {
17    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
18        <Self as std::fmt::Display>::fmt(self, fmt)
19    }
20}
21
22impl std::fmt::Display for SocketAddress {
23    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
24        match self {
25            Self::UnixDomain(unix) => match unix.as_pathname() {
26                Some(path) => path.display().fmt(fmt),
27                None => write!(fmt, "<unbound unix domain>"),
28            },
29            Self::V4(a) => a.fmt(fmt),
30            Self::V6(a) => a.fmt(fmt),
31        }
32    }
33}
34
35impl From<SocketAddress> for String {
36    fn from(a: SocketAddress) -> String {
37        format!("{a}")
38    }
39}
40
41impl TryFrom<String> for SocketAddress {
42    type Error = AddressParseError;
43    fn try_from(s: String) -> Result<SocketAddress, Self::Error> {
44        SocketAddress::from_str(&s)
45    }
46}
47
48impl SocketAddress {
49    /// Returns the "host" portion of the address
50    pub fn host(&self) -> HostAddress {
51        match self {
52            Self::UnixDomain(p) => HostAddress::UnixDomain(p.clone()),
53            Self::V4(a) => HostAddress::V4(*a.ip()),
54            Self::V6(a) => HostAddress::V6(*a.ip()),
55        }
56    }
57
58    /// Returns the unix domain socket representation of the address
59    pub fn unix(&self) -> Option<UnixSocketAddr> {
60        match self {
61            Self::V4(_) | Self::V6(_) => None,
62            Self::UnixDomain(unix) => Some((**unix).clone()),
63        }
64    }
65
66    /// Returns the ip representation of the address
67    pub fn ip(&self) -> Option<SocketAddr> {
68        match self {
69            Self::V4(a) => Some((*a).into()),
70            Self::V6(a) => Some((*a).into()),
71            Self::UnixDomain(_) => None,
72        }
73    }
74}
75
76impl FromStr for SocketAddress {
77    type Err = AddressParseError;
78    fn from_str(s: &str) -> Result<SocketAddress, Self::Err> {
79        // At the time of writing, Rust's IPv6 SockAddr parsing
80        // interally only accepts `[address]:port` while its IPv4
81        // SockAddr parsing only accepts `address:port`.
82        // In the email world, `[]` is used to indicate a literal
83        // IP address so we desire the ability to uniformly use
84        // the `[]` syntax in both cases, so we check for that
85        // first and parse the internal address out.
86
87        if s.starts_with('[') {
88            if let Some(host_end) = s.find(']') {
89                let (host, remainder) = s.split_at(host_end);
90                let host = &host[1..];
91
92                if let Some(port) = remainder.strip_prefix("]:") {
93                    if let Ok(port) = port.parse::<u16>() {
94                        match HostAddress::from_str(host) {
95                            Ok(HostAddress::V4(a)) => {
96                                return Ok(SocketAddress::V4(SocketAddrV4::new(a, port)))
97                            }
98                            Ok(HostAddress::V6(a)) => {
99                                return Ok(SocketAddress::V6(SocketAddrV6::new(a, port, 0, 0)))
100                            }
101
102                            _ => {}
103                        }
104                    }
105                }
106            }
107        }
108
109        match SocketAddr::from_str(s) {
110            Ok(a) => Ok(a.into()),
111            Err(net_err) => {
112                let path: &Path = s.as_ref();
113                if path.is_relative() {
114                    Err(AddressParseError {
115                        candidate: s.to_string(),
116                        net_err,
117                        unix_err: std::io::Error::new(
118                            std::io::ErrorKind::Other,
119                            "unix domain path must be absolute",
120                        ),
121                    })
122                } else {
123                    match UnixSocketAddr::from_pathname(path) {
124                        Ok(unix) => Ok(SocketAddress::UnixDomain(unix.into())),
125                        Err(unix_err) => Err(AddressParseError {
126                            candidate: s.to_string(),
127                            net_err,
128                            unix_err,
129                        }),
130                    }
131                }
132            }
133        }
134    }
135}
136
137impl PartialEq for SocketAddress {
138    fn eq(&self, other: &Self) -> bool {
139        match (self, other) {
140            (Self::UnixDomain(a), Self::UnixDomain(b)) => {
141                match (a.as_pathname(), b.as_pathname()) {
142                    (Some(a), Some(b)) => a.eq(b),
143                    (None, None) => true,
144                    _ => false,
145                }
146            }
147            (Self::V4(a), Self::V4(b)) => a.eq(b),
148            (Self::V6(a), Self::V6(b)) => a.eq(b),
149            _ => false,
150        }
151    }
152}
153
154impl Eq for SocketAddress {}
155
156impl From<UnixSocketAddr> for SocketAddress {
157    fn from(unix: UnixSocketAddr) -> SocketAddress {
158        SocketAddress::UnixDomain(unix.into())
159    }
160}
161
162impl From<SocketAddr> for SocketAddress {
163    fn from(ip: SocketAddr) -> SocketAddress {
164        match ip {
165            SocketAddr::V4(a) => SocketAddress::V4(a),
166            SocketAddr::V6(a) => SocketAddress::V6(a),
167        }
168    }
169}
170
171impl From<tokio::net::unix::SocketAddr> for SocketAddress {
172    fn from(unix: tokio::net::unix::SocketAddr) -> SocketAddress {
173        let unix: UnixSocketAddr = unix.into();
174        unix.into()
175    }
176}
177
#[cfg(test)]
mod test {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Accepted syntaxes and the error reported for rejected ones.
    #[test]
    fn parse() {
        assert_eq!(
            "10.0.0.1:25".parse::<SocketAddress>(),
            Ok(SocketAddress::V4(SocketAddrV4::new(
                Ipv4Addr::new(10, 0, 0, 1),
                25
            )))
        );
        assert_eq!(
            "[10.0.0.1]:25".parse::<SocketAddress>(),
            Ok(SocketAddress::V4(SocketAddrV4::new(
                Ipv4Addr::new(10, 0, 0, 1),
                25
            )))
        );
        assert_eq!(
            "[::1]:100".parse::<SocketAddress>(),
            Ok(SocketAddress::V6(SocketAddrV6::new(
                Ipv6Addr::LOCALHOST,
                100,
                0,
                0
            )))
        );
        assert_eq!(
            "/some/path".parse::<SocketAddress>(),
            Ok(SocketAddress::UnixDomain(
                UnixSocketAddr::from_pathname("/some/path").unwrap().into()
            ))
        );
        assert_eq!(
            format!("{:#}", "hello there".parse::<SocketAddress>().unwrap_err()),
            "Failed to parse 'hello there' as an address. \
            Got 'invalid socket address syntax' when considering it as \
            an IP address and 'unix domain path must be absolute' \
            when considering it as a unix domain socket path."
        );
        assert_eq!(
            format!("{:#}", "[10.0.0.1]".parse::<SocketAddress>().unwrap_err()),
            "Failed to parse '[10.0.0.1]' as an address. \
            Got 'invalid socket address syntax' when considering it as \
            an IP address and 'unix domain path must be absolute' \
            when considering it as a unix domain socket path."
        );
        // A bracketed literal with a port that does not fit in u16 must be
        // rejected rather than silently truncated.
        assert_eq!(
            format!(
                "{:#}",
                "[10.0.0.1]:70000".parse::<SocketAddress>().unwrap_err()
            ),
            "Failed to parse '[10.0.0.1]:70000' as an address. \
            Got 'invalid socket address syntax' when considering it as \
            an IP address and 'unix domain path must be absolute' \
            when considering it as a unix domain socket path."
        );
    }

    /// `Display`, `String` (serde `into`) and `TryFrom<String>` (serde
    /// `try_from`) must round-trip the canonical forms.
    #[test]
    fn display_round_trip() {
        for text in ["10.0.0.1:25", "[::1]:100", "/some/path"] {
            let addr: SocketAddress = text.parse().unwrap();
            assert_eq!(addr.to_string(), text);
            assert_eq!(format!("{addr:?}"), text);
            assert_eq!(String::from(addr.clone()), text);
            assert_eq!(SocketAddress::try_from(text.to_string()).unwrap(), addr);
        }
        // The bracketed IPv4 form is accepted, but rendered without brackets.
        assert_eq!(
            "[10.0.0.1]:25".parse::<SocketAddress>().unwrap().to_string(),
            "10.0.0.1:25"
        );
    }

    /// `ip()` and `unix()` expose exactly one representation per variant.
    #[test]
    fn accessors() {
        let v4: SocketAddress = "10.0.0.1:25".parse().unwrap();
        assert_eq!(
            v4.ip(),
            Some(SocketAddr::from((Ipv4Addr::new(10, 0, 0, 1), 25)))
        );
        assert!(v4.unix().is_none());

        let path_addr: SocketAddress = "/some/path".parse().unwrap();
        assert_eq!(path_addr.ip(), None);
        assert_eq!(
            path_addr
                .unix()
                .and_then(|u| u.as_pathname().map(Path::to_path_buf)),
            Some(std::path::PathBuf::from("/some/path"))
        );
    }

    /// Equality is per-variant; unix addresses compare by pathname.
    #[test]
    fn equality() {
        let a: SocketAddress = "/some/path".parse().unwrap();
        let b: SocketAddress = "/some/path".parse().unwrap();
        let c: SocketAddress = "/other/path".parse().unwrap();
        assert_eq!(a, b);
        assert_ne!(a, c);

        let bracketed: SocketAddress = "[10.0.0.1]:25".parse().unwrap();
        assert_eq!(bracketed, "10.0.0.1:25".parse::<SocketAddress>().unwrap());
        assert_ne!(
            "127.0.0.1:25".parse::<SocketAddress>().unwrap(),
            "[::1]:25".parse::<SocketAddress>().unwrap()
        );
        assert_ne!(a, bracketed);
    }
}