kumo_api_types/
shaping.rs

1use crate::egress_path::EgressPathConfig;
2#[cfg(feature = "lua")]
3use anyhow::Context;
4#[cfg(feature = "lua")]
5use config::any_err;
6#[cfg(feature = "lua")]
7use config::serialize_options;
8#[cfg(feature = "lua")]
9use dns_resolver::{fully_qualify, MailExchanger};
10#[cfg(feature = "lua")]
11use kumo_log_types::JsonLogRecord;
12#[cfg(feature = "lua")]
13use mlua::prelude::LuaUserData;
14#[cfg(feature = "lua")]
15use mlua::{LuaSerdeExt, UserDataMethods};
16use ordermap::OrderMap;
17use serde::de::{SeqAccess, Visitor};
18use serde::{Deserialize, Deserializer, Serialize};
19use serde_with::formats::PreferOne;
20use serde_with::{serde_as, DeserializeAs, OneOrMany};
21#[cfg(feature = "lua")]
22use sha2::{Digest, Sha256};
23#[cfg(feature = "lua")]
24use std::collections::BTreeMap;
25use std::hash::{Hash, Hasher};
26use std::marker::PhantomData;
27#[cfg(feature = "lua")]
28use std::sync::Arc;
29use std::time::Duration;
30#[cfg(feature = "lua")]
31use throttle::LimitSpec;
32use throttle::ThrottleSpec;
33
34#[derive(Deserialize, Serialize, Debug, Clone)]
35#[serde(try_from = "String", into = "String")]
36pub struct Regex(fancy_regex::Regex);
37
38impl TryFrom<String> for Regex {
39    type Error = fancy_regex::Error;
40
41    fn try_from(s: String) -> fancy_regex::Result<Self> {
42        Ok(Self(fancy_regex::Regex::new(&s)?))
43    }
44}
45
46impl From<Regex> for String {
47    fn from(r: Regex) -> String {
48        r.0.as_str().to_string()
49    }
50}
51
52impl std::ops::Deref for Regex {
53    type Target = fancy_regex::Regex;
54    fn deref(&self) -> &Self::Target {
55        &self.0
56    }
57}
58
59impl std::hash::Hash for Regex {
60    fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
61        self.0.as_str().hash(hasher)
62    }
63}
64
65/// toml::Value is not Hash because it may contain floating
66/// point numbers, which are problematic from a Ord and Eq
67/// perspective. We're okay with skirting around that for
68/// our purposes here, so we implement our own hashable
69/// wrapper around the toml value.
70#[derive(Deserialize, Serialize, Debug, Clone)]
71#[serde(from = "toml::Value", into = "toml::Value")]
72pub struct HashableTomlValue {
73    value: toml::Value,
74}
75
76impl From<toml::Value> for HashableTomlValue {
77    fn from(value: toml::Value) -> Self {
78        Self { value }
79    }
80}
81
82impl From<HashableTomlValue> for toml::Value {
83    fn from(value: HashableTomlValue) -> toml::Value {
84        value.value
85    }
86}
87
88impl std::ops::Deref for HashableTomlValue {
89    type Target = toml::Value;
90    fn deref(&self) -> &toml::Value {
91        &self.value
92    }
93}
94
95fn hash_toml<H>(value: &toml::Value, h: &mut H)
96where
97    H: Hasher,
98{
99    match value {
100        toml::Value::Boolean(v) => v.hash(h),
101        toml::Value::Datetime(v) => {
102            if let Some(d) = &v.date {
103                d.year.hash(h);
104                d.month.hash(h);
105                d.day.hash(h);
106            }
107            if let Some(t) = &v.time {
108                t.hour.hash(h);
109                t.minute.hash(h);
110                t.second.hash(h);
111                t.nanosecond.hash(h);
112            }
113            if let Some(toml::value::Offset::Custom { minutes }) = &v.offset {
114                minutes.hash(h);
115            }
116        }
117        toml::Value::String(v) => v.hash(h),
118        toml::Value::Integer(v) => v.hash(h),
119        toml::Value::Float(v) => v.to_ne_bytes().hash(h),
120        toml::Value::Array(a) => {
121            for v in a.iter() {
122                hash_toml(v, h);
123            }
124        }
125        toml::Value::Table(m) => {
126            for (k, v) in m.iter() {
127                k.hash(h);
128                hash_toml(v, h);
129            }
130        }
131    }
132}
133
134impl Hash for HashableTomlValue {
135    fn hash<H>(&self, h: &mut H)
136    where
137        H: Hasher,
138    {
139        hash_toml(&self.value, h);
140    }
141}
142
143/// Represents an individual EgressPathConfig field name and value.
144/// It only allows deserializing from valid EgressPathConfig field + values.
145#[derive(Deserialize, Serialize, Debug, Clone, Hash)]
146#[serde(
147    try_from = "EgressPathConfigValueUnchecked",
148    into = "EgressPathConfigValueUnchecked"
149)]
150pub struct EgressPathConfigValue {
151    pub name: String,
152    pub value: HashableTomlValue,
153}
154
155/// This is the type that we actually use to deserialize EgressPathConfigValue items.
156/// It doesn't care about validity; it is used solely to tell serde what shape of
157/// data to expect.
158/// The validation is performed by the TryFrom impl that is used to convert to the
159/// checked form below.
160#[derive(Deserialize, Serialize, Debug, Clone)]
161pub struct EgressPathConfigValueUnchecked {
162    pub name: String,
163    pub value: toml::Value,
164}
165
166impl TryFrom<EgressPathConfigValueUnchecked> for EgressPathConfigValue {
167    type Error = anyhow::Error;
168    fn try_from(config: EgressPathConfigValueUnchecked) -> anyhow::Result<EgressPathConfigValue> {
169        let mut map = toml::map::Map::new();
170        map.insert(config.name.clone(), config.value.clone());
171        let table = toml::Value::Table(map);
172
173        // Attempt to deserialize as EgressPathConfig.
174        // If it fails, then the field name/value are invalid
175        EgressPathConfig::deserialize(table)?;
176
177        // If we reach this point, we can pass along the name/value
178        Ok(EgressPathConfigValue {
179            name: config.name,
180            value: HashableTomlValue {
181                value: config.value,
182            },
183        })
184    }
185}
186
187impl From<EgressPathConfigValue> for EgressPathConfigValueUnchecked {
188    fn from(config: EgressPathConfigValue) -> EgressPathConfigValueUnchecked {
189        EgressPathConfigValueUnchecked {
190            name: config.name,
191            value: config.value.value,
192        }
193    }
194}
195
196#[derive(Deserialize, Serialize, Debug, Clone, Hash)]
197pub enum Action {
198    Suspend,
199    SetConfig(EgressPathConfigValue),
200    SuspendTenant,
201    SuspendCampaign,
202    SetDomainConfig(EgressPathConfigValue),
203    Bounce,
204    BounceTenant,
205    BounceCampaign,
206}
207
208#[derive(Deserialize, Serialize, Debug, Clone, Hash, Default)]
209pub enum Trigger {
210    /// Trigger on the first match, immediately
211    #[default]
212    Immediate,
213    /// Trigger when a certain number of matches occur
214    /// over a certain time period.
215    Threshold(ThrottleSpec),
216}
217
218#[serde_as]
219#[derive(Deserialize, Serialize, Debug, Hash, Clone)]
220pub struct Rule {
221    #[serde(deserialize_with = "regex_string_or_array")]
222    pub regex: Vec<Regex>,
223
224    #[serde(deserialize_with = "one_or_many_action")]
225    pub action: Vec<Action>,
226
227    #[serde(default)]
228    pub trigger: Trigger,
229
230    #[serde(with = "duration_serde")]
231    pub duration: Duration,
232
233    #[serde(skip)]
234    pub was_rollup: bool,
235
236    /// if true, this rule can match kumomta internally generated
237    /// Response messages, otherwise, the rule will skip testing
238    /// against those.
239    #[serde(default)]
240    pub match_internal: bool,
241}
242
243impl Rule {
244    pub fn matches(&self, is_internal: bool, response: &str) -> bool {
245        if is_internal && !self.match_internal {
246            return false;
247        }
248        self.regex
249            .iter()
250            .any(|r| r.is_match(response).unwrap_or(false))
251    }
252
253    pub fn clone_and_set_rollup(&self) -> Self {
254        let mut result = self.clone();
255        result.was_rollup = true;
256        result
257    }
258}
259
/// The merged shaping configuration that a `Shaping` handle shares.
#[cfg(feature = "lua")]
#[derive(Default, Debug)]
struct ShapingInner {
    /// Entries keyed by site name.
    by_site: OrderMap<String, PartialEntry>,
    /// Entries keyed by domain name; the key `"default"` holds the
    /// baseline configuration.
    by_domain: OrderMap<String, PartialEntry>,
    /// Entries keyed by provider name.
    by_provider: OrderMap<String, ProviderEntry>,
    /// Warning-level messages collected while merging.
    warnings: Vec<String>,
    /// Error-level messages collected while merging.
    errors: Vec<String>,
    /// Lowercase hex digest computed over the merged data and messages.
    hash: String,
}
270
#[cfg(feature = "lua")]
impl ShapingInner {
    /// Computes the effective egress path parameters for the given
    /// `domain`, `egress_source` and `site_name`.
    ///
    /// Layers are merged in this order, later layers overriding earlier ones:
    /// 1. the `"default"` domain entry,
    /// 2. every matching provider (only when `domain` parses as a DNS name),
    /// 3. the per-source settings of those matching providers,
    /// 4. the site entry,
    /// 5. the domain entry,
    /// 6. the per-source settings of the site entry,
    /// 7. the per-source settings of the domain entry.
    pub async fn get_egress_path_config(
        &self,
        domain: &str,
        egress_source: &str,
        site_name: &str,
    ) -> PartialEntry {
        let mut params = PartialEntry::default();

        // Apply basic/default configuration
        if let Some(default) = self.by_domain.get("default") {
            params.merge_from(default.clone());
        }

        // Provider rules come next.
        // They can only match valid domain names, so we'll
        // skip processing them if we have something like our
        // TSA http "domain name" here
        let is_domain_name = dns_resolver::Name::from_str_relaxed(domain).is_ok();
        if is_domain_name {
            let mut prov_with_sources = vec![];

            for prov in self.by_provider.values() {
                if prov.domain_matches(domain).await {
                    toml_table_merge_from(&mut params.params, &prov.params);
                    prov.apply_provider_params_to(egress_source, &mut params.params);

                    if !prov.sources.is_empty() {
                        // Remember this matching provider, so that we
                        // can apply any source rules after we've applied
                        // any/all base provider rules for this domain
                        prov_with_sources.push(prov);
                    }
                }
            }

            // Then Provider source rules
            for prov in prov_with_sources {
                if let Some(source) = prov.sources.get(egress_source) {
                    toml_table_merge_from(&mut params.params, source);
                    prov.apply_provider_params_to(egress_source, &mut params.params);
                }
            }
        }

        // The site and domain entries are each consulted twice below
        // (base settings, then per-source settings), so look them up once.
        let site_entry = self.by_site.get(site_name);
        let domain_entry = self.by_domain.get(domain);

        // Then site config
        if let Some(by_site) = site_entry {
            params.merge_from(by_site.clone());
        }

        // Then domain config
        if let Some(by_domain) = domain_entry {
            params.merge_from(by_domain.clone());
        }

        // Then source config for the site
        if let Some(by_site) = site_entry {
            if let Some(source) = by_site.sources.get(egress_source) {
                toml_table_merge_from(&mut params.params, source);
            }
        }

        // Then source config for the domain
        if let Some(by_domain) = domain_entry {
            if let Some(source) = by_domain.sources.get(egress_source) {
                toml_table_merge_from(&mut params.params, source);
            }
        }

        params
    }

