1use serde::{Deserialize, Deserializer, Serialize, Serializer};
6use std::time::Duration;
7
/// Newtype wrapper that carries this module's (de)serialization rules.
///
/// `Serialize`/`Deserialize` cannot be implemented directly for foreign types such as
/// [`Duration`] (orphan rule), so the impls in this module target `Wrap<Duration>`,
/// `Wrap<&Duration>`, `Wrap<Option<Duration>>` and `Wrap<&Option<Duration>>` instead.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct Wrap<T>(T);
12
13impl<T> Wrap<T> {
14 pub fn into_inner(self) -> T {
15 self.0
16 }
17}
18
19pub fn serialize<T, S>(d: &T, s: S) -> Result<S::Ok, S::Error>
20where
21 for<'a> Wrap<&'a T>: Serialize,
22 S: Serializer,
23{
24 Wrap(d).serialize(s)
25}
26
27pub fn deserialize<'a, T, D>(d: D) -> Result<T, D::Error>
28where
29 Wrap<T>: Deserialize<'a>,
30 D: Deserializer<'a>,
31{
32 Wrap::deserialize(d).map(|w| w.0)
33}
34
35impl<'de> Deserialize<'de> for Wrap<Duration> {
36 fn deserialize<D>(d: D) -> Result<Wrap<Duration>, D::Error>
37 where
38 D: Deserializer<'de>,
39 {
40 struct V;
41
42 impl serde::de::Visitor<'_> for V {
43 type Value = Duration;
44
45 fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
46 fmt.write_str("a duration")
47 }
48
49 fn visit_f64<E>(self, v: f64) -> Result<Duration, E>
50 where
51 E: serde::de::Error,
52 {
53 Ok(Duration::from_secs_f64(v))
54 }
55
56 fn visit_u64<E>(self, v: u64) -> Result<Duration, E>
57 where
58 E: serde::de::Error,
59 {
60 Ok(Duration::from_secs(v))
61 }
62
63 fn visit_i64<E>(self, v: i64) -> Result<Duration, E>
64 where
65 E: serde::de::Error,
66 {
67 match v.try_into() {
68 Ok(secs) => Ok(Duration::from_secs(secs)),
69 Err(err) => Err(E::custom(format!(
70 "duration must either be a string or a \
71 positive integer specifying the number of seconds. \
72 (error: {err:#})"
73 ))),
74 }
75 }
76
77 fn visit_str<E>(self, v: &str) -> Result<Duration, E>
78 where
79 E: serde::de::Error,
80 {
81 humantime::parse_duration(v)
82 .map_err(|_| E::invalid_value(serde::de::Unexpected::Str(v), &self))
83 }
84 }
85
86 d.deserialize_any(V).map(Wrap)
87 }
88}
89
90impl<'de> Deserialize<'de> for Wrap<Option<Duration>> {
91 fn deserialize<D>(d: D) -> Result<Wrap<Option<Duration>>, D::Error>
92 where
93 D: Deserializer<'de>,
94 {
95 match Option::<Wrap<Duration>>::deserialize(d)? {
96 Some(Wrap(dur)) => Ok(Wrap(Some(dur))),
97 None => Ok(Wrap(None)),
98 }
99 }
100}
101
102impl Serialize for Wrap<&Duration> {
103 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
104 where
105 S: Serializer,
106 {
107 humantime::format_duration(*self.0)
108 .to_string()
109 .serialize(serializer)
110 }
111}
112
113impl Serialize for Wrap<Duration> {
114 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
115 where
116 S: Serializer,
117 {
118 humantime::format_duration(self.0)
119 .to_string()
120 .serialize(serializer)
121 }
122}
123
124impl Serialize for Wrap<&Option<Duration>> {
125 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
126 where
127 S: Serializer,
128 {
129 match *self.0 {
130 Some(dur) => serializer.serialize_some(&Wrap(dur)),
131 None => serializer.serialize_none(),
132 }
133 }
134}
135
136impl Serialize for Wrap<Option<Duration>> {
137 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
138 where
139 S: Serializer,
140 {
141 Wrap(&self.0).serialize(serializer)
142 }
143}
144
#[cfg(test)]
mod test {
    //! Round-trip and rejection tests for the `#[serde(with = ...)]` duration helpers.

    use super::*;

    #[test]
    fn simple_string() {
        #[derive(Serialize, Deserialize)]
        struct Foo {
            #[serde(with = "super")]
            time: Duration,
        }

        let json = r#"{"time": "15 seconds"}"#;
        let foo = serde_json::from_str::<Foo>(json).unwrap();
        assert_eq!(foo.time, Duration::from_secs(15));
        let reverse = serde_json::to_string(&foo).unwrap();
        assert_eq!(reverse, r#"{"time":"15s"}"#);
    }

    #[test]
    fn simple_int() {
        #[derive(Serialize, Deserialize)]
        struct Foo {
            #[serde(with = "super")]
            time: Duration,
        }

        let json = r#"{"time": 15}"#;
        let foo = serde_json::from_str::<Foo>(json).unwrap();
        assert_eq!(foo.time, Duration::from_secs(15));
    }

    #[test]
    fn simple_float() {
        #[derive(Serialize, Deserialize)]
        struct Foo {
            #[serde(with = "super")]
            time: Duration,
        }

        let json = r#"{"time": 15.0}"#;
        let foo = serde_json::from_str::<Foo>(json).unwrap();
        assert_eq!(foo.time, Duration::from_secs(15));

        // Fractional seconds are preserved.
        let json = r#"{"time": 1.5}"#;
        let foo = serde_json::from_str::<Foo>(json).unwrap();
        assert_eq!(foo.time, Duration::from_millis(1500));
    }

    #[test]
    fn rejects_negative_int_and_garbage_string() {
        #[derive(Deserialize)]
        struct Foo {
            #[serde(with = "super")]
            time: Duration,
        }

        let negative = serde_json::from_str::<Foo>(r#"{"time": -5}"#).map(|foo| foo.time);
        assert!(negative.is_err());

        let garbage =
            serde_json::from_str::<Foo>(r#"{"time": "not a duration"}"#).map(|foo| foo.time);
        assert!(garbage.is_err());
    }

    #[test]
    fn optional_round_trip() {
        #[derive(Serialize, Deserialize)]
        struct Foo {
            #[serde(with = "super")]
            time: Option<Duration>,
        }

        let foo = serde_json::from_str::<Foo>(r#"{"time": "90 seconds"}"#).unwrap();
        assert_eq!(foo.time, Some(Duration::from_secs(90)));
        assert_eq!(serde_json::to_string(&foo).unwrap(), r#"{"time":"1m 30s"}"#);

        let foo = serde_json::from_str::<Foo>(r#"{"time": null}"#).unwrap();
        assert_eq!(foo.time, None);
        assert_eq!(serde_json::to_string(&foo).unwrap(), r#"{"time":null}"#);
    }
}