kumo_api_types/
shaping.rs

1use crate::egress_path::EgressPathConfig;
2#[cfg(feature = "lua")]
3use anyhow::Context;
4#[cfg(feature = "lua")]
5use config::any_err;
6#[cfg(feature = "lua")]
7use config::serialize_options;
8#[cfg(feature = "lua")]
9use dns_resolver::{fully_qualify, MailExchanger};
10#[cfg(feature = "lua")]
11use kumo_log_types::JsonLogRecord;
12#[cfg(feature = "lua")]
13use mlua::prelude::LuaUserData;
14#[cfg(feature = "lua")]
15use mlua::{LuaSerdeExt, UserDataMethods};
16use ordermap::OrderMap;
17use serde::de::{SeqAccess, Visitor};
18use serde::{Deserialize, Deserializer, Serialize};
19use serde_with::formats::PreferOne;
20use serde_with::{serde_as, DeserializeAs, OneOrMany};
21#[cfg(feature = "lua")]
22use sha2::{Digest, Sha256};
23#[cfg(feature = "lua")]
24use std::collections::BTreeMap;
25use std::hash::{Hash, Hasher};
26use std::marker::PhantomData;
27#[cfg(feature = "lua")]
28use std::sync::Arc;
29use std::time::Duration;
30#[cfg(feature = "lua")]
31use throttle::LimitSpec;
32use throttle::ThrottleSpec;
33
/// A compiled regular expression that (de)serializes as its pattern string.
///
/// Deserialization goes through the `TryFrom<String>` impl, so an invalid
/// pattern surfaces as a deserialization error. Serialization goes through
/// `From<Regex> for String`, which emits the pattern text.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(try_from = "String", into = "String")]
pub struct Regex(fancy_regex::Regex);
37
38impl TryFrom<String> for Regex {
39    type Error = fancy_regex::Error;
40
41    fn try_from(s: String) -> fancy_regex::Result<Self> {
42        Ok(Self(fancy_regex::Regex::new(&s)?))
43    }
44}
45
46impl From<Regex> for String {
47    fn from(r: Regex) -> String {
48        r.0.as_str().to_string()
49    }
50}
51
52impl std::ops::Deref for Regex {
53    type Target = fancy_regex::Regex;
54    fn deref(&self) -> &Self::Target {
55        &self.0
56    }
57}
58
59impl std::hash::Hash for Regex {
60    fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
61        self.0.as_str().hash(hasher)
62    }
63}
64
/// toml::Value is not Hash because it may contain floating
/// point numbers, which are problematic from an Ord and Eq
/// perspective. We're okay with skirting around that for
/// our purposes here, so we implement our own hashable
/// wrapper around the toml value.
///
/// The wrapper is transparent to serde: it (de)serializes exactly
/// as the wrapped `toml::Value` via the `From` conversions.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(from = "toml::Value", into = "toml::Value")]
pub struct HashableTomlValue {
    // The wrapped value; hashed structurally by `hash_toml`.
    value: toml::Value,
}
75
76impl From<toml::Value> for HashableTomlValue {
77    fn from(value: toml::Value) -> Self {
78        Self { value }
79    }
80}
81
82impl From<HashableTomlValue> for toml::Value {
83    fn from(value: HashableTomlValue) -> toml::Value {
84        value.value
85    }
86}
87
88impl std::ops::Deref for HashableTomlValue {
89    type Target = toml::Value;
90    fn deref(&self) -> &toml::Value {
91        &self.value
92    }
93}
94
95fn hash_toml<H>(value: &toml::Value, h: &mut H)
96where
97    H: Hasher,
98{
99    match value {
100        toml::Value::Boolean(v) => v.hash(h),
101        toml::Value::Datetime(v) => {
102            if let Some(d) = &v.date {
103                d.year.hash(h);
104                d.month.hash(h);
105                d.day.hash(h);
106            }
107            if let Some(t) = &v.time {
108                t.hour.hash(h);
109                t.minute.hash(h);
110                t.second.hash(h);
111                t.nanosecond.hash(h);
112            }
113            if let Some(toml::value::Offset::Custom { minutes }) = &v.offset {
114                minutes.hash(h);
115            }
116        }
117        toml::Value::String(v) => v.hash(h),
118        toml::Value::Integer(v) => v.hash(h),
119        toml::Value::Float(v) => v.to_ne_bytes().hash(h),
120        toml::Value::Array(a) => {
121            for v in a.iter() {
122                hash_toml(v, h);
123            }
124        }
125        toml::Value::Table(m) => {
126            for (k, v) in m.iter() {
127                k.hash(h);
128                hash_toml(v, h);
129            }
130        }
131    }
132}
133
134impl Hash for HashableTomlValue {
135    fn hash<H>(&self, h: &mut H)
136    where
137        H: Hasher,
138    {
139        hash_toml(&self.value, h);
140    }
141}
142
/// Represents an individual EgressPathConfig field name and value.
/// It only allows deserializing from valid EgressPathConfig field + values.
///
/// Serde routes through `EgressPathConfigValueUnchecked`: deserialization
/// uses the validating `TryFrom` conversion, serialization uses the
/// infallible `From` conversion back to the unchecked shape.
#[derive(Deserialize, Serialize, Debug, Clone, Hash)]
#[serde(
    try_from = "EgressPathConfigValueUnchecked",
    into = "EgressPathConfigValueUnchecked"
)]
pub struct EgressPathConfigValue {
    /// Name of the `EgressPathConfig` field being set.
    pub name: String,
    /// Value for that field, wrapped so that the struct can derive `Hash`.
    pub value: HashableTomlValue,
}
154
/// This is the type that we actually use to deserialize EgressPathConfigValue items.
/// It doesn't care about validity; it is used solely to tell serde what shape of
/// data to expect.
/// The validation is performed by the TryFrom impl that is used to convert to the
/// checked form below.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct EgressPathConfigValueUnchecked {
    /// Candidate field name; not yet known to exist on `EgressPathConfig`.
    pub name: String,
    /// Candidate value; not yet known to be valid for that field.
    pub value: toml::Value,
}
165
166impl TryFrom<EgressPathConfigValueUnchecked> for EgressPathConfigValue {
167    type Error = anyhow::Error;
168    fn try_from(config: EgressPathConfigValueUnchecked) -> anyhow::Result<EgressPathConfigValue> {
169        let mut map = toml::map::Map::new();
170        map.insert(config.name.clone(), config.value.clone());
171        let table = toml::Value::Table(map);
172
173        // Attempt to deserialize as EgressPathConfig.
174        // If it fails, then the field name/value are invalid
175        EgressPathConfig::deserialize(table)?;
176
177        // If we reach this point, we can pass along the name/value
178        Ok(EgressPathConfigValue {
179            name: config.name,
180            value: HashableTomlValue {
181                value: config.value,
182            },
183        })
184    }
185}
186
187impl From<EgressPathConfigValue> for EgressPathConfigValueUnchecked {
188    fn from(config: EgressPathConfigValue) -> EgressPathConfigValueUnchecked {
189        EgressPathConfigValueUnchecked {
190            name: config.name,
191            value: config.value.value,
192        }
193    }
194}
195
/// An action that can be attached to a [`Rule`] (see `Rule::action`).
///
/// The `SetConfig` and `SetDomainConfig` variants carry a validated
/// `EgressPathConfig` field name/value pair. The effect of each action
/// is implemented by the consumer of the matched rules, not in this file.
#[derive(Deserialize, Serialize, Debug, Clone, Hash)]
pub enum Action {
    Suspend,
    SetConfig(EgressPathConfigValue),
    SuspendTenant,
    SuspendCampaign,
    SetDomainConfig(EgressPathConfigValue),
    Bounce,
    BounceTenant,
    BounceCampaign,
}
207
/// Controls when a matching [`Rule`] triggers.
/// Defaults to [`Trigger::Immediate`].
#[derive(Deserialize, Serialize, Debug, Clone, Hash, Default)]
pub enum Trigger {
    /// Trigger on the first match, immediately
    #[default]
    Immediate,
    /// Trigger when a certain number of matches occur
    /// over a certain time period.
    Threshold(ThrottleSpec),
}
217
/// An automation rule: a set of regexes tested against a response,
/// together with the actions to take, the trigger condition and a duration.
#[serde_as]
#[derive(Deserialize, Serialize, Debug, Hash, Clone)]
pub struct Rule {
    /// Patterns to test; the rule matches if any one of them matches.
    /// Deserialized via `regex_string_or_array`.
    #[serde(deserialize_with = "regex_string_or_array")]
    pub regex: Vec<Regex>,

    /// Actions associated with the rule.
    /// Deserialized via `one_or_many_action`.
    #[serde(deserialize_with = "one_or_many_action")]
    pub action: Vec<Action>,

    /// When the rule triggers; defaults to `Trigger::Immediate`.
    #[serde(default)]
    pub trigger: Trigger,

    /// Duration associated with the rule, (de)serialized via `duration_serde`.
    #[serde(with = "duration_serde")]
    pub duration: Duration,

    /// Never (de)serialized; set to true by `clone_and_set_rollup`.
    #[serde(skip)]
    pub was_rollup: bool,