    /// Returns the automation rules that match the response in `record`.
    ///
    /// The domain is taken from the first recipient of the record and the
    /// site name is derived from `record.site`; the actual matching is
    /// done by `match_rules_impl`. A record addressed to the local
    /// postmaster yields an empty list.
    ///
    /// # Errors
    ///
    /// Fails when the record has no recipients, or when its first
    /// recipient cannot be parsed as a forward path.
    pub async fn match_rules(&self, record: &JsonLogRecord) -> anyhow::Result<Vec<Rule>> {
        use rfc5321::ForwardPath;
        // Extract the domain from the recipient.
        let recipient = ForwardPath::try_from(
            record
                .recipient
                .first()
                .ok_or_else(|| anyhow::anyhow!("no recipients!?"))?
                .as_str(),
        )
        .map_err(|err| anyhow::anyhow!("parsing record.recipient: {err}"))?;

        let recipient = match recipient {
            ForwardPath::Postmaster => {
                // It doesn't make sense to apply automation on the
                // local postmaster address, so we ignore this.
                return Ok(vec![]);
            }
            ForwardPath::Path(path) => path.mailbox,
        };
        let domain = recipient.domain.to_string();

        // Track events/outcomes by site.
        let source = record.egress_source.as_deref().unwrap_or("unspecified");
        // record.site is poorly named; it is really an identifier for the
        // egress path. For matching purposes, we want just the site_name
        // in the form produced by our MX resolution process.
        // In an earlier incarnation of this logic, we would resolve the
        // site_name for ourselves based on other data in the record,
        // but that could lead to over-resolution of some names and
        // yield surprising results.
        // What we do here is extract the egress path decoration from
        // record.site to arrive at something that looks like the
        // mx site_name.
        // NOTE: this is coupled with the logic in
        // ReadyQueueManager::compute_queue_name
        //
        // strip_prefix/strip_suffix remove the decoration exactly once;
        // the trim_*_matches family would keep stripping for as long as
        // the remainder happened to repeat the same pattern.
        let decoration = format!("{source}->");
        let site: &str = &record.site;
        let site = site.strip_prefix(decoration.as_str()).unwrap_or(site);
        let site_name = site
            .strip_suffix("@smtp_client")
            .unwrap_or(site)
            .to_string();

        Ok(self.match_rules_impl(record, &domain, &site_name).await)
    }

    /// Collects every rule whose patterns match the response of `record`.
    ///
    /// Rules are considered, and returned, in this order: the `"default"`
    /// domain entry, each provider matching `domain`, the entry for
    /// `site_name`, then the entry for `domain`. Rules taken from the
    /// default and site entries are returned with `was_rollup` set.
    /// A response whose content starts with `"KumoMTA internal: "` is
    /// treated as internally generated when testing each rule.
    pub async fn match_rules_impl(
        &self,
        record: &JsonLogRecord,
        domain: &str,
        site_name: &str,
    ) -> Vec<Rule> {
        let mut result = vec![];
        let response = record.response.to_single_line();
        tracing::trace!("Consider rules for {response}");

        let is_internal = record.response.content.starts_with("KumoMTA internal: ");

        if let Some(default) = self.by_domain.get("default") {
            for rule in &default.automation {
                tracing::trace!("Consider \"default\" rule {rule:?} for {response}");
                if rule.matches(is_internal, &response) {
                    // For automation under `default`, we always
                    // assume that mx_rollup should be true.
                    // If you somehow have a domain where that isn't
                    // true, you should avoid using `default` for
                    // automation.  Honestly, it's best to avoid
                    // using `default` for automation.
                    result.push(rule.clone_and_set_rollup());
                }
            }
        }

        for prov in self.by_provider.values() {
            if prov.domain_matches(domain).await {
                for rule in &prov.automation {
                    tracing::trace!(
                        "Consider provider \"{}\" rule {rule:?} for {response}",
                        prov.provider_name
                    );
                    if rule.matches(is_internal, &response) {
                        result.push(rule.clone());
                    }
                }
            }
        }

        // Then site config
        if let Some(by_site) = self.by_site.get(site_name) {
            for rule in &by_site.automation {
                tracing::trace!("Consider \"{site_name}\" rule {rule:?} for {response}");
                if rule.matches(is_internal, &response) {
                    result.push(rule.clone_and_set_rollup());
                }
            }
        }

        // Then domain config
        if let Some(by_domain) = self.by_domain.get(domain) {
            for rule in &by_domain.automation {
                tracing::trace!("Consider \"{domain}\" rule {rule:?} for {response}");
                if rule.matches(is_internal, &response) {
                    result.push(rule.clone());
                }
            }
        }

        result
    }
}
453
/// A handle to a merged shaping configuration.
///
/// The data lives behind an `Arc`, so cloning the handle shares the same
/// underlying `ShapingInner`.
#[cfg(feature = "lua")]
#[derive(Clone, Debug, Default, mlua::FromLua)]
pub struct Shaping {
    inner: Arc<ShapingInner>,
}
459
/// Deserializes `json` into a `T`.
///
/// Deserialization goes through `serde_path_to_error`, so a failure
/// reports the path of the offending element.
///
/// # Errors
///
/// Fails when the text is not valid JSON for `T`, or when anything other
/// than whitespace follows the JSON value.
#[cfg(feature = "lua")]
fn from_json<'a, T: Deserialize<'a>>(json: &'a str) -> anyhow::Result<T> {
    let mut d = serde_json::Deserializer::from_str(json);
    let value = serde_path_to_error::deserialize(&mut d)?;
    // A hand-driven Deserializer stops after the first complete value;
    // end() is what rejects trailing garbage, as serde_json::from_str does.
    d.end()?;
    Ok(value)
}
465
/// Deserializes `toml` into a `T`.
///
/// Deserialization goes through `serde_path_to_error`, so a failure
/// reports the path of the offending element.
///
/// # Errors
///
/// Fails when the text cannot be parsed as TOML, or when the parsed
/// document does not deserialize as `T`.
#[cfg(feature = "lua")]
fn from_toml<'a, T: Deserialize<'a>>(toml: &'a str) -> anyhow::Result<T> {
    let parsed = toml::Deserializer::parse(toml)?;
    let value = serde_path_to_error::deserialize(parsed)?;
    Ok(value)
}
471
/// How a particular kind of finding should be reported.
/// The default is `Ignore`.
#[cfg(feature = "lua")]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize)]
pub enum CheckLevel {
    /// Do not record the finding.
    #[default]
    Ignore,
    /// Record the finding as a warning.
    Warn,
    /// Record the finding as an error.
    Error,
}
#[cfg(feature = "lua")]
impl std::str::FromStr for CheckLevel {
    type Err = String;

