1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::parser::{Command, Domain, EsmtpParameter, ForwardPath, ReversePath};
4use crate::{AsyncReadAndWrite, BoxedAsyncReadAndWrite};
5use bstr::ByteSlice;
6use hickory_proto::rr::rdata::TLSA;
7use memchr::memmem::Finder;
8use nom_utils::DomainString;
9use openssl::x509::{X509Ref, X509};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::net::IpAddr;
13use std::str::FromStr;
14use std::sync::{Arc, LazyLock};
15use std::time::Duration;
16use thiserror::Error;
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::{TcpStream, ToSocketAddrs};
19use tokio::time::timeout;
20use tokio_rustls::rustls::pki_types::ServerName;
21use tracing::Level;
22
23pub use kumo_tls_helper::TlsOptions;
24pub use openssl;
25pub use tokio_rustls;
26
27const MAX_LINE_LEN: usize = 4096;
28
/// Errors reported by the SMTP client in this module.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute; variants that carry commands render them via
/// `command_for_err`.
#[derive(Error, Debug, Clone)]
pub enum ClientError {
    /// A response line could not be decoded as UTF-8.
    #[error("response is not UTF8")]
    Utf8(#[from] std::string::FromUtf8Error),
    /// A response line did not have the expected shape; carries the
    /// text that was reported as malformed.
    #[error("Malformed Response: {0}")]
    MalformedResponseLine(String),
    /// More than `MAX_LINE_LEN` bytes were buffered without finding a
    /// line terminator.
    #[error("Response line is too long")]
    ResponseTooLong,
    /// An operation required a socket but none was present.
    #[error("Not connected")]
    NotConnected,
    /// A single command produced a response with an unacceptable code.
    #[error("Command rejected {0:?}")]
    Rejected(Response),
    /// A group of commands was rejected; carries every response.
    #[error("Commands rejected {0:?}")]
    RejectedBatch(Vec<Response>),
    /// The supplied name could not be used as a DNS name.
    #[error("STARTTLS: {0} is not a valid DNS name")]
    InvalidDnsName(String),
    /// Building the TLS connector failed; `error` is the stringified cause.
    #[error("Invalid client certificate configured: {error:?}")]
    FailedToBuildConnector { error: String },
    /// No data arrived within `duration` while waiting for a response.
    #[error("Timed Out waiting {duration:?} for response to cmd={}", self.command_for_err())]
    TimeOutResponse {
        /// The command whose response was awaited, when known.
        command: Option<Command>,
        /// The timeout that elapsed.
        duration: Duration,
    },
    /// Writing one or more commands did not complete within `duration`.
    #[error("Timed Out after {duration:?} writing cmd={}", self.command_for_err())]
    TimeOutRequest {
        /// The commands that were being written.
        commands: Vec<Command>,
        /// The timeout that elapsed.
        duration: Duration,
    },
    /// Reading a response failed (I/O error, peer close, or the socket
    /// was already gone).
    #[error("Error {error} reading response to cmd={}", self.command_for_err())]
    ReadError {
        /// The command whose response was awaited, when known.
        command: Option<Command>,
        /// Stringified cause.
        error: String,
        /// Lossy UTF-8 rendering of the bytes buffered when the error occurred.
        partial: String,
    },
    /// Flushing the socket before reading a response failed.
    #[error("Error {error} flushing send buffer")]
    FlushError { error: String },
    /// Writing to the socket failed.
    #[error("Error {error} writing {}", self.command_for_err())]
    WriteError {
        /// The commands that were being written; empty when message
        /// payload data was being written.
        commands: Vec<Command>,
        /// Stringified cause.
        error: String,
    },
    /// Writing message payload data timed out.
    #[error("Timed Out sending message payload data")]
    TimeOutData,
    /// An error surfaced from OpenSSL.
    #[error("SSL Error: {0}")]
    SslErrorStack(#[from] openssl::error::ErrorStack),
    /// None of the supplied DANE TLSA records were usable for `hostname`.
    // NOTE(review): not constructed in the code visible here; presumably
    // raised by the TLS helper code. Confirm against callers.
    #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
    NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
}
77
78impl ClientError {
79 pub fn command(&self) -> Option<String> {
81 match self {
82 Self::TimeOutResponse {
83 command: Some(command),
84 ..
85 }
86 | Self::ReadError {
87 command: Some(command),
88 ..
89 } => Some(command.encode().to_string()),
90 Self::TimeOutRequest { commands, .. } | Self::WriteError { commands, .. }
91 if !commands.is_empty() =>
92 {
93 let s: String = commands
94 .iter()
95 .map(|cmd| cmd.encode().to_string())
96 .collect();
97 Some(s)
98 }
99 _ => None,
100 }
101 }
102
103 fn command_for_err(&self) -> String {
106 self.command()
107 .map(|cmd| cmd.replace("\r\n", ""))
108 .unwrap_or_else(|| "NONE".to_string())
109 }
110
111 pub fn was_due_to_message(&self) -> bool {
123 match self {
124 Self::Utf8(_)
125 | Self::MalformedResponseLine(_)
126 | Self::ResponseTooLong
127 | Self::NotConnected
128 | Self::InvalidDnsName(_)
129 | Self::FailedToBuildConnector { .. }
130 | Self::TimeOutResponse { .. }
131 | Self::TimeOutRequest { .. }
132 | Self::ReadError { .. }
133 | Self::FlushError { .. }
134 | Self::WriteError { .. }
135 | Self::TimeOutData
136 | Self::SslErrorStack(_)
137 | Self::NoUsableDaneTlsa { .. } => false,
138 Self::Rejected(response) => response.was_due_to_message(),
139 Self::RejectedBatch(responses) => match responses.len() {
140 1 => responses[0].was_due_to_message(),
141 _ => false,
142 },
143 }
144 }
145}
146
/// One capability line taken from an EHLO/LHLO response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EsmtpCapability {
    /// Capability keyword exactly as the peer sent it (case preserved).
    pub name: String,
    /// Everything after the first space on the line, if there was one.
    pub param: Option<String>,
}
152
/// Events delivered to an `SmtpClientTracer`.
#[derive(Clone, Debug)]
pub enum SmtpClientTraceEvent {
    /// The connection was dropped by the client or found closed.
    Closed,
    /// Raw bytes received from the peer.
    Read(Vec<u8>),
    /// Text about to be written to the peer.
    Write(String),
    /// A free-form diagnostic message with an associated severity.
    Diagnostic {
        level: tracing::Level,
        message: String,
    },
}
163
/// A trace event whose construction is postponed until `trace` is called.
pub trait DeferredTracer {
    /// Builds the concrete event.
    fn trace(&self) -> SmtpClientTraceEvent;
}
167
/// Receiver for trace events emitted by `SmtpClient`.
pub trait SmtpClientTracer: std::fmt::Debug {
    /// Delivers an already-constructed event.
    fn trace_event(&self, event: SmtpClientTraceEvent);
    /// Offers a deferred event; the implementation decides whether to
    /// call `deferred.trace()` to materialize it.
    fn lazy_trace(&self, deferred: &dyn DeferredTracer);
}
172
173struct WriteTracer<'a> {
175 data: &'a str,
176}
177impl DeferredTracer for WriteTracer<'_> {
178 fn trace(&self) -> SmtpClientTraceEvent {
179 SmtpClientTraceEvent::Write(self.data.to_string())
180 }
181}
182impl<'a> WriteTracer<'a> {
183 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
184 tracer.lazy_trace(&Self { data });
185 }
186}
187
188struct BinWriteTracer<'a> {
189 data: &'a [u8],
190}
191impl DeferredTracer for BinWriteTracer<'_> {
192 fn trace(&self) -> SmtpClientTraceEvent {
193 let data = String::from_utf8_lossy(self.data).to_string();
194 SmtpClientTraceEvent::Write(data)
195 }
196}
197impl<'a> BinWriteTracer<'a> {
198 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
199 tracer.lazy_trace(&Self { data });
200 }
201}
202
203struct ReadTracer<'a> {
206 data: &'a [u8],
207}
208impl DeferredTracer for ReadTracer<'_> {
209 fn trace(&self) -> SmtpClientTraceEvent {
210 SmtpClientTraceEvent::Read(self.data.to_vec())
211 }
212}
213
/// Client side of a single SMTP (or LMTP) connection.
#[derive(Debug)]
pub struct SmtpClient {
    /// The underlying stream. Set to `None` when the client gives up on
    /// the connection (I/O error, non-zero timeout, or peer close).
    socket: Option<BoxedAsyncReadAndWrite>,
    /// Peer host name as produced by `extract_hostname`; used in trace
    /// output and as the name presented during the TLS handshake.
    hostname: String,
    /// Capabilities from the most recent accepted EHLO/LHLO response,
    /// keyed by the upper-cased capability name.
    capabilities: HashMap<String, EsmtpCapability>,
    /// Bytes received from the socket that have not yet been consumed as
    /// a complete line.
    read_buffer: Vec<u8>,
    /// Timeout configuration consulted for each command.
    timeouts: SmtpClientTimeouts,
    /// Optional sink for trace events.
    tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
    /// Whether the next mail transaction is prefixed with RSET.
    use_rset: bool,
    /// Caller preference that is copied into `use_rset` after a
    /// successful send.
    enable_rset: bool,
    /// Caller preference; pipelining is used only when this is set and
    /// the peer advertised `PIPELINING`.
    enable_pipelining: bool,
    /// When set, 8-bit content or envelopes are sent even if the peer did
    /// not advertise `8BITMIME` / `SMTPUTF8`.
    ignore_8bit_checks: bool,
}
227
/// Extracts the bare host from a peer address string.
///
/// Accepted shapes are `host`, `host:port`, `[host]` and `[host]:port`.
/// The port (everything after the last `:`) is removed, surrounding
/// square brackets are removed, and a single trailing FQDN dot is
/// removed.
///
/// A fully bracketed value with no port, such as `[::1]`, is recognized
/// before looking for a port so that the colons inside an IPv6 literal
/// are not mistaken for a port separator.
fn extract_hostname(hostname: &str) -> &str {
    // "Fully bracketed" means the first `]` is also the final character,
    // i.e. nothing (and in particular no `:port`) follows the bracket.
    let fully_bracketed =
        hostname.starts_with('[') && hostname.find(']') == Some(hostname.len() - 1);

    // Just the hostname, without any :port
    let host = if fully_bracketed {
        hostname
    } else {
        hostname
            .rsplit_once(':')
            .map_or(hostname, |(host, _port)| host)
    };

    let host = if host.starts_with('[') && host.ends_with(']') {
        &host[1..host.len() - 1]
    } else {
        host
    };

    // Remove any trailing FQDN dot
    host.strip_suffix('.').unwrap_or(host)
}
246
247impl SmtpClient {
248 pub async fn new<A: ToSocketAddrs + ToString + Clone>(
249 addr: A,
250 timeouts: SmtpClientTimeouts,
251 ) -> std::io::Result<Self> {
252 let stream = TcpStream::connect(addr.clone()).await?;
253 stream.set_nodelay(true)?;
255 Ok(Self::with_stream(stream, addr.to_string(), timeouts))
256 }
257
258 pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
259 stream: S,
260 peer_hostname: H,
261 timeouts: SmtpClientTimeouts,
262 ) -> Self {
263 let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
264
265 Self {
266 socket: Some(Box::new(stream)),
267 hostname,
268 capabilities: HashMap::new(),
269 read_buffer: Vec::with_capacity(1024),
270 timeouts,
271 tracer: None,
272 use_rset: false,
273 enable_rset: false,
274 enable_pipelining: false,
275 ignore_8bit_checks: false,
276 }
277 }
278
    /// Controls whether 8-bit message data or envelope addresses may be
    /// sent to a peer that did not advertise `8BITMIME` / `SMTPUTF8`.
    pub fn set_ignore_8bit_checks(&mut self, enable: bool) {
        self.ignore_8bit_checks = enable;
    }
