use dns_resolver::Resolver;
use futures::future::BoxFuture;
use hickory_resolver::Name;
use lruttl::LruCacheWithTtl;
use policy::MtaStsPolicy;
use std::sync::{Arc, LazyLock, Mutex};
use std::time::{Duration, Instant};
/// Process-wide cache of fetched MTA-STS policies.
///
/// Keys are the lowercased policy domain names produced in `get_policy_for_domain_impl`.
/// Each entry is inserted with an expiry `Instant` derived from the policy's `max_age`.
/// NOTE(review): the second argument to `new_named` presumably bounds the number of
/// entries (64 Ki) — confirm against `lruttl`.
static CACHE: LazyLock<Mutex<LruCacheWithTtl<Name, CachedPolicy>>> = LazyLock::new(|| {
    let cache = LruCacheWithTtl::new_named("mta_sts_policy", 64 * 1024);
    Mutex::new(cache)
});
pub mod dns;
pub mod policy;
/// A cached MTA-STS policy paired with the DNS record id it was fetched under.
///
/// Cloning is cheap: the policy itself is shared through an `Arc`.
#[derive(Clone)]
struct CachedPolicy {
    /// `id` of the record returned by `dns::resolve_dns_record` when this policy was
    /// fetched; a later lookup with a different id means the policy should be refetched.
    pub id: String,
    /// The parsed policy document, shared with callers.
    pub policy: Arc<MtaStsPolicy>,
}
struct Getter {}
impl policy::Get for Getter {
fn http_get<'a>(&'a self, url: &'a str) -> BoxFuture<'a, anyhow::Result<String>> {
Box::pin(async move {
let response = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.timeout(std::time::Duration::from_secs(20))
.build()?
.request(reqwest::Method::GET, url)
.send()
.await?;
let status = response.status();
if status != reqwest::StatusCode::OK {
anyhow::bail!("failed to GET {url}: {status}");
}
let content_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.ok_or_else(|| anyhow::anyhow!("missing required Content-Type header"))?;
let content_type = content_type.to_str()?;
let ct = if let Some((ct, _)) = content_type.split_once(';') {
ct.trim()
} else {
content_type.trim()
};
if ct != "text/plain" {
anyhow::bail!("Content-Type must be text/plain, got {content_type}");
}
Ok(response.text().await?)
})
}
}
/// Looks up the MTA-STS policy for `policy_domain`.
///
/// Uses the resolver returned by `dns_resolver::get_resolver()` and fetches policy
/// documents over HTTPS. This is a thin wrapper around `get_policy_for_domain_impl`,
/// which takes the resolver and fetcher as parameters.
pub async fn get_policy_for_domain(policy_domain: &str) -> anyhow::Result<Arc<MtaStsPolicy>> {
    let fetcher = Getter {};
    let shared_resolver = dns_resolver::get_resolver();
    get_policy_for_domain_impl(policy_domain, &**shared_resolver, &fetcher).await
}
/// Returns a clone of the cached entry for `name`, if the cache currently holds one.
///
/// The cache mutex is only locked for the duration of this expression, so it is never
/// held across an `.await` by callers.
///
/// # Panics
/// Panics if the cache mutex is poisoned.
fn cache_lookup(name: &Name) -> Option<CachedPolicy> {
    CACHE
        .lock()
        .unwrap()
        .get(name)
        .map(|hit| hit.clone())
}
/// Returns the MTA-STS policy for `policy_domain`, consulting the in-process cache first.
///
/// The domain's MTA-STS DNS record is resolved exactly once per call:
/// * If a cached entry exists and either the record's `id` matches it or the DNS lookup
///   fails, the cached policy is returned without an HTTPS fetch.
/// * Otherwise a fresh policy is fetched through `getter`. It is cached under the record's
///   `id` for `max_age` seconds, capped at the RFC 8461 maximum.
///
/// If fetching a new policy fails while a cached entry is still present, that cached
/// policy is returned instead of the error, as RFC 8461 §3.3 requires. The cache entry is
/// left untouched, so the next call retries the fetch.
///
/// # Errors
/// Returns an error if `policy_domain` is not a valid DNS name. Also returns an error if
/// there is no cached entry and either the DNS lookup or the policy fetch fails.
async fn get_policy_for_domain_impl(
    policy_domain: &str,
    resolver: &dyn Resolver,
    getter: &dyn policy::Get,
) -> anyhow::Result<Arc<MtaStsPolicy>> {
    // RFC 8461 §3.2: `max_age` has a maximum value of 31557600 seconds. Clamping also
    // prevents `Instant + Duration` from panicking on an absurd remote-supplied value.
    const MAX_AGE_LIMIT_SECS: u64 = 31_557_600;

    let name = Name::from_str_relaxed(policy_domain)?.to_lowercase();
    let cached = cache_lookup(&name);

    // Resolve once and reuse the result, rather than repeating the lookup when the
    // cached id turns out to be stale.
    let lookup = dns::resolve_dns_record(policy_domain, resolver).await;

    if let Some(cached) = &cached {
        let still_valid = match &lookup {
            Ok(record) => record.id == cached.id,
            // No live record could be resolved: keep applying the cached policy.
            Err(_) => true,
        };
        if still_valid {
            return Ok(Arc::clone(&cached.policy));
        }
    }

    let record = lookup?;

    let fetched = policy::load_policy_for_domain(policy_domain, getter).await;
    if fetched.is_err() {
        // RFC 8461 §3.3: if no live policy can be fetched but a valid (non-expired)
        // cached policy exists, the sender MUST apply the cached policy.
        // NOTE(review): relies on the TTL cache not returning expired entries — confirm.
        if let Some(cached) = cached {
            return Ok(cached.policy);
        }
    }
    let policy = Arc::new(fetched?);

    let max_age = Duration::from_secs(policy.max_age.min(MAX_AGE_LIMIT_SECS));
    let expires = Instant::now() + max_age;
    CACHE.lock().unwrap().insert(
        name,
        CachedPolicy {
            id: record.id,
            policy: Arc::clone(&policy),
        },
        expires,
    );
    Ok(policy)
}