    /// Parses `ignore`, `warn` or `error`, ignoring ASCII case.
    /// Any other input produces a message listing the accepted names.
    fn from_str(s: &str) -> Result<Self, String> {
        const NAMES: [(&str, CheckLevel); 3] = [
            ("ignore", CheckLevel::Ignore),
            ("warn", CheckLevel::Warn),
            ("error", CheckLevel::Error),
        ];

        NAMES
            .iter()
            .find(|(name, _)| s.eq_ignore_ascii_case(name))
            .map(|(_, level)| *level)
            .ok_or_else(|| format!("Expected one of `Ignore`, `Warn` or `Error`, got `{s}`"))
    }
}
498
/// Accumulates the warning and error messages produced while loading
/// and merging shaping data.
#[cfg(feature = "lua")]
#[derive(Default)]
struct Collector {
    /// Messages pushed at `CheckLevel::Warn`.
    warnings: Vec<String>,
    /// Messages pushed at `CheckLevel::Error`.
    errors: Vec<String>,
}
505
#[cfg(feature = "lua")]
impl Collector {
    /// Files `msg` under the list selected by `level`.
    /// A message pushed at `CheckLevel::Ignore` is dropped without being
    /// converted.
    fn push<S: Into<String>>(&mut self, level: CheckLevel, msg: S) {
        let bucket = match level {
            CheckLevel::Ignore => return,
            CheckLevel::Warn => &mut self.warnings,
            CheckLevel::Error => &mut self.errors,
        };
        bucket.push(msg.into());
    }
}
516
/// Options that control how shaping sources are loaded and which
/// findings are reported while merging them.
///
/// Unknown fields are rejected during deserialization, and every field
/// may be omitted.
///
/// NOTE(review): a field omitted during deserialization takes its type's
/// default, so `local_load` becomes `CheckLevel::Ignore`, whereas
/// `ShapingMergeOptions::default()` sets `local_load` to
/// `CheckLevel::Error`. Confirm which of the two is intended.
// `serde_as` has to be applied before the derive so that it can rewrite
// the item ahead of serde's derive macros.
#[cfg(feature = "lua")]
#[serde_as]
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ShapingMergeOptions {
    /// Level for reporting an MX-resolved domain that is also matched by
    /// one or more providers.
    #[serde(default)]
    pub provider_overlap: CheckLevel,
    /// Level for reporting a domain whose MX resolution failed.
    #[serde(default)]
    pub dns_fail: CheckLevel,
    /// Level for reporting an `mx_rollup` domain that has a NULL MX.
    #[serde(default)]
    pub null_mx: CheckLevel,
    /// Level for reporting several domain blocks that resolve to the
    /// same site.
    #[serde(default)]
    pub aliased_site: CheckLevel,
    /// When true, `http://` and `https://` sources are not fetched.
    #[serde(default)]
    pub skip_remote: bool,
    /// Level for reporting a remote source that could not be loaded.
    #[serde(default)]
    pub remote_load: CheckLevel,
    /// Level for reporting a local file that could not be loaded.
    #[serde(default)]
    pub local_load: CheckLevel,
    /// Timeout for fetching a remote source; 5 seconds are used when unset.
    #[serde(default, with = "duration_serde")]
    pub http_timeout: Option<Duration>,
}
539
#[cfg(feature = "lua")]
impl Default for ShapingMergeOptions {
    /// Reports nothing except local file load failures, which are
    /// errors; remote sources are fetched and no HTTP timeout is set.
    fn default() -> Self {
        let ignore = CheckLevel::Ignore;
        Self {
            // Failing to read a local file is the one finding that is
            // reported by default.
            local_load: CheckLevel::Error,
            remote_load: ignore,
            provider_overlap: ignore,
            dns_fail: ignore,
            null_mx: ignore,
            aliased_site: ignore,
            skip_remote: false,
            http_timeout: None,
        }
    }
}
555
556#[cfg(feature = "lua")]
557impl Shaping {
558    async fn load_from_file(
559        path: &str,
560        options: &ShapingMergeOptions,
561        collector: &mut Collector,
562    ) -> anyhow::Result<ShapingFile> {
563        let (data, level): (String, CheckLevel) =
564            if path.starts_with("http://") || path.starts_with("https://") {
565                if options.skip_remote {
566                    collector.push(
567                        CheckLevel::Warn,
568                        format!("Ignoring {path} because skip_remote is set to true"),
569                    );
570                    return Ok(ShapingFile::default());
571                }
572
573                // To facilitate startup ordering races, and listing multiple subscription
574                // host replicas and allowing one or more of them to be temporarily down,
575                // we allow the http request to fail.
576                // We'll log the error message but consider it to be an empty map
577
578                async fn http_get(url: &str, timeout: Duration) -> anyhow::Result<String> {
579                    tokio::time::timeout(timeout, async {
580                        let response = reqwest::Client::builder()
581                            .timeout(timeout)
582                            .connect_timeout(timeout)
583                            .read_timeout(timeout)
584                            .build()?
585                            .get(url)
586                            .send()
587                            .await
588                            .with_context(|| format!("making HTTP request to {url}"))?;
589                        let status = response.status();
590                        if !status.is_success() {
591                            let err_text = match response.text().await {
592                                Ok(text) => text,
593                                Err(err) => {
594                                    format!("Additional error {err:#} while reading response body")
595                                }
596                            };
597                            anyhow::bail!(
598                                "HTTP request to {url} failed: status {status} {reason} {err_text}",
599                                status = status.as_u16(),
600                                reason = status.canonical_reason().unwrap_or("")
601                            );
602                        }
603                        response
604                            .text()
605                            .await
606                            .with_context(|| format!("reading text from {url}"))
607                    })
608                    .await
609                    .with_context(|| format!("timeout making HTTP request to {url}"))?
610                }
611
612                let timeout = options.http_timeout.unwrap_or(Duration::from_secs(5));
613
614                match http_get(path, timeout).await {
615                    Ok(s) => (s, options.remote_load),
616                    Err(err) => {
617                        tracing::error!("{err:#}. Ignoring this shaping source for now");
618                        collector.push(
619                            options.remote_load,
620                            format!("remote shaping source {path} error: {err:#}"),
621                        );
622                        return Ok(ShapingFile::default());
623                    }
624                }
625            } else {
626                match std::fs::read_to_string(path)
627                    .with_context(|| format!("loading data from file {path}"))
628                {
629                    Err(err) => {
630                        collector.push(
631                            options.local_load,
632                            format!("local shaping source {path} error: {err:#}"),
633                        );
634                        return Ok(ShapingFile::default());
635                    }
636                    Ok(s) => (s, options.local_load),
637                }
638            };
639
640        if path.ends_with(".toml") {
641            from_toml(&data).with_context(|| format!("parsing toml from file {path}"))
642        } else if path.ends_with(".json") {
643            from_json(&data).with_context(|| format!("parsing json from file {path}"))
644        } else {
645            // Try parsing both ways and see which wins
646            let mut errors = vec![];
647            match from_toml(&data) {
648                Ok(s) => return Ok(s),
649                Err(err) => errors.push(format!("as toml: {err:#}")),
650            }
651            match from_json(&data) {
652                Ok(s) => return Ok(s),
653                Err(err) => errors.push(format!("as json: {err:#}")),
654            }
655
656            collector.push(level, format!("parsing {path}: {}", errors.join(", ")));
657            Ok(ShapingFile::default())
658        }
659    }
660
661    pub async fn merge_files(
662        files: &[String],
663        options: &ShapingMergeOptions,
664    ) -> anyhow::Result<Self> {
665        use futures_util::stream::FuturesUnordered;
666        use futures_util::StreamExt;
667
668        let mut collector = Collector::default();
669        let mut loaded = vec![];
670        for p in files {
671            loaded.push(Self::load_from_file(p, options, &mut collector).await?);
672        }
673
674        let mut by_site: OrderMap<String, PartialEntry> = OrderMap::new();
675        let mut by_domain: OrderMap<String, PartialEntry> = OrderMap::new();
676        let mut by_provider: OrderMap<String, ProviderEntry> = OrderMap::new();
677        let mut site_aliases: OrderMap<String, Vec<String>> = OrderMap::new();
678
679        // Pre-resolve domains. We don't interleave the resolution with
680        // the work below, because we want to ensure that the ordering
681        // is preserved
682        let mut mx = std::collections::HashMap::new();
683        let mut lookups = FuturesUnordered::new();
684        for item in &loaded {
685            for (domain, partial) in &item.domains {
686                if partial.mx_rollup {
687                    let domain = domain.to_string();
688                    lookups.push(tokio::spawn(async move {
689                        let mx_result = MailExchanger::resolve(&domain).await;
690                        (domain, mx_result)
691                    }));
692                }
693            }
694        }
695
696        while let Some(Ok((domain, result))) = lookups.next().await {
697            mx.insert(domain, result);
698        }
699
700        for mut item in loaded {
701            if let Some(mut partial) = item.default.take() {
702                let domain = "default";
703                partial.domain_name.replace(domain.to_string());
704                match by_domain.get_mut(domain) {
705                    Some(existing) => {
706                        existing.merge_from(partial);
707                    }
708                    None => {
709                        by_domain.insert(domain.to_string(), partial);
710                    }
711                }
712            }
713
714            for (domain, mut partial) in item.domains {
715                partial.domain_name.replace(domain.clone());
716
717                if let Ok(name) = fully_qualify(&domain) {
718                    if name.num_labels() == 1 {
719                        collector.push(
720                            CheckLevel::Warn,
721                            format!(
722                                "Entry for domain '{domain}' consists of a \
723                                 single DNS label. Domain names in TOML sections \
724                                 need to be quoted like '[\"{domain}.com\"]` otherwise \
725                                 the '.' will create a nested table rather than being \
726                                 added to the domain name."
727                            ),
728                        );
729                    }
730                }
731
732                #[cfg(test)]
733                if partial._treat_domain_name_as_site_name {
734                    match by_site.get_mut(&domain) {
735                        Some(existing) => {
736                            existing.merge_from(partial);
737                        }
738                        None => {
739                            by_site.insert(domain.to_string(), partial);
740                        }
741                    }
742                    continue;
743                }
744                if partial.mx_rollup {
745                    let mx = match mx.get(&domain) {
746                        Some(Ok(mx)) => mx,
747                        Some(Err(err)) => {
748                            collector.push(
749                                options.dns_fail,
750                                format!(
751                                    "error resolving MX for {domain}: {err:#}. \
752                                 Ignoring the shaping config for that domain."
753                                ),
754                            );
755                            continue;
756                        }
757                        None => {
758                            collector.push(
759                                options.dns_fail,
760                                format!(
761                                "We didn't try to resolve the MX for {domain} for some reason!?. \
762                                 Ignoring the shaping config for that domain."
763                            ),
764                            );
765                            continue;
766                        }
767                    };
768
769                    if mx.site_name.is_empty() {
770                        collector.push(
771                            options.null_mx,
772                            format!(
773                            "domain {domain} has a NULL MX and cannot be used with mx_rollup=true. \
774                             Ignoring the shaping config for that domain."),
775                        );
776                        continue;
777                    }
778
779                    site_aliases
780                        .entry(mx.site_name.to_string())
781                        .or_default()
782                        .push(domain.to_string());
783
784                    match by_site.get_mut(&mx.site_name) {
785                        Some(existing) => {
786                            existing.merge_from(partial);
787                        }
788                        None => {
789                            by_site.insert(mx.site_name.clone(), partial);
790                        }
791                    }
792                } else {
793                    match by_domain.get_mut(&domain) {
794                        Some(existing) => {
795                            existing.merge_from(partial);
796                        }
797                        None => {
798                            by_domain.insert(domain, partial);
799                        }
800                    }
801                }
802            }
803
804            for (provider, mut prov) in item.provider {
805                prov.provider_name = provider.to_string();
806                match by_provider.get_mut(&provider) {
807                    Some(existing) => {
808                        existing.merge_from(prov);
809                    }
810                    None => {
811                        by_provider.insert(provider.to_string(), prov);
812                    }
813                }
814            }
815        }
816
817        for (site, partial) in &by_site {
818            partial
819                .clone()
820                .finish()
821                .with_context(|| format!("site: {site}"))?;
822        }
823
824        for (domain, partial) in &by_domain {
825            partial
826                .clone()
827                .finish()
828                .with_context(|| format!("domain: {domain}"))?;
829        }
830
831        for (provider, prov) in &by_provider {
832            prov.finish_params()
833                .with_context(|| format!("provider: {provider}"))?;
834        }
835
836        if options.aliased_site != CheckLevel::Ignore {
837            for (site, aliases) in site_aliases {
838                if aliases.len() > 1 {
839                    collector.push(
840                        options.aliased_site,
841                        format!(
842                            "multiple domain blocks alias to the same site: {site}: {}",
843                            aliases.join(", ")
844                        ),
845                    );
846                }
847            }
848        }
849
850        if options.provider_overlap != CheckLevel::Ignore {
851            for domain in mx.keys() {
852                let mut matching_providers = vec![];
853                for (prov_name, prov) in &by_provider {
854                    if prov.domain_matches(domain).await {
855                        matching_providers.push(prov_name.to_string());
856                    }
857                }
858                if !matching_providers.is_empty() {
859                    collector.push(
860                        options.provider_overlap,
861                        format!(
862                            "domain {domain} is also matched by provider(s): {}",
863                            matching_providers.join(", ")
864                        ),
865                    );
866                }
867            }
868        }
869
870        let mut ctx = Sha256::new();
871        ctx.update("by_site");
872        for (site, entry) in &by_site {
873            ctx.update(site);
874            entry.hash_into(&mut ctx);
875        }
876        ctx.update("by_domain");
877        for (domain, entry) in &by_domain {
878            ctx.update(domain);
879            entry.hash_into(&mut ctx);
880        }
881        ctx.update("by_provider");
882        for (provider, prov) in &by_provider {
883            ctx.update(provider);
884            prov.hash_into(&mut ctx);
885        }
886        ctx.update("warnings");
887        for warn in &collector.warnings {
888            ctx.update(warn);
889        }
890        ctx.update("errors");
891        for err in &collector.errors {
892            ctx.update(err);
893        }
894        let hash = ctx.finalize();
895        let hash = data_encoding::HEXLOWER.encode(&hash);
896
897        Ok(Self {
898            inner: Arc::new(ShapingInner {
899                by_site,
900                by_domain,
901                by_provider,
902                warnings: collector.warnings,
903                errors: collector.errors,
904                hash,
905            }),
906        })
907    }
908
909    async fn get_egress_path_config(
910        &self,
911        domain: &str,
912        egress_source: &str,
913        site_name: &str,
914    ) -> PartialEntry {
915        self.inner
916            .get_egress_path_config(domain, egress_source, site_name)
917            .await
918    }
919
920    pub async fn get_egress_path_config_value(
921        &self,
922        domain: &str,
923        egress_source: &str,
924        site_name: &str,
925    ) -> anyhow::Result<serde_json::Value> {
926        let partial = self
927            .get_egress_path_config(domain, egress_source, site_name)
928            .await;
929        Ok(serde_json::to_value(&partial)?)
930    }
931
932    pub fn get_errors(&self) -> &[String] {
933        &self.inner.errors
934    }
935
936    pub fn get_warnings(&self) -> &[String] {
937        &self.inner.warnings
938    }
939
940    pub async fn match_rules(&self, record: &JsonLogRecord) -> anyhow::Result<Vec<Rule>> {
941        self.inner.match_rules(record).await
942    }
943
944    pub fn get_referenced_sources(&self) -> BTreeMap<String, Vec<String>> {
945        let mut result = BTreeMap::new();
946
947        for (site_name, site) in &self.inner.by_site {
948            for source_name in site.sources.keys() {
949                result
950                    .entry(source_name.to_string())
951                    .or_insert(vec![])
952                    .push(format!("site:{site_name}"));
953            }
954        }
955        for (domain_name, domain) in &self.inner.by_domain {
956            for source_name in domain.sources.keys() {
957                result
958                    .entry(source_name.to_string())
959                    .or_insert(vec![])
960                    .push(format!("domain:{domain_name}"));
961            }
962        }
963
964        result
965    }
966
967    pub fn hash(&self) -> String {
968        self.inner.hash.clone()
969    }
970}
971
#[cfg(feature = "lua")]
impl LuaUserData for Shaping {
    /// Registers the methods that Lua code can call on a `Shaping` value:
    /// `get_egress_path_config`, `get_errors`, `get_warnings`,
    /// `get_referenced_sources`, `match_rules` and `hash`, in addition to
    /// whatever `Memoized::impl_memoize` installs.
    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
        mod_memoize::Memoized::impl_memoize(methods);