283
    /// Returns `true` while the client still holds a socket.
    pub fn is_connected(&self) -> bool {
        self.socket.is_some()
    }
287
    /// Sets the RSET preference that is applied to subsequent
    /// transactions after a successful send.
    pub fn set_enable_rset(&mut self, enable: bool) {
        self.enable_rset = enable;
    }
291
    /// Allows or forbids pipelining; it is only used when the peer also
    /// advertised `PIPELINING`.
    pub fn set_enable_pipelining(&mut self, enable: bool) {
        self.enable_pipelining = enable;
    }
295
296 pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
297 self.tracer.replace(tracer);
298 }
299
    /// Returns the timeout configuration this client was created with.
    pub fn timeouts(&self) -> &SmtpClientTimeouts {
        &self.timeouts
    }
303
304 async fn read_line(
305 &mut self,
306 timeout_duration: Duration,
307 cmd: Option<&Command>,
308 ) -> Result<String, ClientError> {
309 let mut too_long = false;
310 loop {
311 let mut iter = self.read_buffer.iter().enumerate();
312 while let Some((i, &b)) = iter.next() {
313 if b != b'\r' {
314 continue;
315 }
316 if let Some((_, b'\n')) = iter.next() {
317 if too_long {
318 self.read_buffer.drain(0..i + 2);
319
320 if let Some(tracer) = &self.tracer {
321 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
322 level: Level::ERROR,
323 message: "Response too long".to_string(),
324 });
325 }
326
327 return Err(ClientError::ResponseTooLong);
328 }
329
330 let line = String::from_utf8(self.read_buffer[0..i].to_vec());
331 self.read_buffer.drain(0..i + 2);
332 return Ok(line?);
333 }
334 }
335 if self.read_buffer.len() > MAX_LINE_LEN {
336 self.read_buffer.clear();
337 too_long = true;
338 }
339
340 let mut data = [0u8; MAX_LINE_LEN];
342 let size = match self.socket.as_mut() {
343 Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
344 Ok(Ok(size)) => size,
345 Ok(Err(err)) => {
346 self.socket.take();
347 if let Some(tracer) = &self.tracer {
348 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
349 level: Level::ERROR,
350 message: format!("Error during read: {err:#}"),
351 });
352 tracer.trace_event(SmtpClientTraceEvent::Closed);
353 }
354 return Err(ClientError::ReadError {
355 command: cmd.cloned(),
356 error: format!("{err:#}"),
357 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
358 });
359 }
360 Err(_) => {
361 if timeout_duration != Duration::ZERO {
366 self.socket.take();
367 }
368 if let Some(tracer) = &self.tracer {
369 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
370 level: Level::ERROR,
371 message: format!("Read Timeout after {timeout_duration:?}"),
372 });
373 if self.socket.is_none() {
374 tracer.trace_event(SmtpClientTraceEvent::Closed);
375 }
376 }
377 return Err(ClientError::TimeOutResponse {
378 command: cmd.cloned(),
379 duration: timeout_duration,
380 });
381 }
382 },
383 None => {
384 return Err(ClientError::ReadError {
385 command: cmd.cloned(),
386 error: "the socket was closed in response to an earlier issue".to_string(),
387 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
388 });
389 }
390 };
391 if size == 0 {
392 self.socket.take();
393 if let Some(tracer) = &self.tracer {
394 tracer.trace_event(SmtpClientTraceEvent::Closed);
395 }
396 return Err(ClientError::ReadError {
397 command: cmd.cloned(),
398 error: "Connection closed by peer".to_string(),
399 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
400 });
401 }
402 if let Some(tracer) = &self.tracer {
403 tracer.lazy_trace(&ReadTracer {
404 data: &data[0..size],
405 });
406 }
407 self.read_buffer.extend_from_slice(&data[0..size]);
408 }
409 }
410
411 pub async fn check_unilateral_response(&mut self) -> Result<Option<Response>, ClientError> {
419 match self.read_response(None, Duration::ZERO).await {
420 Ok(response) => Ok(Some(response)),
422 Err(ClientError::TimeOutResponse { .. }) => Ok(None),
424 Err(err) => Err(err),
426 }
427 }
428
429 pub async fn read_response(
430 &mut self,
431 command: Option<&Command>,
432 timeout_duration: Duration,
433 ) -> Result<Response, ClientError> {
434 if let Some(sock) = self.socket.as_mut() {
435 if let Err(err) = sock.flush().await {
436 self.socket.take();
437 if let Some(tracer) = &self.tracer {
438 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
439 level: Level::ERROR,
440 message: format!("Error during flush: {err:#}"),
441 });
442 tracer.trace_event(SmtpClientTraceEvent::Closed);
443 }
444 return Err(ClientError::FlushError {
445 error: format!("{err:#}"),
446 });
447 }
448 }
449
450 let mut line = self.read_line(timeout_duration, command).await?;
451 tracing::trace!("recv<-{}: {line}", self.hostname);
452 let mut parsed = parse_response_line(&line)?;
453 let mut response_builder = ResponseBuilder::new(&parsed);
454
455 let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
456 while !parsed.is_final {
457 line = self
458 .read_line(subsequent_line_timeout_duration, command)
459 .await?;
460 parsed = parse_response_line(&line)?;
461 response_builder
462 .add_line(&parsed)
463 .map_err(ClientError::MalformedResponseLine)?;
464 }
465
466 let response = response_builder.build(command.map(|cmd| cmd.encode().to_string()));
467
468 tracing::trace!("{}: {response:?}", self.hostname);
469
470 Ok(response)
471 }
472
473 pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
474 self.write_command_request(command).await?;
475 self.read_response(Some(command), command.client_timeout(&self.timeouts))
476 .await
477 }
478
479 async fn write_all_with_timeout<F, G>(
483 &mut self,
484 timeout_duration: Duration,
485 bytes: &[u8],
486 make_timeout_err: F,
487 make_write_err: G,
488 ) -> Result<(), ClientError>
489 where
490 F: FnOnce() -> ClientError,
491 G: FnOnce(String) -> ClientError,
492 {
493 match self.socket.as_mut() {
494 Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
495 Ok(Ok(response)) => Ok(response),
496 Ok(Err(err)) => {
497 self.socket.take();
498 if let Some(tracer) = &self.tracer {
499 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
500 level: Level::ERROR,
501 message: format!("Error during write: {err:#}"),
502 });
503 tracer.trace_event(SmtpClientTraceEvent::Closed);
504 }
505 Err(make_write_err(format!("{err:#}")))
506 }
507 Err(_) => {
508 self.socket.take();
509 if let Some(tracer) = &self.tracer {
510 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
511 level: Level::ERROR,
512 message: format!("Write Timeout after {timeout_duration:?}"),
513 });
514 tracer.trace_event(SmtpClientTraceEvent::Closed);
515 }
516 Err(make_timeout_err())
517 }
518 },
519 None => Err(make_write_err(
520 "the socket was closed in response to an earlier issue".to_string(),
521 )),
522 }
523 }
524
525 async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
526 let total_timeout: Duration = commands
527 .iter()
528 .map(|cmd| cmd.client_timeout_request(&self.timeouts))
529 .sum();
530
531 let mut lines: Vec<String> = vec![];
532 let mut all: Vec<u8> = vec![];
533 for cmd in commands {
534 let line = cmd.encode();
535 all.extend_from_slice(&line);
536 lines.push(line.to_string());
537 }
538 tracing::trace!(
539 "send->{}: (PIPELINE) {}",
540 self.hostname,
541 all.as_bstr().escape_bytes()
542 );
543 if self.socket.is_some() {
544 if let Some(tracer) = &self.tracer {
545 for line in &lines {
548 WriteTracer::trace(tracer, line);
549 }
550 }
551 }
552 self.write_all_with_timeout(
553 total_timeout,
554 &all,
555 || ClientError::TimeOutRequest {
556 duration: total_timeout,
557 commands: commands.to_vec(),
558 },
559 |error| ClientError::WriteError {
560 error,
561 commands: commands.to_vec(),
562 },
563 )
564 .await
565 }
566
567 async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
568 let line = command.encode();
569 tracing::trace!("send->{}: {line}", self.hostname);
570 if self.socket.is_some() {
571 if let Some(tracer) = &self.tracer {
572 WriteTracer::trace(tracer, &line.to_string());
573 }
574 }
575
576 let timeout_duration = command.client_timeout_request(&self.timeouts);
577 self.write_all_with_timeout(
578 timeout_duration,
579 &line,
580 || ClientError::TimeOutRequest {
581 duration: timeout_duration,
582 commands: vec![command.clone()],
583 },
584 |error| ClientError::WriteError {
585 error,
586 commands: vec![command.clone()],
587 },
588 )
589 .await
590 }
591
592 async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
593 if self.socket.is_some() {
594 if let Some(tracer) = &self.tracer {
595 BinWriteTracer::trace(tracer, data);
596 }
597 }
598 let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
599 self.write_all_with_timeout(
600 timeout_duration,
601 data,
602 || ClientError::TimeOutData,
603 |error| ClientError::WriteError {
604 error,
605 commands: vec![],
606 },
607 )
608 .await
609 }
610
611 pub async fn pipeline_commands(
627 &mut self,
628 commands: Vec<Command>,
629 ) -> Vec<Result<Response, ClientError>> {
630 let mut results: Vec<Result<Response, ClientError>> = vec![];
631
632 let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
633 if pipeline {
634 if let Err(err) = self.write_pipeline_request(&commands).await {
635 let err: ClientError = err;
636 results.push(Err(err.clone()));
637 while results.len() < commands.len() {
638 results.push(Err(err.clone()));
640 }
641 return results;
642 }
643
644 for cmd in &commands {
646 results.push(
647 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
648 .await,
649 );
650 }
651 return results;
652 }
653
654 for cmd in &commands {
655 if let Err(err) = self.write_command_request(cmd).await {
656 let err: ClientError = err;
657 results.push(Err(err.clone()));
658 while results.len() < commands.len() {
659 results.push(Err(err.clone()));
661 }
662 return results;
663 }
664 results.push(
667 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
668 .await,
669 );
670 }
671 results
672 }
673
674 pub async fn ehlo_lhlo(
675 &mut self,
676 ehlo_name: &str,
677 use_lmtp: bool,
678 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
679 if use_lmtp {
680 self.lhlo(ehlo_name).await
681 } else {
682 self.ehlo(ehlo_name).await
683 }
684 }
685
686 pub async fn lhlo(
687 &mut self,
688 ehlo_name: &str,
689 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
690 let response = self
691 .send_command(&Command::Lhlo(Domain::DomainName(
692 ehlo_name
693 .parse::<DomainString>()
694 .map_err(|_| ClientError::InvalidDnsName(ehlo_name.to_string()))?,
695 )))
696 .await?;
697 self.ehlo_common(response)
698 }
699
700 pub async fn ehlo(
701 &mut self,
702 ehlo_name: &str,
703 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
704 let response = self
705 .send_command(&Command::Ehlo(Domain::DomainName(
706 ehlo_name
707 .parse::<DomainString>()
708 .map_err(|_| ClientError::InvalidDnsName(ehlo_name.to_string()))?,
709 )))
710 .await?;
711 self.ehlo_common(response)
712 }
713
714 fn ehlo_common(
715 &mut self,
716 response: Response,
717 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
718 if response.code != 250 {
719 return Err(ClientError::Rejected(response));
720 }
721
722 let mut capabilities = HashMap::new();
723
724 for line in response.content.lines().skip(1) {
725 let mut fields = line.splitn(2, ' ');
726 if let Some(name) = fields.next() {
727 let param = fields.next().map(|s| s.to_string());
728 let cap = EsmtpCapability {
729 name: name.to_string(),
730 param,
731 };
732 capabilities.insert(name.to_ascii_uppercase(), cap);
733 }
734 }
735
736 self.capabilities = capabilities;
737 Ok(&self.capabilities)
738 }
739
740 pub async fn auth_plain(
741 &mut self,
742 username: &str,
743 password: Option<&str>,
744 ) -> Result<(), ClientError> {
745 let password = password.unwrap_or("");
748 let payload = format!("\x00{username}\x00{password}");
749 let payload = data_encoding::BASE64.encode(payload.as_bytes());
750
751 let response = self
752 .send_command(&Command::Auth {
753 sasl_mech: "PLAIN".to_string(),
754 initial_response: Some(payload),
755 })
756 .await?;
757
758 if response.code != 235 {
759 return Err(ClientError::Rejected(response));
760 }
761
762 Ok(())
763 }
764
    /// Issues `STARTTLS` and attempts to upgrade the connection to TLS.
    ///
    /// OpenSSL is used when `options.prefer_openssl` is set or when DANE
    /// TLSA records were supplied; otherwise rustls is used.
    ///
    /// A failed handshake is *not* returned as `Err`: it is reported as
    /// `Ok(TlsStatus::FailedHandshake(..))` and the client keeps whatever
    /// stream is left (see the per-branch comments). On success
    /// `Ok(TlsStatus::Info(..))` describes the negotiated session.
    ///
    /// # Errors
    ///
    /// * any error from sending `STARTTLS`;
    /// * `Rejected` when the reply code is not 220;
    /// * `FailedToBuildConnector` when the TLS connector cannot be built;
    /// * `SslErrorStack` from OpenSSL setup calls;
    /// * `InvalidDnsName` when the host name is neither an IP address nor
    ///   acceptable as a rustls server name;
    /// * `NotConnected` when there is no socket to upgrade.
    pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
        let resp = self.send_command(&Command::StartTls).await?;
        if resp.code != 220 {
            return Err(ClientError::Rejected(resp));
        }

        let mut handshake_error = None;
        let mut tls_info = TlsInformation::default();

        let stream: BoxedAsyncReadAndWrite = if options.prefer_openssl
            || !options.dane_tlsa.is_empty()
        {
            let connector = options
                .build_openssl_connector(&self.hostname)
                .map_err(|error| ClientError::FailedToBuildConnector {
                    error: error.to_string(),
                })?;
            let ssl = connector.into_ssl(self.hostname.as_str())?;

            // Take ownership of the socket, and try to obtain a duplicate
            // handle to it that can be used if the handshake fails.
            let (stream, dup_stream) = match self.socket.take() {
                Some(s) => {
                    let d = s.try_dup();
                    (s, d)
                }
                None => return Err(ClientError::NotConnected),
            };

            let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;

            // Record the failure rather than returning, so that session
            // details can still be collected and a stream retained.
            if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
                handshake_error.replace(format!("{err:#}"));
            }

            tls_info.provider_name = "openssl".to_string();
            tls_info.cipher = match ssl_stream.ssl().current_cipher() {
                Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
                None => String::new(),
            };
            tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();

            if let Some(cert) = ssl_stream.ssl().peer_certificate() {
                tls_info.subject_name = subject_name(&cert);
            }
            // When a DANE authority certificate is available, its subject
            // replaces the peer certificate's subject.
            if let Ok(authority) = ssl_stream.ssl().dane_authority() {
                if let Some(cert) = &authority.cert {
                    tls_info.subject_name = subject_name(cert);
                }
            }

            match (&handshake_error, dup_stream) {
                // The handshake failed before completing and a duplicate
                // handle exists: discard the TLS wrapper and continue on
                // the duplicate.
                (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
                    drop(ssl_stream);
                    Box::new(dup_stream)
                }
                _ => Box::new(ssl_stream),
            }
        } else {
            tls_info.provider_name = "rustls".to_string();
            let connector = options.build_tls_connector().await.map_err(|error| {
                ClientError::FailedToBuildConnector {
                    error: error.to_string(),
                }
            })?;
            // An IP literal is passed as an address; anything else must be
            // acceptable as a DNS server name.
            let server_name = match IpAddr::from_str(self.hostname.as_str()) {
                Ok(ip) => ServerName::IpAddress(ip.into()),
                Err(_) => ServerName::try_from(self.hostname.clone())
                    .map_err(|_| ClientError::InvalidDnsName(self.hostname.clone()))?,
            };

            match connector
                .connect(
                    server_name,
                    match self.socket.take() {
                        Some(s) => s,
                        None => return Err(ClientError::NotConnected),
                    },
                )
                .into_fallible()
                .await
            {
                Ok(stream) => {
                    let (_, conn) = stream.get_ref();
                    tls_info.cipher = match conn.negotiated_cipher_suite() {
                        Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
                        None => String::new(),
                    };
                    tls_info.protocol_version = match conn.protocol_version() {
                        Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
                        None => String::new(),
                    };

                    if let Some(certs) = conn.peer_certificates() {
                        // NOTE(review): indexes the first certificate without
                        // checking for an empty list; presumably rustls never
                        // yields `Some` with an empty slice — confirm.
                        let peer_cert = &certs[0];
                        if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
                            tls_info.subject_name = subject_name(&cert);
                        }
                    }

                    Box::new(stream)
                }
                // The fallible connect hands the underlying stream back on
                // failure, so the client keeps it.
                Err((err, stream)) => {
                    handshake_error.replace(format!("{err:#}"));
                    stream
                }
            }
        };

        if let Some(tracer) = &self.tracer {
            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
                level: Level::INFO,
                message: match &handshake_error {
                    Some(error) => format!("STARTTLS handshake failed: {error:?}"),
                    None => format!("STARTTLS handshake -> {tls_info:?}"),
                },
            });
        }

        self.socket.replace(stream);
        Ok(match handshake_error {
            Some(error) => TlsStatus::FailedHandshake(error),
            None => TlsStatus::Info(tls_info),
        })
    }