    /// if true, this rule can match kumomta internally generated
    /// Response messages, otherwise, the rule will skip testing
    /// against those.
    #[serde(default)]
    pub match_internal: bool,
}
242
243impl Rule {
244    pub fn matches(&self, is_internal: bool, response: &str) -> bool {
245        if is_internal && !self.match_internal {
246            return false;
247        }
248        self.regex
249            .iter()
250            .any(|r| r.is_match(response).unwrap_or(false))
251    }
252
253    pub fn clone_and_set_rollup(&self) -> Self {
254        let mut result = self.clone();
255        result.was_rollup = true;
256        result
257    }
258}
259
/// The merged shaping data shared behind `Shaping`'s `Arc`.
#[cfg(feature = "lua")]
#[derive(Debug, Default)]
struct ShapingInner {
    // Entries keyed by MX site name (populated for mx_rollup domains).
    by_site: OrderMap<String, PartialEntry>,
    // Entries keyed by domain name; also holds the special "default" entry.
    by_domain: OrderMap<String, PartialEntry>,
    // Entries keyed by provider name.
    by_provider: OrderMap<String, ProviderEntry>,
    // Diagnostics collected while merging.
    warnings: Vec<String>,
    errors: Vec<String>,
    // Lowercase hex SHA-256 digest over the merged data and diagnostics.
    hash: String,
}
270
#[cfg(feature = "lua")]
impl ShapingInner {
    /// Computes the effective egress path parameters for a
    /// domain / egress source / site combination.
    ///
    /// Layers are applied in this order, each merged on top of the previous:
    /// the `default` entry, matching providers (then their per-source
    /// rules), the site entry, the domain entry, the site's per-source
    /// entry and finally the domain's per-source entry.
    pub async fn get_egress_path_config(
        &self,
        domain: &str,
        egress_source: &str,
        site_name: &str,
    ) -> PartialEntry {
        let mut params = PartialEntry::default();

        // Apply basic/default configuration
        if let Some(default) = self.by_domain.get("default") {
            params.merge_from(default.clone());
        }

        // Provider rules come next.
        // They can only match valid domain names, so we'll
        // skip processing them if we have something like our
        // TSA http "domain name" here
        let is_domain_name = dns_resolver::Name::from_str_relaxed(domain).is_ok();
        if is_domain_name {
            let mut prov_with_sources = vec![];

            for prov in self.by_provider.values() {
                if prov.domain_matches(domain).await {
                    toml_table_merge_from(&mut params.params, &prov.params);
                    prov.apply_provider_params_to(egress_source, &mut params.params);

                    if !prov.sources.is_empty() {
                        // Remember this matching provider, so that we
                        // can apply any source rules after we've applied
                        // any/all base provider rules for this domain
                        prov_with_sources.push(prov);
                    }
                }
            }

            // Then Provider source rules
            for prov in prov_with_sources {
                if let Some(source) = prov.sources.get(egress_source) {
                    toml_table_merge_from(&mut params.params, source);
                    prov.apply_provider_params_to(egress_source, &mut params.params);
                }
            }
        }

        // Then site config
        if let Some(by_site) = self.by_site.get(site_name) {
            params.merge_from(by_site.clone());
        }

        // Then domain config
        if let Some(by_domain) = self.by_domain.get(domain) {
            params.merge_from(by_domain.clone());
        }

        // Then source config for the site
        if let Some(by_site) = self.by_site.get(site_name) {
            if let Some(source) = by_site.sources.get(egress_source) {
                toml_table_merge_from(&mut params.params, source);
            }
        }

        // Then source config for the domain
        if let Some(by_domain) = self.by_domain.get(domain) {
            if let Some(source) = by_domain.sources.get(egress_source) {
                toml_table_merge_from(&mut params.params, source);
            }
        }

        params
    }

    /// Returns the automation rules that match a log record.
    ///
    /// The recipient domain is taken from the first recipient of the
    /// record, and the site name is derived from `record.site`.
    /// Records addressed to the local postmaster yield no rules.
    ///
    /// # Errors
    ///
    /// Fails if the record has no recipients or the first recipient
    /// cannot be parsed as a forward path.
    pub async fn match_rules(&self, record: &JsonLogRecord) -> anyhow::Result<Vec<Rule>> {
        use rfc5321::ForwardPath;
        // Extract the domain from the recipient.
        let recipient = ForwardPath::try_from(
            record
                .recipient
                .first()
                .ok_or_else(|| anyhow::anyhow!("no recipients!?"))?
                .as_str(),
        )
        .map_err(|err| anyhow::anyhow!("parsing record.recipient: {err}"))?;

        let recipient = match recipient {
            ForwardPath::Postmaster => {
                // It doesn't make sense to apply automation on the
                // local postmaster address, so we ignore this.
                return Ok(vec![]);
            }
            ForwardPath::Path(path) => path.mailbox,
        };
        let domain = recipient.domain.to_string();

        // Track events/outcomes by site.
        let source = record.egress_source.as_deref().unwrap_or("unspecified");
        // record.site is poorly named; it is really an identifier for the
        // egress path. For matching purposes, we want just the site_name
        // in the form produced by our MX resolution process.
        // In an earlier incarnation of this logic, we would resolve the
        // site_name for ourselves based on other data in the record,
        // but that could lead to over-resolution of some names and
        // yield surprising results.
        // What we do here is extract the egress path decoration from
        // record.site to arrive at something that looks like the
        // mx site_name.
        // NOTE: this is coupled with the logic in
        // ReadyQueueManager::compute_queue_name
        //
        // The decoration is removed at most once from each end:
        // `trim_start_matches`/`trim_end_matches` would strip repeated
        // occurrences and could eat part of the site name itself.
        let source_prefix = format!("{source}->");
        let site: &str = &record.site;
        let site = site.strip_prefix(source_prefix.as_str()).unwrap_or(site);
        let site_name = site
            .strip_suffix("@smtp_client")
            .unwrap_or(site)
            .to_string();

        Ok(self.match_rules_impl(record, &domain, &site_name).await)
    }

    /// Collects matching rules for an already-resolved domain and site name.
    ///
    /// Rules are considered in this order: the `default` entry, every
    /// provider whose matcher accepts `domain`, the site entry and then
    /// the domain entry. Rules coming from `default` and from the site
    /// entry are returned with `was_rollup` set.
    pub async fn match_rules_impl(
        &self,
        record: &JsonLogRecord,
        domain: &str,
        site_name: &str,
    ) -> Vec<Rule> {
        let mut result = vec![];
        let response = record.response.to_single_line();
        tracing::trace!("Consider rules for {response}");

        // Internally generated responses only match rules that opt in
        // through `match_internal`.
        let is_internal = record.response.content.starts_with("KumoMTA internal: ");

        if let Some(default) = self.by_domain.get("default") {
            for rule in &default.automation {
                tracing::trace!("Consider \"default\" rule {rule:?} for {response}");
                if rule.matches(is_internal, &response) {
                    // For automation under `default`, we always
                    // assume that mx_rollup should be true.
                    // If you somehow have a domain where that isn't
                    // true, you should avoid using `default` for
                    // automation.  Honestly, it's best to avoid
                    // using `default` for automation.
                    result.push(rule.clone_and_set_rollup());
                }
            }
        }

        for prov in self.by_provider.values() {
            if prov.domain_matches(domain).await {
                for rule in &prov.automation {
                    tracing::trace!(
                        "Consider provider \"{}\" rule {rule:?} for {response}",
                        prov.provider_name
                    );
                    if rule.matches(is_internal, &response) {
                        result.push(rule.clone());
                    }
                }
            }
        }

        // Then site config
        if let Some(by_site) = self.by_site.get(site_name) {
            for rule in &by_site.automation {
                tracing::trace!("Consider \"{site_name}\" rule {rule:?} for {response}");
                if rule.matches(is_internal, &response) {
                    result.push(rule.clone_and_set_rollup());
                }
            }
        }

        // Then domain config
        if let Some(by_domain) = self.by_domain.get(domain) {
            for rule in &by_domain.automation {
                tracing::trace!("Consider \"{domain}\" rule {rule:?} for {response}");
                if rule.matches(is_internal, &response) {
                    result.push(rule.clone());
                }
            }
        }

        result
    }
}
453
/// Handle to a merged shaping configuration.
///
/// Cloning only clones the `Arc`, so all clones share the same
/// underlying data.
#[cfg(feature = "lua")]
#[derive(Debug, Default, Clone, mlua::FromLua)]
pub struct Shaping {
    inner: Arc<ShapingInner>,
}
459
/// Deserializes `T` from JSON text, using `serde_path_to_error` so that
/// failures identify the path of the offending element.
#[cfg(feature = "lua")]
fn from_json<'a, T: Deserialize<'a>>(json: &'a str) -> anyhow::Result<T> {
    let mut deserializer = serde_json::Deserializer::from_str(json);
    let parsed = serde_path_to_error::deserialize(&mut deserializer)?;
    Ok(parsed)
}
465
/// Deserializes `T` from TOML text, using `serde_path_to_error` so that
/// failures identify the path of the offending element.
#[cfg(feature = "lua")]
fn from_toml<'a, T: Deserialize<'a>>(toml: &'a str) -> anyhow::Result<T> {
    let deserializer = toml::Deserializer::parse(toml)?;
    let parsed = serde_path_to_error::deserialize(deserializer)?;
    Ok(parsed)
}
471
/// How a particular class of problem is reported while merging
/// shaping files. Defaults to [`CheckLevel::Ignore`].
#[cfg(feature = "lua")]
#[derive(Default, Debug, Clone, Copy, Deserialize, PartialEq, Eq)]
pub enum CheckLevel {
    /// Discard the diagnostic.
    #[default]
    Ignore,
    /// Record the diagnostic as a warning.
    Warn,
    /// Record the diagnostic as an error.
    Error,
}
/// Parses a level name, ignoring ASCII case.
#[cfg(feature = "lua")]
impl std::str::FromStr for CheckLevel {
    type Err = String;

