1use futures::future::BoxFuture;
2use std::collections::BTreeMap;
3
/// The `mode` field of an MTA-STS policy.
///
/// `MtaStsPolicy::parse` maps the policy values `enforce`, `testing` and
/// `none` onto these variants and rejects any other value.
#[derive(Debug, PartialEq, Eq)]
pub enum PolicyMode {
    /// `mode: enforce`
    Enforce,
    /// `mode: testing`
    Testing,
    /// `mode: none` — the only mode for which `parse` accepts a policy that
    /// has no `mx` lines at all.
    None,
}
10
/// A parsed MTA-STS policy document, as produced by `MtaStsPolicy::parse`.
///
/// Field order determines the derived `Debug` output, which the tests in
/// this file snapshot.
#[derive(Debug)]
pub struct MtaStsPolicy {
    /// The policy's `mode` value.
    pub mode: PolicyMode,
    /// The `mx` patterns in the order they appeared, lowercased by `parse`.
    /// `name_match` treats each one either as an exact host name or, when it
    /// starts with `*.`, as a single-label wildcard. Empty when the mode is
    /// `none` and the policy had no `mx` line.
    pub mx: Vec<String>,
    /// The `max_age` value.
    /// NOTE(review): RFC 8461 defines this in seconds with a maximum of
    /// 31557600; `parse` only checks that it is a valid `u64`.
    pub max_age: u64,
    /// Every `key: value` field that `parse` did not consume (i.e. anything
    /// other than `version`, `mode`, `mx` and `max_age`). Repeated keys keep
    /// all of their values in file order.
    pub fields: BTreeMap<String, Vec<String>>,
}
18
19impl MtaStsPolicy {
20 pub fn parse(data: &str) -> anyhow::Result<Self> {
21 let mut fields: BTreeMap<String, Vec<String>> = BTreeMap::new();
22
23 for line in data.lines() {
24 if line.is_empty() {
25 continue;
26 }
27 let (key, value) = line
28 .split_once(':')
29 .ok_or_else(|| anyhow::anyhow!("invalid STS policy {data}"))?;
30 let key = key.trim();
31 let value = value.trim();
32
33 fields
34 .entry(key.to_string())
35 .or_default()
36 .push(value.to_string());
37 }
38
39 let version = fields
40 .remove("version")
41 .ok_or_else(|| anyhow::anyhow!("STS policy {data} is missing a version"))?;
42 if version.len() != 1 || version[0] != "STSv1" {
43 anyhow::bail!("STS policy {data} has incompatible STS version");
44 }
45
46 let mode = match fields.remove("mode") {
47 None => anyhow::bail!("STS policy {data} is missing required mode"),
48 Some(mode) if mode.len() == 1 => match mode[0].as_str() {
49 "enforce" => PolicyMode::Enforce,
50 "testing" => PolicyMode::Testing,
51 "none" => PolicyMode::None,
52 _ => anyhow::bail!("STS policy {data} has invalid mode"),
53 },
54 _ => anyhow::bail!("STS policy {data} has invalid mode"),
55 };
56
57 let mut mx = match fields.remove("mx") {
58 None if mode == PolicyMode::None => vec![],
59 None => anyhow::bail!("STS policy {data} is missing required mx"),
60 Some(v) => v,
61 };
62
63 mx.iter_mut()
66 .for_each(|entry| *entry = entry.to_lowercase());
67
68 let max_age: u64 = match fields.remove("max_age") {
69 None => anyhow::bail!("STS policy {data} is missing required max_age"),
70 Some(v) if v.len() == 1 => {
71 let max_age = &v[0];
72 max_age.parse().map_err(|err| anyhow::anyhow!("STS policy {data} has max_age {max_age} that is not a valid integer: {err:#}"))?
73 }
74 _ => anyhow::bail!("STS policy {data} has invalid max_age"),
75 };
76
77 Ok(Self {
78 fields,
79 mode,
80 mx,
81 max_age,
82 })
83 }
84
85 pub fn mx_name_matches(&self, name: &str) -> bool {
89 for pattern in &self.mx {
90 if name_match(name, pattern) {
91 return true;
92 }
93 }
94 false
95 }
96}
97
/// Reports whether the MX host `name` matches a single policy `mx` `pattern`.
///
/// Trailing `.` characters on `name` (a fully-qualified root dot) are
/// ignored. Both sides are lowercased before comparing, because host names
/// are case-insensitive. `MtaStsPolicy::parse` already lowercases the `mx`
/// list, but callers may pass a candidate name in any case.
///
/// A pattern of the form `*.<suffix>` matches a name made of exactly one
/// non-empty label followed by `.<suffix>`. It matches neither `<suffix>`
/// itself nor names with more than one extra label. Every other pattern
/// must equal the name exactly.
fn name_match(name: &str, pattern: &str) -> bool {
    let name = name.trim_end_matches('.').to_lowercase();
    let pattern = pattern.to_lowercase();

    match pattern.strip_prefix('*') {
        // For `*.example.com` the suffix is `.example.com`; keeping the dot
        // ensures the wildcard can only stand in for whole labels.
        Some(suffix) if suffix.starts_with('.') => match name.strip_suffix(suffix) {
            // The wildcard covers exactly one left-most label, which must be
            // non-empty: `.example.com` is not a host name.
            Some(label) => !label.is_empty() && !label.contains('.'),
            None => false,
        },
        _ => name == pattern,
    }
}
114
/// Minimal HTTP client abstraction used to fetch MTA-STS policy text.
///
/// `load_policy_for_domain` calls `http_get` with the policy URL and hands
/// the returned `String` to `MtaStsPolicy::parse`. Keeping this behind a
/// trait lets tests substitute an in-memory implementation.
///
/// NOTE(review): TLS certificate validation, redirect handling and the
/// treatment of non-2xx statuses are left entirely to the implementation.
/// RFC 8461 requires a validated HTTPS fetch that does not follow redirects;
/// confirm the production implementation enforces this.
pub trait Get: Sync + Send {
    /// Expected to perform an HTTP GET of `url` and resolve to the response
    /// body, or to an error if the fetch fails.
    fn http_get<'a>(&'a self, url: &'a str) -> BoxFuture<'a, anyhow::Result<String>>;
}
118
119pub async fn load_policy_for_domain(
120 policy_domain: &str,
121 getter: &dyn Get,
122) -> anyhow::Result<MtaStsPolicy> {
123 let policy_domain = policy_domain.trim_end_matches('.');
127 let url = format!("https://mta-sts.{policy_domain}/.well-known/mta-sts.txt");
128 let policy = getter.http_get(&url).await?;
129 MtaStsPolicy::parse(&policy)
130}
131
#[cfg(test)]
pub(crate) mod test {
    use super::*;