        methods.add_async_method(
            "get_egress_path_config",
            |lua, this, (domain, egress_source, site_name): (String, String, String)| async move {
                let entry = this
                    .get_egress_path_config(&domain, &egress_source, &site_name)
                    .await;
                // Only the parameter table of the entry is handed to Lua
                lua.to_value_with(&entry.params, serialize_options())
            },
        );

        methods.add_method("get_errors", move |_lua, this, ()| {
            Ok(this.get_errors().to_vec())
        });

        methods.add_method("get_warnings", move |_lua, this, ()| {
            Ok(this.get_warnings().to_vec())
        });

        methods.add_method("get_referenced_sources", move |_lua, this, ()| {
            let referenced = this.get_referenced_sources();
            Ok(referenced)
        });

        methods.add_async_method("match_rules", |lua, this, record: mlua::Value| async move {
            let record: JsonLogRecord = lua.from_value(record)?;
            let matched = this.match_rules(&record).await.map_err(any_err)?;
            // Convert each rule to a Lua value, stopping at the first failure
            matched
                .iter()
                .map(|rule| lua.to_value(rule))
                .collect::<mlua::Result<Vec<_>>>()
        });

        methods.add_method("hash", move |_, this, ()| {
            let digest = this.hash();
            Ok(digest)
        });
    }
}
1013
/// The typed result of finalizing a shaping entry, as produced by
/// `PartialEntry::finish` and `ProviderEntry::finish_params`.
#[derive(Default, Debug)]
pub struct MergedEntry {
    /// The entry's own parameters, deserialized as an `EgressPathConfig`.
    pub params: EgressPathConfig,
    /// Per-source parameters, keyed by source name and kept in insertion
    /// order.
    pub sources: OrderMap<String, EgressPathConfig>,
    /// The automation rules carried over from the entry.
    pub automation: Vec<Rule>,
}
1020
/// The deserialized form of a shaping file.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
struct ShapingFile {
    /// The optional `default` table.
    pub default: Option<PartialEntry>,
    /// All remaining top-level tables (flattened), keyed by table name.
    #[serde(flatten, default)]
    pub domains: OrderMap<String, PartialEntry>,
    /// The `provider` table, mapping a provider name to its definition.
    #[serde(default)]
    pub provider: OrderMap<String, ProviderEntry>,
}
1030
/// A shaping entry whose parameters are still held as loosely typed TOML
/// tables. `finish` converts it to a `MergedEntry`.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
struct PartialEntry {
    /// Name associated with this entry; never read from or written to the
    /// serialized form. Used in error context by `finish` and fed into the
    /// hash by `hash_into`.
    #[serde(skip)]
    pub domain_name: Option<String>,

    /// Every key not claimed by another field (flattened). `finish`
    /// deserializes this table as an `EgressPathConfig`.
    #[serde(flatten)]
    pub params: toml::Table,

    /// Defaults to `true` when absent from the input.
    #[serde(default = "default_true")]
    pub mx_rollup: bool,

    // This is present to facilitate unit testing without requiring
    // DNS to resolve the site_name. When set to true, the domain_name
    // is considered to be the site_name for this entry.
    #[cfg(test)]
    #[serde(default)]
    pub _treat_domain_name_as_site_name: bool,

    /// When `true`, `merge_from` replaces the params, automation and sources
    /// of the entry being merged into rather than merging with them.
    #[serde(default)]
    pub replace_base: bool,

    /// Automation rules attached to this entry.
    #[serde(default)]
    pub automation: Vec<Rule>,

    /// Per-source parameter tables, keyed by source name.
    #[serde(default)]
    pub sources: OrderMap<String, toml::Table>,
}
1059
/// A provider definition from a shaping file: a set of match rules plus the
/// parameters that apply to matching domains.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
pub struct ProviderEntry {
    /// Name of the provider. Never (de)serialized; it is assigned from the
    /// provider table key when files are merged.
    #[serde(skip, default)]
    pub provider_name: String,

    /// Optional connection limit. `apply_provider_params_to` publishes it
    /// as an `additional_connection_limits` entry named
    /// `shaping-provider-<provider>-<source>-limit`.
    #[serde(default)]
    pub provider_connection_limit: Option<LimitSpec>,

    /// Optional message rate. `apply_provider_params_to` publishes it as an
    /// `additional_message_rate_throttles` entry named
    /// `shaping-provider-<provider>-<source>-rate`.
    #[serde(default)]
    pub provider_max_message_rate: Option<ThrottleSpec>,

    /// The match rules, spelled `match` in the input.
    #[serde(default, rename = "match")]
    pub matches: Vec<ProviderMatch>,

    /// When `true`, `merge_from` replaces the definition being merged into
    /// rather than merging with it.
    #[serde(default)]
    pub replace_base: bool,

    /// Every key not claimed by another field (flattened). `finish_params`
    /// deserializes this table as an `EgressPathConfig`.
    #[serde(flatten)]
    pub params: toml::Table,

    /// Automation rules attached to this provider.
    #[serde(default)]
    pub automation: Vec<Rule>,

    /// Per-source parameter tables, keyed by source name. `finish_params`
    /// strips `provider_source_selection_rate` from each table before
    /// interpreting it as an `EgressPathConfig`.
    #[serde(default)]
    pub sources: OrderMap<String, toml::Table>,
}
1087
1088#[cfg(feature = "lua")]
/// Reports whether `candidate` ends with `suffix`.
///
/// A single trailing `.` on `candidate` is ignored, because our resolver
/// tends to leave the canonical dot on the host name it was given.
///
/// The comparison is deliberately case-sensitive: we assume that the
/// dns-resolver crate normalized the names to lowercase and that `suffix`
/// is lowercase as well.
fn suffix_matches(candidate: &str, suffix: &str) -> bool {
    let host = match candidate.strip_suffix('.') {
        Some(without_dot) => without_dot,
        None => candidate,
    };
    host.ends_with(suffix)
}
1098
/// Exercises `suffix_matches`: plain suffixes, the trailing-dot handling of
/// the candidate, and the documented case-sensitivity.
#[cfg(feature = "lua")]
#[cfg(test)]
#[test]
fn test_suffix_matches() {
    assert!(suffix_matches("a", "a"));
    assert!(suffix_matches("foo.com", "foo.com"));
    assert!(!suffix_matches("foo.com", ".foo.com"));
    assert!(!suffix_matches("foo.com", "longer.com"));
    assert!(!suffix_matches("réputation.net", ".mx.microsoft"));

    // The canonical trailing dot left by the resolver must not defeat
    // the match, nor turn a non-match into a match.
    assert!(suffix_matches("foo.com.", "foo.com"));
    assert!(suffix_matches("mx.foo.com.", ".foo.com"));
    assert!(!suffix_matches("foo.com.", ".foo.com"));

    // We DON'T handle case mismatches in this code.
    // We assume that the dns-resolver crate normalized the
    // names to lowercase and that the input suffix is lowercase as well.
    assert!(!suffix_matches("foo.Com", ".com"));
    assert!(!suffix_matches("foo.Cam", ".com"));
}
1115
1116#[cfg(feature = "lua")]
/// Reports whether `candidate` is exactly `name`.
///
/// A single trailing `.` on `candidate` is ignored, because our resolver
/// tends to leave the canonical dot on the host name it was given.
fn host_matches(candidate: &str, name: &str) -> bool {
    match candidate.strip_suffix('.') {
        Some(without_dot) => without_dot == name,
        None => candidate == name,
    }
}
1123
/// Table-driven check of `host_matches`: exact match, match with a trailing
/// dot on the candidate, and a different name.
#[cfg(feature = "lua")]
#[cfg(test)]
#[test]
fn test_host_matches() {
    let cases = [
        ("foo.com", "foo.com", true),
        ("foo.com.", "foo.com", true),
        ("foo.com", "notfoo.com", false),
    ];

    for (candidate, name, expected) in cases {
        assert!(host_matches(candidate, name) == expected);
    }
}
1132
#[cfg(feature = "lua")]
impl ProviderEntry {
    /// Reports whether `domain` is matched by this provider's rules.
    ///
    /// `DomainSuffix` rules are evaluated first because they need no DNS;
    /// any one of them matching is sufficient. Otherwise, if the provider
    /// has at least one `HostName` or `MXSuffix` rule, the MX for `domain`
    /// is resolved and the provider matches only when there is at least one
    /// MX host and every MX host satisfies at least one of those rules.
    ///
    /// A failed MX resolution is traced and treated as "no match".
    async fn domain_matches(&self, domain: &str) -> bool {
        // We'd like to avoid doing DNS if we can do a simple suffix match,
        // so we bias to looking at those first
        let mut need_mx = false;

        tracing::trace!(
            "ProviderEntry::domain_matches({domain}) vs {:?}",
            self.matches
        );

        for rule in &self.matches {
            match rule {
                ProviderMatch::DomainSuffix(suffix) => {
                    if suffix_matches(domain, suffix) {
                        tracing::trace!("{domain} suffix matches {suffix}");
                        return true;
                    }
                }
                ProviderMatch::HostName(_) | ProviderMatch::MXSuffix(_) => {
                    need_mx = true;
                }
            }
        }

        if !need_mx {
            return false;
        }

        // Now we can consider DNS
        let mx = match MailExchanger::resolve(domain).await {
            Ok(mx) => mx,
            Err(err) => {
                // Didn't resolve; could be legit, for example, could
                // be an internal or fake name that is handled via
                // smart host or custom routing, so we log what
                // happened as a trace rather than polluting the
                // logs about it
                tracing::trace!(
                    "Error resolving MX for {domain}: {err:#}. \
                    Provider {} match rules will be ignored",
                    self.provider_name
                );
                return false;
            }
        };

        // With no MX hosts at all, the "every host matches" test below
        // would be vacuously true and the domain would be claimed by every
        // provider that has MX based rules. No host means no match.
        if mx.hosts.is_empty() {
            tracing::trace!("{domain} has no MX hosts; MX based rules cannot match");
            return false;
        }

        tracing::trace!("Consider MXSuffix rules");
        for host in &mx.hosts {
            // Every host must be covered by at least one MXSuffix or
            // HostName rule. This is so that we don't falsely lump a
            // vanity domain that blends providers together.
            let matched = self.matches.iter().any(|rule| match rule {
                ProviderMatch::MXSuffix(suffix) => {
                    tracing::trace!("suffix={suffix} vs host {host}");
                    suffix_matches(host, suffix)
                }
                ProviderMatch::HostName(name) => host_matches(host, name),
                ProviderMatch::DomainSuffix(_) => false,
            });

            if !matched {
                tracing::trace!("host didn't match any of these rules");
                return false;
            }
        }

        true
    }

