1use crate::egress_path::EgressPathConfig;
2#[cfg(feature = "lua")]
3use anyhow::Context;
4#[cfg(feature = "lua")]
5use config::any_err;
6#[cfg(feature = "lua")]
7use config::serialize_options;
8#[cfg(feature = "lua")]
9use dns_resolver::{fully_qualify, MailExchanger};
10#[cfg(feature = "lua")]
11use kumo_log_types::JsonLogRecord;
12#[cfg(feature = "lua")]
13use mlua::prelude::LuaUserData;
14#[cfg(feature = "lua")]
15use mlua::{LuaSerdeExt, UserDataMethods};
16use ordermap::OrderMap;
17use serde::de::{SeqAccess, Visitor};
18use serde::{Deserialize, Deserializer, Serialize};
19use serde_with::formats::PreferOne;
20use serde_with::{serde_as, DeserializeAs, OneOrMany};
21#[cfg(feature = "lua")]
22use sha2::{Digest, Sha256};
23#[cfg(feature = "lua")]
24use std::collections::BTreeMap;
25use std::hash::{Hash, Hasher};
26use std::marker::PhantomData;
27#[cfg(feature = "lua")]
28use std::sync::Arc;
29use std::time::Duration;
30#[cfg(feature = "lua")]
31use throttle::LimitSpec;
32use throttle::ThrottleSpec;
33
34#[derive(Deserialize, Serialize, Debug, Clone)]
35#[serde(try_from = "String", into = "String")]
36pub struct Regex(fancy_regex::Regex);
37
38impl TryFrom<String> for Regex {
39 type Error = fancy_regex::Error;
40
41 fn try_from(s: String) -> fancy_regex::Result<Self> {
42 Ok(Self(fancy_regex::Regex::new(&s)?))
43 }
44}
45
46impl From<Regex> for String {
47 fn from(r: Regex) -> String {
48 r.0.as_str().to_string()
49 }
50}
51
52impl std::ops::Deref for Regex {
53 type Target = fancy_regex::Regex;
54 fn deref(&self) -> &Self::Target {
55 &self.0
56 }
57}
58
59impl std::hash::Hash for Regex {
60 fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
61 self.0.as_str().hash(hasher)
62 }
63}
64
65#[derive(Deserialize, Serialize, Debug, Clone)]
71#[serde(from = "toml::Value", into = "toml::Value")]
72pub struct HashableTomlValue {
73 value: toml::Value,
74}
75
76impl From<toml::Value> for HashableTomlValue {
77 fn from(value: toml::Value) -> Self {
78 Self { value }
79 }
80}
81
82impl From<HashableTomlValue> for toml::Value {
83 fn from(value: HashableTomlValue) -> toml::Value {
84 value.value
85 }
86}
87
88impl std::ops::Deref for HashableTomlValue {
89 type Target = toml::Value;
90 fn deref(&self) -> &toml::Value {
91 &self.value
92 }
93}
94
95fn hash_toml<H>(value: &toml::Value, h: &mut H)
96where
97 H: Hasher,
98{
99 match value {
100 toml::Value::Boolean(v) => v.hash(h),
101 toml::Value::Datetime(v) => {
102 if let Some(d) = &v.date {
103 d.year.hash(h);
104 d.month.hash(h);
105 d.day.hash(h);
106 }
107 if let Some(t) = &v.time {
108 t.hour.hash(h);
109 t.minute.hash(h);
110 t.second.hash(h);
111 t.nanosecond.hash(h);
112 }
113 if let Some(toml::value::Offset::Custom { minutes }) = &v.offset {
114 minutes.hash(h);
115 }
116 }
117 toml::Value::String(v) => v.hash(h),
118 toml::Value::Integer(v) => v.hash(h),
119 toml::Value::Float(v) => v.to_ne_bytes().hash(h),
120 toml::Value::Array(a) => {
121 for v in a.iter() {
122 hash_toml(v, h);
123 }
124 }
125 toml::Value::Table(m) => {
126 for (k, v) in m.iter() {
127 k.hash(h);
128 hash_toml(v, h);
129 }
130 }
131 }
132}
133
134impl Hash for HashableTomlValue {
135 fn hash<H>(&self, h: &mut H)
136 where
137 H: Hasher,
138 {
139 hash_toml(&self.value, h);
140 }
141}
142
143#[derive(Deserialize, Serialize, Debug, Clone, Hash)]
146#[serde(
147 try_from = "EgressPathConfigValueUnchecked",
148 into = "EgressPathConfigValueUnchecked"
149)]
150pub struct EgressPathConfigValue {
151 pub name: String,
152 pub value: HashableTomlValue,
153}
154
155#[derive(Deserialize, Serialize, Debug, Clone)]
161pub struct EgressPathConfigValueUnchecked {
162 pub name: String,
163 pub value: toml::Value,
164}
165
166impl TryFrom<EgressPathConfigValueUnchecked> for EgressPathConfigValue {
167 type Error = anyhow::Error;
168 fn try_from(config: EgressPathConfigValueUnchecked) -> anyhow::Result<EgressPathConfigValue> {
169 let mut map = toml::map::Map::new();
170 map.insert(config.name.clone(), config.value.clone());
171 let table = toml::Value::Table(map);
172
173 EgressPathConfig::deserialize(table)?;
176
177 Ok(EgressPathConfigValue {
179 name: config.name,
180 value: HashableTomlValue {
181 value: config.value,
182 },
183 })
184 }
185}
186
187impl From<EgressPathConfigValue> for EgressPathConfigValueUnchecked {
188 fn from(config: EgressPathConfigValue) -> EgressPathConfigValueUnchecked {
189 EgressPathConfigValueUnchecked {
190 name: config.name,
191 value: config.value.value,
192 }
193 }
194}
195
196#[derive(Deserialize, Serialize, Debug, Clone, Hash)]
197pub enum Action {
198 Suspend,
199 SetConfig(EgressPathConfigValue),
200 SuspendTenant,
201 SuspendCampaign,
202 SetDomainConfig(EgressPathConfigValue),
203 Bounce,
204 BounceTenant,
205 BounceCampaign,
206}
207
208#[derive(Deserialize, Serialize, Debug, Clone, Hash, Default)]
209pub enum Trigger {
210 #[default]
212 Immediate,
213 Threshold(ThrottleSpec),
216}
217
218#[serde_as]
219#[derive(Deserialize, Serialize, Debug, Hash, Clone)]
220pub struct Rule {
221 #[serde(deserialize_with = "regex_string_or_array")]
222 pub regex: Vec<Regex>,
223
224 #[serde(deserialize_with = "one_or_many_action")]
225 pub action: Vec<Action>,
226
227 #[serde(default)]
228 pub trigger: Trigger,
229
230 #[serde(with = "duration_serde")]
231 pub duration: Duration,
232
233 #[serde(skip)]
234 pub was_rollup: bool,
235
236 #[serde(default)]
240 pub match_internal: bool,
241}
242
243impl Rule {
244 pub fn matches(&self, is_internal: bool, response: &str) -> bool {
245 if is_internal && !self.match_internal {
246 return false;
247 }
248 self.regex
249 .iter()
250 .any(|r| r.is_match(response).unwrap_or(false))
251 }
252
253 pub fn clone_and_set_rollup(&self) -> Self {
254 let mut result = self.clone();
255 result.was_rollup = true;
256 result
257 }
258}
259
/// The merged shaping data; [`Shaping`] shares it behind an `Arc`.
#[cfg(feature = "lua")]
#[derive(Debug, Default)]
struct ShapingInner {
    /// Entries keyed by site name (produced from `mx_rollup` domains).
    by_site: OrderMap<String, PartialEntry>,
    /// Entries keyed by domain name; `"default"` holds the default entry.
    by_domain: OrderMap<String, PartialEntry>,
    /// Provider entries keyed by provider name.
    by_provider: OrderMap<String, ProviderEntry>,
    /// Warnings collected while merging.
    warnings: Vec<String>,
    /// Errors collected while merging (reported, not fatal).
    errors: Vec<String>,
    /// Hex-encoded SHA-256 digest of the merged data and diagnostics.
    hash: String,
}
270
#[cfg(feature = "lua")]
impl ShapingInner {
    /// Compute the effective (not yet validated) egress path parameters for
    /// `domain` when sending from `egress_source` to `site_name`.
    ///
    /// Layers are applied in this order; later layers take precedence:
    /// 1. The `"default"` domain entry.
    /// 2. Only when `domain` parses as a DNS name:
    ///    - the params of every provider that matches `domain`;
    ///    - then the per-source params of those providers.
    /// 3. The site entry, then the domain entry.
    /// 4. The site's per-source params, then the domain's.
    pub async fn get_egress_path_config(
        &self,
        domain: &str,
        egress_source: &str,
        site_name: &str,
    ) -> PartialEntry {
        let mut params = PartialEntry::default();

        if let Some(default_entry) = self.by_domain.get("default") {
            params.merge_from(default_entry.clone());
        }

        if dns_resolver::Name::from_str_relaxed(domain).is_ok() {
            let mut matched_with_sources = Vec::new();

            for provider in self.by_provider.values() {
                if !provider.domain_matches(domain).await {
                    continue;
                }
                toml_table_merge_from(&mut params.params, &provider.params);
                provider.apply_provider_params_to(egress_source, &mut params.params);
                if !provider.sources.is_empty() {
                    matched_with_sources.push(provider);
                }
            }

            // Per-source provider settings are layered on only after every
            // matching provider's base params have been applied.
            for provider in matched_with_sources {
                if let Some(source_params) = provider.sources.get(egress_source) {
                    toml_table_merge_from(&mut params.params, source_params);
                    provider.apply_provider_params_to(egress_source, &mut params.params);
                }
            }
        }

        let site_entry = self.by_site.get(site_name);
        let domain_entry = self.by_domain.get(domain);

        if let Some(entry) = site_entry {
            params.merge_from(entry.clone());
        }
        if let Some(entry) = domain_entry {
            params.merge_from(entry.clone());
        }

        for entry in site_entry.into_iter().chain(domain_entry) {
            if let Some(source_params) = entry.sources.get(egress_source) {
                toml_table_merge_from(&mut params.params, source_params);
            }
        }

        params
    }