897
898 pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
899 &mut self,
900 sender: SENDER,
901 recipient: RECIP,
902 data: B,
903 ) -> Result<Response, ClientError> {
904 let recipient: ForwardPath = recipient.into();
905 let status = self
906 .send_mail_multi_recip(sender, vec![recipient], data)
907 .await?;
908 Ok(status.response)
909 }
910
    /// Runs one mail transaction: optional `RSET`, `MAIL FROM`, one
    /// `RCPT TO` per recipient, `DATA`, then the message body and the
    /// terminating dot.
    ///
    /// The body is dot-stuffed first. If the body contains bytes >= 0x80
    /// the `BODY=8BITMIME` parameter is added when the peer advertised
    /// `8BITMIME`; if the envelope is not ASCII the `SMTPUTF8` parameter
    /// is added when the peer advertised `SMTPUTF8`. When the capability
    /// is missing and `ignore_8bit_checks` is not set, a synthetic 554
    /// `Rejected` error is returned without contacting the peer.
    ///
    /// On success returns the final response together with the
    /// per-recipient responses.
    ///
    /// # Errors
    ///
    /// * the first transport error among the command results;
    /// * `Rejected` for a non-250 `RSET` or `MAIL FROM`, a non-354 `DATA`,
    ///   a non-250 final response, or a single rejected recipient;
    /// * `RejectedBatch` when no recipient was accepted with code 250
    ///   (and there was not exactly one recipient).
    pub async fn send_mail_multi_recip<B: AsRef<[u8]>, SENDER: Into<ReversePath>>(
        &mut self,
        sender: SENDER,
        recipient_list: Vec<ForwardPath>,
        data: B,
    ) -> Result<BatchSendSuccess, ClientError> {
        let sender = sender.into();

        let data: &[u8] = data.as_ref();
        // Holds the stuffed copy, if one is needed, so that `data` can
        // borrow from it.
        let stuffed;

        let data = match apply_dot_stuffing(data) {
            Some(d) => {
                stuffed = d;
                &stuffed
            }
            None => data,
        };

        let data_is_8bit = data.iter().any(|&b| b >= 0x80);
        let envelope_is_8bit =
            !sender.is_ascii() || recipient_list.iter().any(|recipient| !recipient.is_ascii());

        let mut mail_from_params = vec![];
        if data_is_8bit {
            if self.capabilities.contains_key("8BITMIME") {
                mail_from_params.push(EsmtpParameter {
                    name: "BODY".to_string(),
                    value: Some("8BITMIME".to_string()),
                });
            } else if !self.ignore_8bit_checks {
                return Err(ClientError::Rejected(Response {
                    code: 554,
                    command: None,
                    enhanced_code: Some(EnhancedStatusCode {
                        class: 5,
                        subject: 6,
                        detail: 3,
                    }),
                    content: "KumoMTA internal: DATA is 8bit, destination does \
                    not support 8BITMIME. Conversion via msg:check_fix_conformance \
                    during reception is required"
                        .to_string(),
                }));
            }
        }

        if envelope_is_8bit {
            if self.capabilities.contains_key("SMTPUTF8") {
                mail_from_params.push(EsmtpParameter {
                    name: "SMTPUTF8".to_string(),
                    value: None,
                });
            } else if !self.ignore_8bit_checks {
                return Err(ClientError::Rejected(Response {
                    code: 554,
                    command: None,
                    enhanced_code: Some(EnhancedStatusCode {
                        class: 5,
                        subject: 6,
                        detail: 7,
                    }),
                    content: "KumoMTA internal: envelope is 8bit, destination does \
                    not support SMTPUTF8."
                        .to_string(),
                }));
            }
        }

        let mut commands = vec![];

        // Remember whether an RSET was queued so that its response can be
        // consumed first below.
        let used_rset = self.use_rset;
        if self.use_rset {
            commands.push(Command::Rset);
        }
        commands.push(Command::MailFrom {
            address: sender,
            parameters: mail_from_params,
        });

        for recipient in &recipient_list {
            commands.push(Command::RcptTo {
                address: recipient.clone(),
                parameters: vec![],
            });
        }
        commands.push(Command::Data);

        // Assume something may go wrong below: any early return leaves
        // this set, so the next transaction starts with an RSET.
        self.use_rset = true;

        let mut responses = self.pipeline_commands(commands).await;

        // The response codes are normally evaluated only after the
        // "DATA accepted although nothing else was" case has been handled
        // further down. If there was a transport error anywhere in the
        // batch, however, that recovery cannot happen, so the codes are
        // checked eagerly as each response is taken.
        let is_err = responses.iter().any(|r| r.is_err());

        if used_rset {
            let rset_resp = responses.remove(0)?;
            if rset_resp.code != 250 {
                return Err(ClientError::Rejected(rset_resp));
            }
        }

        let mail_resp = responses.remove(0)?;
        if is_err && mail_resp.code != 250 {
            return Err(ClientError::Rejected(mail_resp));
        }

        let mut rcpt_responses = vec![];
        for _ in &recipient_list {
            rcpt_responses.push(responses.remove(0)?);
        }

        if is_err && rcpt_responses.iter().all(|resp| resp.code != 250) {
            return Err(ClientError::RejectedBatch(rcpt_responses));
        }

        let data_resp = responses.remove(0)?;
        if is_err && data_resp.code != 354 {
            return Err(ClientError::Rejected(data_resp));
        }

        if data_resp.code == 354
            && (mail_resp.code != 250 || rcpt_responses.iter().all(|resp| resp.code != 250))
        {
            // The peer accepted DATA even though the sender or every
            // recipient was refused. Send an empty body (a lone dot) and
            // consume its response so the session stays in step; the
            // checks below then report the earlier rejection.
            // NOTE(review): this appears to follow the pipelining guidance
            // in RFC 2920 section 3.1 — confirm.
            self.write_data_with_timeout(b".\r\n").await?;
            let data_dot = Command::DataDot;
            let _ = self
                .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
                .await?;

        }

        if mail_resp.code != 250 {
            return Err(ClientError::Rejected(mail_resp));
        }
        if rcpt_responses.iter().all(|resp| resp.code != 250) {
            // A single recipient is reported as a plain rejection rather
            // than as a batch.
            if rcpt_responses.len() == 1 {
                return Err(ClientError::Rejected(
                    rcpt_responses.pop().expect("have at least one"),
                ));
            }
            return Err(ClientError::RejectedBatch(rcpt_responses));
        }
        if data_resp.code != 354 {
            return Err(ClientError::Rejected(data_resp));
        }

        // The terminating dot must start on its own line; add a CRLF first
        // if the body does not already end with a newline.
        let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);

        tracing::trace!("message data is {} bytes", data.len());

        self.write_data_with_timeout(data).await?;

        let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };

        tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());

        self.write_data_with_timeout(marker.as_bytes()).await?;

        let data_dot = Command::DataDot;
        let resp = self
            .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
            .await?;
        if resp.code != 250 {
            return Err(ClientError::Rejected(resp));
        }

        // Everything succeeded: fall back to the caller's RSET preference
        // for the next transaction.
        self.use_rset = self.enable_rset;

        Ok(BatchSendSuccess {
            response: resp,
            rcpt_responses,
        })
    }