    /// Folds `other`, a later definition of the same provider, into `self`.
    ///
    /// When `other.replace_base` is set, the limits, match rules, params,
    /// sources and automation of `self` are replaced wholesale by those of
    /// `other`. Otherwise:
    ///
    /// * a limit or rate set in `other` overrides the one in `self`,
    /// * params and per-source tables are merged via
    ///   `toml_table_merge_from`,
    /// * match rules and automation rules are appended.
    fn merge_from(&mut self, mut other: Self) {
        if other.replace_base {
            self.provider_connection_limit = other.provider_connection_limit;
            self.provider_max_message_rate = other.provider_max_message_rate;
            self.matches = other.matches;
            self.params = other.params;
            self.sources = other.sources;
            self.automation = other.automation;
            return;
        }

        if other.provider_connection_limit.is_some() {
            self.provider_connection_limit = other.provider_connection_limit;
        }
        // The message rate must be carried over just like the connection
        // limit, otherwise a rate defined in a later file is lost.
        if other.provider_max_message_rate.is_some() {
            self.provider_max_message_rate = other.provider_max_message_rate;
        }

        toml_table_merge_from(&mut self.params, &other.params);

        for (source, tbl) in other.sources {
            match self.sources.get_mut(&source) {
                Some(existing) => {
                    toml_table_merge_from(existing, &tbl);
                }
                None => {
                    self.sources.insert(source, tbl);
                }
            }
        }

        self.matches.append(&mut other.matches);
        self.automation.append(&mut other.automation);
    }

    /// Injects the settings implied by this provider into `target`, the
    /// parameter table being assembled for `source`.
    ///
    /// The implied settings are:
    ///
    /// * `provider_name`,
    /// * an `additional_connection_limits` entry when
    ///   `provider_connection_limit` is set,
    /// * an `additional_message_rate_throttles` entry when
    ///   `provider_max_message_rate` is set,
    /// * an `additional_source_selection_rates` entry when `target` holds a
    ///   `provider_source_selection_rate` key, which is removed from
    ///   `target` in the process.
    ///
    /// Each entry is named after both the provider and `source`. The
    /// implied settings are then merged into `target` with
    /// `toml_table_merge_from`.
    fn apply_provider_params_to(&self, source: &str, target: &mut toml::Table) {
        let mut implied = toml::Table::new();
        implied.insert(
            "provider_name".to_string(),
            toml::Value::String(self.provider_name.to_string()),
        );

        if let Some(limit) = &self.provider_connection_limit {
            let mut limits = toml::Table::new();
            limits.insert(
                format!("shaping-provider-{}-{source}-limit", self.provider_name),
                toml::Value::String(limit.to_string()),
            );
            implied.insert(
                "additional_connection_limits".to_string(),
                toml::Value::Table(limits),
            );
        }
        if let Some(rate) = &self.provider_max_message_rate {
            let rate = rate.as_string();
            let mut limits = toml::Table::new();
            limits.insert(
                format!("shaping-provider-{}-{source}-rate", self.provider_name),
                rate.into(),
            );
            implied.insert(
                "additional_message_rate_throttles".to_string(),
                toml::Value::Table(limits),
            );
        }

        if let Some(rate) = target.remove("provider_source_selection_rate") {
            let mut limits = toml::Table::new();
            limits.insert(
                format!(
                    "shaping-provider-{}-{source}-selection-rate",
                    self.provider_name
                ),
                rate,
            );
            implied.insert(
                "additional_source_selection_rates".to_string(),
                toml::Value::Table(limits),
            );
        }

        toml_table_merge_from(target, &implied);
    }

    /// Interprets the provider's params and each of its per-source tables
    /// as `EgressPathConfig`, producing a `MergedEntry`.
    ///
    /// # Errors
    ///
    /// Returns an error, annotated with the provider (and source) name,
    /// when a table cannot be deserialized as an `EgressPathConfig`.
    fn finish_params(&self) -> anyhow::Result<MergedEntry> {
        let provider_name = &self.provider_name;

        let params = EgressPathConfig::deserialize(self.params.clone()).with_context(|| {
            format!(
                "interpreting provider '{provider_name}' params {:#?} as EgressPathConfig",
                self.params
            )
        })?;
        let mut sources = OrderMap::new();

        for (source, params) in &self.sources {
            let mut params = params.clone();
            // I don't really like this remove call. The issue is that we don't
            // have an alternative way to filter this out of the partial source
            // definition, and this provider_ option is not valid in the
            // EgressPathConfig struct itself; it is a shaping source-specific
            // addition to help with setting up shared throttles across providers
            params.remove("provider_source_selection_rate");
            sources.insert(
                source.clone(),
                EgressPathConfig::deserialize(params.clone()).with_context(|| {
                    format!("interpreting provider '{provider_name}' source '{source}' {params:#} as EgressPathConfig")
                })?,
            );
        }

        Ok(MergedEntry {
            params,
            sources,
            automation: self.automation.clone(),
        })
    }

    /// Feeds the provider name followed by the JSON serialization of this
    /// entry into `ctx`.
    ///
    /// If serialization fails, an empty string is hashed in its place.
    fn hash_into(&self, ctx: &mut Sha256) {
        ctx.update(&self.provider_name);
        ctx.update(serde_json::to_string(self).unwrap_or_default());
    }
}
1335
/// A single rule used by `ProviderEntry::domain_matches` to decide whether
/// a domain belongs to a provider.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone)]
pub enum ProviderMatch {
    /// Suffix compared against the name of each MX host of the domain.
    MXSuffix(String),
    /// Suffix compared against the domain name itself; needs no DNS lookup.
    DomainSuffix(String),
    /// Exact name compared against each MX host of the domain.
    HostName(String),
}
1343
/// Copies every key of `source` into `tbl`.
///
/// A key normally replaces whatever `tbl` already holds under that name.
/// The exception is the small set of throttle related tables: when both
/// sides hold a table under such a key, the entries of the incoming table
/// are added to (or overwrite entries of) the existing table instead.
#[cfg(feature = "lua")]
fn toml_table_merge_from(tbl: &mut toml::Table, source: &toml::Table) {
    // Limit merging to just the throttle related fields, as their purpose
    // is for creating broader scoped limits that cut across normal boundaries
    fn is_mergeable(name: &str) -> bool {
        matches!(
            name,
            "additional_connection_limits"
                | "additional_message_rate_throttles"
                | "additional_source_selection_rates"
        )
    }

    for (key, value) in source {
        if is_mergeable(key) {
            // Merge Table values together, rather than simply replacing them.
            if let (Some(toml::Value::Table(existing)), Some(incoming)) =
                (tbl.get_mut(key), value.as_table())
            {
                for (inner_key, inner_value) in incoming {
                    existing.insert(inner_key.clone(), inner_value.clone());
                }
                continue;
            }
        }

        tbl.insert(key.clone(), value.clone());
    }
}
1371
#[cfg(feature = "lua")]
impl PartialEntry {
    /// Folds `other`, a later definition of the same entry, into `self`.
    ///
    /// With `other.replace_base` set, the params, automation and sources of
    /// `self` are replaced by those of `other`. Otherwise params and
    /// per-source tables are merged via `toml_table_merge_from` and the
    /// automation rules of `other` are appended.
    fn merge_from(&mut self, mut other: Self) {
        if other.replace_base {
            self.params = other.params;
            self.automation = other.automation;
            self.sources = other.sources;
            return;
        }

        toml_table_merge_from(&mut self.params, &other.params);

        for (source, tbl) in other.sources {
            if let Some(existing) = self.sources.get_mut(&source) {
                toml_table_merge_from(existing, &tbl);
            } else {
                self.sources.insert(source, tbl);
            }
        }

        self.automation.append(&mut other.automation);
    }

    /// Consumes the entry, interpreting its params and each per-source
    /// table as `EgressPathConfig`, and returns the resulting `MergedEntry`.
    ///
    /// # Errors
    ///
    /// Returns an error, annotated with the domain (and source) name, when
    /// a table cannot be deserialized as an `EgressPathConfig`.
    fn finish(self) -> anyhow::Result<MergedEntry> {
        let domain = self.domain_name.unwrap_or_default();

        let params = EgressPathConfig::deserialize(self.params.clone()).with_context(|| {
            format!(
                "interpreting domain '{domain}' params {:#?} as EgressPathConfig",
                self.params
            )
        })?;

        let mut sources = OrderMap::new();
        for (source, params) in self.sources {
            let config = EgressPathConfig::deserialize(params.clone()).with_context(|| {
                format!("interpreting domain '{domain}' source '{source}' {params:#} as EgressPathConfig")
            })?;
            sources.insert(source, config);
        }

        let automation = self.automation;
        Ok(MergedEntry {
            params,
            sources,
            automation,
        })
    }

