1use crate::host::{AddressParseError, HostAddress};
2use serde::{Deserialize, Serialize};
3use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
4use std::os::unix::net::SocketAddr as UnixSocketAddr;
5use std::path::Path;
6use std::str::FromStr;
7
8#[derive(Clone, Serialize, Deserialize)]
9#[serde(try_from = "String", into = "String")]
10pub enum SocketAddress {
11 UnixDomain(Box<UnixSocketAddr>),
12 V4(std::net::SocketAddrV4),
13 V6(std::net::SocketAddrV6),
14}
15
16impl std::fmt::Debug for SocketAddress {
17 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
18 <Self as std::fmt::Display>::fmt(self, fmt)
19 }
20}
21
22impl std::fmt::Display for SocketAddress {
23 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
24 match self {
25 Self::UnixDomain(unix) => match unix.as_pathname() {
26 Some(path) => path.display().fmt(fmt),
27 None => write!(fmt, "<unbound unix domain>"),
28 },
29 Self::V4(a) => a.fmt(fmt),
30 Self::V6(a) => a.fmt(fmt),
31 }
32 }
33}
34
35impl From<SocketAddress> for String {
36 fn from(a: SocketAddress) -> String {
37 format!("{a}")
38 }
39}
40
41impl TryFrom<String> for SocketAddress {
42 type Error = AddressParseError;
43 fn try_from(s: String) -> Result<SocketAddress, Self::Error> {
44 SocketAddress::from_str(&s)
45 }
46}
47
48impl SocketAddress {
49 pub fn host(&self) -> HostAddress {
51 match self {
52 Self::UnixDomain(p) => HostAddress::UnixDomain(p.clone()),
53 Self::V4(a) => HostAddress::V4(*a.ip()),
54 Self::V6(a) => HostAddress::V6(*a.ip()),
55 }
56 }
57
58 pub fn unix(&self) -> Option<UnixSocketAddr> {
60 match self {
61 Self::V4(_) | Self::V6(_) => None,
62 Self::UnixDomain(unix) => Some((**unix).clone()),
63 }
64 }
65
66 pub fn ip(&self) -> Option<SocketAddr> {
68 match self {
69 Self::V4(a) => Some((*a).into()),
70 Self::V6(a) => Some((*a).into()),
71 Self::UnixDomain(_) => None,
72 }
73 }
74}
75
76impl FromStr for SocketAddress {
77 type Err = AddressParseError;
78 fn from_str(s: &str) -> Result<SocketAddress, Self::Err> {
79 if s.starts_with('[') {
88 if let Some(host_end) = s.find(']') {
89 let (host, remainder) = s.split_at(host_end);
90 let host = &host[1..];
91
92 if let Some(port) = remainder.strip_prefix("]:") {
93 if let Ok(port) = port.parse::<u16>() {
94 match HostAddress::from_str(host) {
95 Ok(HostAddress::V4(a)) => {
96 return Ok(SocketAddress::V4(SocketAddrV4::new(a, port)))
97 }
98 Ok(HostAddress::V6(a)) => {
99 return Ok(SocketAddress::V6(SocketAddrV6::new(a, port, 0, 0)))
100 }
101
102 _ => {}
103 }
104 }
105 }
106 }
107 }
108
109 match SocketAddr::from_str(s) {
110 Ok(a) => Ok(a.into()),
111 Err(net_err) => {
112 let path: &Path = s.as_ref();
113 if path.is_relative() {
114 Err(AddressParseError {
115 candidate: s.to_string(),
116 net_err,
117 unix_err: std::io::Error::new(
118 std::io::ErrorKind::Other,
119 "unix domain path must be absolute",
120 ),
121 })
122 } else {
123 match UnixSocketAddr::from_pathname(path) {
124 Ok(unix) => Ok(SocketAddress::UnixDomain(unix.into())),
125 Err(unix_err) => Err(AddressParseError {
126 candidate: s.to_string(),
127 net_err,
128 unix_err,
129 }),
130 }
131 }
132 }
133 }
134 }
135}
136
137impl PartialEq for SocketAddress {
138 fn eq(&self, other: &Self) -> bool {
139 match (self, other) {
140 (Self::UnixDomain(a), Self::UnixDomain(b)) => {
141 match (a.as_pathname(), b.as_pathname()) {
142 (Some(a), Some(b)) => a.eq(b),
143 (None, None) => true,
144 _ => false,
145 }
146 }
147 (Self::V4(a), Self::V4(b)) => a.eq(b),
148 (Self::V6(a), Self::V6(b)) => a.eq(b),
149 _ => false,
150 }
151 }
152}
153
154impl Eq for SocketAddress {}
155
156impl From<UnixSocketAddr> for SocketAddress {
157 fn from(unix: UnixSocketAddr) -> SocketAddress {
158 SocketAddress::UnixDomain(unix.into())
159 }
160}
161
162impl From<SocketAddr> for SocketAddress {
163 fn from(ip: SocketAddr) -> SocketAddress {
164 match ip {
165 SocketAddr::V4(a) => SocketAddress::V4(a),
166 SocketAddr::V6(a) => SocketAddress::V6(a),
167 }
168 }
169}
170
171impl From<tokio::net::unix::SocketAddr> for SocketAddress {
172 fn from(unix: tokio::net::unix::SocketAddr) -> SocketAddress {
173 let unix: UnixSocketAddr = unix.into();
174 unix.into()
175 }
176}
177
#[cfg(test)]
mod test {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Expected alternate-form error text for input that is neither an IP
    /// socket address nor an absolute path.
    fn unparseable(candidate: &str) -> String {
        format!(
            "Failed to parse '{candidate}' as an address. \
             Got 'invalid socket address syntax' when considering it as \
             an IP address and 'unix domain path must be absolute' \
             when considering it as a unix domain socket path."
        )
    }

    /// Parses `input`, expecting failure, and renders the error with `{:#}`.
    fn error_text(input: &str) -> String {
        format!("{:#}", input.parse::<SocketAddress>().unwrap_err())
    }

    #[test]
    fn parse() {
        let ten_dot_one =
            SocketAddress::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 25));
        assert_eq!(
            "10.0.0.1:25".parse::<SocketAddress>(),
            Ok(ten_dot_one.clone())
        );
        assert_eq!("[10.0.0.1]:25".parse::<SocketAddress>(), Ok(ten_dot_one));

        let loopback_v6 =
            SocketAddress::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 100, 0, 0));
        assert_eq!("[::1]:100".parse::<SocketAddress>(), Ok(loopback_v6));

        let some_path = SocketAddress::UnixDomain(
            UnixSocketAddr::from_pathname("/some/path").unwrap().into(),
        );
        assert_eq!("/some/path".parse::<SocketAddress>(), Ok(some_path));

        for bad in ["hello there", "[10.0.0.1]"] {
            assert_eq!(error_text(bad), unparseable(bad));
        }
    }
}