duration_serde/
lib.rs

1//! Based on <https://github.com/jean-airoldie/humantime-serde>
2//! which is made available under the terms of the Apache 2.0 License.
3//! This implementation allows for deserializing from integer and
4//! floating point values; they are assumed to represent seconds.
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6use std::time::Duration;
7
/// A wrapper type which implements `Serialize` and `Deserialize` for
/// `Duration` and `Option<Duration>`.
///
/// Durations are serialized as humantime strings (e.g. `"15s"`); on input,
/// humantime strings as well as integer/floating point seconds are accepted.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub struct Wrap<T>(T);
12
13impl<T> Wrap<T> {
14    pub fn into_inner(self) -> T {
15        self.0
16    }
17}
18
19pub fn serialize<T, S>(d: &T, s: S) -> Result<S::Ok, S::Error>
20where
21    for<'a> Wrap<&'a T>: Serialize,
22    S: Serializer,
23{
24    Wrap(d).serialize(s)
25}
26
27pub fn deserialize<'a, T, D>(d: D) -> Result<T, D::Error>
28where
29    Wrap<T>: Deserialize<'a>,
30    D: Deserializer<'a>,
31{
32    Wrap::deserialize(d).map(|w| w.0)
33}
34
35impl<'de> Deserialize<'de> for Wrap<Duration> {
36    fn deserialize<D>(d: D) -> Result<Wrap<Duration>, D::Error>
37    where
38        D: Deserializer<'de>,
39    {
40        struct V;
41
42        impl serde::de::Visitor<'_> for V {
43            type Value = Duration;
44
45            fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
46                fmt.write_str("a duration")
47            }
48
49            fn visit_f64<E>(self, v: f64) -> Result<Duration, E>
50            where
51                E: serde::de::Error,
52            {
53                Ok(Duration::from_secs_f64(v))
54            }
55
56            fn visit_u64<E>(self, v: u64) -> Result<Duration, E>
57            where
58                E: serde::de::Error,
59            {
60                Ok(Duration::from_secs(v))
61            }
62
63            fn visit_i64<E>(self, v: i64) -> Result<Duration, E>
64            where
65                E: serde::de::Error,
66            {
67                match v.try_into() {
68                    Ok(secs) => Ok(Duration::from_secs(secs)),
69                    Err(err) => Err(E::custom(format!(
70                        "duration must either be a string or a \
71                         positive integer specifying the number of seconds. \
72                         (error: {err:#})"
73                    ))),
74                }
75            }
76
77            fn visit_str<E>(self, v: &str) -> Result<Duration, E>
78            where
79                E: serde::de::Error,
80            {
81                humantime::parse_duration(v)
82                    .map_err(|_| E::invalid_value(serde::de::Unexpected::Str(v), &self))
83            }
84        }
85
86        d.deserialize_any(V).map(Wrap)
87    }
88}
89
90impl<'de> Deserialize<'de> for Wrap<Option<Duration>> {
91    fn deserialize<D>(d: D) -> Result<Wrap<Option<Duration>>, D::Error>
92    where
93        D: Deserializer<'de>,
94    {
95        match Option::<Wrap<Duration>>::deserialize(d)? {
96            Some(Wrap(dur)) => Ok(Wrap(Some(dur))),
97            None => Ok(Wrap(None)),
98        }
99    }
100}
101
102impl Serialize for Wrap<&Duration> {
103    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
104    where
105        S: Serializer,
106    {
107        humantime::format_duration(*self.0)
108            .to_string()
109            .serialize(serializer)
110    }
111}
112
113impl Serialize for Wrap<Duration> {
114    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
115    where
116        S: Serializer,
117    {
118        humantime::format_duration(self.0)
119            .to_string()
120            .serialize(serializer)
121    }
122}
123
124impl Serialize for Wrap<&Option<Duration>> {
125    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
126    where
127        S: Serializer,
128    {
129        match *self.0 {
130            Some(dur) => serializer.serialize_some(&Wrap(dur)),
131            None => serializer.serialize_none(),
132        }
133    }
134}
135
136impl Serialize for Wrap<Option<Duration>> {
137    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
138    where
139        S: Serializer,
140    {
141        Wrap(&self.0).serialize(serializer)
142    }
143}
144
#[cfg(test)]
mod test {
    use super::*;

    /// Fixture exercising the `with = "super"` entry points for a plain `Duration`.
    #[derive(Serialize, Deserialize)]
    struct Foo {
        #[serde(with = "super")]
        time: Duration,
    }

    /// Fixture exercising the `Option<Duration>` impls.
    #[derive(Serialize, Deserialize)]
    struct MaybeFoo {
        #[serde(with = "super")]
        time: Option<Duration>,
    }

    #[test]
    fn simple_string() {
        let json = r#"{"time": "15 seconds"}"#;
        let foo = serde_json::from_str::<Foo>(json).unwrap();
        assert_eq!(foo.time, Duration::from_secs(15));
        let reverse = serde_json::to_string(&foo).unwrap();
        assert_eq!(reverse, r#"{"time":"15s"}"#);
    }

    #[test]
    fn simple_int() {
        let json = r#"{"time": 15}"#;
        let foo = serde_json::from_str::<Foo>(json).unwrap();
        assert_eq!(foo.time, Duration::from_secs(15));
    }

    #[test]
    fn simple_float() {
        let json = r#"{"time": 15.0}"#;
        let foo = serde_json::from_str::<Foo>(json).unwrap();
        assert_eq!(foo.time, Duration::from_secs(15));
    }

    #[test]
    fn negative_int_is_rejected() {
        // Negative integers reach `visit_i64` and must not convert to seconds.
        assert!(serde_json::from_str::<Foo>(r#"{"time": -15}"#).is_err());
    }

    #[test]
    fn invalid_string_is_rejected() {
        assert!(serde_json::from_str::<Foo>(r#"{"time": "soon"}"#).is_err());
    }

    #[test]
    fn optional_round_trip() {
        let some = serde_json::from_str::<MaybeFoo>(r#"{"time": "1m"}"#).unwrap();
        assert_eq!(some.time, Some(Duration::from_secs(60)));
        assert_eq!(serde_json::to_string(&some).unwrap(), r#"{"time":"1m"}"#);

        let none = serde_json::from_str::<MaybeFoo>(r#"{"time": null}"#).unwrap();
        assert_eq!(none.time, None);
        assert_eq!(serde_json::to_string(&none).unwrap(), r#"{"time":null}"#);
    }
}