mta_sts/
lib.rs

1use dns_resolver::Resolver;
2use futures::future::BoxFuture;
3use hickory_resolver::Name;
4use policy::MtaStsPolicy;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8lruttl::declare_cache! {
9/// Caches MTA-STS policy information by domain
10static CACHE: LruCacheWithTtl<Name, CachedPolicy>::new("mta_sts_policy", 64 * 1024);
11}
12
13pub mod dns;
14pub mod policy;
15
16#[derive(Clone, Debug)]
17struct CachedPolicy {
18    pub id: String,
19    pub policy: Arc<MtaStsPolicy>,
20}
21
22struct Getter {}
23
24impl policy::Get for Getter {
25    fn http_get<'a>(&'a self, url: &'a str) -> BoxFuture<'a, anyhow::Result<String>> {
26        Box::pin(async move {
27            let response = reqwest::Client::builder()
28                // <https://datatracker.ietf.org/doc/html/rfc8461#section-3.3>
29                // HTTP 3xx redirects MUST NOT be followed
30                .redirect(reqwest::redirect::Policy::none())
31                .timeout(std::time::Duration::from_secs(20))
32                .build()?
33                .request(reqwest::Method::GET, url)
34                .send()
35                .await?;
36
37            // <https://datatracker.ietf.org/doc/html/rfc8461#section-3.3>
38            // Policies fetched via HTTPS are only valid if the HTTP
39            // response code is 200 (OK)
40            let status = response.status();
41            if status != reqwest::StatusCode::OK {
42                anyhow::bail!("failed to GET {url}: {status}");
43            }
44
45            // <https://datatracker.ietf.org/doc/html/rfc8461#section-3.2>
46            // senders SHOULD validate that the media type is "text/plain"
47            // to guard against cases where web servers allow untrusted users
48            // to host non-text content.
49            // We need to do some manual grubbing about for this, as reqwest's
50            // Response::text() method doesn't verify that the type is textual,
51            // just whether it decodes as text, which is precisely what we're
52            // trying to guard against.
53
54            let content_type = response
55                .headers()
56                .get(reqwest::header::CONTENT_TYPE)
57                .ok_or_else(|| anyhow::anyhow!("missing required Content-Type header"))?;
58
59            let content_type = content_type.to_str()?;
60
61            let ct = if let Some((ct, _)) = content_type.split_once(';') {
62                ct.trim()
63            } else {
64                content_type.trim()
65            };
66            if ct != "text/plain" {
67                anyhow::bail!("Content-Type must be text/plain, got {content_type}");
68            }
69
70            Ok(response.text().await?)
71        })
72    }
73}
74
75pub async fn get_policy_for_domain(policy_domain: &str) -> anyhow::Result<Arc<MtaStsPolicy>> {
76    let resolver = dns_resolver::get_resolver();
77    get_policy_for_domain_impl(policy_domain, &**resolver, &Getter {}).await
78}
79
80fn cache_lookup(name: &Name) -> Option<CachedPolicy> {
81    CACHE.get(name)
82}
83
84async fn get_policy_for_domain_impl(
85    policy_domain: &str,
86    resolver: &dyn Resolver,
87    getter: &dyn policy::Get,
88) -> anyhow::Result<Arc<MtaStsPolicy>> {
89    let name = Name::from_str_relaxed(policy_domain)?.to_lowercase();
90
91    if let Some(cached) = cache_lookup(&name) {
92        // Removal of the DNS record does not invalidate our
93        // cached result, only updating it with a different id
94        let still_valid = dns::resolve_dns_record(policy_domain, resolver)
95            .await
96            .map(|r| cached.id == r.id)
97            .unwrap_or(true);
98
99        if still_valid {
100            return Ok(Arc::clone(&cached.policy));
101        }
102    }
103
104    let record = dns::resolve_dns_record(policy_domain, resolver).await?;
105
106    let policy = Arc::new(policy::load_policy_for_domain(policy_domain, getter).await?);
107
108    let expires = Instant::now() + Duration::from_secs(policy.max_age);
109
110    CACHE
111        .insert(
112            name,
113            CachedPolicy {
114                id: record.id,
115                policy: Arc::clone(&policy),
116            },
117            expires.into(),
118        )
119        .await;
120
121    Ok(policy)
122}
123
124/*
125#[cfg(test)]
126mod test {
127    use super::*;
128
129    #[tokio::test]
130    async fn get_gmail_policy() {
131        k9::snapshot!(
132            get_policy_for_domain("gmail.com").await.unwrap(),
133            r#"
134MtaStsPolicy {
135    mode: Enforce,
136    mx: [
137        "gmail-smtp-in.l.google.com",
138        "*.gmail-smtp-in.l.google.com",
139    ],
140    max_age: 86400,
141    fields: {},
142}
143"#
144        );
145    }
146}
147*/