mta_sts/
lib.rs

1use dns_resolver::Resolver;
2use futures::future::BoxFuture;
3use hickory_resolver::Name;
4use policy::MtaStsPolicy;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
// Process-wide cache of fetched MTA-STS policies, keyed by the lowercased
// policy domain `Name`. Entries are inserted with an expiry derived from the
// policy's `max_age` (see `get_policy_for_domain_impl`).
// NOTE(review): "mta_sts_policy" appears to be the cache's name and
// 64 * 1024 its capacity (entry count) — confirm against lruttl.
lruttl::declare_cache! {
static CACHE: LruCacheWithTtl<Name, CachedPolicy>::new("mta_sts_policy", 64 * 1024);
}

/// Discovery of the MTA-STS DNS record (provides `resolve_dns_record`, whose
/// result carries the policy `id`).
pub mod dns;
/// Fetching (through the `Get` trait) and loading of MTA-STS policies
/// (provides `MtaStsPolicy` and `load_policy_for_domain`).
pub mod policy;
14
15#[derive(Clone, Debug)]
16struct CachedPolicy {
17    pub id: String,
18    pub policy: Arc<MtaStsPolicy>,
19}
20
/// Stateless production implementation of `policy::Get`, backed by `reqwest`.
struct Getter;
22
23impl policy::Get for Getter {
24    fn http_get<'a>(&'a self, url: &'a str) -> BoxFuture<'a, anyhow::Result<String>> {
25        Box::pin(async move {
26            let response = reqwest::Client::builder()
27                // <https://datatracker.ietf.org/doc/html/rfc8461#section-3.3>
28                // HTTP 3xx redirects MUST NOT be followed
29                .redirect(reqwest::redirect::Policy::none())
30                .timeout(std::time::Duration::from_secs(20))
31                .build()?
32                .request(reqwest::Method::GET, url)
33                .send()
34                .await?;
35
36            // <https://datatracker.ietf.org/doc/html/rfc8461#section-3.3>
37            // Policies fetched via HTTPS are only valid if the HTTP
38            // response code is 200 (OK)
39            let status = response.status();
40            if status != reqwest::StatusCode::OK {
41                anyhow::bail!("failed to GET {url}: {status}");
42            }
43
44            // <https://datatracker.ietf.org/doc/html/rfc8461#section-3.2>
45            // senders SHOULD validate that the media type is "text/plain"
46            // to guard against cases where web servers allow untrusted users
47            // to host non-text content.
48            // We need to do some manual grubbing about for this, as reqwest's
49            // Response::text() method doesn't verify that the type is textual,
50            // just whether it decodes as text, which is precisely what we're
51            // trying to guard against.
52
53            let content_type = response
54                .headers()
55                .get(reqwest::header::CONTENT_TYPE)
56                .ok_or_else(|| anyhow::anyhow!("missing required Content-Type header"))?;
57
58            let content_type = content_type.to_str()?;
59
60            let ct = if let Some((ct, _)) = content_type.split_once(';') {
61                ct.trim()
62            } else {
63                content_type.trim()
64            };
65            if ct != "text/plain" {
66                anyhow::bail!("Content-Type must be text/plain, got {content_type}");
67            }
68
69            Ok(response.text().await?)
70        })
71    }
72}
73
74pub async fn get_policy_for_domain(policy_domain: &str) -> anyhow::Result<Arc<MtaStsPolicy>> {
75    let resolver = dns_resolver::get_resolver();
76    get_policy_for_domain_impl(policy_domain, &**resolver, &Getter {}).await
77}
78
79fn cache_lookup(name: &Name) -> Option<CachedPolicy> {
80    CACHE.get(name)
81}
82
83async fn get_policy_for_domain_impl(
84    policy_domain: &str,
85    resolver: &dyn Resolver,
86    getter: &dyn policy::Get,
87) -> anyhow::Result<Arc<MtaStsPolicy>> {
88    let name = Name::from_str_relaxed(policy_domain)?.to_lowercase();
89
90    if let Some(cached) = cache_lookup(&name) {
91        // Removal of the DNS record does not invalidate our
92        // cached result, only updating it with a different id
93        let still_valid = dns::resolve_dns_record(policy_domain, resolver)
94            .await
95            .map(|r| cached.id == r.id)
96            .unwrap_or(true);
97
98        if still_valid {
99            return Ok(Arc::clone(&cached.policy));
100        }
101    }
102
103    let record = dns::resolve_dns_record(policy_domain, resolver).await?;
104
105    let policy = Arc::new(policy::load_policy_for_domain(policy_domain, getter).await?);
106
107    let expires = Instant::now() + Duration::from_secs(policy.max_age);
108
109    CACHE
110        .insert(
111            name,
112            CachedPolicy {
113                id: record.id,
114                policy: Arc::clone(&policy),
115            },
116            expires.into(),
117        )
118        .await;
119
120    Ok(policy)
121}
122
123/*
124#[cfg(test)]
125mod test {
126    use super::*;
127
128    #[tokio::test]
129    async fn get_gmail_policy() {
130        k9::snapshot!(
131            get_policy_for_domain("gmail.com").await.unwrap(),
132            r#"
133MtaStsPolicy {
134    mode: Enforce,
135    mx: [
136        "gmail-smtp-in.l.google.com",
137        "*.gmail-smtp-in.l.google.com",
138    ],
139    max_age: 86400,
140    fields: {},
141}
142"#
143        );
144    }
145}
146*/