    /// In-memory [`Get`] implementation for tests.
    ///
    /// It maps full URLs to canned response bodies and fails with a
    /// `404 <url>` error for any URL it does not know.
    pub struct TestGetter {
        policies: BTreeMap<&'static str, &'static str>,
    }

    impl TestGetter {
        /// Builds a getter from `(url, body)` pairs.
        pub fn new<I: IntoIterator<Item = (&'static str, &'static str)>>(iter: I) -> Self {
            Self {
                policies: BTreeMap::from_iter(iter),
            }
        }
    }

    impl Get for TestGetter {
        fn http_get<'a>(&'a self, url: &'a str) -> BoxFuture<'a, anyhow::Result<String>> {
            Box::pin(async move {
                match self.policies.get(url) {
                    Some(result) => Ok(result.to_string()),
                    None => anyhow::bail!("404 {url}"),
                }
            })
        }
    }

    /// Example policy that exercises trailing whitespace after a value,
    /// mixed `\n`/`\r\n` line endings, a tab after the `:`, and several `mx`
    /// entries including a wildcard.
    const SAMPLE_POLICY: &str =
        "version: STSv1 \nmode: enforce\nmx: mail.example.com\r\nmx:\t*.example.net\nmx: backupmx.example.com\nmax_age: 604800";

    /// The policy is fetched from the `mta-sts.` host of the domain, with the
    /// trailing dot of the input domain removed from the URL.
    #[tokio::test]
    async fn get_policy() {
        let getter = TestGetter::new([(
            "https://mta-sts.example.com/.well-known/mta-sts.txt",
            SAMPLE_POLICY,
        )]);

        k9::snapshot!(
            load_policy_for_domain("example.com.", &getter)
                .await
                .unwrap(),
            r#"
MtaStsPolicy {
    mode: Enforce,
    mx: [
        "mail.example.com",
        "*.example.net",
        "backupmx.example.com",
    ],
    max_age: 604800,
    fields: {},
}
"#
        );
    }

    /// Parsing consumes all of the known keys, leaving `fields` empty.
    #[test]
    fn parse_policy() {
        k9::snapshot!(
            MtaStsPolicy::parse(SAMPLE_POLICY).unwrap(),
            r#"
MtaStsPolicy {
    mode: Enforce,
    mx: [
        "mail.example.com",
        "*.example.net",
        "backupmx.example.com",
    ],
    max_age: 604800,
    fields: {},
}
"#
        );
    }

    /// Exact matches, trailing-dot handling, and the single-label rule for
    /// `*.` wildcards.
    #[test]
    fn name_matching() {
        assert!(name_match("foo.com", "foo.com"));
        assert!(name_match("foo.com.", "foo.com"));
        assert!(!name_match("bar.com", "foo.com"));
        assert!(name_match("foo.com", "*.com"));
        assert!(name_match("mx.example.com", "*.example.com"));
        // A wildcard stands in for exactly one label, not several.
        assert!(!name_match("not.mx.example.com", "*.example.com"));
        // ...and not for zero labels either.
        assert!(!name_match("example.com", "*.example.com"));
    }
}