    /// Feeds the domain name, when there is one, followed by the JSON
    /// serialization of this entry into `ctx`.
    ///
    /// If serialization fails, an empty string is hashed in its place.
    fn hash_into(&self, ctx: &mut Sha256) {
        if let Some(name) = &self.domain_name {
            ctx.update(name);
        }
        ctx.update(serde_json::to_string(self).unwrap_or_default());
    }
}
1431
1432fn one_or_many<'de, T, D>(deserializer: D, expecting: &str) -> Result<Vec<T>, D::Error>
1433where
1434    T: Deserialize<'de>,
1435    D: Deserializer<'de>,
1436{
1437    let result: Result<Vec<T>, _> =
1438        OneOrMany::<serde_with::Same, PreferOne>::deserialize_as(deserializer);
1439    match result {
1440        Ok(r) => Ok(r),
1441        Err(err) => Err(serde::de::Error::custom(format!(
1442            "{expecting}.\nThe underlying error message is:\n{err:#}"
1443        ))),
1444    }
1445}
1446
1447fn one_or_many_action<'de, D>(deserializer: D) -> Result<Vec<Action>, D::Error>
1448where
1449    D: Deserializer<'de>,
1450{
1451    one_or_many(
1452        deserializer,
1453        "\"action\" field expected either a single Action or an array of Actions",
1454    )
1455}
1456
1457fn regex_string_or_array<'de, D>(deserializer: D) -> Result<Vec<Regex>, D::Error>
1458where
1459    D: Deserializer<'de>,
1460{
1461    string_or_array(
1462        deserializer,
1463        "regex string or array of regex strings for field regex",
1464    )
1465}
1466
1467fn string_or_array<'de, T, D>(deserializer: D, expecting: &'static str) -> Result<Vec<T>, D::Error>
1468where
1469    T: Deserialize<'de> + TryFrom<String>,
1470    <T as TryFrom<String>>::Error: std::fmt::Debug,
1471    D: Deserializer<'de>,
1472{
1473    // This is a Visitor that forwards string types to T's `TryFrom<String>` impl and
1474    // forwards map types to T's `Deserialize` impl. The `PhantomData` is to
1475    // keep the compiler from complaining about T being an unused generic type
1476    // parameter. We need T in order to know the Value type for the Visitor
1477    // impl.
1478    struct StringOrArray<T>(PhantomData<fn() -> T>, &'static str);
1479
1480    impl<'de, T> Visitor<'de> for StringOrArray<T>
1481    where
1482        T: Deserialize<'de> + TryFrom<String>,
1483        <T as TryFrom<String>>::Error: std::fmt::Debug,
1484    {
1485        type Value = Vec<T>;
1486
1487        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
1488            formatter.write_str(self.1)
1489        }
1490
1491        fn visit_str<E>(self, value: &str) -> Result<Vec<T>, E>
1492        where
1493            E: serde::de::Error,
1494        {
1495            Ok(vec![
1496                T::try_from(value.to_string()).map_err(|e| E::custom(format!("{e:?}")))?
1497            ])
1498        }
1499
1500        fn visit_seq<S>(self, seq: S) -> Result<Vec<T>, S::Error>
1501        where
1502            S: SeqAccess<'de>,
1503        {
1504            Deserialize::deserialize(serde::de::value::SeqAccessDeserializer::new(seq))
1505        }
1506    }
1507
1508    deserializer.deserialize_any(StringOrArray(PhantomData, expecting))
1509}
1510
1511#[cfg(feature = "lua")]
/// Serde default helper that yields `true`; it supplies the default value
/// of `PartialEntry::mx_rollup`.
fn default_true() -> bool {
    true
}
1515
/// Installs the `load` function into the `shaping` Lua sub-module.
///
/// `load(paths, options)` takes a list of shaping file paths and an
/// optional options value, merges the files with `Shaping::merge_files`
/// and returns the resulting `Shaping` object. When `options` is omitted,
/// the default options are used.
///
/// # Errors
///
/// Returns an error if the sub-module cannot be obtained or the function
/// cannot be created or registered.
#[cfg(feature = "lua")]
pub fn register(lua: &mlua::Lua) -> anyhow::Result<()> {
    let shaping_mod = config::get_or_create_sub_module(lua, "shaping")?;

    let load = lua.create_async_function(
        move |lua, (paths, options): (Vec<String>, Option<mlua::Value>)| async move {
            let options = match options {
                Some(v) => lua.from_value(v)?,
                None => Default::default(),
            };
            Shaping::merge_files(&paths, &options)
                .await
                .map_err(any_err)
        },
    )?;
    shaping_mod.set("load", load)?;

    Ok(())
}
1538
1539#[cfg(test)]
1540mod test {
1541    use super::*;
1542    use kumo_log_types::RecordType;
1543    use rfc5321::Response;
1544    use std::io::Write;
1545    use tempfile::NamedTempFile;
1546    use uuid::Uuid;
1547
1548    async fn make_shaping_configs(inputs: &[&str]) -> Shaping {
1549        let mut files = vec![];
1550        let mut file_names = vec![];
1551
1552        for (i, content) in inputs.iter().enumerate() {
1553            let mut shaping_file = NamedTempFile::with_prefix(format!("file{i}")).unwrap();
1554            shaping_file.write_all(content.as_bytes()).unwrap();
1555            file_names.push(shaping_file.path().to_str().unwrap().to_string());
1556            files.push(shaping_file);
1557        }
1558
1559        Shaping::merge_files(&file_names, &ShapingMergeOptions::default())
1560            .await
1561            .unwrap()
1562    }
1563
    /// Two files define the same domain; the `additional_*` tables from
    /// both must end up combined rather than the second replacing the first.
    #[tokio::test]
    async fn test_merge_additional() {
        let shaping = make_shaping_configs(&[
            r#"
["example.com"]
mx_rollup = false
additional_connection_limits = {"first"=10}
        "#,
            r#"
["example.com"]
mx_rollup = false
additional_connection_limits = {"second"=32}
additional_message_rate_throttles = {"second"="100/hr"}
        "#,
        ])
        .await;

        let resolved = shaping
            .get_egress_path_config("example.com", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();

        // Both the "first" and the "second" limit are present
        k9::snapshot!(
            resolved.params.additional_connection_limits,
            r#"
{
    "first": 10,
    "second": 32,
}
"#
        );
        // Only the second file defines a rate throttle
        k9::snapshot!(
            resolved.params.additional_message_rate_throttles,
            r#"
{
    "second": 100/h,
}
"#
        );
    }
1605
    /// A provider described by several `HostName` rules is applied to
    /// `yahoo.com`.
    ///
    /// NOTE(review): the rules are MX based, so this presumably relies on
    /// live DNS for `yahoo.com` returning exactly these hosts; confirm.
    #[tokio::test]
    async fn test_provider_multi_hostname() {
        let shaping = make_shaping_configs(&[r#"
[provider."yahoo"]
match=[{HostName="mta5.am0.yahoodns.net"},{HostName="mta6.am0.yahoodns.net"},{HostName="mta7.am0.yahoodns.net"}]
enable_tls = "Required"
        "#])
        .await;

        let resolved = shaping
            .get_egress_path_config("yahoo.com", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();

        // The provider's params and its name must both be applied
        k9::assert_equal!(
            resolved.params.enable_tls,
            crate::egress_path::Tls::Required
        );
        k9::assert_equal!(resolved.params.provider_name.unwrap(), "yahoo");
    }
1627
    /// Same scenario as the multi-hostname test, but the provider is
    /// described by several `MXSuffix` rules instead.
    ///
    /// NOTE(review): the rules are MX based, so this presumably relies on
    /// live DNS for `yahoo.com`; confirm.
    #[tokio::test]
    async fn test_provider_multi_suffix() {
        let shaping = make_shaping_configs(&[r#"
[provider."yahoo"]
match=[{MXSuffix="mta5.am0.yahoodns.net"},{MXSuffix="mta6.am0.yahoodns.net"},{MXSuffix="mta7.am0.yahoodns.net"}]
enable_tls = "Required"
        "#])
        .await;

        let resolved = shaping
            .get_egress_path_config("yahoo.com", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();

        // The provider's params and its name must both be applied
        k9::assert_equal!(
            resolved.params.enable_tls,
            crate::egress_path::Tls::Required
        );
        k9::assert_equal!(resolved.params.provider_name.unwrap(), "yahoo");
    }
1649
    /// A provider with a connection limit, a message rate and a per-source
    /// selection rate: checks the names under which those settings appear
    /// in the resolved config, for a source without and a source with
    /// source-specific settings.
    ///
    /// NOTE(review): `outlook.com` does not end with the `.outlook.com`
    /// domain suffix, so the match presumably comes from the `MXSuffix`
    /// rule and relies on live DNS; confirm.
    #[tokio::test]
    async fn test_provider() {
        let shaping = make_shaping_configs(&[r#"
[provider."Office 365"]
match=[{MXSuffix=".olc.protection.outlook.com"},{DomainSuffix=".outlook.com"}]
enable_tls = "Required"
provider_connection_limit = 10
provider_max_message_rate = "120/s"

[provider."Office 365".sources."new-source"]
provider_source_selection_rate = "500/d,max_burst=1"
        "#])
        .await;

        // A source that has no source-specific block in the provider
        let resolved = shaping
            .get_egress_path_config("outlook.com", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();

        k9::assert_equal!(
            resolved.params.enable_tls,
            crate::egress_path::Tls::Required
        );
        k9::assert_equal!(resolved.params.provider_name.unwrap(), "Office 365");

        // The limit and rate are published under names that embed both the
        // provider and the source
        k9::snapshot!(
            resolved.params.additional_connection_limits,
            r#"
{
    "shaping-provider-Office 365-invalid.source-limit": 10,
}
"#
        );
        k9::snapshot!(
            resolved.params.additional_message_rate_throttles,
            r#"
{
    "shaping-provider-Office 365-invalid.source-rate": 120/s,
}
"#
        );
        assert!(resolved.params.source_selection_rate.is_none());
        assert!(resolved.params.additional_source_selection_rates.is_empty());

        // The source that defines provider_source_selection_rate
        let resolved = shaping
            .get_egress_path_config("outlook.com", "new-source", "invalid.site")
            .await
            .finish()
            .unwrap();
        assert!(resolved.params.source_selection_rate.is_none());
        k9::snapshot!(
            resolved.params.additional_source_selection_rates,
            r#"
{
    "shaping-provider-Office 365-new-source-selection-rate": 500/d,max_burst=1,
}
"#
        );
    }
1710
1711    #[tokio::test]
1712    async fn test_rule_matching() {
1713        let shaping = make_shaping_configs(&[r#"
1714[["default".automation]]
1715regex="default"
1716action = {SetConfig={name="connection_limit", value=1}}
1717duration = "1hr"
1718
1719["fake.site"]
1720_treat_domain_name_as_site_name = true
1721
1722[["fake.site".automation]]
1723regex="fake_rollup"
1724action = {SetConfig={name="connection_limit", value=2}}
1725duration = "1hr"
1726
1727["woot.provider"]
1728mx_rollup = false
1729
1730[["woot.provider".automation]]
1731regex="woot_domain"
1732action = {SetConfig={name="connection_limit", value=2}}
1733duration = "1hr"
1734
1735[provider."provider"]
1736match=[{DomainSuffix=".provider"}]
1737
1738[[provider."provider".automation]]
1739regex="provider"
1740action = {SetConfig={name="connection_limit", value=3}}
1741duration = "1hr"
1742match_internal = true
1743
1744"#])
1745        .await;
1746
1747        eprintln!("{:?}", shaping.inner.warnings);
1748
1749        fn make_record(content: &str, recipient: &str, site: &str) -> JsonLogRecord {
1750            JsonLogRecord {
1751                kind: RecordType::TransientFailure,
1752                id: String::new(),
1753                sender: String::new(),
1754                recipient: vec![recipient.to_string()],
1755                queue: String::new(),
1756                site: site.to_string(),
1757                size: 0,
1758                response: Response {
1759                    code: 400,
1760                    command: None,
1761                    enhanced_code: None,
1762                    content: content.to_string(),
1763                },
1764                peer_address: None,
1765                timestamp: Default::default(),
1766                created: Default::default(),
1767                num_attempts: 1,
1768                bounce_classification: Default::default(),
1769                egress_pool: None,
1770                egress_source: None,
1771                source_address: None,
1772                feedback_report: None,
1773                meta: Default::default(),
1774                headers: Default::default(),
1775                delivery_protocol: None,
1776                reception_protocol: None,
1777                nodeid: Uuid::default(),
1778                tls_cipher: None,
1779                tls_protocol_version: None,
1780                tls_peer_subject_name: None,
1781                provider_name: None,
1782                session_id: None,
1783            }
1784        }
1785
1786        let matches = shaping
1787            .match_rules(&make_record("default", "user@example.com", "dummy_site"))
1788            .await
1789            .unwrap();
1790        k9::assert_equal!(
1791            matches[0].regex[0].to_string(),
1792            "default",
1793            "matches against default automation rule"
1794        );
1795
1796        let matches = shaping
1797            .match_rules(&make_record(
1798                "KumoMTA internal: default",
1799                "user@example.com",
1800                "dummy_site",
1801            ))
1802            .await
1803            .unwrap();
1804        assert!(matches.is_empty(), "internal bounce should not match");
1805
1806        let matches = shaping
1807            .match_rules(&make_record(
1808                "woot_domain",
1809                "user@woot.provider",
1810                "dummy_site",
1811            ))
1812            .await
1813            .unwrap();
1814        k9::assert_equal!(
1815            matches[0].regex[0].to_string(),
1816            "woot_domain",
1817            "matches against domain rule with mx_rollup=false"
1818        );
1819
1820        let matches = shaping
1821            .match_rules(&make_record("fake_rollup", "user@fake.rollup", "fake.site"))
1822            .await
1823            .unwrap();
1824        k9::assert_equal!(
1825            matches[0].regex[0].to_string(),
1826            "fake_rollup",
1827            "matches against domain rule with mx_rollup=true"
1828        );
1829
1830        let matches = shaping
1831            .match_rules(&make_record("provider", "user@woot.provider", "dummy_site"))
1832            .await
1833            .unwrap();
1834        k9::assert_equal!(
1835            matches[0].regex[0].to_string(),
1836            "provider",
1837            "matches against provider rule"
1838        );
1839
1840        let matches = shaping
1841            .match_rules(&make_record(
1842                "KumoMTA internal: provider",
1843                "user@woot.provider",
1844                "dummy_site",
1845            ))
1846            .await
1847            .unwrap();
1848        k9::assert_equal!(
1849            matches[0].regex[0].to_string(),
1850            "provider",
1851            "internal response matches against provider rule"
1852        );
1853    }
1854
1855    #[tokio::test]
1856    async fn test_defaults() {
1857        let shaping = make_shaping_configs(&[
1858            r#"
1859["default"]
1860connection_limit = 10
1861max_connection_rate = "100/min"
1862max_deliveries_per_connection = 100
1863max_message_rate = "100/s"
1864idle_timeout = "60s"
1865data_timeout = "30s"
1866data_dot_timeout = "60s"
1867enable_tls = "Opportunistic"
1868consecutive_connection_failures_before_delay = 100
1869
1870[["default".automation]]
1871regex=[
1872        '/Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/',
1873        '/All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/',
1874        '/has been temporarily rate limited due to IP reputation/',
1875        '/Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/',
1876        '/Server busy\. Please try again later from/'
1877]
1878action = [
1879        {SetConfig={name="max_message_rate", value="1/minute"}},
1880        {SetConfig={name="connection_limit", value=1}}
1881]
1882duration = "90m"
1883
1884[["default".automation]]
1885regex="KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site"
1886action = {SetConfig={name="enable_tls", value="Disabled"}}
1887duration = "30 days"
1888
1889["gmail.com"]
1890max_deliveries_per_connection = 50
1891connection_limit = 5
1892enable_tls = "Required"
1893consecutive_connection_failures_before_delay = 5
1894
1895["yahoo.com"]
1896max_deliveries_per_connection = 20
1897
1898[["yahoo.com".automation]]
1899regex = "\\[TS04\\]"
1900action = "Suspend"
1901duration = "2 hours"
1902
1903["comcast.net"]
1904connection_limit = 25
1905max_deliveries_per_connection = 250
1906enable_tls = "Required"
1907idle_timeout = "30s"
1908consecutive_connection_failures_before_delay = 24
1909
1910["mail.com"]
1911max_deliveries_per_connection = 100
1912
1913["orange.fr"]
1914connection_limit = 3
1915
1916["smtp.mailgun.com"]
1917connection_limit = 7000
1918max_deliveries_per_connection = 3
1919
1920["example.com"]
1921mx_rollup = false
1922max_deliveries_per_connection = 100
1923connection_limit = 3
1924
1925["example.com".sources."my source name"]
1926connection_limit = 5
1927        "#,
1928        ])
1929        .await;
1930
1931        let default = shaping
1932            .get_egress_path_config("invalid.domain", "invalid.source", "invalid.site")
1933            .await
1934            .finish()
1935            .unwrap();
1936        k9::snapshot!(
1937            default,
1938            r#"
1939MergedEntry {
1940    params: EgressPathConfig {
1941        connection_limit: 10,
1942        additional_connection_limits: {},
1943        enable_tls: Opportunistic,
1944        enable_mta_sts: true,
1945        enable_dane: false,
1946        enable_pipelining: true,
1947        enable_rset: true,
1948        tls_prefer_openssl: false,
1949        tls_certificate: None,
1950        tls_private_key: None,
1951        openssl_cipher_list: None,
1952        openssl_cipher_suites: None,
1953        openssl_options: None,
1954        rustls_cipher_suites: [],
1955        client_timeouts: SmtpClientTimeouts {
1956            connect_timeout: 60s,
1957            banner_timeout: 60s,
1958            ehlo_timeout: 300s,
1959            mail_from_timeout: 300s,
1960            rcpt_to_timeout: 300s,
1961            data_timeout: 30s,
1962            data_dot_timeout: 60s,
1963            rset_timeout: 5s,
1964            idle_timeout: 60s,
1965            starttls_timeout: 5s,
1966            auth_timeout: 60s,
1967        },
1968        system_shutdown_timeout: None,
1969        max_ready: 1024,
1970        consecutive_connection_failures_before_delay: 100,
1971        smtp_port: 25,
1972        smtp_auth_plain_username: None,
1973        smtp_auth_plain_password: None,
1974        allow_smtp_auth_plain_without_tls: false,
1975        max_message_rate: Some(
1976            100/s,
1977        ),
1978        additional_message_rate_throttles: {},
1979        source_selection_rate: None,
1980        additional_source_selection_rates: {},
1981        max_connection_rate: Some(
1982            100/m,
1983        ),
1984        max_deliveries_per_connection: 100,
1985        max_recipients_per_batch: 100,
1986        prohibited_hosts: {
1987            "0.0.0.0",
1988            "127.0.0.0/8",
1989            "::/127",
1990        },
1991        skip_hosts: {},
1992        ip_lookup_strategy: Ipv4AndIpv6,
1993        ehlo_domain: None,
1994        aggressive_connection_opening: false,
1995        refresh_interval: 60s,
1996        refresh_strategy: Ttl,
1997        dispatcher_wakeup_strategy: Aggressive,
1998        maintainer_wakeup_strategy: Aggressive,
1999        provider_name: None,
2000        remember_broken_tls: None,
2001        opportunistic_tls_reconnect_on_failed_handshake: false,
2002        use_lmtp: false,
2003        reconnect_strategy: ConnectNextHost,
2004        readyq_pool_name: None,
2005        low_memory_reduction_policy: ShrinkDataAndMeta,
2006        no_memory_reduction_policy: ShrinkDataAndMeta,
2007        try_next_host_on_transport_error: false,
2008        ignore_8bit_checks: false,
2009    },
2010    sources: {},
2011    automation: [
2012        Rule {
2013            regex: [
2014                Regex(
2015                    /Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/,
2016                ),
2017                Regex(
2018                    /All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/,
2019                ),
2020                Regex(
2021                    /has been temporarily rate limited due to IP reputation/,
2022                ),
2023                Regex(
2024                    /Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/,
2025                ),
2026                Regex(
2027                    /Server busy\. Please try again later from/,
2028                ),
2029            ],
2030            action: [
2031                SetConfig(
2032                    EgressPathConfigValue {
2033                        name: "max_message_rate",
2034                        value: HashableTomlValue {
2035                            value: String(
2036                                "1/minute",
2037                            ),
2038                        },
2039                    },
2040                ),
2041                SetConfig(
2042                    EgressPathConfigValue {
2043                        name: "connection_limit",
2044                        value: HashableTomlValue {
2045                            value: Integer(
2046                                1,
2047                            ),
2048                        },
2049                    },
2050                ),
2051            ],
2052            trigger: Immediate,
2053            duration: 5400s,
2054            was_rollup: false,
2055            match_internal: false,
2056        },
2057        Rule {
2058            regex: [
2059                Regex(
2060                    KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site,
2061                ),
2062            ],
2063            action: [
2064                SetConfig(
2065                    EgressPathConfigValue {
2066                        name: "enable_tls",
2067                        value: HashableTomlValue {
2068                            value: String(
2069                                "Disabled",
2070                            ),
2071                        },
2072                    },
2073                ),
2074            ],
2075            trigger: Immediate,
2076            duration: 2592000s,
2077            was_rollup: false,
2078            match_internal: false,
2079        },
2080    ],
2081}
2082"#
2083        );
2084
2085        let example_com = shaping
2086            .get_egress_path_config("example.com", "invalid.source", "invalid.site")
2087            .await
2088            .finish()
2089            .unwrap();
2090        k9::snapshot!(
2091            example_com,
2092            r#"
2093MergedEntry {
2094    params: EgressPathConfig {
2095        connection_limit: 3,
2096        additional_connection_limits: {},
2097        enable_tls: Opportunistic,
2098        enable_mta_sts: true,
2099        enable_dane: false,
2100        enable_pipelining: true,
2101        enable_rset: true,
2102        tls_prefer_openssl: false,
2103        tls_certificate: None,
2104        tls_private_key: None,
2105        openssl_cipher_list: None,
2106        openssl_cipher_suites: None,
2107        openssl_options: None,
2108        rustls_cipher_suites: [],
2109        client_timeouts: SmtpClientTimeouts {
2110            connect_timeout: 60s,
2111            banner_timeout: 60s,
2112            ehlo_timeout: 300s,
2113            mail_from_timeout: 300s,
2114            rcpt_to_timeout: 300s,
2115            data_timeout: 30s,
2116            data_dot_timeout: 60s,
2117            rset_timeout: 5s,
2118            idle_timeout: 60s,
2119            starttls_timeout: 5s,
2120            auth_timeout: 60s,
2121        },
2122        system_shutdown_timeout: None,
2123        max_ready: 1024,
2124        consecutive_connection_failures_before_delay: 100,
2125        smtp_port: 25,
2126        smtp_auth_plain_username: None,
2127        smtp_auth_plain_password: None,
2128        allow_smtp_auth_plain_without_tls: false,
2129        max_message_rate: Some(
2130            100/s,
2131        ),
2132        additional_message_rate_throttles: {},
2133        source_selection_rate: None,
2134        additional_source_selection_rates: {},
2135        max_connection_rate: Some(
2136            100/m,
2137        ),
2138        max_deliveries_per_connection: 100,
2139        max_recipients_per_batch: 100,
2140        prohibited_hosts: {
2141            "0.0.0.0",
2142            "127.0.0.0/8",
2143            "::/127",
2144        },
2145        skip_hosts: {},
2146        ip_lookup_strategy: Ipv4AndIpv6,
2147        ehlo_domain: None,
2148        aggressive_connection_opening: false,
2149        refresh_interval: 60s,
2150        refresh_strategy: Ttl,
2151        dispatcher_wakeup_strategy: Aggressive,
2152        maintainer_wakeup_strategy: Aggressive,
2153        provider_name: None,
2154        remember_broken_tls: None,
2155        opportunistic_tls_reconnect_on_failed_handshake: false,
2156        use_lmtp: false,
2157        reconnect_strategy: ConnectNextHost,
2158        readyq_pool_name: None,
2159        low_memory_reduction_policy: ShrinkDataAndMeta,
2160        no_memory_reduction_policy: ShrinkDataAndMeta,
2161        try_next_host_on_transport_error: false,
2162        ignore_8bit_checks: false,
2163    },
2164    sources: {
2165        "my source name": EgressPathConfig {
2166            connection_limit: 5,
2167            additional_connection_limits: {},
2168            enable_tls: Opportunistic,
2169            enable_mta_sts: true,
2170            enable_dane: false,
2171            enable_pipelining: true,
2172            enable_rset: true,
2173            tls_prefer_openssl: false,
2174            tls_certificate: None,
2175            tls_private_key: None,
2176            openssl_cipher_list: None,
2177            openssl_cipher_suites: None,
2178            openssl_options: None,
2179            rustls_cipher_suites: [],
2180            client_timeouts: SmtpClientTimeouts {
2181                connect_timeout: 60s,
2182                banner_timeout: 60s,
2183                ehlo_timeout: 300s,
2184                mail_from_timeout: 300s,
2185                rcpt_to_timeout: 300s,
2186                data_timeout: 300s,
2187                data_dot_timeout: 300s,
2188                rset_timeout: 5s,
2189                idle_timeout: 5s,
2190                starttls_timeout: 5s,
2191                auth_timeout: 60s,
2192            },
2193            system_shutdown_timeout: None,
2194            max_ready: 1024,
2195            consecutive_connection_failures_before_delay: 100,
2196            smtp_port: 25,
2197            smtp_auth_plain_username: None,
2198            smtp_auth_plain_password: None,
2199            allow_smtp_auth_plain_without_tls: false,
2200            max_message_rate: None,
2201            additional_message_rate_throttles: {},
2202            source_selection_rate: None,
2203            additional_source_selection_rates: {},
2204            max_connection_rate: None,
2205            max_deliveries_per_connection: 1024,
2206            max_recipients_per_batch: 100,
2207            prohibited_hosts: {
2208                "0.0.0.0",
2209                "127.0.0.0/8",
2210                "::/127",
2211            },
2212            skip_hosts: {},
2213            ip_lookup_strategy: Ipv4AndIpv6,
2214            ehlo_domain: None,
2215            aggressive_connection_opening: false,
2216            refresh_interval: 60s,
2217            refresh_strategy: Ttl,
2218            dispatcher_wakeup_strategy: Aggressive,
2219            maintainer_wakeup_strategy: Aggressive,
2220            provider_name: None,
2221            remember_broken_tls: None,
2222            opportunistic_tls_reconnect_on_failed_handshake: false,
2223            use_lmtp: false,
2224            reconnect_strategy: ConnectNextHost,
2225            readyq_pool_name: None,
2226            low_memory_reduction_policy: ShrinkDataAndMeta,
2227            no_memory_reduction_policy: ShrinkDataAndMeta,
2228            try_next_host_on_transport_error: false,
2229            ignore_8bit_checks: false,
2230        },
2231    },
2232    automation: [
2233        Rule {
2234            regex: [
2235                Regex(
2236                    /Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/,
2237                ),
2238                Regex(
2239                    /All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/,
2240                ),
2241                Regex(
2242                    /has been temporarily rate limited due to IP reputation/,
2243                ),
2244                Regex(
2245                    /Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/,
2246                ),
2247                Regex(
2248                    /Server busy\. Please try again later from/,
2249                ),
2250            ],
2251            action: [
2252                SetConfig(
2253                    EgressPathConfigValue {
2254                        name: "max_message_rate",
2255                        value: HashableTomlValue {
2256                            value: String(
2257                                "1/minute",
2258                            ),
2259                        },
2260                    },
2261                ),
2262                SetConfig(
2263                    EgressPathConfigValue {
2264                        name: "connection_limit",
2265                        value: HashableTomlValue {
2266                            value: Integer(
2267                                1,
2268                            ),
2269                        },
2270                    },
2271                ),
2272            ],
2273            trigger: Immediate,
2274            duration: 5400s,
2275            was_rollup: false,
2276            match_internal: false,
2277        },
2278        Rule {
2279            regex: [
2280                Regex(
2281                    KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site,
2282                ),
2283            ],
2284            action: [
2285                SetConfig(
2286                    EgressPathConfigValue {
2287                        name: "enable_tls",
2288                        value: HashableTomlValue {
2289                            value: String(
2290                                "Disabled",
2291                            ),
2292                        },
2293                    },
2294                ),
2295            ],
2296            trigger: Immediate,
2297            duration: 2592000s,
2298            was_rollup: false,
2299            match_internal: false,
2300        },
2301    ],
2302}
2303"#
2304        );
2305
2306        // The site name here will need to be updated if yahoo changes
2307        // their MX records
2308        let yahoo_com = shaping
2309            .get_egress_path_config(
2310                "yahoo.com",
2311                "invalid.source",
2312                "(mta5|mta6|mta7).am0.yahoodns.net",
2313            )
2314            .await
2315            .finish()
2316            .unwrap();
2317        k9::snapshot!(
2318            yahoo_com,
2319            r#"
2320MergedEntry {
2321    params: EgressPathConfig {
2322        connection_limit: 10,
2323        additional_connection_limits: {},
2324        enable_tls: Opportunistic,
2325        enable_mta_sts: true,
2326        enable_dane: false,
2327        enable_pipelining: true,
2328        enable_rset: true,
2329        tls_prefer_openssl: false,
2330        tls_certificate: None,
2331        tls_private_key: None,
2332        openssl_cipher_list: None,
2333        openssl_cipher_suites: None,
2334        openssl_options: None,
2335        rustls_cipher_suites: [],
2336        client_timeouts: SmtpClientTimeouts {
2337            connect_timeout: 60s,
2338            banner_timeout: 60s,
2339            ehlo_timeout: 300s,
2340            mail_from_timeout: 300s,
2341            rcpt_to_timeout: 300s,
2342            data_timeout: 30s,
2343            data_dot_timeout: 60s,
2344            rset_timeout: 5s,
2345            idle_timeout: 60s,
2346            starttls_timeout: 5s,
2347            auth_timeout: 60s,
2348        },
2349        system_shutdown_timeout: None,
2350        max_ready: 1024,
2351        consecutive_connection_failures_before_delay: 100,
2352        smtp_port: 25,
2353        smtp_auth_plain_username: None,
2354        smtp_auth_plain_password: None,
2355        allow_smtp_auth_plain_without_tls: false,
2356        max_message_rate: Some(
2357            100/s,
2358        ),
2359        additional_message_rate_throttles: {},
2360        source_selection_rate: None,
2361        additional_source_selection_rates: {},
2362        max_connection_rate: Some(
2363            100/m,
2364        ),
2365        max_deliveries_per_connection: 20,
2366        max_recipients_per_batch: 100,
2367        prohibited_hosts: {
2368            "0.0.0.0",
2369            "127.0.0.0/8",
2370            "::/127",
2371        },
2372        skip_hosts: {},
2373        ip_lookup_strategy: Ipv4AndIpv6,
2374        ehlo_domain: None,
2375        aggressive_connection_opening: false,
2376        refresh_interval: 60s,
2377        refresh_strategy: Ttl,
2378        dispatcher_wakeup_strategy: Aggressive,
2379        maintainer_wakeup_strategy: Aggressive,
2380        provider_name: None,
2381        remember_broken_tls: None,
2382        opportunistic_tls_reconnect_on_failed_handshake: false,
2383        use_lmtp: false,
2384        reconnect_strategy: ConnectNextHost,
2385        readyq_pool_name: None,
2386        low_memory_reduction_policy: ShrinkDataAndMeta,
2387        no_memory_reduction_policy: ShrinkDataAndMeta,
2388        try_next_host_on_transport_error: false,
2389        ignore_8bit_checks: false,
2390    },
2391    sources: {},
2392    automation: [
2393        Rule {
2394            regex: [
2395                Regex(
2396                    /Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/,
2397                ),
2398                Regex(
2399                    /All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/,
2400                ),
2401                Regex(
2402                    /has been temporarily rate limited due to IP reputation/,
2403                ),
2404                Regex(
2405                    /Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/,
2406                ),
2407                Regex(
2408                    /Server busy\. Please try again later from/,
2409                ),
2410            ],
2411            action: [
2412                SetConfig(
2413                    EgressPathConfigValue {
2414                        name: "max_message_rate",
2415                        value: HashableTomlValue {
2416                            value: String(
2417                                "1/minute",
2418                            ),
2419                        },
2420                    },
2421                ),
2422                SetConfig(
2423                    EgressPathConfigValue {
2424                        name: "connection_limit",
2425                        value: HashableTomlValue {
2426                            value: Integer(
2427                                1,
2428                            ),
2429                        },
2430                    },
2431                ),
2432            ],
2433            trigger: Immediate,
2434            duration: 5400s,
2435            was_rollup: false,
2436            match_internal: false,
2437        },
2438        Rule {
2439            regex: [
2440                Regex(
2441                    KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site,
2442                ),
2443            ],
2444            action: [
2445                SetConfig(
2446                    EgressPathConfigValue {
2447                        name: "enable_tls",
2448                        value: HashableTomlValue {
2449                            value: String(
2450                                "Disabled",
2451                            ),
2452                        },
2453                    },
2454                ),
2455            ],
2456            trigger: Immediate,
2457            duration: 2592000s,
2458            was_rollup: false,
2459            match_internal: false,
2460        },
2461        Rule {
2462            regex: [
2463                Regex(
2464                    \[TS04\],
2465                ),
2466            ],
2467            action: [
2468                Suspend,
2469            ],
2470            trigger: Immediate,
2471            duration: 7200s,
2472            was_rollup: false,
2473            match_internal: false,
2474        },
2475    ],
2476}
2477"#
2478        );
2479    }
2480
2481    #[tokio::test]
2482    async fn test_load_default_shaping_toml() {
2483        Shaping::merge_files(
2484            &["../../assets/policy-extras/shaping.toml".into()],
2485            &ShapingMergeOptions::default(),
2486        )
2487        .await
2488        .unwrap();
2489    }
2490}