    /// Collect the automation rules that match `record`.
    ///
    /// The rules considered are selected by two values:
    /// - the domain of the record's first recipient;
    /// - the site name recovered from `record.site`.
    ///
    /// A `postmaster` recipient matches nothing.
    ///
    /// # Errors
    /// Fails if the record has no recipient, or if the first recipient
    /// cannot be parsed as a forward path.
    pub async fn match_rules(&self, record: &JsonLogRecord) -> anyhow::Result<Vec<Rule>> {
        use rfc5321::ForwardPath;

        let first_recipient = record
            .recipient
            .first()
            .ok_or_else(|| anyhow::anyhow!("no recipients!?"))?;
        let forward_path = ForwardPath::try_from(first_recipient.as_str())
            .map_err(|err| anyhow::anyhow!("parsing record.recipient: {err}"))?;

        let mailbox = match forward_path {
            ForwardPath::Postmaster => return Ok(vec![]),
            ForwardPath::Path(path) => path.mailbox,
        };
        let domain = mailbox.domain.to_string();

        let source = record.egress_source.as_deref().unwrap_or("unspecified");
        // Strip the `{source}->` prefix and `@smtp_client` suffix from the
        // record's site to recover the bare site name.
        let site_name = record
            .site
            .trim_start_matches(&format!("{source}->"))
            .trim_end_matches("@smtp_client")
            .to_string();

        Ok(self.match_rules_impl(record, &domain, &site_name).await)
    }

    /// Collect every rule whose patterns match `record`'s response for the
    /// given `domain` and `site_name`.
    ///
    /// Rules are gathered in this order:
    /// 1. the `"default"` entry;
    /// 2. each provider that matches `domain`;
    /// 3. the site entry;
    /// 4. the domain entry.
    ///
    /// Rules from the default and site entries are returned with
    /// `was_rollup` set.
    pub async fn match_rules_impl(
        &self,
        record: &JsonLogRecord,
        domain: &str,
        site_name: &str,
    ) -> Vec<Rule> {
        let response = record.response.to_single_line();
        tracing::trace!("Consider rules for {response}");

        let is_internal = record.response.content.starts_with("KumoMTA internal: ");
        let mut matched = Vec::new();

        if let Some(default_entry) = self.by_domain.get("default") {
            for rule in &default_entry.automation {
                tracing::trace!("Consider \"default\" rule {rule:?} for {response}");
                if rule.matches(is_internal, &response) {
                    matched.push(rule.clone_and_set_rollup());
                }
            }
        }

        for provider in self.by_provider.values() {
            if !provider.domain_matches(domain).await {
                continue;
            }
            for rule in &provider.automation {
                tracing::trace!(
                    "Consider provider \"{}\" rule {rule:?} for {response}",
                    provider.provider_name
                );
                if rule.matches(is_internal, &response) {
                    matched.push(rule.clone());
                }
            }
        }

        if let Some(site_entry) = self.by_site.get(site_name) {
            for rule in &site_entry.automation {
                tracing::trace!("Consider \"{site_name}\" rule {rule:?} for {response}");
                if rule.matches(is_internal, &response) {
                    matched.push(rule.clone_and_set_rollup());
                }
            }
        }

        if let Some(domain_entry) = self.by_domain.get(domain) {
            for rule in &domain_entry.automation {
                tracing::trace!("Consider \"{domain}\" rule {rule:?} for {response}");
                if rule.matches(is_internal, &response) {
                    matched.push(rule.clone());
                }
            }
        }

        matched
    }
}
453
/// Merged traffic-shaping configuration, exposed to Lua as userdata.
///
/// Cloning is cheap: the data is shared behind an `Arc`.
#[cfg(feature = "lua")]
#[derive(Debug, Default, Clone, mlua::FromLua)]
pub struct Shaping {
    inner: Arc<ShapingInner>,
}
459
/// Deserialize a `T` from JSON text.
///
/// Errors report the path of the offending field via `serde_path_to_error`.
#[cfg(feature = "lua")]
fn from_json<'a, T: Deserialize<'a>>(json: &'a str) -> anyhow::Result<T> {
    let mut deserializer = serde_json::Deserializer::from_str(json);
    serde_path_to_error::deserialize(&mut deserializer).map_err(anyhow::Error::from)
}
465
/// Deserialize a `T` from TOML text.
///
/// Errors report the path of the offending field via `serde_path_to_error`.
#[cfg(feature = "lua")]
fn from_toml<'a, T: Deserialize<'a>>(toml: &'a str) -> anyhow::Result<T> {
    let deserializer = toml::Deserializer::parse(toml)?;
    serde_path_to_error::deserialize(deserializer).map_err(anyhow::Error::from)
}
471
/// How to treat a class of problem found while loading and merging
/// shaping data (see `Collector::push`).
#[cfg(feature = "lua")]
#[derive(Default, Debug, Clone, Copy, Deserialize, PartialEq, Eq)]
pub enum CheckLevel {
    /// Discard the diagnostic.
    #[default]
    Ignore,
    /// Record it as a warning.
    Warn,
    /// Record it as an error.
    Error,
}

#[cfg(feature = "lua")]
impl std::str::FromStr for CheckLevel {
    type Err = String;

    /// Parse a level name, ignoring ASCII case.
    fn from_str(s: &str) -> Result<Self, String> {
        const LEVELS: [(&str, CheckLevel); 3] = [
            ("ignore", CheckLevel::Ignore),
            ("warn", CheckLevel::Warn),
            ("error", CheckLevel::Error),
        ];
        LEVELS
            .iter()
            .find(|(name, _)| s.eq_ignore_ascii_case(name))
            .map(|&(_, level)| level)
            .ok_or_else(|| format!("Expected one of `Ignore`, `Warn` or `Error`, got `{s}`"))
    }
}
498
/// Accumulates diagnostics produced while loading shaping data.
#[cfg(feature = "lua")]
#[derive(Default)]
struct Collector {
    warnings: Vec<String>,
    errors: Vec<String>,
}

#[cfg(feature = "lua")]
impl Collector {
    /// Record `msg` according to `level`:
    /// - `Ignore` drops it;
    /// - `Warn` stores it as a warning;
    /// - `Error` stores it as an error.
    fn push<S: Into<String>>(&mut self, level: CheckLevel, msg: S) {
        let bucket = match level {
            CheckLevel::Ignore => return,
            CheckLevel::Warn => &mut self.warnings,
            CheckLevel::Error => &mut self.errors,
        };
        bucket.push(msg.into());
    }
}
516
/// Options controlling how `Shaping::merge_files` loads and checks
/// shaping data.
///
/// The container-level `#[serde(default)]` fills any field missing from the
/// input from the `Default` impl below, not from the field type's own
/// default. This means supplying a partial options table (e.g. only
/// `provider_overlap`) no longer silently resets `local_load` from `Error`
/// to `Ignore`.
#[cfg(feature = "lua")]
#[derive(Debug, Clone, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct ShapingMergeOptions {
    /// Level for `mx_rollup` domains that are also matched by a provider.
    pub provider_overlap: CheckLevel,
    /// Level for MX resolution failures of `mx_rollup` domains.
    pub dns_fail: CheckLevel,
    /// Level for `mx_rollup` domains that have a NULL MX.
    pub null_mx: CheckLevel,
    /// Level for multiple domain blocks that resolve to the same site.
    pub aliased_site: CheckLevel,
    /// Skip `http://` and `https://` sources entirely (with a warning).
    pub skip_remote: bool,
    /// Level for problems loading remote sources.
    pub remote_load: CheckLevel,
    /// Level for problems loading local files; `Error` by default.
    pub local_load: CheckLevel,
    /// Timeout for fetching remote sources; 5 seconds when unset.
    #[serde(default, with = "duration_serde")]
    pub http_timeout: Option<Duration>,
}

