mta_sts/
policy.rs

1use futures::future::BoxFuture;
2use std::collections::BTreeMap;
3
/// The `mode` declared by an MTA-STS policy.
///
/// Each variant corresponds to one of the literal `mode` values that
/// `MtaStsPolicy::parse` accepts. The type is a plain fieldless enum, so it
/// is `Copy` and can be read out of a borrowed policy without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PolicyMode {
    /// Parsed from `mode: enforce`.
    Enforce,
    /// Parsed from `mode: testing`.
    Testing,
    /// Parsed from `mode: none`; a policy in this mode may omit `mx`.
    None,
}
10
/// A parsed MTA-STS policy document, as produced by `MtaStsPolicy::parse`.
#[derive(Debug)]
pub struct MtaStsPolicy {
    /// The policy's declared `mode`.
    pub mode: PolicyMode,
    /// The `mx` host name patterns, in document order.
    /// `parse` lowercases every entry; the list is empty when the policy
    /// has `mode: none` and declares no `mx` field.
    pub mx: Vec<String>,
    /// The numeric value of the `max_age` field
    /// (RFC 8461 defines this as the policy lifetime in seconds).
    pub max_age: u64,
    /// Every field that `parse` did not consume (i.e. anything other than
    /// `version`, `mode`, `mx` and `max_age`), keyed by field name with the
    /// values in the order they appeared.
    pub fields: BTreeMap<String, Vec<String>>,
}
18
19impl MtaStsPolicy {
20    pub fn parse(data: &str) -> anyhow::Result<Self> {
21        let mut fields: BTreeMap<String, Vec<String>> = BTreeMap::new();
22
23        for line in data.lines() {
24            if line.is_empty() {
25                continue;
26            }
27            let (key, value) = line
28                .split_once(':')
29                .ok_or_else(|| anyhow::anyhow!("invalid STS policy {data}"))?;
30            let key = key.trim();
31            let value = value.trim();
32
33            fields
34                .entry(key.to_string())
35                .or_default()
36                .push(value.to_string());
37        }
38
39        let version = fields
40            .remove("version")
41            .ok_or_else(|| anyhow::anyhow!("STS policy {data} is missing a version"))?;
42        if version.len() != 1 || version[0] != "STSv1" {
43            anyhow::bail!("STS policy {data} has incompatible STS version");
44        }
45
46        let mode = match fields.remove("mode") {
47            None => anyhow::bail!("STS policy {data} is missing required mode"),
48            Some(mode) if mode.len() == 1 => match mode[0].as_str() {
49                "enforce" => PolicyMode::Enforce,
50                "testing" => PolicyMode::Testing,
51                "none" => PolicyMode::None,
52                _ => anyhow::bail!("STS policy {data} has invalid mode"),
53            },
54            _ => anyhow::bail!("STS policy {data} has invalid mode"),
55        };
56
57        let mut mx = match fields.remove("mx") {
58            None if mode == PolicyMode::None => vec![],
59            None => anyhow::bail!("STS policy {data} is missing required mx"),
60            Some(v) => v,
61        };
62
63        // Ensure that the mx entries are lowercased to aid
64        // the mx_name_matches method
65        mx.iter_mut()
66            .for_each(|entry| *entry = entry.to_lowercase());
67
68        let max_age: u64 = match fields.remove("max_age") {
69            None => anyhow::bail!("STS policy {data} is missing required max_age"),
70            Some(v) if v.len() == 1 => {
71                let max_age = &v[0];
72                max_age.parse().map_err(|err| anyhow::anyhow!("STS policy {data} has max_age {max_age} that is not a valid integer: {err:#}"))?
73            }
74            _ => anyhow::bail!("STS policy {data} has invalid max_age"),
75        };
76
77        Ok(Self {
78            fields,
79            mode,
80            mx,
81            max_age,
82        })
83    }
84
85    /// Returns true if `name` matches any of the allowed mx
86    /// host name patterns.
87    /// `name` must be lowercase.
88    pub fn mx_name_matches(&self, name: &str) -> bool {
89        for pattern in &self.mx {
90            if name_match(name, pattern) {
91                return true;
92            }
93        }
94        false
95    }
96}
97
/// Returns true if the host `name` matches the MTA-STS mx `pattern`.
///
/// Any trailing `.` on `name` is ignored. A pattern of the form
/// `*.suffix` matches a name made of exactly one non-empty label followed
/// by `.suffix`; every other pattern must equal the name exactly.
/// The comparison is case-sensitive, so both arguments are expected to be
/// lowercase already.
fn name_match(name: &str, pattern: &str) -> bool {
    // kumo uses canonicalized names that include a trailing period.
    // remove that from the name when matching against a pattern.
    let name = name.trim_end_matches('.');

    match pattern.strip_prefix('*') {
        // `suffix` keeps its leading dot, e.g. ".example.com".
        Some(suffix) if suffix.starts_with('.') => matches!(
            name.strip_suffix(suffix),
            // The wildcard stands for exactly one label: it must not be
            // empty and must not span multiple components.
            Some(label) if !label.is_empty() && !label.contains('.')
        ),
        _ => name == pattern,
    }
}
114
/// Abstraction over the HTTP client used to retrieve policy documents.
///
/// `load_policy_for_domain` takes a `&dyn Get`, so the fetch can be
/// replaced by another implementation of this trait.
pub trait Get: Sync + Send {
    /// Performs a GET of `url`, resolving to the retrieved text or an error.
    /// The returned future borrows both `self` and `url`.
    fn http_get<'a>(&'a self, url: &'a str) -> BoxFuture<'a, anyhow::Result<String>>;
}
118
119pub async fn load_policy_for_domain(
120    policy_domain: &str,
121    getter: &dyn Get,
122) -> anyhow::Result<MtaStsPolicy> {
123    // kumo uses canonicalized names that include a trailing period.
124    // remove that from the policy_domain to avoid extraneous redirects
125    // on some remote servers
126    let policy_domain = policy_domain.trim_end_matches('.');
127    let url = format!("https://mta-sts.{policy_domain}/.well-known/mta-sts.txt");
128    let policy = getter.http_get(&url).await?;
129    MtaStsPolicy::parse(&policy)
130}
131
// Unit tests for policy parsing, policy retrieval and mx name matching.
// The module is `pub(crate)` and `TestGetter` is `pub`, so other test code
// in the crate can reuse the getter.
#[cfg(test)]
pub(crate) mod test {
    use super::*;