1116}
1117
/// Result of a successful `send_mail_multi_recip` call.
#[derive(Debug)]
pub struct BatchSendSuccess {
    /// The response to the end-of-data marker.
    pub response: Response,
    /// One response per recipient, in the order the recipients were given.
    pub rcpt_responses: Vec<Response>,
}
1123
/// Outcome of a `STARTTLS` attempt that got as far as the handshake.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
    /// The handshake failed; carries the formatted error text.
    FailedHandshake(String),
    /// The handshake succeeded; carries details of the session.
    Info(TlsInformation),
}
1129
/// Details of a negotiated TLS session.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
    /// Name of the negotiated cipher, or empty when unavailable.
    pub cipher: String,
    /// Negotiated protocol version, or empty when unavailable.
    pub protocol_version: String,
    /// Certificate subject entries formatted as `short_name=value`.
    pub subject_name: Vec<String>,
    /// TLS implementation that was used: `"openssl"` or `"rustls"`.
    pub provider_name: String,
}
1137
1138impl Drop for SmtpClient {
1139 fn drop(&mut self) {
1140 if let Some(tracer) = &self.tracer {
1141 if self.socket.is_some() {
1142 tracer.trace_event(SmtpClientTraceEvent::Closed);
1143 }
1144 }
1145 }
1146}
1147fn parse_response_line(line: &'_ str) -> Result<ResponseLine<'_>, ClientError> {
1148 if line.len() < 4 {
1149 return Err(ClientError::MalformedResponseLine(line.to_string()));
1150 }
1151
1152 match line.as_bytes()[3] {
1153 b' ' | b'-' => match line[0..3].parse::<u16>() {
1154 Ok(code) => Ok(ResponseLine {
1155 code,
1156 is_final: line.as_bytes()[3] == b' ',
1157 content: &line[4..],
1158 }),
1159 Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
1160 },
1161 _ => Err(ClientError::MalformedResponseLine(line.to_string())),
1162 }
1163}
1164
1165fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1166 static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1167
1168 if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1169 return None;
1170 }
1171
1172 let mut stuffed = vec![];
1173 if data.starts_with(b".") {
1174 stuffed.push(b'.');
1175 }
1176 let mut last_idx = 0;
1177 for i in LFDOT.find_iter(data) {
1178 stuffed.extend_from_slice(&data[last_idx..=i]);
1179 stuffed.push(b'.');
1180 last_idx = i + 1;
1181 }
1182 stuffed.extend_from_slice(&data[last_idx..]);
1183 Some(stuffed)
1184}
1185
1186pub fn subject_name(cert: &X509Ref) -> Vec<String> {
1193 let mut subject_name = vec![];
1194 for entry in cert.subject_name().entries() {
1195 if let Ok(obj) = entry.object().nid().short_name() {
1196 if let Ok(name) = entry.data().as_utf8() {
1197 subject_name.push(format!("{obj}={name}"));
1198 }
1199 }
1200 }
1201 subject_name
1202}
1203
/// Unit tests for the parsing and formatting helpers in this module.
#[cfg(test)]
mod test {
    use super::*;
    use crate::parser::{EnvelopeAddress, ReversePath};