#[cfg(feature = "lua")]
impl Default for ShapingMergeOptions {
    fn default() -> Self {
        Self {
            provider_overlap: CheckLevel::Ignore,
            dns_fail: CheckLevel::Ignore,
            null_mx: CheckLevel::Ignore,
            aliased_site: CheckLevel::Ignore,
            skip_remote: false,
            remote_load: CheckLevel::Ignore,
            local_load: CheckLevel::Error,
            http_timeout: None,
        }
    }
}
555
#[cfg(feature = "lua")]
impl Shaping {
    /// Load and parse a single shaping source.
    ///
    /// `path` is either an `http://`/`https://` URL or a local file path.
    /// - A remote source is skipped with a warning when `skip_remote` is set.
    /// - A source that cannot be fetched or read is reported through
    ///   `collector` (at the `remote_load` or `local_load` level) and yields
    ///   an empty `ShapingFile`.
    ///
    /// Paths ending in `.toml` or `.json` are parsed as that format, and a
    /// parse failure is returned as an error. Any other path is tried as TOML
    /// and then as JSON; if both fail, the failure is reported through
    /// `collector` at that source's load level.
    async fn load_from_file(
        path: &str,
        options: &ShapingMergeOptions,
        collector: &mut Collector,
    ) -> anyhow::Result<ShapingFile> {
        let (data, level): (String, CheckLevel) =
            if path.starts_with("http://") || path.starts_with("https://") {
                if options.skip_remote {
                    collector.push(
                        CheckLevel::Warn,
                        format!("Ignoring {path} because skip_remote is set to true"),
                    );
                    return Ok(ShapingFile::default());
                }

                /// Fetch `url` as text, bounding the whole exchange by `timeout`.
                async fn http_get(url: &str, timeout: Duration) -> anyhow::Result<String> {
                    tokio::time::timeout(timeout, async {
                        reqwest::Client::builder()
                            .timeout(timeout)
                            .connect_timeout(timeout)
                            .read_timeout(timeout)
                            .build()?
                            .get(url)
                            .send()
                            .await
                            .with_context(|| format!("making HTTP request to {url}"))?
                            // Treat 4xx/5xx as failures; otherwise the body of an
                            // error page would be handed to the parser as if it
                            // were shaping data.
                            .error_for_status()
                            .with_context(|| format!("HTTP request to {url} failed"))?
                            .text()
                            .await
                            .with_context(|| format!("reading text from {url}"))
                    })
                    .await
                    .with_context(|| format!("timeout making HTTP request to {url}"))?
                }

                let timeout = options.http_timeout.unwrap_or(Duration::from_secs(5));

                match http_get(path, timeout).await {
                    Ok(s) => (s, options.remote_load),
                    Err(err) => {
                        tracing::error!("{err:#}. Ignoring this shaping source for now");
                        collector.push(
                            options.remote_load,
                            format!("remote shaping source {path} error: {err:#}"),
                        );
                        return Ok(ShapingFile::default());
                    }
                }
            } else {
                // NOTE(review): this is blocking file I/O inside an async fn.
                match std::fs::read_to_string(path)
                    .with_context(|| format!("loading data from file {path}"))
                {
                    Err(err) => {
                        collector.push(
                            options.local_load,
                            format!("local shaping source {path} error: {err:#}"),
                        );
                        return Ok(ShapingFile::default());
                    }
                    Ok(s) => (s, options.local_load),
                }
            };

        if path.ends_with(".toml") {
            from_toml(&data).with_context(|| format!("parsing toml from file {path}"))
        } else if path.ends_with(".json") {
            from_json(&data).with_context(|| format!("parsing json from file {path}"))
        } else {
            let mut errors = vec![];
            match from_toml(&data) {
                Ok(s) => return Ok(s),
                Err(err) => errors.push(format!("as toml: {err:#}")),
            }
            match from_json(&data) {
                Ok(s) => return Ok(s),
                Err(err) => errors.push(format!("as json: {err:#}")),
            }

            collector.push(level, format!("parsing {path}: {}", errors.join(", ")));
            Ok(ShapingFile::default())
        }
    }

    /// Load every source in `files`, in order, and merge them into a single
    /// `Shaping`.
    ///
    /// Later files are layered over earlier ones. The MX records of
    /// `mx_rollup` domains are resolved concurrently, and those domains are
    /// re-keyed by their MX site name.
    ///
    /// Every merged site, domain and provider is validated as
    /// `EgressPathConfig`; a validation failure is returned as an error. The
    /// optional checks in `options` are recorded as warnings/errors on the
    /// returned value instead. The result also carries a SHA-256 hash of the
    /// merged data and diagnostics.
    pub async fn merge_files(
        files: &[String],
        options: &ShapingMergeOptions,
    ) -> anyhow::Result<Self> {
        use futures_util::stream::FuturesUnordered;
        use futures_util::StreamExt;

        let mut collector = Collector::default();
        let mut loaded = vec![];
        for p in files {
            loaded.push(Self::load_from_file(p, options, &mut collector).await?);
        }

        let mut by_site: OrderMap<String, PartialEntry> = OrderMap::new();
        let mut by_domain: OrderMap<String, PartialEntry> = OrderMap::new();
        let mut by_provider: OrderMap<String, ProviderEntry> = OrderMap::new();
        let mut site_aliases: OrderMap<String, Vec<String>> = OrderMap::new();

        // Resolve the MX for every mx_rollup domain concurrently, up front.
        let mut mx = std::collections::HashMap::new();
        let mut lookups = FuturesUnordered::new();
        for item in &loaded {
            for (domain, partial) in &item.domains {
                if partial.mx_rollup {
                    let domain = domain.to_string();
                    lookups.push(tokio::spawn(async move {
                        let mx_result = MailExchanger::resolve(&domain).await;
                        (domain, mx_result)
                    }));
                }
            }
        }

        // Drain every lookup. Stopping at the first failed task would throw
        // away the results of all the remaining lookups. A domain whose task
        // failed is reported below as never having been resolved.
        while let Some(joined) = lookups.next().await {
            match joined {
                Ok((domain, result)) => {
                    mx.insert(domain, result);
                }
                Err(err) => {
                    tracing::error!("MX lookup task failed: {err}");
                }
            }
        }

        for mut item in loaded {
            if let Some(mut partial) = item.default.take() {
                let domain = "default";
                partial.domain_name.replace(domain.to_string());
                match by_domain.get_mut(domain) {
                    Some(existing) => {
                        existing.merge_from(partial);
                    }
                    None => {
                        by_domain.insert(domain.to_string(), partial);
                    }
                }
            }

            for (domain, mut partial) in item.domains {
                partial.domain_name.replace(domain.clone());

                if let Ok(name) = fully_qualify(&domain) {
                    if name.num_labels() == 1 {
                        collector.push(
                            CheckLevel::Warn,
                            format!(
                                "Entry for domain '{domain}' consists of a \
                                 single DNS label. Domain names in TOML sections \
                                 need to be quoted like '[\"{domain}.com\"]' otherwise \
                                 the '.' will create a nested table rather than being \
                                 added to the domain name."
                            ),
                        );
                    }
                }

                #[cfg(test)]
                if partial._treat_domain_name_as_site_name {
                    match by_site.get_mut(&domain) {
                        Some(existing) => {
                            existing.merge_from(partial);
                        }
                        None => {
                            by_site.insert(domain.to_string(), partial);
                        }
                    }
                    continue;
                }
                if partial.mx_rollup {
                    let mx = match mx.get(&domain) {
                        Some(Ok(mx)) => mx,
                        Some(Err(err)) => {
                            collector.push(
                                options.dns_fail,
                                format!(
                                    "error resolving MX for {domain}: {err:#}. \
                                     Ignoring the shaping config for that domain."
                                ),
                            );
                            continue;
                        }
                        None => {
                            collector.push(
                                options.dns_fail,
                                format!(
                                    "We didn't try to resolve the MX for {domain} for some reason!?. \
                                     Ignoring the shaping config for that domain."
                                ),
                            );
                            continue;
                        }
                    };

                    if mx.site_name.is_empty() {
                        collector.push(
                            options.null_mx,
                            format!(
                                "domain {domain} has a NULL MX and cannot be used with mx_rollup=true. \
                                 Ignoring the shaping config for that domain."),
                        );
                        continue;
                    }

                    site_aliases
                        .entry(mx.site_name.to_string())
                        .or_default()
                        .push(domain.to_string());

                    match by_site.get_mut(&mx.site_name) {
                        Some(existing) => {
                            existing.merge_from(partial);
                        }
                        None => {
                            by_site.insert(mx.site_name.clone(), partial);
                        }
                    }
                } else {
                    match by_domain.get_mut(&domain) {
                        Some(existing) => {
                            existing.merge_from(partial);
                        }
                        None => {
                            by_domain.insert(domain, partial);
                        }
                    }
                }
            }

            for (provider, mut prov) in item.provider {
                prov.provider_name = provider.to_string();
                match by_provider.get_mut(&provider) {
                    Some(existing) => {
                        existing.merge_from(prov);
                    }
                    None => {
                        by_provider.insert(provider.to_string(), prov);
                    }
                }
            }
        }

        // Validate everything that was merged; the results are discarded.
        for (site, partial) in &by_site {
            partial
                .clone()
                .finish()
                .with_context(|| format!("site: {site}"))?;
        }

        for (domain, partial) in &by_domain {
            partial
                .clone()
                .finish()
                .with_context(|| format!("domain: {domain}"))?;
        }

        for (provider, prov) in &by_provider {
            prov.finish_params()
                .with_context(|| format!("provider: {provider}"))?;
        }

        if options.aliased_site != CheckLevel::Ignore {
            for (site, aliases) in site_aliases {
                if aliases.len() > 1 {
                    collector.push(
                        options.aliased_site,
                        format!(
                            "multiple domain blocks alias to the same site: {site}: {}",
                            aliases.join(", ")
                        ),
                    );
                }
            }
        }

        if options.provider_overlap != CheckLevel::Ignore {
            for domain in mx.keys() {
                let mut matching_providers = vec![];
                for (prov_name, prov) in &by_provider {
                    if prov.domain_matches(domain).await {
                        matching_providers.push(prov_name.to_string());
                    }
                }
                if !matching_providers.is_empty() {
                    collector.push(
                        options.provider_overlap,
                        format!(
                            "domain {domain} is also matched by provider(s): {}",
                            matching_providers.join(", ")
                        ),
                    );
                }
            }
        }

        let mut ctx = Sha256::new();
        ctx.update("by_site");
        for (site, entry) in &by_site {
            ctx.update(site);
            entry.hash_into(&mut ctx);
        }
        ctx.update("by_domain");
        for (domain, entry) in &by_domain {
            ctx.update(domain);
            entry.hash_into(&mut ctx);
        }
        ctx.update("by_provider");
        for (provider, prov) in &by_provider {
            ctx.update(provider);
            prov.hash_into(&mut ctx);
        }
        ctx.update("warnings");
        for warn in &collector.warnings {
            ctx.update(warn);
        }
        ctx.update("errors");
        for err in &collector.errors {
            ctx.update(err);
        }
        let hash = ctx.finalize();
        let hash = data_encoding::HEXLOWER.encode(&hash);

        Ok(Self {
            inner: Arc::new(ShapingInner {
                by_site,
                by_domain,
                by_provider,
                warnings: collector.warnings,
                errors: collector.errors,
                hash,
            }),
        })
    }

    /// Compute the layered egress path params for the given
    /// domain/source/site (see `ShapingInner::get_egress_path_config`).
    async fn get_egress_path_config(
        &self,
        domain: &str,
        egress_source: &str,
        site_name: &str,
    ) -> PartialEntry {
        self.inner
            .get_egress_path_config(domain, egress_source, site_name)
            .await
    }