    /// An in-memory `Get` implementation that serves canned policy text
    /// keyed by the exact URL being requested.
    pub struct TestGetter {
        // Maps a full policy URL to the text returned for it.
        policies: BTreeMap<&'static str, &'static str>,
    }

    impl TestGetter {
        /// Builds a getter from `(url, policy_text)` pairs.
        pub fn new<I: IntoIterator<Item = (&'static str, &'static str)>>(iter: I) -> Self {
            Self {
                policies: BTreeMap::from_iter(iter),
            }
        }
    }

    impl Get for TestGetter {
        /// Resolves to the stored text for `url`, or to a `404 <url>` error
        /// when no entry exists for it.
        fn http_get<'a>(&'a self, url: &'a str) -> BoxFuture<'a, anyhow::Result<String>> {
            Box::pin(async move {
                match self.policies.get(url) {
                    Some(result) => Ok(result.to_string()),
                    None => anyhow::bail!("404 {url}"),
                }
            })
        }
    }

    // A valid policy whose lines mix separators and padding: a trailing
    // space after the version, a CRLF line ending and a tab after `mx:`.
    // This exercises the whitespace trimming done by the parser.
    const SAMPLE_POLICY: &str =
        "version: STSv1 \nmode: enforce\nmx: mail.example.com\r\nmx:\t*.example.net\nmx: backupmx.example.com\nmax_age: 604800";

    // Loading by domain: the trailing period on "example.com." must be
    // dropped for the request URL to match the single URL the getter knows.
    #[tokio::test]
    async fn get_policy() {
        let getter = TestGetter::new([(
            "https://mta-sts.example.com/.well-known/mta-sts.txt",
            SAMPLE_POLICY,
        )]);

        k9::snapshot!(
            load_policy_for_domain("example.com.", &getter)
                .await
                .unwrap(),
            r#"
MtaStsPolicy {
    mode: Enforce,
    mx: [
        "mail.example.com",
        "*.example.net",
        "backupmx.example.com",
    ],
    max_age: 604800,
    fields: {},
}
"#
        );
    }

    // Parsing the sample directly yields the same policy; `fields` is empty
    // because every field in the sample is one the parser consumes.
    #[test]
    fn parse_policy() {
        k9::snapshot!(
            MtaStsPolicy::parse(SAMPLE_POLICY).unwrap(),
            r#"
MtaStsPolicy {
    mode: Enforce,
    mx: [
        "mail.example.com",
        "*.example.net",
        "backupmx.example.com",
    ],
    max_age: 604800,
    fields: {},
}
"#
        );
    }

    // Exact matches, trailing-period handling, and wildcard patterns that
    // cover exactly one leading label (neither zero nor two).
    #[test]
    fn name_matching() {
        assert!(name_match("foo.com", "foo.com"));
        assert!(name_match("foo.com.", "foo.com"));
        assert!(!name_match("bar.com", "foo.com"));
        assert!(name_match("foo.com", "*.com"));
        assert!(name_match("mx.example.com", "*.example.com"));
        assert!(!name_match("not.mx.example.com", "*.example.com"));
        assert!(!name_match("example.com", "*.example.com"));
    }
}
217}