    /// Dot stuffing leaves dot-free input alone and doubles dots at the
    /// start of the data and after each newline.
    #[test]
    fn test_stuffing() {
        assert_eq!(apply_dot_stuffing(b"foo"), None);
        assert_eq!(apply_dot_stuffing(b".foo").unwrap(), b"..foo");
        assert_eq!(apply_dot_stuffing(b"foo\n.bar").unwrap(), b"foo\n..bar");
        assert_eq!(
            apply_dot_stuffing(b"foo\n.bar\n..baz\n").unwrap(),
            b"foo\n..bar\n...baz\n"
        );
    }

    /// A space separator marks a final line, `-` marks a continuation,
    /// and anything else is malformed.
    #[test]
    fn response_line_parsing() {
        assert_eq!(
            parse_response_line("220 woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: true,
                content: "woot"
            }
        );
        assert_eq!(
            parse_response_line("220-woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: false,
                content: "woot"
            }
        );

        assert!(matches!(
            parse_response_line("220_woot"),
            Err(ClientError::MalformedResponseLine(_))
        ));
        assert!(matches!(
            parse_response_line("not really"),
            Err(ClientError::MalformedResponseLine(_))
        ));
    }

    /// Test helper: feeds `lines` through the line parser and the
    /// response builder, asserting that the last line is final.
    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let mut parsed = parse_response_line(lines[0])?;
        let mut b = ResponseBuilder::new(&parsed);
        for line in &lines[1..] {
            parsed = parse_response_line(line)?;
            b.add_line(&parsed)
                .map_err(ClientError::MalformedResponseLine)?;
        }
        assert!(parsed.is_final);
        Ok(b.build(None))
    }

    /// Multi-line responses are joined with newlines, must keep the same
    /// code on every line, and take their enhanced status code from the
    /// first line.
    #[test]
    fn multi_line_response() {
        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: None,
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // A differing code on a continuation line is rejected.
        let res = parse_multi_line(&["220-woot", "221-more", "220 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221-more"
            ),
            "got error {res:?}"
        );

        // A differing code on the final line is rejected.
        let res = parse_multi_line(&["220-woot", "220-more", "221 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221 done"
            ),
            "got error {res:?}"
        );

        // Enhanced code present on the first line only.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // A repeated, matching enhanced code is stripped from later lines.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // A non-matching enhanced code on a later line stays in the content.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\n5.5.5 done".to_string(),
                command: None
            }
        );
    }

    /// Ports, brackets and a trailing dot are all removed.
    #[test]
    fn test_extract_hostname() {
        assert_eq!(extract_hostname("foo"), "foo");
        assert_eq!(extract_hostname("foo."), "foo");
        assert_eq!(extract_hostname("foo:25"), "foo");
        assert_eq!(extract_hostname("foo.:25"), "foo");
        assert_eq!(extract_hostname("[foo]:25"), "foo");
        assert_eq!(extract_hostname("[foo.]:25"), "foo");
        assert_eq!(extract_hostname("[::1]:25"), "::1");
        assert_eq!(extract_hostname("::1:25"), "::1");
    }

    /// Error messages embed the encoded command with its CRLF removed.
    #[test]
    fn test_format_error_command() {
        assert_eq!(
            format!(
                "{:#}",
                ClientError::TimeOutRequest {
                    commands: vec![Command::DataDot],
                    duration: Duration::from_secs(10),
                }
            ),
            "Timed Out after 10s writing cmd=."
        );
        assert_eq!(
            format!(
                "{:#}",
                ClientError::TimeOutResponse {
                    command: Some(Command::MailFrom {
                        address: {
                            let EnvelopeAddress::Path(p) =
                                EnvelopeAddress::parse("user@host").unwrap()
                            else {
                                panic!("expected Path")
                            };
                            ReversePath::Path(p)
                        },
                        parameters: vec![],
                    }),
                    duration: Duration::from_secs(10),
                }
            ),
            r#"Timed Out waiting 10s for response to cmd=MAIL FROM:<user@host>"#
        );
    }
}