    /// Like `get_egress_path_config`, but serialized to a JSON value.
    pub async fn get_egress_path_config_value(
        &self,
        domain: &str,
        egress_source: &str,
        site_name: &str,
    ) -> anyhow::Result<serde_json::Value> {
        let partial = self
            .get_egress_path_config(domain, egress_source, site_name)
            .await;
        Ok(serde_json::to_value(&partial)?)
    }

    /// Errors collected while merging.
    pub fn get_errors(&self) -> &[String] {
        &self.inner.errors
    }

    /// Warnings collected while merging.
    pub fn get_warnings(&self) -> &[String] {
        &self.inner.warnings
    }

    /// Automation rules matching `record` (see `ShapingInner::match_rules`).
    pub async fn match_rules(&self, record: &JsonLogRecord) -> anyhow::Result<Vec<Rule>> {
        self.inner.match_rules(record).await
    }

    /// Map each egress source name to the `site:`/`domain:` entries that
    /// define per-source params for it.
    pub fn get_referenced_sources(&self) -> BTreeMap<String, Vec<String>> {
        let mut result: BTreeMap<String, Vec<String>> = BTreeMap::new();

        for (site_name, site) in &self.inner.by_site {
            for source_name in site.sources.keys() {
                result
                    .entry(source_name.to_string())
                    .or_default()
                    .push(format!("site:{site_name}"));
            }
        }
        for (domain_name, domain) in &self.inner.by_domain {
            for source_name in domain.sources.keys() {
                result
                    .entry(source_name.to_string())
                    .or_default()
                    .push(format!("domain:{domain_name}"));
            }
        }

        result
    }

    /// Hex-encoded SHA-256 digest computed by `merge_files`.
    pub fn hash(&self) -> String {
        self.inner.hash.clone()
    }
}
956
/// Lua bindings for [`Shaping`].
#[cfg(feature = "lua")]
impl LuaUserData for Shaping {
    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
        mod_memoize::Memoized::impl_memoize(methods);

        // Returns only the merged parameter table, converted to a Lua value.
        methods.add_async_method(
            "get_egress_path_config",
            |lua, this, (domain, egress_source, site_name): (String, String, String)| async move {
                let params = this
                    .get_egress_path_config(&domain, &egress_source, &site_name)
                    .await;
                lua.to_value_with(&params.params, serialize_options())
            },
        );

        methods.add_method("get_errors", move |_lua, this, ()| {
            Ok(this.get_errors().to_vec())
        });

        methods.add_method("get_warnings", move |_lua, this, ()| {
            Ok(this.get_warnings().to_vec())
        });

        methods.add_method("get_referenced_sources", move |_lua, this, ()| {
            Ok(this.get_referenced_sources())
        });

        // Accepts a log record table and returns the matching rules as Lua values.
        methods.add_async_method("match_rules", |lua, this, record: mlua::Value| async move {
            let record: JsonLogRecord = lua.from_value(record)?;
            let rules = this.match_rules(&record).await.map_err(any_err)?;
            let mut result = vec![];
            for rule in rules {
                result.push(lua.to_value(&rule)?);
            }
            Ok(result)
        });

        methods.add_method("hash", move |_, this, ()| Ok(this.hash()));
    }
}
998
999#[derive(Default, Debug)]
1000pub struct MergedEntry {
1001 pub params: EgressPathConfig,
1002 pub sources: OrderMap<String, EgressPathConfig>,
1003 pub automation: Vec<Rule>,
1004}
1005
/// The top-level layout of one shaping source file (TOML or JSON).
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
struct ShapingFile {
    /// The `default` entry, layered under every domain.
    pub default: Option<PartialEntry>,
    /// Every other top-level key is a domain entry.
    #[serde(flatten, default)]
    pub domains: OrderMap<String, PartialEntry>,
    /// Entries from the `provider` table, keyed by provider name.
    #[serde(default)]
    pub provider: OrderMap<String, ProviderEntry>,
}
1015
/// A domain, site or `default` entry as read from a shaping file, before
/// validation.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
struct PartialEntry {
    /// Key the entry was loaded under; set while merging, never (de)serialized.
    #[serde(skip)]
    pub domain_name: Option<String>,

    /// Every key not claimed by another field: raw `EgressPathConfig` params.
    #[serde(flatten)]
    pub params: toml::Table,

    /// When set, the entry is keyed by the domain's MX site name rather than
    /// by the domain itself. Defaults via `default_true` when deserializing.
    #[serde(default = "default_true")]
    pub mx_rollup: bool,

    /// Test-only: key the entry by site name without resolving MX.
    #[cfg(test)]
    #[serde(default)]
    pub _treat_domain_name_as_site_name: bool,

    /// Replace, rather than merge into, previously loaded data for this key.
    #[serde(default)]
    pub replace_base: bool,

    /// Automation rules for this entry.
    #[serde(default)]
    pub automation: Vec<Rule>,

    /// Per-egress-source params, keyed by source name.
    #[serde(default)]
    pub sources: OrderMap<String, toml::Table>,
}
1044
/// A block from the `provider` table.
///
/// Its settings apply to every domain matched by its `match` rules.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
pub struct ProviderEntry {
    /// Key of this block in the `provider` table; set while merging.
    #[serde(skip, default)]
    pub provider_name: String,

    /// Emitted as a named entry in `additional_connection_limits`.
    #[serde(default)]
    pub provider_connection_limit: Option<LimitSpec>,

    /// Emitted as a named entry in `additional_message_rate_throttles`.
    #[serde(default)]
    pub provider_max_message_rate: Option<ThrottleSpec>,

    /// Rules deciding which domains belong to this provider (`match` key).
    #[serde(default, rename = "match")]
    pub matches: Vec<ProviderMatch>,

    /// Replace, rather than merge into, a previously loaded block of the
    /// same name.
    #[serde(default)]
    pub replace_base: bool,

    /// Every key not claimed by another field: raw `EgressPathConfig` params.
    #[serde(flatten)]
    pub params: toml::Table,

    /// Automation rules for matched domains.
    #[serde(default)]
    pub automation: Vec<Rule>,