    /// Accepts `ignore`, `warn` or `error` in any ASCII casing; anything
    /// else yields a descriptive error string.
    fn from_str(s: &str) -> Result<Self, String> {
        const NAMES: [(&str, CheckLevel); 3] = [
            ("ignore", CheckLevel::Ignore),
            ("warn", CheckLevel::Warn),
            ("error", CheckLevel::Error),
        ];

        NAMES
            .iter()
            .find(|(name, _)| s.eq_ignore_ascii_case(name))
            .map(|&(_, level)| level)
            .ok_or_else(|| {
                format!(
                    "Expected one of `Ignore`, `Warn` or `Error`, got `{s}`"
                )
            })
    }
}
498
/// Accumulates the diagnostics produced while loading and merging
/// shaping files, split by severity.
#[cfg(feature = "lua")]
#[derive(Default)]
struct Collector {
    warnings: Vec<String>,
    errors: Vec<String>,
}
505
#[cfg(feature = "lua")]
impl Collector {
    /// Records `msg` in the list selected by `level`.
    /// `CheckLevel::Ignore` drops the message without converting it.
    fn push<S: Into<String>>(&mut self, level: CheckLevel, msg: S) {
        let bucket = match level {
            CheckLevel::Ignore => return,
            CheckLevel::Warn => &mut self.warnings,
            CheckLevel::Error => &mut self.errors,
        };
        bucket.push(msg.into());
    }
}
516
/// Options controlling how `Shaping::merge_files` loads its inputs and
/// how each class of problem is reported.
///
/// Unknown fields are rejected on deserialization. Every field is
/// optional when deserializing; an omitted `CheckLevel` field takes
/// `CheckLevel::default()`.
// `serde_as` must precede the derive so that it can process the item
// before serde's derive macros see it (same ordering as `Rule`).
#[cfg(feature = "lua")]
#[serde_as]
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ShapingMergeOptions {
    /// Level for "domain is also matched by provider(s)" diagnostics.
    #[serde(default)]
    pub provider_overlap: CheckLevel,
    /// Level for MX resolution failures of `mx_rollup` domains.
    #[serde(default)]
    pub dns_fail: CheckLevel,
    /// Level for `mx_rollup` domains that have a NULL MX.
    #[serde(default)]
    pub null_mx: CheckLevel,
    /// Level for multiple domain blocks resolving to the same site.
    #[serde(default)]
    pub aliased_site: CheckLevel,
    /// When true, `http://` and `https://` sources are not fetched.
    #[serde(default)]
    pub skip_remote: bool,
    /// Level for failures loading a remote (HTTP) source.
    #[serde(default)]
    pub remote_load: CheckLevel,
    /// Level for failures loading a local file.
    #[serde(default)]
    pub local_load: CheckLevel,
    /// Timeout for remote fetches; 5 seconds is used when unset.
    #[serde(default, with = "duration_serde")]
    pub http_timeout: Option<Duration>,
}
539
/// Programmatic defaults: everything is ignored except local load
/// failures, which are errors.
///
/// Note that this differs from deserializing with `local_load` omitted,
/// where the field-level `#[serde(default)]` yields
/// `CheckLevel::default()`, i.e. `Ignore`.
#[cfg(feature = "lua")]
impl Default for ShapingMergeOptions {
    fn default() -> Self {
        let ignore = CheckLevel::Ignore;
        Self {
            local_load: CheckLevel::Error,
            remote_load: ignore,
            skip_remote: false,
            http_timeout: None,
            provider_overlap: ignore,
            dns_fail: ignore,
            null_mx: ignore,
            aliased_site: ignore,
        }
    }
}
555
556#[cfg(feature = "lua")]
557impl Shaping {
    /// Loads a single shaping source from a local path or an
    /// `http://` / `https://` URL and parses it into a `ShapingFile`.
    ///
    /// Load failures (HTTP errors, unreadable files) are reported through
    /// `collector` at the level configured in `options` and produce an
    /// empty `ShapingFile` rather than an error. The format is chosen by
    /// extension: `.toml` or `.json`; anything else is tried as TOML and
    /// then as JSON.
    ///
    /// # Errors
    ///
    /// A parse failure for a path ending in `.toml` or `.json` is returned
    /// as an error. For other paths, a parse failure is reported via
    /// `collector` instead.
    async fn load_from_file(
        path: &str,
        options: &ShapingMergeOptions,
        collector: &mut Collector,
    ) -> anyhow::Result<ShapingFile> {
        // `level` is the reporting level that applies to this kind of
        // source; it is only consulted by the extension-less parse path.
        let (data, level): (String, CheckLevel) =
            if path.starts_with("http://") || path.starts_with("https://") {
                if options.skip_remote {
                    // Skipping is always reported as a warning,
                    // independent of `remote_load`.
                    collector.push(
                        CheckLevel::Warn,
                        format!("Ignoring {path} because skip_remote is set to true"),
                    );
                    return Ok(ShapingFile::default());
                }

                // To facilitate startup ordering races, and listing multiple subscription
                // host replicas and allowing one or more of them to be temporarily down,
                // we allow the http request to fail.
                // We'll log the error message but consider it to be an empty map

                // Fetches `url` as text. The same `timeout` is applied to
                // the client's individual phases and, via
                // `tokio::time::timeout`, to the request as a whole.
                async fn http_get(url: &str, timeout: Duration) -> anyhow::Result<String> {
                    tokio::time::timeout(timeout, async {
                        reqwest::Client::builder()
                            .timeout(timeout)
                            .connect_timeout(timeout)
                            .read_timeout(timeout)
                            .build()?
                            .get(url)
                            .send()
                            .await
                            .with_context(|| format!("making HTTP request to {url}"))?
                            .text()
                            .await
                            .with_context(|| format!("reading text from {url}"))
                    })
                    .await
                    .with_context(|| format!("timeout making HTTP request to {url}"))?
                }

                let timeout = options.http_timeout.unwrap_or(Duration::from_secs(5));

                match http_get(path, timeout).await {
                    Ok(s) => (s, options.remote_load),
                    Err(err) => {
                        // Always logged; additionally collected at the
                        // configured `remote_load` level.
                        tracing::error!("{err:#}. Ignoring this shaping source for now");
                        collector.push(
                            options.remote_load,
                            format!("remote shaping source {path} error: {err:#}"),
                        );
                        return Ok(ShapingFile::default());
                    }
                }
            } else {
                match std::fs::read_to_string(path)
                    .with_context(|| format!("loading data from file {path}"))
                {
                    Err(err) => {
                        collector.push(
                            options.local_load,
                            format!("local shaping source {path} error: {err:#}"),
                        );
                        return Ok(ShapingFile::default());
                    }
                    Ok(s) => (s, options.local_load),
                }
            };

        if path.ends_with(".toml") {
            from_toml(&data).with_context(|| format!("parsing toml from file {path}"))
        } else if path.ends_with(".json") {
            from_json(&data).with_context(|| format!("parsing json from file {path}"))
        } else {
            // Try parsing both ways and see which wins
            let mut errors = vec![];
            match from_toml(&data) {
                Ok(s) => return Ok(s),
                Err(err) => errors.push(format!("as toml: {err:#}")),
            }
            match from_json(&data) {
                Ok(s) => return Ok(s),
                Err(err) => errors.push(format!("as json: {err:#}")),
            }

            // Neither format parsed: report both failures together and
            // treat the source as empty.
            collector.push(level, format!("parsing {path}: {}", errors.join(", ")));
            Ok(ShapingFile::default())
        }
    }
645
646    pub async fn merge_files(
647        files: &[String],
648        options: &ShapingMergeOptions,
649    ) -> anyhow::Result<Self> {
650        use futures_util::stream::FuturesUnordered;
651        use futures_util::StreamExt;
652
653        let mut collector = Collector::default();
654        let mut loaded = vec![];
655        for p in files {
656            loaded.push(Self::load_from_file(p, options, &mut collector).await?);
657        }
658
659        let mut by_site: OrderMap<String, PartialEntry> = OrderMap::new();
660        let mut by_domain: OrderMap<String, PartialEntry> = OrderMap::new();
661        let mut by_provider: OrderMap<String, ProviderEntry> = OrderMap::new();
662        let mut site_aliases: OrderMap<String, Vec<String>> = OrderMap::new();
663
664        // Pre-resolve domains. We don't interleave the resolution with
665        // the work below, because we want to ensure that the ordering
666        // is preserved
667        let mut mx = std::collections::HashMap::new();
668        let mut lookups = FuturesUnordered::new();
669        for item in &loaded {
670            for (domain, partial) in &item.domains {
671                if partial.mx_rollup {
672                    let domain = domain.to_string();
673                    lookups.push(tokio::spawn(async move {
674                        let mx_result = MailExchanger::resolve(&domain).await;
675                        (domain, mx_result)
676                    }));
677                }
678            }
679        }
680
681        while let Some(Ok((domain, result))) = lookups.next().await {
682            mx.insert(domain, result);
683        }
684
685        for mut item in loaded {
686            if let Some(mut partial) = item.default.take() {
687                let domain = "default";
688                partial.domain_name.replace(domain.to_string());
689                match by_domain.get_mut(domain) {
690                    Some(existing) => {
691                        existing.merge_from(partial);
692                    }
693                    None => {
694                        by_domain.insert(domain.to_string(), partial);
695                    }
696                }
697            }
698
699            for (domain, mut partial) in item.domains {
700                partial.domain_name.replace(domain.clone());
701
702                if let Ok(name) = fully_qualify(&domain) {
703                    if name.num_labels() == 1 {
704                        collector.push(
705                            CheckLevel::Warn,
706                            format!(
707                                "Entry for domain '{domain}' consists of a \
708                                 single DNS label. Domain names in TOML sections \
709                                 need to be quoted like '[\"{domain}.com\"]` otherwise \
710                                 the '.' will create a nested table rather than being \
711                                 added to the domain name."
712                            ),
713                        );
714                    }
715                }
716
717                #[cfg(test)]
718                if partial._treat_domain_name_as_site_name {
719                    match by_site.get_mut(&domain) {
720                        Some(existing) => {
721                            existing.merge_from(partial);
722                        }
723                        None => {
724                            by_site.insert(domain.to_string(), partial);
725                        }
726                    }
727                    continue;
728                }
729                if partial.mx_rollup {
730                    let mx = match mx.get(&domain) {
731                        Some(Ok(mx)) => mx,
732                        Some(Err(err)) => {
733                            collector.push(
734                                options.dns_fail,
735                                format!(
736                                    "error resolving MX for {domain}: {err:#}. \
737                                 Ignoring the shaping config for that domain."
738                                ),
739                            );
740                            continue;
741                        }
742                        None => {
743                            collector.push(
744                                options.dns_fail,
745                                format!(
746                                "We didn't try to resolve the MX for {domain} for some reason!?. \
747                                 Ignoring the shaping config for that domain."
748                            ),
749                            );
750                            continue;
751                        }
752                    };
753
754                    if mx.site_name.is_empty() {
755                        collector.push(
756                            options.null_mx,
757                            format!(
758                            "domain {domain} has a NULL MX and cannot be used with mx_rollup=true. \
759                             Ignoring the shaping config for that domain."),
760                        );
761                        continue;
762                    }
763
764                    site_aliases
765                        .entry(mx.site_name.to_string())
766                        .or_default()
767                        .push(domain.to_string());
768
769                    match by_site.get_mut(&mx.site_name) {
770                        Some(existing) => {
771                            existing.merge_from(partial);
772                        }
773                        None => {
774                            by_site.insert(mx.site_name.clone(), partial);
775                        }
776                    }
777                } else {
778                    match by_domain.get_mut(&domain) {
779                        Some(existing) => {
780                            existing.merge_from(partial);
781                        }
782                        None => {
783                            by_domain.insert(domain, partial);
784                        }
785                    }
786                }
787            }
788
789            for (provider, mut prov) in item.provider {
790                prov.provider_name = provider.to_string();
791                match by_provider.get_mut(&provider) {
792                    Some(existing) => {
793                        existing.merge_from(prov);
794                    }
795                    None => {
796                        by_provider.insert(provider.to_string(), prov);
797                    }
798                }
799            }
800        }
801
802        for (site, partial) in &by_site {
803            partial
804                .clone()
805                .finish()
806                .with_context(|| format!("site: {site}"))?;
807        }
808
809        for (domain, partial) in &by_domain {
810            partial
811                .clone()
812                .finish()
813                .with_context(|| format!("domain: {domain}"))?;
814        }
815
816        for (provider, prov) in &by_provider {
817            prov.finish_params()
818                .with_context(|| format!("provider: {provider}"))?;
819        }
820
821        if options.aliased_site != CheckLevel::Ignore {
822            for (site, aliases) in site_aliases {
823                if aliases.len() > 1 {
824                    collector.push(
825                        options.aliased_site,
826                        format!(
827                            "multiple domain blocks alias to the same site: {site}: {}",
828                            aliases.join(", ")
829                        ),
830                    );
831                }
832            }
833        }
834
835        if options.provider_overlap != CheckLevel::Ignore {
836            for domain in mx.keys() {
837                let mut matching_providers = vec![];
838                for (prov_name, prov) in &by_provider {
839                    if prov.domain_matches(domain).await {
840                        matching_providers.push(prov_name.to_string());
841                    }
842                }
843                if !matching_providers.is_empty() {
844                    collector.push(
845                        options.provider_overlap,
846                        format!(
847                            "domain {domain} is also matched by provider(s): {}",
848                            matching_providers.join(", ")
849                        ),
850                    );
851                }
852            }
853        }
854
855        let mut ctx = Sha256::new();
856        ctx.update("by_site");
857        for (site, entry) in &by_site {
858            ctx.update(site);
859            entry.hash_into(&mut ctx);
860        }
861        ctx.update("by_domain");
862        for (domain, entry) in &by_domain {
863            ctx.update(domain);
864            entry.hash_into(&mut ctx);
865        }
866        ctx.update("by_provider");
867        for (provider, prov) in &by_provider {
868            ctx.update(provider);
869            prov.hash_into(&mut ctx);
870        }
871        ctx.update("warnings");
872        for warn in &collector.warnings {
873            ctx.update(warn);
874        }
875        ctx.update("errors");
876        for err in &collector.errors {
877            ctx.update(err);
878        }
879        let hash = ctx.finalize();
880        let hash = data_encoding::HEXLOWER.encode(&hash);
881
882        Ok(Self {
883            inner: Arc::new(ShapingInner {
884                by_site,
885                by_domain,
886                by_provider,
887                warnings: collector.warnings,
888                errors: collector.errors,
889                hash,
890            }),
891        })
892    }
893
894    async fn get_egress_path_config(
895        &self,
896        domain: &str,
897        egress_source: &str,
898        site_name: &str,
899    ) -> PartialEntry {
900        self.inner
901            .get_egress_path_config(domain, egress_source, site_name)
902            .await
903    }
904
905    pub async fn get_egress_path_config_value(
906        &self,
907        domain: &str,
908        egress_source: &str,
909        site_name: &str,
910    ) -> anyhow::Result<serde_json::Value> {
911        let partial = self
912            .get_egress_path_config(domain, egress_source, site_name)
913            .await;
914        Ok(serde_json::to_value(&partial)?)
915    }
916
917    pub fn get_errors(&self) -> &[String] {
918        &self.inner.errors
919    }
920
921    pub fn get_warnings(&self) -> &[String] {
922        &self.inner.warnings
923    }
924
925    pub async fn match_rules(&self, record: &JsonLogRecord) -> anyhow::Result<Vec<Rule>> {
926        self.inner.match_rules(record).await
927    }
928
929    pub fn get_referenced_sources(&self) -> BTreeMap<String, Vec<String>> {
930        let mut result = BTreeMap::new();
931
932        for (site_name, site) in &self.inner.by_site {
933            for source_name in site.sources.keys() {
934                result
935                    .entry(source_name.to_string())
936                    .or_insert(vec![])
937                    .push(format!("site:{site_name}"));
938            }
939        }
940        for (domain_name, domain) in &self.inner.by_domain {
941            for source_name in domain.sources.keys() {
942                result
943                    .entry(source_name.to_string())
944                    .or_insert(vec![])
945                    .push(format!("domain:{domain_name}"));
946            }
947        }
948
949        result
950    }
951
952    pub fn hash(&self) -> String {
953        self.inner.hash.clone()
954    }
955}
956
/// Exposes `Shaping` to Lua as userdata.
///
/// Besides whatever `mod_memoize::Memoized::impl_memoize` registers, the
/// following methods are made available to Lua:
/// `get_egress_path_config`, `get_errors`, `get_warnings`,
/// `get_referenced_sources`, `match_rules` and `hash`.
#[cfg(feature = "lua")]
impl LuaUserData for Shaping {
    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
        mod_memoize::Memoized::impl_memoize(methods);

