mta_sts/lib.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
use dns_resolver::Resolver;
use futures::future::BoxFuture;
use hickory_resolver::Name;
use policy::MtaStsPolicy;
use std::sync::Arc;
use std::time::{Duration, Instant};
// Process-wide cache of fetched MTA-STS policies, keyed by the (lowercased)
// policy domain `Name`. Each entry is a `CachedPolicy`, which pairs the parsed
// policy with the DNS record id seen when it was fetched. Entries are inserted
// with an explicit expiry computed from the policy's `max_age`.
// NOTE(review): `64 * 1024` is presumably the maximum number of entries before
// LRU eviction — confirm against `lruttl::declare_cache!`.
lruttl::declare_cache! {
static CACHE: LruCacheWithTtl<Name, CachedPolicy>::new("mta_sts_policy", 64 * 1024);
}
pub mod dns;
pub mod policy;
/// A policy stored in `CACHE`, remembered together with the `id` of the DNS
/// record (as returned by `dns::resolve_dns_record`) that was current when
/// the policy was fetched.
#[derive(Debug, Clone)]
struct CachedPolicy {
    /// Record `id` observed at fetch time; a later lookup yielding a
    /// different `id` means this cached policy must be refreshed.
    pub id: String,
    /// The parsed policy, shared with callers via cheap `Arc` clones.
    pub policy: Arc<MtaStsPolicy>,
}
struct Getter {}
impl policy::Get for Getter {
fn http_get<'a>(&'a self, url: &'a str) -> BoxFuture<'a, anyhow::Result<String>> {
Box::pin(async move {
let response = reqwest::Client::builder()
// <https://datatracker.ietf.org/doc/html/rfc8461#section-3.3>
// HTTP 3xx redirects MUST NOT be followed
.redirect(reqwest::redirect::Policy::none())
.timeout(std::time::Duration::from_secs(20))
.build()?
.request(reqwest::Method::GET, url)
.send()
.await?;
// <https://datatracker.ietf.org/doc/html/rfc8461#section-3.3>
// Policies fetched via HTTPS are only valid if the HTTP
// response code is 200 (OK)
let status = response.status();
if status != reqwest::StatusCode::OK {
anyhow::bail!("failed to GET {url}: {status}");
}
// <https://datatracker.ietf.org/doc/html/rfc8461#section-3.2>
// senders SHOULD validate that the media type is "text/plain"
// to guard against cases where web servers allow untrusted users
// to host non-text content.
// We need to do some manual grubbing about for this, as reqwest's
// Response::text() method doesn't verify that the type is textual,
// just whether it decodes as text, which is precisely what we're
// trying to guard against.
let content_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.ok_or_else(|| anyhow::anyhow!("missing required Content-Type header"))?;
let content_type = content_type.to_str()?;
let ct = if let Some((ct, _)) = content_type.split_once(';') {
ct.trim()
} else {
content_type.trim()
};
if ct != "text/plain" {
anyhow::bail!("Content-Type must be text/plain, got {content_type}");
}
Ok(response.text().await?)
})
}
}
/// Resolve the MTA-STS policy for `policy_domain`.
///
/// Uses the shared resolver from `dns_resolver::get_resolver()` for the DNS
/// record lookup and [`Getter`] (HTTPS via `reqwest`) to fetch the policy
/// body, reusing a cached policy where possible.
///
/// # Errors
///
/// Propagates any error from the underlying lookup/fetch logic.
pub async fn get_policy_for_domain(policy_domain: &str) -> anyhow::Result<Arc<MtaStsPolicy>> {
    let https_getter = Getter {};
    let resolver = dns_resolver::get_resolver();
    get_policy_for_domain_impl(policy_domain, &**resolver, &https_getter).await
}
/// Look up the entry stored in `CACHE` for the policy domain `name`.
///
/// Returns `None` when nothing is cached for `name`.
async fn cache_lookup(name: &Name) -> Option<CachedPolicy> {
    // `name` is already a reference; borrowing it again (`&name`) only adds
    // a needless `&&Name` that relies on deref coercion.
    CACHE.get(name).await
}
async fn get_policy_for_domain_impl(
policy_domain: &str,
resolver: &dyn Resolver,
getter: &dyn policy::Get,
) -> anyhow::Result<Arc<MtaStsPolicy>> {
let name = Name::from_str_relaxed(policy_domain)?.to_lowercase();
if let Some(cached) = cache_lookup(&name).await {
// Removal of the DNS record does not invalidate our
// cached result, only updating it with a different id
let still_valid = dns::resolve_dns_record(policy_domain, resolver)
.await
.map(|r| cached.id == r.id)
.unwrap_or(true);
if still_valid {
return Ok(Arc::clone(&cached.policy));
}
}
let record = dns::resolve_dns_record(policy_domain, resolver).await?;
let policy = Arc::new(policy::load_policy_for_domain(policy_domain, getter).await?);
let expires = Instant::now() + Duration::from_secs(policy.max_age);
CACHE
.insert(
name,
CachedPolicy {
id: record.id,
policy: Arc::clone(&policy),
},
expires.into(),
)
.await;
Ok(policy)
}
/*
#[cfg(test)]
mod test {
use super::*;
#[tokio::test]
async fn get_gmail_policy() {
k9::snapshot!(
get_policy_for_domain("gmail.com").await.unwrap(),
r#"
MtaStsPolicy {
mode: Enforce,
mx: [
"gmail-smtp-in.l.google.com",
"*.gmail-smtp-in.l.google.com",
],
max_age: 86400,
fields: {},
}
"#
);
}
}
*/