    /// Per-egress-source params, keyed by source name.
    #[serde(default)]
    pub sources: OrderMap<String, toml::Table>,
}
1072
/// Returns true if `candidate` ends with `suffix`, ignoring ASCII case.
///
/// DNS names are case-insensitive (RFC 4343). Recipient domains and MX
/// host names can arrive in mixed case, so a case-sensitive comparison
/// would let e.g. `Example.COM` slip past a `.com` rule.
///
/// A single trailing `.` on `candidate` (the fully-qualified form) is
/// ignored. The comparison is done on bytes, so a suffix that would begin
/// inside a multi-byte character simply fails to match instead of hitting a
/// char-boundary slicing panic.
#[cfg(feature = "lua")]
fn suffix_matches(candidate: &str, suffix: &str) -> bool {
    let candidate = candidate.strip_suffix('.').unwrap_or(candidate);
    let (candidate, suffix) = (candidate.as_bytes(), suffix.as_bytes());
    candidate.len() >= suffix.len()
        && candidate[candidate.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
}

#[cfg(feature = "lua")]
#[cfg(test)]
#[test]
fn test_suffix_matches() {
    assert!(suffix_matches("a", "a"));
    assert!(suffix_matches("foo.com", "foo.com"));
    assert!(!suffix_matches("foo.com", ".foo.com"));
    assert!(!suffix_matches("foo.com", "longer.com"));
    assert!(!suffix_matches("réputation.net", ".mx.microsoft"));

    // DNS names compare case-insensitively.
    assert!(suffix_matches("foo.Com", ".com"));
    assert!(suffix_matches("mx.EXAMPLE.com.", ".example.com"));
    assert!(!suffix_matches("foo.Cam", ".com"));
}
1100
/// Returns true if `candidate` is the host `name`.
///
/// ASCII case is ignored, since DNS names are case-insensitive (RFC 4343).
/// A single trailing `.` on `candidate` (the fully-qualified form) is
/// ignored too.
#[cfg(feature = "lua")]
fn host_matches(candidate: &str, name: &str) -> bool {
    let candidate = candidate.strip_suffix('.').unwrap_or(candidate);
    candidate.eq_ignore_ascii_case(name)
}

#[cfg(feature = "lua")]
#[cfg(test)]
#[test]
fn test_host_matches() {
    assert!(host_matches("foo.com", "foo.com"));
    assert!(host_matches("foo.com.", "foo.com"));
    assert!(!host_matches("foo.com", "notfoo.com"));
    assert!(host_matches("MX.Foo.COM.", "mx.foo.com"));
}
1117
#[cfg(feature = "lua")]
impl ProviderEntry {
    /// Returns true if `domain` is matched by this provider's `match` rules.
    ///
    /// Any `DomainSuffix` rule matching `domain` is sufficient on its own.
    /// Otherwise, if there are `MXSuffix` or `HostName` rules, the MX for
    /// `domain` is resolved, and *every* MX host must satisfy at least one
    /// of those rules. A failed MX lookup, or an MX with no hosts, is not a
    /// match.
    async fn domain_matches(&self, domain: &str) -> bool {
        tracing::trace!(
            "ProviderEntry::domain_matches({domain}) vs {:?}",
            self.matches
        );

        let mut need_mx = false;
        for rule in &self.matches {
            match rule {
                ProviderMatch::DomainSuffix(suffix) => {
                    if suffix_matches(domain, suffix) {
                        tracing::trace!("{domain} suffix matches {suffix}");
                        return true;
                    }
                }
                ProviderMatch::HostName(_) | ProviderMatch::MXSuffix(_) => {
                    need_mx = true;
                }
            }
        }

        if !need_mx {
            return false;
        }

        let mx = match MailExchanger::resolve(domain).await {
            Ok(mx) => mx,
            Err(err) => {
                tracing::trace!(
                    "Error resolving MX for {domain}: {err:#}. \
                     Provider {} match rules will be ignored",
                    self.provider_name
                );
                return false;
            }
        };

        // "Every host matches" is vacuously true for an empty host list.
        // Without this guard, a domain with no MX hosts (presumably a NULL MX)
        // would match every MX-based provider.
        if mx.hosts.is_empty() {
            tracing::trace!("{domain} has no MX hosts to match against");
            return false;
        }

        tracing::trace!("Consider MXSuffix rules");
        for host in &mx.hosts {
            let mut matched = false;

            for rule in &self.matches {
                match rule {
                    ProviderMatch::MXSuffix(suffix) => {
                        tracing::trace!("suffix={suffix} vs host {host}");
                        if suffix_matches(host, suffix) {
                            matched = true;
                            break;
                        }
                    }
                    ProviderMatch::HostName(name) => {
                        if host_matches(host, name) {
                            matched = true;
                            break;
                        }
                    }
                    ProviderMatch::DomainSuffix(_) => {}
                }
            }

            if !matched {
                tracing::trace!("host didn't match any of these rules");
                return false;
            }
        }

        true
    }

    /// Layer `other` (a later block with the same provider name) onto `self`.
    ///
    /// With `other.replace_base`, `other`'s settings replace `self`'s
    /// outright. Otherwise:
    /// - limits/rates that `other` sets override `self`'s;
    /// - params and per-source tables are merged via `toml_table_merge_from`;
    /// - match rules and automation rules are appended.
    ///
    /// `provider_max_message_rate` is handled in both paths, just like
    /// `provider_connection_limit`; previously an override file's rate was
    /// silently dropped.
    fn merge_from(&mut self, mut other: Self) {
        if other.replace_base {
            self.provider_connection_limit = other.provider_connection_limit;
            self.provider_max_message_rate = other.provider_max_message_rate;
            self.matches = other.matches;
            self.params = other.params;
            self.sources = other.sources;
            self.automation = other.automation;
        } else {
            if let Some(limit) = other.provider_connection_limit {
                self.provider_connection_limit = Some(limit);
            }
            if let Some(rate) = other.provider_max_message_rate {
                self.provider_max_message_rate = Some(rate);
            }

            toml_table_merge_from(&mut self.params, &other.params);

            for (source, tbl) in other.sources {
                match self.sources.get_mut(&source) {
                    Some(existing) => {
                        toml_table_merge_from(existing, &tbl);
                    }
                    None => {
                        self.sources.insert(source, tbl);
                    }
                }
            }

            self.matches.append(&mut other.matches);
            self.automation.append(&mut other.automation);
        }
    }

    /// Merge this provider's implied settings into `target` for `source`.
    ///
    /// - `provider_name` is always set.
    /// - Each configured connection limit / message rate is added under a
    ///   provider- and source-specific key.
    /// - A `provider_source_selection_rate` already present in `target` is
    ///   moved into `additional_source_selection_rates` under such a key.
    fn apply_provider_params_to(&self, source: &str, target: &mut toml::Table) {
        let mut implied = toml::Table::new();
        implied.insert(
            "provider_name".to_string(),
            toml::Value::String(self.provider_name.to_string()),
        );

        if let Some(limit) = &self.provider_connection_limit {
            let mut limits = toml::Table::new();
            limits.insert(
                format!("shaping-provider-{}-{source}-limit", self.provider_name),
                toml::Value::String(limit.to_string()),
            );
            implied.insert(
                "additional_connection_limits".to_string(),
                toml::Value::Table(limits),
            );
        }
        if let Some(rate) = &self.provider_max_message_rate {
            let rate = rate.as_string();
            let mut limits = toml::Table::new();
            limits.insert(
                format!("shaping-provider-{}-{source}-rate", self.provider_name),
                rate.into(),
            );
            implied.insert(
                "additional_message_rate_throttles".to_string(),
                toml::Value::Table(limits),
            );
        }

        if let Some(rate) = target.remove("provider_source_selection_rate") {
            let mut limits = toml::Table::new();
            limits.insert(
                format!(
                    "shaping-provider-{}-{source}-selection-rate",
                    self.provider_name
                ),
                rate,
            );
            implied.insert(
                "additional_source_selection_rates".to_string(),
                toml::Value::Table(limits),
            );
        }

        toml_table_merge_from(target, &implied);
    }

    /// Validate the provider's params and per-source params as
    /// `EgressPathConfig`.
    ///
    /// `provider_source_selection_rate` is stripped from each source table
    /// before it is validated.
    fn finish_params(&self) -> anyhow::Result<MergedEntry> {
        let provider_name = &self.provider_name;

        let params = EgressPathConfig::deserialize(self.params.clone()).with_context(|| {
            format!(
                "interpreting provider '{provider_name}' params {:#?} as EgressPathConfig",
                self.params
            )
        })?;
        let mut sources = OrderMap::new();

        for (source, params) in &self.sources {
            let mut params = params.clone();
            params.remove("provider_source_selection_rate");
            sources.insert(
                source.clone(),
                EgressPathConfig::deserialize(params.clone()).with_context(|| {
                    format!("interpreting provider '{provider_name}' source '{source}' {params:#} as EgressPathConfig")
                })?,
            );
        }

        Ok(MergedEntry {
            params,
            sources,
            automation: self.automation.clone(),
        })
    }

    /// Feed the provider name and its JSON serialization into `ctx`.
    ///
    /// A serialization failure hashes as an empty string.
    fn hash_into(&self, ctx: &mut Sha256) {
        ctx.update(&self.provider_name);
        ctx.update(serde_json::to_string(self).unwrap_or_else(|_| String::new()));
    }
}
1320
#[cfg(feature = "lua")]
/// One matcher in a provider block's `match = [...]` list, e.g.
/// `match=[{MXSuffix=".olc.protection.outlook.com"},{DomainSuffix=".outlook.com"}]`.
///
/// The variant name is the key used in the shaping file.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ProviderMatch {
    /// Matches on a suffix of the destination's MX host names
    /// (exact all/any semantics live in the matching code — confirm there).
    MXSuffix(String),
    /// Matches on a suffix of the destination domain name.
    DomainSuffix(String),
    /// Matches on an exact MX host name.
    HostName(String),
}
1328
/// Merges the entries of `source` into `tbl`, with `source` winning on conflicts.
///
/// Most keys are copied over wholesale, replacing any existing value. The
/// exception is the set of `additional_*` keys recognized by `is_mergeable`:
/// when both `tbl` and `source` hold a table for one of those keys, their
/// inner entries are merged one by one, so that limits declared in separate
/// shaping files accumulate instead of the later file discarding the earlier
/// ones.
#[cfg(feature = "lua")]
fn toml_table_merge_from(tbl: &mut toml::Table, source: &toml::Table) {
    /// Returns true for keys whose table values accumulate across merges
    /// rather than being replaced.
    fn is_mergeable(name: &str) -> bool {
        matches!(
            name,
            "additional_connection_limits"
                | "additional_message_rate_throttles"
                | "additional_source_selection_rates"
        )
    }

    for (k, v) in source {
        match (tbl.get_mut(k), v.as_table()) {
            (Some(toml::Value::Table(existing)), Some(v)) if is_mergeable(k) => {
                // Inner keys from `source` override same-named inner keys.
                for (inner_k, inner_v) in v {
                    existing.insert(inner_k.clone(), inner_v.clone());
                }
            }
            _ => {
                // Missing key, non-table on either side, or a key that is
                // not mergeable: `source`'s value replaces the existing one.
                tbl.insert(k.clone(), v.clone());
            }
        }
    }
}
1356
#[cfg(feature = "lua")]
impl PartialEntry {
    /// Merges `other` into `self`, with `other` taking precedence.
    ///
    /// If `other.replace_base` is set, `other`'s params, automation rules and
    /// sources replace `self`'s outright. Otherwise:
    /// * `params` are merged with `toml_table_merge_from`;
    /// * each source table is merged into the same-named existing source, or
    ///   added if `self` has no source by that name;
    /// * `other`'s automation rules are appended after `self`'s.
    fn merge_from(&mut self, mut other: Self) {
        if other.replace_base {
            self.params = other.params;
            self.automation = other.automation;
            self.sources = other.sources;
            return;
        }

        toml_table_merge_from(&mut self.params, &other.params);

        for (source_name, source_tbl) in other.sources {
            if let Some(existing) = self.sources.get_mut(&source_name) {
                toml_table_merge_from(existing, &source_tbl);
            } else {
                self.sources.insert(source_name, source_tbl);
            }
        }

        self.automation.append(&mut other.automation);
    }

    /// Converts the accumulated TOML tables into a `MergedEntry`.
    ///
    /// `params` and every source table are each deserialized independently as
    /// an `EgressPathConfig`, so a source's config contains only what its own
    /// table sets, with `EgressPathConfig` defaults elsewhere.
    ///
    /// # Errors
    /// Fails with context naming the domain (empty if `domain_name` is unset),
    /// the source if applicable, and the offending table when deserialization
    /// fails.
    fn finish(self) -> anyhow::Result<MergedEntry> {
        let domain = self.domain_name.unwrap_or_default();

        let base = EgressPathConfig::deserialize(self.params.clone()).with_context(|| {
            format!(
                "interpreting domain '{domain}' params {:#?} as EgressPathConfig",
                self.params
            )
        })?;

        let mut sources = OrderMap::new();
        for (source, params) in self.sources {
            let config = EgressPathConfig::deserialize(params.clone()).with_context(|| {
                format!("interpreting domain '{domain}' source '{source}' {params:#} as EgressPathConfig")
            })?;
            sources.insert(source, config);
        }

        Ok(MergedEntry {
            params: base,
            sources,
            automation: self.automation,
        })
    }