        // Only the `params` table of the resolved entry is handed to Lua.
        methods.add_async_method(
            "get_egress_path_config",
            |lua, this, (domain, egress_source, site_name): (String, String, String)| async move {
                let entry = this
                    .get_egress_path_config(&domain, &egress_source, &site_name)
                    .await;
                lua.to_value_with(&entry.params, serialize_options())
            },
        );

        // The diagnostics are copied so that Lua receives owned strings.
        methods.add_method("get_errors", |_lua, this, ()| {
            Ok(this.get_errors().to_vec())
        });

        methods.add_method("get_warnings", |_lua, this, ()| {
            Ok(this.get_warnings().to_vec())
        });

        methods.add_method("get_referenced_sources", |_lua, this, ()| {
            Ok(this.get_referenced_sources())
        });

        // Accepts a Lua value shaped like a `JsonLogRecord` and returns the
        // matching rules, each converted to a Lua value.
        methods.add_async_method("match_rules", |lua, this, record: mlua::Value| async move {
            let record: JsonLogRecord = lua.from_value(record)?;
            let matched = this.match_rules(&record).await.map_err(any_err)?;
            let mut converted = Vec::with_capacity(matched.len());
            for rule in &matched {
                converted.push(lua.to_value(rule)?);
            }
            Ok(converted)
        });

        methods.add_method("hash", |_, this, ()| Ok(this.hash()));
    }
}
998
/// The typed form of a shaping entry.
///
/// Values of this type are produced by `PartialEntry::finish` and
/// `ProviderEntry::finish_params`, which deserialize the loosely typed TOML
/// tables of an entry into `EgressPathConfig` values.
#[derive(Default, Debug)]
pub struct MergedEntry {
    /// The entry's base parameters.
    pub params: EgressPathConfig,
    /// Source-specific parameters, keyed by source name.
    pub sources: OrderMap<String, EgressPathConfig>,
    /// The automation rules attached to the entry.
    pub automation: Vec<Rule>,
}
1005
/// The deserialized form of a single shaping file.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
struct ShapingFile {
    /// The optional `default` block.
    pub default: Option<PartialEntry>,
    /// Flattened into the top level of the file: every top-level key that
    /// is not claimed by another field is collected here, keyed by its name.
    #[serde(flatten, default)]
    pub domains: OrderMap<String, PartialEntry>,
    /// The entries of the `provider` table, keyed by provider name.
    #[serde(default)]
    pub provider: OrderMap<String, ProviderEntry>,
}
1015
/// A loosely typed shaping entry, as read from a shaping file and before it
/// has been converted to a `MergedEntry` by `finish`.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
struct PartialEntry {
    /// Not read from or written to the serialized form. `finish` uses it in
    /// error messages and `hash_into` mixes it into the digest.
    #[serde(skip)]
    pub domain_name: Option<String>,

    /// Every key of the entry that is not claimed by one of the other
    /// fields. `finish` deserializes this table as an `EgressPathConfig`.
    #[serde(flatten)]
    pub params: toml::Table,

    /// Defaults to `true` when absent.
    // NOTE(review): presumably selects whether the entry applies to the
    // MX-derived site rather than the literal domain name; confirm against
    // the file-merging code.
    #[serde(default = "default_true")]
    pub mx_rollup: bool,

    // This is present to facilitate unit testing without requiring
    // DNS to resolve the site_name. When set to true, the domain_name
    // is considered to be the site_name for this entry.
    #[cfg(test)]
    #[serde(default)]
    pub _treat_domain_name_as_site_name: bool,

    /// When set, `merge_from` replaces the params, sources and automation
    /// accumulated so far with this entry's values instead of merging.
    #[serde(default)]
    pub replace_base: bool,

    /// Automation rules attached to this entry.
    #[serde(default)]
    pub automation: Vec<Rule>,

    /// Source-specific parameter tables, keyed by source name.
    #[serde(default)]
    pub sources: OrderMap<String, toml::Table>,
}
1044
/// A loosely typed `provider` block from a shaping file.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
pub struct ProviderEntry {
    /// Not part of the serialized form. It is embedded in the names of the
    /// limits generated by `apply_provider_params_to` and mixed into the
    /// digest by `hash_into`.
    #[serde(skip, default)]
    pub provider_name: String,

    /// When present, `apply_provider_params_to` publishes it as an entry in
    /// `additional_connection_limits`.
    #[serde(default)]
    pub provider_connection_limit: Option<LimitSpec>,

    /// When present, `apply_provider_params_to` publishes it as an entry in
    /// `additional_message_rate_throttles`.
    #[serde(default)]
    pub provider_max_message_rate: Option<ThrottleSpec>,

    /// The rules, spelled `match` in the file, that `domain_matches`
    /// evaluates to decide whether a domain belongs to this provider.
    #[serde(default, rename = "match")]
    pub matches: Vec<ProviderMatch>,

    /// When set, `merge_from` replaces the previously accumulated definition
    /// with this one instead of merging into it.
    #[serde(default)]
    pub replace_base: bool,

    /// Every key of the block that is not claimed by one of the other
    /// fields. `finish_params` deserializes it as an `EgressPathConfig`.
    #[serde(flatten)]
    pub params: toml::Table,

    /// Automation rules attached to this provider.
    #[serde(default)]
    pub automation: Vec<Rule>,