    /// Feeds this entry into `ctx`: the domain name first (when set), then the
    /// JSON serialization of the whole entry. If serialization fails, an empty
    /// string is hashed in its place.
    fn hash_into(&self, ctx: &mut Sha256) {
        if let Some(name) = &self.domain_name {
            ctx.update(name);
        }
        ctx.update(serde_json::to_string(self).unwrap_or_default());
    }
}
1416
1417fn one_or_many<'de, T, D>(deserializer: D, expecting: &str) -> Result<Vec<T>, D::Error>
1418where
1419 T: Deserialize<'de>,
1420 D: Deserializer<'de>,
1421{
1422 let result: Result<Vec<T>, _> =
1423 OneOrMany::<serde_with::Same, PreferOne>::deserialize_as(deserializer);
1424 match result {
1425 Ok(r) => Ok(r),
1426 Err(err) => Err(serde::de::Error::custom(format!(
1427 "{expecting}.\nThe underlying error message is:\n{err:#}"
1428 ))),
1429 }
1430}
1431
1432fn one_or_many_action<'de, D>(deserializer: D) -> Result<Vec<Action>, D::Error>
1433where
1434 D: Deserializer<'de>,
1435{
1436 one_or_many(
1437 deserializer,
1438 "\"action\" field expected either a single Action or an array of Actions",
1439 )
1440}
1441
1442fn regex_string_or_array<'de, D>(deserializer: D) -> Result<Vec<Regex>, D::Error>
1443where
1444 D: Deserializer<'de>,
1445{
1446 string_or_array(
1447 deserializer,
1448 "regex string or array of regex strings for field regex",
1449 )
1450}
1451
1452fn string_or_array<'de, T, D>(deserializer: D, expecting: &'static str) -> Result<Vec<T>, D::Error>
1453where
1454 T: Deserialize<'de> + TryFrom<String>,
1455 <T as TryFrom<String>>::Error: std::fmt::Debug,
1456 D: Deserializer<'de>,
1457{
1458 struct StringOrArray<T>(PhantomData<fn() -> T>, &'static str);
1464
1465 impl<'de, T> Visitor<'de> for StringOrArray<T>
1466 where
1467 T: Deserialize<'de> + TryFrom<String>,
1468 <T as TryFrom<String>>::Error: std::fmt::Debug,
1469 {
1470 type Value = Vec<T>;
1471
1472 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
1473 formatter.write_str(self.1)
1474 }
1475
1476 fn visit_str<E>(self, value: &str) -> Result<Vec<T>, E>
1477 where
1478 E: serde::de::Error,
1479 {
1480 Ok(vec![
1481 T::try_from(value.to_string()).map_err(|e| E::custom(format!("{e:?}")))?
1482 ])
1483 }
1484
1485 fn visit_seq<S>(self, seq: S) -> Result<Vec<T>, S::Error>
1486 where
1487 S: SeqAccess<'de>,
1488 {
1489 Deserialize::deserialize(serde::de::value::SeqAccessDeserializer::new(seq))
1490 }
1491 }
1492
1493 deserializer.deserialize_any(StringOrArray(PhantomData, expecting))
1494}
1495
#[cfg(feature = "lua")]
/// Always returns `true`.
///
/// NOTE(review): presumably used as a `#[serde(default = "default_true")]`
/// helper for boolean fields that should default to enabled; confirm at the
/// use sites.
fn default_true() -> bool {
    true
}
1500
#[cfg(feature = "lua")]
/// Installs a `load` function on the `shaping` Lua sub-module.
///
/// From Lua, `load(paths, options)` takes a list of shaping file paths and an
/// optional options table. The table is deserialized into the merge options,
/// and defaults apply when it is omitted. The call merges the files with
/// `Shaping::merge_files` and returns the resulting `Shaping`. A merge
/// failure is raised in Lua as an error.
pub fn register(lua: &mlua::Lua) -> anyhow::Result<()> {
    let shaping_mod = config::get_or_create_sub_module(lua, "shaping")?;

    let load = lua.create_async_function(
        move |lua, (paths, options): (Vec<String>, Option<mlua::Value>)| async move {
            let merge_options = if let Some(value) = options {
                lua.from_value(value)?
            } else {
                Default::default()
            };
            let merged = Shaping::merge_files(&paths, &merge_options)
                .await
                .map_err(any_err)?;
            Ok(merged)
        },
    )?;
    shaping_mod.set("load", load)?;

    Ok(())
}
1523
#[cfg(test)]
mod test {
    // Unit tests for shaping-file merging, provider application, automation
    // rule matching and resolved defaults. Most expectations are k9 snapshots
    // of the `{:#?}` output of the resolved structures.
    use super::*;
    use kumo_log_types::RecordType;
    use rfc5321::Response;
    use std::io::Write;
    use tempfile::NamedTempFile;
    use uuid::Uuid;

    /// Writes each input string to its own temporary file and merges them, in
    /// order, with default merge options. Panics on any failure.
    async fn make_shaping_configs(inputs: &[&str]) -> Shaping {
        // The NamedTempFile handles are kept in `files` so the temporary files
        // still exist while `merge_files` reads them (NamedTempFile removes its
        // file when dropped).
        let mut files = vec![];
        let mut file_names = vec![];

        for (i, content) in inputs.iter().enumerate() {
            let mut shaping_file = NamedTempFile::with_prefix(format!("file{i}")).unwrap();
            shaping_file.write_all(content.as_bytes()).unwrap();
            file_names.push(shaping_file.path().to_str().unwrap().to_string());
            files.push(shaping_file);
        }

        Shaping::merge_files(&file_names, &ShapingMergeOptions::default())
            .await
            .unwrap()
    }