    /// Source-specific parameter tables, keyed by source name.
    #[serde(default)]
    pub sources: OrderMap<String, toml::Table>,
}
1072
1073#[cfg(feature = "lua")]
/// Reports whether `candidate` ends with `suffix`, after dropping at most
/// one trailing dot from `candidate`.
///
/// The comparison is case sensitive: no case folding is performed here.
fn suffix_matches(candidate: &str, suffix: &str) -> bool {
    // Our resolver tends to leave the canonical trailing dot on the host
    // name it hands us, so strip one off if it is there.
    let bare = match candidate.strip_suffix('.') {
        Some(without_dot) => without_dot,
        None => candidate,
    };
    // We DON'T handle case mismatches in this code: the dns-resolver crate
    // is assumed to have normalized names to lowercase, and the configured
    // suffix is assumed to be lowercase as well.
    bare.ends_with(suffix)
}
1083
#[cfg(feature = "lua")]
#[cfg(test)]
#[test]
fn test_suffix_matches() {
    // (candidate, suffix) pairs that must match.
    let expected_hits = [("a", "a"), ("foo.com", "foo.com")];
    for &(candidate, suffix) in &expected_hits {
        assert!(suffix_matches(candidate, suffix));
    }

    // (candidate, suffix) pairs that must NOT match.
    let expected_misses = [
        ("foo.com", ".foo.com"),
        ("foo.com", "longer.com"),
        ("réputation.net", ".mx.microsoft"),
        // We DON'T handle case mismatches in this code.
        // We assume that the dns-resolver crate normalized the
        // names to lowercase and that the input suffix is lowercase as well.
        ("foo.Com", ".com"),
        ("foo.Cam", ".com"),
    ];
    for &(candidate, suffix) in &expected_misses {
        assert!(!suffix_matches(candidate, suffix));
    }
}
1100
1101#[cfg(feature = "lua")]
/// Reports whether `candidate` is exactly `name`, after dropping at most
/// one trailing dot from `candidate`.
///
/// The comparison is case sensitive.
fn host_matches(candidate: &str, name: &str) -> bool {
    // Our resolver tends to leave the canonical trailing dot on the host
    // name it hands us, so compare without it when it is present.
    match candidate.strip_suffix('.') {
        Some(without_dot) => without_dot == name,
        None => candidate == name,
    }
}
1108
#[cfg(feature = "lua")]
#[cfg(test)]
#[test]
fn test_host_matches() {
    // (candidate, name, expected outcome)
    let cases = [
        ("foo.com", "foo.com", true),
        ("foo.com.", "foo.com", true),
        ("foo.com", "notfoo.com", false),
    ];
    for &(candidate, name, expected) in &cases {
        assert!(host_matches(candidate, name) == expected);
    }
}
1117
1118#[cfg(feature = "lua")]
1119impl ProviderEntry {
1120    async fn domain_matches(&self, domain: &str) -> bool {
1121        // We'd like to avoid doing DNS if we can do a simple suffix match,
1122        // so we bias to looking at those first
1123        let mut need_mx = false;
1124
1125        tracing::trace!(
1126            "ProviderEntry::domain_matches({domain}) vs {:?}",
1127            self.matches
1128        );
1129
1130        for rule in &self.matches {
1131            match rule {
1132                ProviderMatch::DomainSuffix(suffix) => {
1133                    if suffix_matches(domain, suffix) {
1134                        tracing::trace!("{domain} suffix matches {suffix}");
1135                        return true;
1136                    }
1137                }
1138                ProviderMatch::HostName(_) | ProviderMatch::MXSuffix(_) => {
1139                    need_mx = true;
1140                }
1141            }
1142        }
1143
1144        if !need_mx {
1145            return false;
1146        }
1147
1148        // Now we can consider DNS
1149        match MailExchanger::resolve(domain).await {
1150            Err(err) => {
1151                // Didn't resolve; could be legit, for example, could
1152                // be an internal or fake name that is handled via
1153                // smart host or custom routing, so we log what
1154                // happened as a trace rather than polluting the
1155                // logs about it
1156                tracing::trace!(
1157                    "Error resolving MX for {domain}: {err:#}. \
1158                    Provider {} match rules will be ignored",
1159                    self.provider_name
1160                );
1161                false
1162            }
1163            Ok(mx) => {
1164                tracing::trace!("Consider MXSuffix rules");
1165                for host in &mx.hosts {
1166                    let mut matched = false;
1167
1168                    for rule in &self.matches {
1169                        match rule {
1170                            ProviderMatch::MXSuffix(suffix) => {
1171                                // For a given MX suffix rule, all hosts must match
1172                                // it for it to be valid. This is so that we don't
1173                                // falsely lump a vanity domain that blends providers
1174                                // together.
1175                                tracing::trace!("suffix={suffix} vs host {host}");
1176                                if suffix_matches(host, suffix) {
1177                                    matched = true;
1178                                    break;
1179                                }
1180                            }
1181                            ProviderMatch::HostName(name) => {
1182                                if host_matches(host, name) {
1183                                    matched = true;
1184                                    break;
1185                                }
1186                            }
1187                            ProviderMatch::DomainSuffix(_) => {}
1188                        }
1189                    }
1190
1191                    if !matched {
1192                        tracing::trace!("host didn't match any of these rules");
1193                        return false;
1194                    }
1195                }
1196
1197                true
1198            }
1199        }
1200    }
1201
1202    fn merge_from(&mut self, mut other: Self) {
1203        if other.replace_base {
1204            self.provider_connection_limit = other.provider_connection_limit;
1205            self.matches = other.matches;
1206            self.params = other.params;
1207            self.sources = other.sources;
1208            self.automation = other.automation;
1209        } else {
1210            if other.provider_connection_limit.is_some() {
1211                self.provider_connection_limit = other.provider_connection_limit;
1212            }
1213
1214            toml_table_merge_from(&mut self.params, &other.params);
1215
1216            for (source, tbl) in other.sources {
1217                match self.sources.get_mut(&source) {
1218                    Some(existing) => {
1219                        toml_table_merge_from(existing, &tbl);
1220                    }
1221                    None => {
1222                        self.sources.insert(source, tbl);
1223                    }
1224                }
1225            }
1226
1227            self.matches.append(&mut other.matches);
1228            self.automation.append(&mut other.automation);
1229        }
1230    }
1231
1232    fn apply_provider_params_to(&self, source: &str, target: &mut toml::Table) {
1233        let mut implied = toml::Table::new();
1234        implied.insert(
1235            "provider_name".to_string(),
1236            toml::Value::String(self.provider_name.to_string()),
1237        );
1238
1239        if let Some(limit) = &self.provider_connection_limit {
1240            let mut limits = toml::Table::new();
1241            limits.insert(
1242                format!("shaping-provider-{}-{source}-limit", self.provider_name),
1243                toml::Value::String(limit.to_string()),
1244            );
1245            implied.insert(
1246                "additional_connection_limits".to_string(),
1247                toml::Value::Table(limits),
1248            );
1249        }
1250        if let Some(rate) = &self.provider_max_message_rate {
1251            let rate = rate.as_string();
1252            let mut limits = toml::Table::new();
1253            limits.insert(
1254                format!("shaping-provider-{}-{source}-rate", self.provider_name),
1255                rate.into(),
1256            );
1257            implied.insert(
1258                "additional_message_rate_throttles".to_string(),
1259                toml::Value::Table(limits),
1260            );
1261        }
1262
1263        if let Some(rate) = target.remove("provider_source_selection_rate") {
1264            let mut limits = toml::Table::new();
1265            limits.insert(
1266                format!(
1267                    "shaping-provider-{}-{source}-selection-rate",
1268                    self.provider_name
1269                ),
1270                rate,
1271            );
1272            implied.insert(
1273                "additional_source_selection_rates".to_string(),
1274                toml::Value::Table(limits),
1275            );
1276        }
1277
1278        toml_table_merge_from(target, &implied);
1279    }
1280
1281    fn finish_params(&self) -> anyhow::Result<MergedEntry> {
1282        let provider_name = &self.provider_name;
1283
1284        let params = EgressPathConfig::deserialize(self.params.clone()).with_context(|| {
1285            format!(
1286                "interpreting provider '{provider_name}' params {:#?} as EgressPathConfig",
1287                self.params
1288            )
1289        })?;
1290        let mut sources = OrderMap::new();
1291
1292        for (source, params) in &self.sources {
1293            let mut params = params.clone();
1294            // I don't really like this remove call. The issue is that we don't
1295            // have an alternative way to filter this out of the partial source
1296            // definition, and this provider_ option is not valid in the
1297            // EgressPathConfig struct itself; it is a shaping source-specific
1298            // addition to help with setting up shared throttles across providers
1299            params.remove("provider_source_selection_rate");
1300            sources.insert(
1301                source.clone(),
1302                EgressPathConfig::deserialize(params.clone()).with_context(|| {
1303                    format!("interpreting provider '{provider_name}' source '{source}' {params:#} as EgressPathConfig")
1304                })?,
1305            );
1306        }
1307
1308        Ok(MergedEntry {
1309            params,
1310            sources,
1311            automation: self.automation.clone(),
1312        })
1313    }
1314
1315    fn hash_into(&self, ctx: &mut Sha256) {
1316        ctx.update(&self.provider_name);
1317        ctx.update(serde_json::to_string(self).unwrap_or_else(|_| String::new()));
1318    }
1319}
1320
/// A single rule used by `ProviderEntry::domain_matches` to decide whether
/// a domain belongs to a provider.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone)]
pub enum ProviderMatch {
    /// Compared, as a suffix, against each MX host name resolved for the
    /// domain.
    MXSuffix(String),
    /// Compared, as a suffix, against the domain name itself; this needs no
    /// DNS lookup.
    DomainSuffix(String),
    /// Compared for equality against each MX host name resolved for the
    /// domain.
    HostName(String),
}
1328
/// Copies every entry of `source` into `tbl`.
///
/// An entry normally replaces whatever `tbl` held under the same key. The
/// exception is the small set of throttle-related keys listed in
/// `is_mergeable`: when both sides hold a table under such a key, the inner
/// entries of the `source` table are inserted into the existing table
/// instead, so entries from both sides are retained (with `source` winning
/// on conflicting inner keys).
#[cfg(feature = "lua")]
fn toml_table_merge_from(tbl: &mut toml::Table, source: &toml::Table) {
    // Limit merging to just the throttle related fields, as their purpose
    // is for creating broader scoped limits that cut across normal boundaries
    fn is_mergeable(name: &str) -> bool {
        matches!(
            name,
            "additional_connection_limits"
                | "additional_message_rate_throttles"
                | "additional_source_selection_rates"
        )
    }

    for (key, incoming) in source {
        if is_mergeable(key) {
            // Merge Table values together, rather than simply replacing them.
            if let (Some(toml::Value::Table(existing)), Some(additions)) =
                (tbl.get_mut(key), incoming.as_table())
            {
                for (inner_key, inner_value) in additions {
                    existing.insert(inner_key.clone(), inner_value.clone());
                }
                continue;
            }
        }

        tbl.insert(key.clone(), incoming.clone());
    }
}
1356
1357#[cfg(feature = "lua")]
1358impl PartialEntry {
1359    fn merge_from(&mut self, mut other: Self) {
1360        if other.replace_base {
1361            self.params = other.params;
1362            self.automation = other.automation;
1363            self.sources = other.sources;
1364        } else {
1365            toml_table_merge_from(&mut self.params, &other.params);
1366
1367            for (source, tbl) in other.sources {
1368                match self.sources.get_mut(&source) {
1369                    Some(existing) => {
1370                        toml_table_merge_from(existing, &tbl);
1371                    }
1372                    None => {
1373                        self.sources.insert(source, tbl);
1374                    }
1375                }
1376            }
1377
1378            self.automation.append(&mut other.automation);
1379        }
1380    }
1381
1382    fn finish(self) -> anyhow::Result<MergedEntry> {
1383        let domain = self.domain_name.unwrap_or_default();
1384
1385        let params = EgressPathConfig::deserialize(self.params.clone()).with_context(|| {
1386            format!(
1387                "interpreting domain '{domain}' params {:#?} as EgressPathConfig",
1388                self.params
1389            )
1390        })?;
1391        let mut sources = OrderMap::new();
1392
1393        for (source, params) in self.sources {
1394            sources.insert(
1395                source.clone(),
1396                EgressPathConfig::deserialize(params.clone()).with_context(|| {
1397                    format!("interpreting domain '{domain}' source '{source}' {params:#} as EgressPathConfig")
1398                })?,
1399            );
1400        }
1401
1402        Ok(MergedEntry {
1403            params,
1404            sources,
1405            automation: self.automation,
1406        })
1407    }
1408
1409    fn hash_into(&self, ctx: &mut Sha256) {
1410        if let Some(name) = self.domain_name.as_ref() {
1411            ctx.update(name)
1412        }
1413        ctx.update(serde_json::to_string(self).unwrap_or_else(|_| String::new()));
1414    }
1415}
1416
1417fn one_or_many<'de, T, D>(deserializer: D, expecting: &str) -> Result<Vec<T>, D::Error>
1418where
1419    T: Deserialize<'de>,
1420    D: Deserializer<'de>,
1421{
1422    let result: Result<Vec<T>, _> =
1423        OneOrMany::<serde_with::Same, PreferOne>::deserialize_as(deserializer);
1424    match result {
1425        Ok(r) => Ok(r),
1426        Err(err) => Err(serde::de::Error::custom(format!(
1427            "{expecting}.\nThe underlying error message is:\n{err:#}"
1428        ))),
1429    }
1430}
1431
1432fn one_or_many_action<'de, D>(deserializer: D) -> Result<Vec<Action>, D::Error>
1433where
1434    D: Deserializer<'de>,
1435{
1436    one_or_many(
1437        deserializer,
1438        "\"action\" field expected either a single Action or an array of Actions",
1439    )
1440}
1441
1442fn regex_string_or_array<'de, D>(deserializer: D) -> Result<Vec<Regex>, D::Error>
1443where
1444    D: Deserializer<'de>,
1445{
1446    string_or_array(
1447        deserializer,
1448        "regex string or array of regex strings for field regex",
1449    )
1450}
1451
1452fn string_or_array<'de, T, D>(deserializer: D, expecting: &'static str) -> Result<Vec<T>, D::Error>
1453where
1454    T: Deserialize<'de> + TryFrom<String>,
1455    <T as TryFrom<String>>::Error: std::fmt::Debug,
1456    D: Deserializer<'de>,
1457{
1458    // This is a Visitor that forwards string types to T's `TryFrom<String>` impl and
1459    // forwards map types to T's `Deserialize` impl. The `PhantomData` is to
1460    // keep the compiler from complaining about T being an unused generic type
1461    // parameter. We need T in order to know the Value type for the Visitor
1462    // impl.
1463    struct StringOrArray<T>(PhantomData<fn() -> T>, &'static str);
1464
1465    impl<'de, T> Visitor<'de> for StringOrArray<T>
1466    where
1467        T: Deserialize<'de> + TryFrom<String>,
1468        <T as TryFrom<String>>::Error: std::fmt::Debug,
1469    {
1470        type Value = Vec<T>;
1471
1472        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
1473            formatter.write_str(self.1)
1474        }
1475
1476        fn visit_str<E>(self, value: &str) -> Result<Vec<T>, E>
1477        where
1478            E: serde::de::Error,
1479        {
1480            Ok(vec![
1481                T::try_from(value.to_string()).map_err(|e| E::custom(format!("{e:?}")))?
1482            ])
1483        }
1484
1485        fn visit_seq<S>(self, seq: S) -> Result<Vec<T>, S::Error>
1486        where
1487            S: SeqAccess<'de>,
1488        {
1489            Deserialize::deserialize(serde::de::value::SeqAccessDeserializer::new(seq))
1490        }
1491    }
1492
1493    deserializer.deserialize_any(StringOrArray(PhantomData, expecting))
1494}
1495
/// serde `default` helper that yields `true`; it supplies the value of
/// `PartialEntry::mx_rollup` when that key is absent.
#[cfg(feature = "lua")]
fn default_true() -> bool {
    true
}
1500
/// Registers the `load` function in the `shaping` Lua sub-module.
///
/// `load(paths, options)` takes a list of shaping file paths and an
/// optional options value. When the options are omitted the defaults are
/// used. The files are merged with `Shaping::merge_files` and the resulting
/// `Shaping` object is returned to Lua.
///
/// # Errors
///
/// Fails if the sub-module cannot be obtained, or the function cannot be
/// created or stored in it.
#[cfg(feature = "lua")]
pub fn register(lua: &mlua::Lua) -> anyhow::Result<()> {
    let shaping_mod = config::get_or_create_sub_module(lua, "shaping")?;

    let load = lua.create_async_function(
        move |lua, (paths, options): (Vec<String>, Option<mlua::Value>)| async move {
            let options = if let Some(raw) = options {
                lua.from_value(raw)?
            } else {
                Default::default()
            };

            let shaping = Shaping::merge_files(&paths, &options)
                .await
                .map_err(any_err)?;
            Ok(shaping)
        },
    )?;
    shaping_mod.set("load", load)?;

    Ok(())
}
1523
1524#[cfg(test)]
1525mod test {
1526    use super::*;
1527    use kumo_log_types::RecordType;
1528    use rfc5321::Response;
1529    use std::io::Write;
1530    use tempfile::NamedTempFile;
1531    use uuid::Uuid;
1532
1533    async fn make_shaping_configs(inputs: &[&str]) -> Shaping {
1534        let mut files = vec![];
1535        let mut file_names = vec![];
1536
1537        for (i, content) in inputs.iter().enumerate() {
1538            let mut shaping_file = NamedTempFile::with_prefix(format!("file{i}")).unwrap();
1539            shaping_file.write_all(content.as_bytes()).unwrap();
1540            file_names.push(shaping_file.path().to_str().unwrap().to_string());
1541            files.push(shaping_file);
1542        }
1543
1544        Shaping::merge_files(&file_names, &ShapingMergeOptions::default())
1545            .await
1546            .unwrap()
1547    }
1548
1549    #[tokio::test]
1550    async fn test_merge_additional() {
1551        let shaping = make_shaping_configs(&[
1552            r#"
1553["example.com"]
1554mx_rollup = false
1555additional_connection_limits = {"first"=10}
1556        "#,
1557            r#"
1558["example.com"]
1559mx_rollup = false
1560additional_connection_limits = {"second"=32}
1561additional_message_rate_throttles = {"second"="100/hr"}
1562        "#,
1563        ])
1564        .await;
1565
1566        let resolved = shaping
1567            .get_egress_path_config("example.com", "invalid.source", "invalid.site")
1568            .await
1569            .finish()
1570            .unwrap();
1571
1572        k9::snapshot!(
1573            resolved.params.additional_connection_limits,
1574            r#"
1575{
1576    "first": 10,
1577    "second": 32,
1578}
1579"#
1580        );
1581        k9::snapshot!(
1582            resolved.params.additional_message_rate_throttles,
1583            r#"
1584{
1585    "second": 100/h,
1586}
1587"#
1588        );
1589    }
1590
1591    #[tokio::test]
1592    async fn test_provider_multi_hostname() {
1593        let shaping = make_shaping_configs(&[r#"
1594[provider."yahoo"]
1595match=[{HostName="mta5.am0.yahoodns.net"},{HostName="mta6.am0.yahoodns.net"},{HostName="mta7.am0.yahoodns.net"}]
1596enable_tls = "Required"
1597        "#])
1598        .await;
1599
1600        let resolved = shaping
1601            .get_egress_path_config("yahoo.com", "invalid.source", "invalid.site")
1602            .await
1603            .finish()
1604            .unwrap();
1605
1606        k9::assert_equal!(
1607            resolved.params.enable_tls,
1608            crate::egress_path::Tls::Required
1609        );
1610        k9::assert_equal!(resolved.params.provider_name.unwrap(), "yahoo");
1611    }
1612
1613    #[tokio::test]
1614    async fn test_provider_multi_suffix() {
1615        let shaping = make_shaping_configs(&[r#"
1616[provider."yahoo"]
1617match=[{MXSuffix="mta5.am0.yahoodns.net"},{MXSuffix="mta6.am0.yahoodns.net"},{MXSuffix="mta7.am0.yahoodns.net"}]
1618enable_tls = "Required"
1619        "#])
1620        .await;
1621
1622        let resolved = shaping
1623            .get_egress_path_config("yahoo.com", "invalid.source", "invalid.site")
1624            .await
1625            .finish()
1626            .unwrap();
1627
1628        k9::assert_equal!(
1629            resolved.params.enable_tls,
1630            crate::egress_path::Tls::Required
1631        );
1632        k9::assert_equal!(resolved.params.provider_name.unwrap(), "yahoo");
1633    }
1634
1635    #[tokio::test]
1636    async fn test_provider() {
1637        let shaping = make_shaping_configs(&[r#"
1638[provider."Office 365"]
1639match=[{MXSuffix=".olc.protection.outlook.com"},{DomainSuffix=".outlook.com"}]
1640enable_tls = "Required"
1641provider_connection_limit = 10
1642provider_max_message_rate = "120/s"
1643
1644[provider."Office 365".sources."new-source"]
1645provider_source_selection_rate = "500/d,max_burst=1"
1646        "#])
1647        .await;
1648
1649        let resolved = shaping
1650            .get_egress_path_config("outlook.com", "invalid.source", "invalid.site")
1651            .await
1652            .finish()
1653            .unwrap();
1654
1655        k9::assert_equal!(
1656            resolved.params.enable_tls,
1657            crate::egress_path::Tls::Required
1658        );
1659        k9::assert_equal!(resolved.params.provider_name.unwrap(), "Office 365");
1660
1661        k9::snapshot!(
1662            resolved.params.additional_connection_limits,
1663            r#"
1664{
1665    "shaping-provider-Office 365-invalid.source-limit": 10,
1666}
1667"#
1668        );
1669        k9::snapshot!(
1670            resolved.params.additional_message_rate_throttles,
1671            r#"
1672{
1673    "shaping-provider-Office 365-invalid.source-rate": 120/s,
1674}
1675"#
1676        );
1677        assert!(resolved.params.source_selection_rate.is_none());
1678        assert!(resolved.params.additional_source_selection_rates.is_empty());
1679
1680        let resolved = shaping
1681            .get_egress_path_config("outlook.com", "new-source", "invalid.site")
1682            .await
1683            .finish()
1684            .unwrap();
1685        assert!(resolved.params.source_selection_rate.is_none());
1686        k9::snapshot!(
1687            resolved.params.additional_source_selection_rates,
1688            r#"
1689{
1690    "shaping-provider-Office 365-new-source-selection-rate": 500/d,max_burst=1,
1691}
1692"#
1693        );
1694    }
1695
1696    #[tokio::test]
1697    async fn test_rule_matching() {
1698        let shaping = make_shaping_configs(&[r#"
1699[["default".automation]]
1700regex="default"
1701action = {SetConfig={name="connection_limit", value=1}}
1702duration = "1hr"
1703
1704["fake.site"]
1705_treat_domain_name_as_site_name = true
1706
1707[["fake.site".automation]]
1708regex="fake_rollup"
1709action = {SetConfig={name="connection_limit", value=2}}
1710duration = "1hr"
1711
1712["woot.provider"]
1713mx_rollup = false
1714
1715[["woot.provider".automation]]
1716regex="woot_domain"
1717action = {SetConfig={name="connection_limit", value=2}}
1718duration = "1hr"
1719
1720[provider."provider"]
1721match=[{DomainSuffix=".provider"}]
1722
1723[[provider."provider".automation]]
1724regex="provider"
1725action = {SetConfig={name="connection_limit", value=3}}
1726duration = "1hr"
1727match_internal = true
1728
1729"#])
1730        .await;
1731
1732        eprintln!("{:?}", shaping.inner.warnings);
1733
1734        fn make_record(content: &str, recipient: &str, site: &str) -> JsonLogRecord {
1735            JsonLogRecord {
1736                kind: RecordType::TransientFailure,
1737                id: String::new(),
1738                sender: String::new(),
1739                recipient: vec![recipient.to_string()],
1740                queue: String::new(),
1741                site: site.to_string(),
1742                size: 0,
1743                response: Response {
1744                    code: 400,
1745                    command: None,
1746                    enhanced_code: None,
1747                    content: content.to_string(),
1748                },
1749                peer_address: None,
1750                timestamp: Default::default(),
1751                created: Default::default(),
1752                num_attempts: 1,
1753                bounce_classification: Default::default(),
1754                egress_pool: None,
1755                egress_source: None,
1756                source_address: None,
1757                feedback_report: None,
1758                meta: Default::default(),
1759                headers: Default::default(),
1760                delivery_protocol: None,
1761                reception_protocol: None,
1762                nodeid: Uuid::default(),
1763                tls_cipher: None,
1764                tls_protocol_version: None,
1765                tls_peer_subject_name: None,
1766                provider_name: None,
1767                session_id: None,
1768            }
1769        }
1770
1771        let matches = shaping
1772            .match_rules(&make_record("default", "user@example.com", "dummy_site"))
1773            .await
1774            .unwrap();
1775        k9::assert_equal!(
1776            matches[0].regex[0].to_string(),
1777            "default",
1778            "matches against default automation rule"
1779        );
1780
1781        let matches = shaping
1782            .match_rules(&make_record(
1783                "KumoMTA internal: default",
1784                "user@example.com",
1785                "dummy_site",
1786            ))
1787            .await
1788            .unwrap();
1789        assert!(matches.is_empty(), "internal bounce should not match");
1790
1791        let matches = shaping
1792            .match_rules(&make_record(
1793                "woot_domain",
1794                "user@woot.provider",
1795                "dummy_site",
1796            ))
1797            .await
1798            .unwrap();
1799        k9::assert_equal!(
1800            matches[0].regex[0].to_string(),
1801            "woot_domain",
1802            "matches against domain rule with mx_rollup=false"
1803        );
1804
1805        let matches = shaping
1806            .match_rules(&make_record("fake_rollup", "user@fake.rollup", "fake.site"))
1807            .await
1808            .unwrap();
1809        k9::assert_equal!(
1810            matches[0].regex[0].to_string(),
1811            "fake_rollup",
1812            "matches against domain rule with mx_rollup=true"
1813        );
1814
1815        let matches = shaping
1816            .match_rules(&make_record("provider", "user@woot.provider", "dummy_site"))
1817            .await
1818            .unwrap();
1819        k9::assert_equal!(
1820            matches[0].regex[0].to_string(),
1821            "provider",
1822            "matches against provider rule"
1823        );
1824
1825        let matches = shaping
1826            .match_rules(&make_record(
1827                "KumoMTA internal: provider",
1828                "user@woot.provider",
1829                "dummy_site",
1830            ))
1831            .await
1832            .unwrap();
1833        k9::assert_equal!(
1834            matches[0].regex[0].to_string(),
1835            "provider",
1836            "internal response matches against provider rule"
1837        );
1838    }
1839
    /// Snapshot test for the merged egress path configuration produced from
    /// a single inline TOML shaping document.
    ///
    /// The document defines a `"default"` block (with two automation rules)
    /// plus several per-domain blocks. The test then resolves the
    /// configuration for three lookups -- an unknown domain, `example.com`
    /// and `yahoo.com` -- and compares the `Debug` rendering of each result
    /// against an inline `k9::snapshot!` expectation.
    ///
    /// The raw string literals below are compared verbatim, so they must not
    /// be reformatted or annotated.
    #[tokio::test]
    async fn test_defaults() {
        let shaping = make_shaping_configs(&[
            r#"
["default"]
connection_limit = 10
max_connection_rate = "100/min"
max_deliveries_per_connection = 100
max_message_rate = "100/s"
idle_timeout = "60s"
data_timeout = "30s"
data_dot_timeout = "60s"
enable_tls = "Opportunistic"
consecutive_connection_failures_before_delay = 100

[["default".automation]]
regex=[
        '/Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/',
        '/All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/',
        '/has been temporarily rate limited due to IP reputation/',
        '/Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/',
        '/Server busy\. Please try again later from/'
]
action = [
        {SetConfig={name="max_message_rate", value="1/minute"}},
        {SetConfig={name="connection_limit", value=1}}
]
duration = "90m"

[["default".automation]]
regex="KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site"
action = {SetConfig={name="enable_tls", value="Disabled"}}
duration = "30 days"

["gmail.com"]
max_deliveries_per_connection = 50
connection_limit = 5
enable_tls = "Required"
consecutive_connection_failures_before_delay = 5

["yahoo.com"]
max_deliveries_per_connection = 20

[["yahoo.com".automation]]
regex = "\\[TS04\\]"
action = "Suspend"
duration = "2 hours"

["comcast.net"]
connection_limit = 25
max_deliveries_per_connection = 250
enable_tls = "Required"
idle_timeout = "30s"
consecutive_connection_failures_before_delay = 24

["mail.com"]
max_deliveries_per_connection = 100

["orange.fr"]
connection_limit = 3

["smtp.mailgun.com"]
connection_limit = 7000
max_deliveries_per_connection = 3

["example.com"]
mx_rollup = false
max_deliveries_per_connection = 100
connection_limit = 3

["example.com".sources."my source name"]
connection_limit = 5
        "#,
        ])
        .await;

        // Case 1: domain, source and site that appear nowhere in the TOML.
        // The expected snapshot carries the values of the "default" block
        // (e.g. connection_limit 10) and its two automation rules.
        let default = shaping
            .get_egress_path_config("invalid.domain", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();
        k9::snapshot!(
            default,
            r#"
MergedEntry {
    params: EgressPathConfig {
        connection_limit: 10,
        additional_connection_limits: {},
        enable_tls: Opportunistic,
        enable_mta_sts: true,
        enable_dane: false,
        enable_pipelining: true,
        enable_rset: true,
        tls_prefer_openssl: false,
        tls_certificate: None,
        tls_private_key: None,
        openssl_cipher_list: None,
        openssl_cipher_suites: None,
        openssl_options: None,
        rustls_cipher_suites: [],
        client_timeouts: SmtpClientTimeouts {
            connect_timeout: 60s,
            banner_timeout: 60s,
            ehlo_timeout: 300s,
            mail_from_timeout: 300s,
            rcpt_to_timeout: 300s,
            data_timeout: 30s,
            data_dot_timeout: 60s,
            rset_timeout: 5s,
            idle_timeout: 60s,
            starttls_timeout: 5s,
            auth_timeout: 60s,
        },
        system_shutdown_timeout: None,
        max_ready: 1024,
        consecutive_connection_failures_before_delay: 100,
        smtp_port: 25,
        smtp_auth_plain_username: None,
        smtp_auth_plain_password: None,
        allow_smtp_auth_plain_without_tls: false,
        max_message_rate: Some(
            100/s,
        ),
        additional_message_rate_throttles: {},
        source_selection_rate: None,
        additional_source_selection_rates: {},
        max_connection_rate: Some(
            100/m,
        ),
        max_deliveries_per_connection: 100,
        max_recipients_per_batch: 100,
        prohibited_hosts: {
            "127.0.0.0/8",
            "::1",
        },
        skip_hosts: {},
        ehlo_domain: None,
        aggressive_connection_opening: false,
        refresh_interval: 60s,
        refresh_strategy: Ttl,
        dispatcher_wakeup_strategy: Aggressive,
        maintainer_wakeup_strategy: Aggressive,
        provider_name: None,
        remember_broken_tls: None,
        opportunistic_tls_reconnect_on_failed_handshake: false,
        use_lmtp: false,
        reconnect_strategy: ConnectNextHost,
        readyq_pool_name: None,
        low_memory_reduction_policy: ShrinkDataAndMeta,
        no_memory_reduction_policy: ShrinkDataAndMeta,
        try_next_host_on_transport_error: false,
        ignore_8bit_checks: false,
    },
    sources: {},
    automation: [
        Rule {
            regex: [
                Regex(
                    /Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/,
                ),
                Regex(
                    /All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/,
                ),
                Regex(
                    /has been temporarily rate limited due to IP reputation/,
                ),
                Regex(
                    /Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/,
                ),
                Regex(
                    /Server busy\. Please try again later from/,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "max_message_rate",
                        value: HashableTomlValue {
                            value: String(
                                "1/minute",
                            ),
                        },
                    },
                ),
                SetConfig(
                    EgressPathConfigValue {
                        name: "connection_limit",
                        value: HashableTomlValue {
                            value: Integer(
                                1,
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 5400s,
            was_rollup: false,
            match_internal: false,
        },
        Rule {
            regex: [
                Regex(
                    KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "enable_tls",
                        value: HashableTomlValue {
                            value: String(
                                "Disabled",
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 2592000s,
            was_rollup: false,
            match_internal: false,
        },
    ],
}
"#
        );

        // Case 2: "example.com", which has its own block (connection_limit 3)
        // and a per-source table. In the expected snapshot the top-level
        // params combine the "default" values with the example.com override,
        // and the "my source name" entry appears under `sources`.
        // NOTE(review): the per-source entry is snapshotted with values such
        // as data_timeout 300s and max_deliveries_per_connection 1024 rather
        // than the "default" block's 30s / 100 -- presumably it is built from
        // the source table alone; confirm against the merge code.
        let example_com = shaping
            .get_egress_path_config("example.com", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();
        k9::snapshot!(
            example_com,
            r#"
MergedEntry {
    params: EgressPathConfig {
        connection_limit: 3,
        additional_connection_limits: {},
        enable_tls: Opportunistic,
        enable_mta_sts: true,
        enable_dane: false,
        enable_pipelining: true,
        enable_rset: true,
        tls_prefer_openssl: false,
        tls_certificate: None,
        tls_private_key: None,
        openssl_cipher_list: None,
        openssl_cipher_suites: None,
        openssl_options: None,
        rustls_cipher_suites: [],
        client_timeouts: SmtpClientTimeouts {
            connect_timeout: 60s,
            banner_timeout: 60s,
            ehlo_timeout: 300s,
            mail_from_timeout: 300s,
            rcpt_to_timeout: 300s,
            data_timeout: 30s,
            data_dot_timeout: 60s,
            rset_timeout: 5s,
            idle_timeout: 60s,
            starttls_timeout: 5s,
            auth_timeout: 60s,
        },
        system_shutdown_timeout: None,
        max_ready: 1024,
        consecutive_connection_failures_before_delay: 100,
        smtp_port: 25,
        smtp_auth_plain_username: None,
        smtp_auth_plain_password: None,
        allow_smtp_auth_plain_without_tls: false,
        max_message_rate: Some(
            100/s,
        ),
        additional_message_rate_throttles: {},
        source_selection_rate: None,
        additional_source_selection_rates: {},
        max_connection_rate: Some(
            100/m,
        ),
        max_deliveries_per_connection: 100,
        max_recipients_per_batch: 100,
        prohibited_hosts: {
            "127.0.0.0/8",
            "::1",
        },
        skip_hosts: {},
        ehlo_domain: None,
        aggressive_connection_opening: false,
        refresh_interval: 60s,
        refresh_strategy: Ttl,
        dispatcher_wakeup_strategy: Aggressive,
        maintainer_wakeup_strategy: Aggressive,
        provider_name: None,
        remember_broken_tls: None,
        opportunistic_tls_reconnect_on_failed_handshake: false,
        use_lmtp: false,
        reconnect_strategy: ConnectNextHost,
        readyq_pool_name: None,
        low_memory_reduction_policy: ShrinkDataAndMeta,
        no_memory_reduction_policy: ShrinkDataAndMeta,
        try_next_host_on_transport_error: false,
        ignore_8bit_checks: false,
    },
    sources: {
        "my source name": EgressPathConfig {
            connection_limit: 5,
            additional_connection_limits: {},
            enable_tls: Opportunistic,
            enable_mta_sts: true,
            enable_dane: false,
            enable_pipelining: true,
            enable_rset: true,
            tls_prefer_openssl: false,
            tls_certificate: None,
            tls_private_key: None,
            openssl_cipher_list: None,
            openssl_cipher_suites: None,
            openssl_options: None,
            rustls_cipher_suites: [],
            client_timeouts: SmtpClientTimeouts {
                connect_timeout: 60s,
                banner_timeout: 60s,
                ehlo_timeout: 300s,
                mail_from_timeout: 300s,
                rcpt_to_timeout: 300s,
                data_timeout: 300s,
                data_dot_timeout: 300s,
                rset_timeout: 5s,
                idle_timeout: 5s,
                starttls_timeout: 5s,
                auth_timeout: 60s,
            },
            system_shutdown_timeout: None,
            max_ready: 1024,
            consecutive_connection_failures_before_delay: 100,
            smtp_port: 25,
            smtp_auth_plain_username: None,
            smtp_auth_plain_password: None,
            allow_smtp_auth_plain_without_tls: false,
            max_message_rate: None,
            additional_message_rate_throttles: {},
            source_selection_rate: None,
            additional_source_selection_rates: {},
            max_connection_rate: None,
            max_deliveries_per_connection: 1024,
            max_recipients_per_batch: 100,
            prohibited_hosts: {
                "127.0.0.0/8",
                "::1",
            },
            skip_hosts: {},
            ehlo_domain: None,
            aggressive_connection_opening: false,
            refresh_interval: 60s,
            refresh_strategy: Ttl,
            dispatcher_wakeup_strategy: Aggressive,
            maintainer_wakeup_strategy: Aggressive,
            provider_name: None,
            remember_broken_tls: None,
            opportunistic_tls_reconnect_on_failed_handshake: false,
            use_lmtp: false,
            reconnect_strategy: ConnectNextHost,
            readyq_pool_name: None,
            low_memory_reduction_policy: ShrinkDataAndMeta,
            no_memory_reduction_policy: ShrinkDataAndMeta,
            try_next_host_on_transport_error: false,
            ignore_8bit_checks: false,
        },
    },
    automation: [
        Rule {
            regex: [
                Regex(
                    /Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/,
                ),
                Regex(
                    /All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/,
                ),
                Regex(
                    /has been temporarily rate limited due to IP reputation/,
                ),
                Regex(
                    /Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/,
                ),
                Regex(
                    /Server busy\. Please try again later from/,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "max_message_rate",
                        value: HashableTomlValue {
                            value: String(
                                "1/minute",
                            ),
                        },
                    },
                ),
                SetConfig(
                    EgressPathConfigValue {
                        name: "connection_limit",
                        value: HashableTomlValue {
                            value: Integer(
                                1,
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 5400s,
            was_rollup: false,
            match_internal: false,
        },
        Rule {
            regex: [
                Regex(
                    KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "enable_tls",
                        value: HashableTomlValue {
                            value: String(
                                "Disabled",
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 2592000s,
            was_rollup: false,
            match_internal: false,
        },
    ],
}
"#
        );

        // Case 3: "yahoo.com". The expected snapshot shows the "default"
        // values with max_deliveries_per_connection overridden to 20, and
        // the yahoo-specific `Suspend` rule listed after the two rules from
        // the "default" block.
        //
        // The site name here will need to be updated if yahoo changes
        // their MX records
        let yahoo_com = shaping
            .get_egress_path_config(
                "yahoo.com",
                "invalid.source",
                "(mta5|mta6|mta7).am0.yahoodns.net",
            )
            .await
            .finish()
            .unwrap();
        k9::snapshot!(
            yahoo_com,
            r#"
MergedEntry {
    params: EgressPathConfig {
        connection_limit: 10,
        additional_connection_limits: {},
        enable_tls: Opportunistic,
        enable_mta_sts: true,
        enable_dane: false,
        enable_pipelining: true,
        enable_rset: true,
        tls_prefer_openssl: false,
        tls_certificate: None,
        tls_private_key: None,
        openssl_cipher_list: None,
        openssl_cipher_suites: None,
        openssl_options: None,
        rustls_cipher_suites: [],
        client_timeouts: SmtpClientTimeouts {
            connect_timeout: 60s,
            banner_timeout: 60s,
            ehlo_timeout: 300s,
            mail_from_timeout: 300s,
            rcpt_to_timeout: 300s,
            data_timeout: 30s,
            data_dot_timeout: 60s,
            rset_timeout: 5s,
            idle_timeout: 60s,
            starttls_timeout: 5s,
            auth_timeout: 60s,
        },
        system_shutdown_timeout: None,
        max_ready: 1024,
        consecutive_connection_failures_before_delay: 100,
        smtp_port: 25,
        smtp_auth_plain_username: None,
        smtp_auth_plain_password: None,
        allow_smtp_auth_plain_without_tls: false,
        max_message_rate: Some(
            100/s,
        ),
        additional_message_rate_throttles: {},
        source_selection_rate: None,
        additional_source_selection_rates: {},
        max_connection_rate: Some(
            100/m,
        ),
        max_deliveries_per_connection: 20,
        max_recipients_per_batch: 100,
        prohibited_hosts: {
            "127.0.0.0/8",
            "::1",
        },
        skip_hosts: {},
        ehlo_domain: None,
        aggressive_connection_opening: false,
        refresh_interval: 60s,
        refresh_strategy: Ttl,
        dispatcher_wakeup_strategy: Aggressive,
        maintainer_wakeup_strategy: Aggressive,
        provider_name: None,
        remember_broken_tls: None,
        opportunistic_tls_reconnect_on_failed_handshake: false,
        use_lmtp: false,
        reconnect_strategy: ConnectNextHost,
        readyq_pool_name: None,
        low_memory_reduction_policy: ShrinkDataAndMeta,
        no_memory_reduction_policy: ShrinkDataAndMeta,
        try_next_host_on_transport_error: false,
        ignore_8bit_checks: false,
    },
    sources: {},
    automation: [
        Rule {
            regex: [
                Regex(
                    /Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/,
                ),
                Regex(
                    /All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/,
                ),
                Regex(
                    /has been temporarily rate limited due to IP reputation/,
                ),
                Regex(
                    /Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/,
                ),
                Regex(
                    /Server busy\. Please try again later from/,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "max_message_rate",
                        value: HashableTomlValue {
                            value: String(
                                "1/minute",
                            ),
                        },
                    },
                ),
                SetConfig(
                    EgressPathConfigValue {
                        name: "connection_limit",
                        value: HashableTomlValue {
                            value: Integer(
                                1,
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 5400s,
            was_rollup: false,
            match_internal: false,
        },
        Rule {
            regex: [
                Regex(
                    KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "enable_tls",
                        value: HashableTomlValue {
                            value: String(
                                "Disabled",
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 2592000s,
            was_rollup: false,
            match_internal: false,
        },
        Rule {
            regex: [
                Regex(
                    \[TS04\],
                ),
            ],
            action: [
                Suspend,
            ],
            trigger: Immediate,
            duration: 7200s,
            was_rollup: false,
            match_internal: false,
        },
    ],
}
"#
        );
    }
2457
2458    #[tokio::test]
2459    async fn test_load_default_shaping_toml() {
2460        Shaping::merge_files(
2461            &["../../assets/policy-extras/shaping.toml".into()],
2462            &ShapingMergeOptions::default(),
2463        )
2464        .await
2465        .unwrap();
2466    }
2467}