    /// Two files define the same domain: the `additional_*` tables accumulate
    /// entries from both files instead of the second file replacing the first.
    #[tokio::test]
    async fn test_merge_additional() {
        let shaping = make_shaping_configs(&[
            r#"
["example.com"]
mx_rollup = false
additional_connection_limits = {"first"=10}
            "#,
            r#"
["example.com"]
mx_rollup = false
additional_connection_limits = {"second"=32}
additional_message_rate_throttles = {"second"="100/hr"}
            "#,
        ])
        .await;

        let resolved = shaping
            .get_egress_path_config("example.com", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();

        k9::snapshot!(
            resolved.params.additional_connection_limits,
            r#"
{
    "first": 10,
    "second": 32,
}
"#
        );
        k9::snapshot!(
            resolved.params.additional_message_rate_throttles,
            r#"
{
    "second": 100/h,
}
"#
        );
    }

    /// A provider with several `HostName` matchers applies to yahoo.com and its
    /// settings and name show up in the resolved config.
    /// NOTE(review): resolution likely involves MX lookups for yahoo.com;
    /// confirm whether this test needs network access.
    #[tokio::test]
    async fn test_provider_multi_hostname() {
        let shaping = make_shaping_configs(&[r#"
[provider."yahoo"]
match=[{HostName="mta5.am0.yahoodns.net"},{HostName="mta6.am0.yahoodns.net"},{HostName="mta7.am0.yahoodns.net"}]
enable_tls = "Required"
        "#])
        .await;

        let resolved = shaping
            .get_egress_path_config("yahoo.com", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();

        k9::assert_equal!(
            resolved.params.enable_tls,
            crate::egress_path::Tls::Required
        );
        k9::assert_equal!(resolved.params.provider_name.unwrap(), "yahoo");
    }

    /// Same as `test_provider_multi_hostname`, but with several `MXSuffix`
    /// matchers.
    #[tokio::test]
    async fn test_provider_multi_suffix() {
        let shaping = make_shaping_configs(&[r#"
[provider."yahoo"]
match=[{MXSuffix="mta5.am0.yahoodns.net"},{MXSuffix="mta6.am0.yahoodns.net"},{MXSuffix="mta7.am0.yahoodns.net"}]
enable_tls = "Required"
        "#])
        .await;

        let resolved = shaping
            .get_egress_path_config("yahoo.com", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();

        k9::assert_equal!(
            resolved.params.enable_tls,
            crate::egress_path::Tls::Required
        );
        k9::assert_equal!(resolved.params.provider_name.unwrap(), "yahoo");
    }

    /// Provider-level sugar is expanded into named entries:
    /// `provider_connection_limit` and `provider_max_message_rate` become
    /// `additional_connection_limits` / `additional_message_rate_throttles`
    /// keyed by provider and source, and a per-source
    /// `provider_source_selection_rate` becomes an
    /// `additional_source_selection_rates` entry for that source only.
    #[tokio::test]
    async fn test_provider() {
        let shaping = make_shaping_configs(&[r#"
[provider."Office 365"]
match=[{MXSuffix=".olc.protection.outlook.com"},{DomainSuffix=".outlook.com"}]
enable_tls = "Required"
provider_connection_limit = 10
provider_max_message_rate = "120/s"

[provider."Office 365".sources."new-source"]
provider_source_selection_rate = "500/d,max_burst=1"
        "#])
        .await;

        let resolved = shaping
            .get_egress_path_config("outlook.com", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();

        k9::assert_equal!(
            resolved.params.enable_tls,
            crate::egress_path::Tls::Required
        );
        k9::assert_equal!(resolved.params.provider_name.unwrap(), "Office 365");

        k9::snapshot!(
            resolved.params.additional_connection_limits,
            r#"
{
    "shaping-provider-Office 365-invalid.source-limit": 10,
}
"#
        );
        k9::snapshot!(
            resolved.params.additional_message_rate_throttles,
            r#"
{
    "shaping-provider-Office 365-invalid.source-rate": 120/s,
}
"#
        );
        // A source without its own provider_source_selection_rate gets none.
        assert!(resolved.params.source_selection_rate.is_none());
        assert!(resolved.params.additional_source_selection_rates.is_empty());

        // The configured source gets a named selection-rate entry rather than
        // `source_selection_rate`.
        let resolved = shaping
            .get_egress_path_config("outlook.com", "new-source", "invalid.site")
            .await
            .finish()
            .unwrap();
        assert!(resolved.params.source_selection_rate.is_none());
        k9::snapshot!(
            resolved.params.additional_source_selection_rates,
            r#"
{
    "shaping-provider-Office 365-new-source-selection-rate": 500/d,max_burst=1,
}
"#
        );
    }

    /// Automation rules are selected from the default entry, a site entry
    /// (`_treat_domain_name_as_site_name`), a non-rollup domain entry and a
    /// provider entry. Responses prefixed with "KumoMTA internal: " only match
    /// rules that set `match_internal = true`.
    #[tokio::test]
    async fn test_rule_matching() {
        let shaping = make_shaping_configs(&[r#"
[["default".automation]]
regex="default"
action = {SetConfig={name="connection_limit", value=1}}
duration = "1hr"

["fake.site"]
_treat_domain_name_as_site_name = true

[["fake.site".automation]]
regex="fake_rollup"
action = {SetConfig={name="connection_limit", value=2}}
duration = "1hr"

["woot.provider"]
mx_rollup = false

[["woot.provider".automation]]
regex="woot_domain"
action = {SetConfig={name="connection_limit", value=2}}
duration = "1hr"

[provider."provider"]
match=[{DomainSuffix=".provider"}]

[[provider."provider".automation]]
regex="provider"
action = {SetConfig={name="connection_limit", value=3}}
duration = "1hr"
match_internal = true

"#])
        .await;

        // Print any warnings collected while merging, to aid debugging.
        eprintln!("{:?}", shaping.inner.warnings);

        /// Builds a minimal TransientFailure record (response code 400) with
        /// the given response text, single recipient and site name; all other
        /// fields are empty/default.
        fn make_record(content: &str, recipient: &str, site: &str) -> JsonLogRecord {
            JsonLogRecord {
                kind: RecordType::TransientFailure,
                id: String::new(),
                sender: String::new(),
                recipient: vec![recipient.to_string()],
                queue: String::new(),
                site: site.to_string(),
                size: 0,
                response: Response {
                    code: 400,
                    command: None,
                    enhanced_code: None,
                    content: content.to_string(),
                },
                peer_address: None,
                timestamp: Default::default(),
                created: Default::default(),
                num_attempts: 1,
                bounce_classification: Default::default(),
                egress_pool: None,
                egress_source: None,
                source_address: None,
                feedback_report: None,
                meta: Default::default(),
                headers: Default::default(),
                delivery_protocol: None,
                reception_protocol: None,
                nodeid: Uuid::default(),
                tls_cipher: None,
                tls_protocol_version: None,
                tls_peer_subject_name: None,
                provider_name: None,
                session_id: None,
            }
        }

        let matches = shaping
            .match_rules(&make_record("default", "user@example.com", "dummy_site"))
            .await
            .unwrap();
        k9::assert_equal!(
            matches[0].regex[0].to_string(),
            "default",
            "matches against default automation rule"
        );

        // The default rule does not set match_internal, so an internal
        // response must not match it.
        let matches = shaping
            .match_rules(&make_record(
                "KumoMTA internal: default",
                "user@example.com",
                "dummy_site",
            ))
            .await
            .unwrap();
        assert!(matches.is_empty(), "internal bounce should not match");

        let matches = shaping
            .match_rules(&make_record(
                "woot_domain",
                "user@woot.provider",
                "dummy_site",
            ))
            .await
            .unwrap();
        k9::assert_equal!(
            matches[0].regex[0].to_string(),
            "woot_domain",
            "matches against domain rule with mx_rollup=false"
        );

        let matches = shaping
            .match_rules(&make_record("fake_rollup", "user@fake.rollup", "fake.site"))
            .await
            .unwrap();
        k9::assert_equal!(
            matches[0].regex[0].to_string(),
            "fake_rollup",
            "matches against domain rule with mx_rollup=true"
        );

        let matches = shaping
            .match_rules(&make_record("provider", "user@woot.provider", "dummy_site"))
            .await
            .unwrap();
        k9::assert_equal!(
            matches[0].regex[0].to_string(),
            "provider",
            "matches against provider rule"
        );

        // The provider rule sets match_internal = true, so internal responses
        // still match it.
        let matches = shaping
            .match_rules(&make_record(
                "KumoMTA internal: provider",
                "user@woot.provider",
                "dummy_site",
            ))
            .await
            .unwrap();
        k9::assert_equal!(
            matches[0].regex[0].to_string(),
            "provider",
            "internal response matches against provider rule"
        );
    }

    /// Snapshots the fully-resolved config for an unknown domain (only
    /// "default" applies), for example.com (domain overrides plus a named
    /// source), and for yahoo.com (default automation plus its own rule).
    #[tokio::test]
    async fn test_defaults() {
        let shaping = make_shaping_configs(&[
            r#"
["default"]
connection_limit = 10
max_connection_rate = "100/min"
max_deliveries_per_connection = 100
max_message_rate = "100/s"
idle_timeout = "60s"
data_timeout = "30s"
data_dot_timeout = "60s"
enable_tls = "Opportunistic"
consecutive_connection_failures_before_delay = 100

[["default".automation]]
regex=[
    '/Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/',
    '/All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/',
    '/has been temporarily rate limited due to IP reputation/',
    '/Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/',
    '/Server busy\. Please try again later from/'
]
action = [
    {SetConfig={name="max_message_rate", value="1/minute"}},
    {SetConfig={name="connection_limit", value=1}}
]
duration = "90m"

[["default".automation]]
regex="KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site"
action = {SetConfig={name="enable_tls", value="Disabled"}}
duration = "30 days"

["gmail.com"]
max_deliveries_per_connection = 50
connection_limit = 5
enable_tls = "Required"
consecutive_connection_failures_before_delay = 5

["yahoo.com"]
max_deliveries_per_connection = 20

[["yahoo.com".automation]]
regex = "\\[TS04\\]"
action = "Suspend"
duration = "2 hours"

["comcast.net"]
connection_limit = 25
max_deliveries_per_connection = 250
enable_tls = "Required"
idle_timeout = "30s"
consecutive_connection_failures_before_delay = 24

["mail.com"]
max_deliveries_per_connection = 100

["orange.fr"]
connection_limit = 3

["smtp.mailgun.com"]
connection_limit = 7000
max_deliveries_per_connection = 3

["example.com"]
mx_rollup = false
max_deliveries_per_connection = 100
connection_limit = 3

["example.com".sources."my source name"]
connection_limit = 5
            "#,
        ])
        .await;

        // An unknown domain resolves to just the "default" entry.
        let default = shaping
            .get_egress_path_config("invalid.domain", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();
        k9::snapshot!(
            default,
            r#"
MergedEntry {
    params: EgressPathConfig {
        connection_limit: 10,
        additional_connection_limits: {},
        enable_tls: Opportunistic,
        enable_mta_sts: true,
        enable_dane: false,
        enable_pipelining: true,
        enable_rset: true,
        tls_prefer_openssl: false,
        tls_certificate: None,
        tls_private_key: None,
        openssl_cipher_list: None,
        openssl_cipher_suites: None,
        openssl_options: None,
        rustls_cipher_suites: [],
        client_timeouts: SmtpClientTimeouts {
            connect_timeout: 60s,
            banner_timeout: 60s,
            ehlo_timeout: 300s,
            mail_from_timeout: 300s,
            rcpt_to_timeout: 300s,
            data_timeout: 30s,
            data_dot_timeout: 60s,
            rset_timeout: 5s,
            idle_timeout: 60s,
            starttls_timeout: 5s,
            auth_timeout: 60s,
        },
        system_shutdown_timeout: None,
        max_ready: 1024,
        consecutive_connection_failures_before_delay: 100,
        smtp_port: 25,
        smtp_auth_plain_username: None,
        smtp_auth_plain_password: None,
        allow_smtp_auth_plain_without_tls: false,
        max_message_rate: Some(
            100/s,
        ),
        additional_message_rate_throttles: {},
        source_selection_rate: None,
        additional_source_selection_rates: {},
        max_connection_rate: Some(
            100/m,
        ),
        max_deliveries_per_connection: 100,
        max_recipients_per_batch: 100,
        prohibited_hosts: {
            "127.0.0.0/8",
            "::1",
        },
        skip_hosts: {},
        ehlo_domain: None,
        aggressive_connection_opening: false,
        refresh_interval: 60s,
        refresh_strategy: Ttl,
        dispatcher_wakeup_strategy: Aggressive,
        maintainer_wakeup_strategy: Aggressive,
        provider_name: None,
        remember_broken_tls: None,
        opportunistic_tls_reconnect_on_failed_handshake: false,
        use_lmtp: false,
        reconnect_strategy: ConnectNextHost,
        readyq_pool_name: None,
        low_memory_reduction_policy: ShrinkDataAndMeta,
        no_memory_reduction_policy: ShrinkDataAndMeta,
        try_next_host_on_transport_error: false,
        ignore_8bit_checks: false,
    },
    sources: {},
    automation: [
        Rule {
            regex: [
                Regex(
                    /Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/,
                ),
                Regex(
                    /All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/,
                ),
                Regex(
                    /has been temporarily rate limited due to IP reputation/,
                ),
                Regex(
                    /Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/,
                ),
                Regex(
                    /Server busy\. Please try again later from/,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "max_message_rate",
                        value: HashableTomlValue {
                            value: String(
                                "1/minute",
                            ),
                        },
                    },
                ),
                SetConfig(
                    EgressPathConfigValue {
                        name: "connection_limit",
                        value: HashableTomlValue {
                            value: Integer(
                                1,
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 5400s,
            was_rollup: false,
            match_internal: false,
        },
        Rule {
            regex: [
                Regex(
                    KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "enable_tls",
                        value: HashableTomlValue {
                            value: String(
                                "Disabled",
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 2592000s,
            was_rollup: false,
            match_internal: false,
        },
    ],
}
"#
        );

        // example.com overrides connection_limit. Its "my source name" source is
        // deserialized from its own table only, so fields it does not set show
        // EgressPathConfig defaults, not the domain/default values.
        let example_com = shaping
            .get_egress_path_config("example.com", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();
        k9::snapshot!(
            example_com,
            r#"
MergedEntry {
    params: EgressPathConfig {
        connection_limit: 3,
        additional_connection_limits: {},
        enable_tls: Opportunistic,
        enable_mta_sts: true,
        enable_dane: false,
        enable_pipelining: true,
        enable_rset: true,
        tls_prefer_openssl: false,
        tls_certificate: None,
        tls_private_key: None,
        openssl_cipher_list: None,
        openssl_cipher_suites: None,
        openssl_options: None,
        rustls_cipher_suites: [],
        client_timeouts: SmtpClientTimeouts {
            connect_timeout: 60s,
            banner_timeout: 60s,
            ehlo_timeout: 300s,
            mail_from_timeout: 300s,
            rcpt_to_timeout: 300s,
            data_timeout: 30s,
            data_dot_timeout: 60s,
            rset_timeout: 5s,
            idle_timeout: 60s,
            starttls_timeout: 5s,
            auth_timeout: 60s,
        },
        system_shutdown_timeout: None,
        max_ready: 1024,
        consecutive_connection_failures_before_delay: 100,
        smtp_port: 25,
        smtp_auth_plain_username: None,
        smtp_auth_plain_password: None,
        allow_smtp_auth_plain_without_tls: false,
        max_message_rate: Some(
            100/s,
        ),
        additional_message_rate_throttles: {},
        source_selection_rate: None,
        additional_source_selection_rates: {},
        max_connection_rate: Some(
            100/m,
        ),
        max_deliveries_per_connection: 100,
        max_recipients_per_batch: 100,
        prohibited_hosts: {
            "127.0.0.0/8",
            "::1",
        },
        skip_hosts: {},
        ehlo_domain: None,
        aggressive_connection_opening: false,
        refresh_interval: 60s,
        refresh_strategy: Ttl,
        dispatcher_wakeup_strategy: Aggressive,
        maintainer_wakeup_strategy: Aggressive,
        provider_name: None,
        remember_broken_tls: None,
        opportunistic_tls_reconnect_on_failed_handshake: false,
        use_lmtp: false,
        reconnect_strategy: ConnectNextHost,
        readyq_pool_name: None,
        low_memory_reduction_policy: ShrinkDataAndMeta,
        no_memory_reduction_policy: ShrinkDataAndMeta,
        try_next_host_on_transport_error: false,
        ignore_8bit_checks: false,
    },
    sources: {
        "my source name": EgressPathConfig {
            connection_limit: 5,
            additional_connection_limits: {},
            enable_tls: Opportunistic,
            enable_mta_sts: true,
            enable_dane: false,
            enable_pipelining: true,
            enable_rset: true,
            tls_prefer_openssl: false,
            tls_certificate: None,
            tls_private_key: None,
            openssl_cipher_list: None,
            openssl_cipher_suites: None,
            openssl_options: None,
            rustls_cipher_suites: [],
            client_timeouts: SmtpClientTimeouts {
                connect_timeout: 60s,
                banner_timeout: 60s,
                ehlo_timeout: 300s,
                mail_from_timeout: 300s,
                rcpt_to_timeout: 300s,
                data_timeout: 300s,
                data_dot_timeout: 300s,
                rset_timeout: 5s,
                idle_timeout: 5s,
                starttls_timeout: 5s,
                auth_timeout: 60s,
            },
            system_shutdown_timeout: None,
            max_ready: 1024,
            consecutive_connection_failures_before_delay: 100,
            smtp_port: 25,
            smtp_auth_plain_username: None,
            smtp_auth_plain_password: None,
            allow_smtp_auth_plain_without_tls: false,
            max_message_rate: None,
            additional_message_rate_throttles: {},
            source_selection_rate: None,
            additional_source_selection_rates: {},
            max_connection_rate: None,
            max_deliveries_per_connection: 1024,
            max_recipients_per_batch: 100,
            prohibited_hosts: {
                "127.0.0.0/8",
                "::1",
            },
            skip_hosts: {},
            ehlo_domain: None,
            aggressive_connection_opening: false,
            refresh_interval: 60s,
            refresh_strategy: Ttl,
            dispatcher_wakeup_strategy: Aggressive,
            maintainer_wakeup_strategy: Aggressive,
            provider_name: None,
            remember_broken_tls: None,
            opportunistic_tls_reconnect_on_failed_handshake: false,
            use_lmtp: false,
            reconnect_strategy: ConnectNextHost,
            readyq_pool_name: None,
            low_memory_reduction_policy: ShrinkDataAndMeta,
            no_memory_reduction_policy: ShrinkDataAndMeta,
            try_next_host_on_transport_error: false,
            ignore_8bit_checks: false,
        },
    },
    automation: [
        Rule {
            regex: [
                Regex(
                    /Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/,
                ),
                Regex(
                    /All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/,
                ),
                Regex(
                    /has been temporarily rate limited due to IP reputation/,
                ),
                Regex(
                    /Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/,
                ),
                Regex(
                    /Server busy\. Please try again later from/,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "max_message_rate",
                        value: HashableTomlValue {
                            value: String(
                                "1/minute",
                            ),
                        },
                    },
                ),
                SetConfig(
                    EgressPathConfigValue {
                        name: "connection_limit",
                        value: HashableTomlValue {
                            value: Integer(
                                1,
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 5400s,
            was_rollup: false,
            match_internal: false,
        },
        Rule {
            regex: [
                Regex(
                    KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "enable_tls",
                        value: HashableTomlValue {
                            value: String(
                                "Disabled",
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 2592000s,
            was_rollup: false,
            match_internal: false,
        },
    ],
}
"#
        );

        // yahoo.com: max_deliveries_per_connection is overridden and its own
        // [TS04] Suspend rule is appended after the default automation rules.
        // The site name uses the `(mta5|mta6|mta7).am0.yahoodns.net` form,
        // presumably mirroring an MX-rollup site name (confirm).
        let yahoo_com = shaping
            .get_egress_path_config(
                "yahoo.com",
                "invalid.source",
                "(mta5|mta6|mta7).am0.yahoodns.net",
            )
            .await
            .finish()
            .unwrap();
        k9::snapshot!(
            yahoo_com,
            r#"
MergedEntry {
    params: EgressPathConfig {
        connection_limit: 10,
        additional_connection_limits: {},
        enable_tls: Opportunistic,
        enable_mta_sts: true,
        enable_dane: false,
        enable_pipelining: true,
        enable_rset: true,
        tls_prefer_openssl: false,
        tls_certificate: None,
        tls_private_key: None,
        openssl_cipher_list: None,
        openssl_cipher_suites: None,
        openssl_options: None,
        rustls_cipher_suites: [],
        client_timeouts: SmtpClientTimeouts {
            connect_timeout: 60s,
            banner_timeout: 60s,
            ehlo_timeout: 300s,
            mail_from_timeout: 300s,
            rcpt_to_timeout: 300s,
            data_timeout: 30s,
            data_dot_timeout: 60s,
            rset_timeout: 5s,
            idle_timeout: 60s,
            starttls_timeout: 5s,
            auth_timeout: 60s,
        },
        system_shutdown_timeout: None,
        max_ready: 1024,
        consecutive_connection_failures_before_delay: 100,
        smtp_port: 25,
        smtp_auth_plain_username: None,
        smtp_auth_plain_password: None,
        allow_smtp_auth_plain_without_tls: false,
        max_message_rate: Some(
            100/s,
        ),
        additional_message_rate_throttles: {},
        source_selection_rate: None,
        additional_source_selection_rates: {},
        max_connection_rate: Some(
            100/m,
        ),
        max_deliveries_per_connection: 20,
        max_recipients_per_batch: 100,
        prohibited_hosts: {
            "127.0.0.0/8",
            "::1",
        },
        skip_hosts: {},
        ehlo_domain: None,
        aggressive_connection_opening: false,
        refresh_interval: 60s,
        refresh_strategy: Ttl,
        dispatcher_wakeup_strategy: Aggressive,
        maintainer_wakeup_strategy: Aggressive,
        provider_name: None,
        remember_broken_tls: None,
        opportunistic_tls_reconnect_on_failed_handshake: false,
        use_lmtp: false,
        reconnect_strategy: ConnectNextHost,
        readyq_pool_name: None,
        low_memory_reduction_policy: ShrinkDataAndMeta,
        no_memory_reduction_policy: ShrinkDataAndMeta,
        try_next_host_on_transport_error: false,
        ignore_8bit_checks: false,
    },
    sources: {},
    automation: [
        Rule {
            regex: [
                Regex(
                    /Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/,
                ),
                Regex(
                    /All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/,
                ),
                Regex(
                    /has been temporarily rate limited due to IP reputation/,
                ),
                Regex(
                    /Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/,
                ),
                Regex(
                    /Server busy\. Please try again later from/,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "max_message_rate",
                        value: HashableTomlValue {
                            value: String(
                                "1/minute",
                            ),
                        },
                    },
                ),
                SetConfig(
                    EgressPathConfigValue {
                        name: "connection_limit",
                        value: HashableTomlValue {
                            value: Integer(
                                1,
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 5400s,
            was_rollup: false,
            match_internal: false,
        },
        Rule {
            regex: [
                Regex(
                    KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "enable_tls",
                        value: HashableTomlValue {
                            value: String(
                                "Disabled",
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 2592000s,
            was_rollup: false,
            match_internal: false,
        },
        Rule {
            regex: [
                Regex(
                    \[TS04\],
                ),
            ],
            action: [
                Suspend,
            ],
            trigger: Immediate,
            duration: 7200s,
            was_rollup: false,
            match_internal: false,
        },
    ],
}
"#
        );
    }

    /// The shaping.toml shipped under assets/policy-extras must load without
    /// error. The relative path assumes the test runs with this crate's
    /// directory as the working directory.
    #[tokio::test]
    async fn test_load_default_shaping_toml() {
        Shaping::merge_files(
            &["../../assets/policy-extras/shaping.toml".into()],
            &ShapingMergeOptions::default(),
        )
        .await
        .unwrap();
    }
}