1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::{
4 AsyncReadAndWrite, BoxedAsyncReadAndWrite, Command, Domain, EsmtpParameter, ForwardPath,
5 ReversePath,
6};
7use hickory_proto::rr::rdata::TLSA;
8use memchr::memmem::Finder;
9use openssl::x509::{X509Ref, X509};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::net::IpAddr;
13use std::str::FromStr;
14use std::sync::{Arc, LazyLock};
15use std::time::Duration;
16use thiserror::Error;
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::{TcpStream, ToSocketAddrs};
19use tokio::time::timeout;
20use tokio_rustls::rustls::pki_types::ServerName;
21use tracing::Level;
22
23pub use kumo_tls_helper::TlsOptions;
24pub use {openssl, tokio_rustls};
25
/// Maximum number of bytes buffered for a single reply line before it is
/// treated as too long; also the size of each socket read in `read_line`.
const MAX_LINE_LEN: usize = 4 * 1024;
27
/// Errors produced while talking to an SMTP/LMTP server via [`SmtpClient`].
#[derive(Error, Debug, Clone)]
pub enum ClientError {
    /// A reply line received from the server was not valid UTF-8.
    #[error("response is not UTF8")]
    Utf8(#[from] std::string::FromUtf8Error),
    /// A reply line did not match the `NNN<SP|->text` reply syntax, or a
    /// continuation line was inconsistent with the first line.
    #[error("Malformed Response: {0}")]
    MalformedResponseLine(String),
    /// A reply line exceeded `MAX_LINE_LEN` bytes; it was discarded.
    #[error("Response line is too long")]
    ResponseTooLong,
    /// The operation needed a socket but the client no longer holds one.
    #[error("Not connected")]
    NotConnected,
    /// The server answered a command with an unexpected reply code.
    #[error("Command rejected {0:?}")]
    Rejected(Response),
    /// Several replies were rejected together (e.g. every RCPT TO in a batch).
    #[error("Commands rejected {0:?}")]
    RejectedBatch(Vec<Response>),
    /// The peer hostname could not be used as a TLS server name for STARTTLS.
    #[error("STARTTLS: {0} is not a valid DNS name")]
    InvalidDnsName(String),
    /// Building the TLS connector from the configured `TlsOptions` failed.
    #[error("Invalid client certificate configured: {error:?}")]
    FailedToBuildConnector { error: String },
    /// No reply arrived within `duration`; `command` is the command awaiting it.
    #[error("Timed Out waiting {duration:?} for response to cmd={}", self.command_for_err())]
    TimeOutResponse {
        command: Option<Command>,
        duration: Duration,
    },
    /// Writing `commands` did not complete within `duration`.
    #[error("Timed Out after {duration:?} writing cmd={}", self.command_for_err())]
    TimeOutRequest {
        commands: Vec<Command>,
        duration: Duration,
    },
    /// Reading a reply failed; `partial` holds any bytes buffered so far.
    #[error("Error {error} reading response to cmd={}", self.command_for_err())]
    ReadError {
        command: Option<Command>,
        error: String,
        partial: String,
    },
    /// Flushing buffered writes before reading a reply failed.
    #[error("Error {error} flushing send buffer")]
    FlushError { error: String },
    /// Writing `commands` (empty for message payload writes) failed.
    #[error("Error {error} writing {}", self.command_for_err())]
    WriteError {
        commands: Vec<Command>,
        error: String,
    },
    /// Writing message payload data did not complete in time.
    #[error("Timed Out sending message payload data")]
    TimeOutData,
    /// An OpenSSL error occurred while setting up TLS.
    #[error("SSL Error: {0}")]
    SslErrorStack(#[from] openssl::error::ErrorStack),
    /// DANE TLSA records were supplied but none could be used for `hostname`.
    #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
    NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
}
76
77impl ClientError {
78 pub fn command(&self) -> Option<String> {
80 match self {
81 Self::TimeOutResponse {
82 command: Some(command),
83 ..
84 }
85 | Self::ReadError {
86 command: Some(command),
87 ..
88 } => Some(command.encode()),
89 Self::TimeOutRequest { commands, .. } | Self::WriteError { commands, .. }
90 if !commands.is_empty() =>
91 {
92 let commands: Vec<String> = commands.into_iter().map(|cmd| cmd.encode()).collect();
93 Some(commands.join(""))
94 }
95 _ => None,
96 }
97 }
98
99 fn command_for_err(&self) -> String {
102 self.command()
103 .map(|cmd| cmd.replace("\r\n", ""))
104 .unwrap_or_else(|| "NONE".to_string())
105 }
106
107 pub fn was_due_to_message(&self) -> bool {
119 match self {
120 Self::Utf8(_)
121 | Self::MalformedResponseLine(_)
122 | Self::ResponseTooLong
123 | Self::NotConnected
124 | Self::InvalidDnsName(_)
125 | Self::FailedToBuildConnector { .. }
126 | Self::TimeOutResponse { .. }
127 | Self::TimeOutRequest { .. }
128 | Self::ReadError { .. }
129 | Self::FlushError { .. }
130 | Self::WriteError { .. }
131 | Self::TimeOutData
132 | Self::SslErrorStack(_)
133 | Self::NoUsableDaneTlsa { .. } => false,
134 Self::Rejected(response) => response.was_due_to_message(),
135 Self::RejectedBatch(responses) => match responses.len() {
136 1 => responses[0].was_due_to_message(),
137 _ => false,
138 },
139 }
140 }
141}
142
/// One ESMTP extension advertised in an EHLO/LHLO reply.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EsmtpCapability {
    /// The extension keyword with the case the server used.
    pub name: String,
    /// The text after the first space on the capability line, if any.
    pub param: Option<String>,
}
148
/// An event reported to an [`SmtpClientTracer`].
#[derive(Clone, Debug)]
pub enum SmtpClientTraceEvent {
    /// The client dropped its socket (after an error, peer close, or on drop).
    Closed,
    /// Raw bytes as received from the socket.
    Read(Vec<u8>),
    /// Data written to the socket; binary payloads are converted lossily to UTF-8.
    Write(String),
    /// A client-side diagnostic, e.g. a read/write error or timeout.
    Diagnostic {
        level: tracing::Level,
        message: String,
    },
}
159
/// Builds a trace event on demand.
///
/// Passed to [`SmtpClientTracer::lazy_trace`]; presumably this lets a tracer
/// skip the cost of building the event (copying payload bytes) when it does not
/// need it. NOTE(review): confirm against tracer implementations.
pub trait DeferredTracer {
    fn trace(&self) -> SmtpClientTraceEvent;
}
163
/// Receives protocol-level trace events from an [`SmtpClient`].
pub trait SmtpClientTracer: std::fmt::Debug {
    /// Records an already-constructed event.
    fn trace_event(&self, event: SmtpClientTraceEvent);
    /// Records an event that is only materialized if `deferred.trace()` is called.
    fn lazy_trace(&self, deferred: &dyn DeferredTracer);
}
168
169struct WriteTracer<'a> {
171 data: &'a str,
172}
173impl DeferredTracer for WriteTracer<'_> {
174 fn trace(&self) -> SmtpClientTraceEvent {
175 SmtpClientTraceEvent::Write(self.data.to_string())
176 }
177}
178impl<'a> WriteTracer<'a> {
179 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
180 tracer.lazy_trace(&Self { data });
181 }
182}
183
184struct BinWriteTracer<'a> {
185 data: &'a [u8],
186}
187impl DeferredTracer for BinWriteTracer<'_> {
188 fn trace(&self) -> SmtpClientTraceEvent {
189 let data = String::from_utf8_lossy(self.data).to_string();
190 SmtpClientTraceEvent::Write(data)
191 }
192}
193impl<'a> BinWriteTracer<'a> {
194 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
195 tracer.lazy_trace(&Self { data });
196 }
197}
198
199struct ReadTracer<'a> {
202 data: &'a [u8],
203}
204impl DeferredTracer for ReadTracer<'_> {
205 fn trace(&self) -> SmtpClientTraceEvent {
206 SmtpClientTraceEvent::Read(self.data.to_vec())
207 }
208}
209
/// An SMTP/LMTP client session over a (possibly TLS-wrapped) stream.
#[derive(Debug)]
pub struct SmtpClient {
    /// The connection; set to `None` after any read/write/flush error,
    /// timeout or peer close, after which I/O methods return errors.
    socket: Option<BoxedAsyncReadAndWrite>,
    /// Peer hostname normalized by `extract_hostname`; used in trace logs and
    /// as the TLS server name for STARTTLS.
    hostname: String,
    /// Capabilities from the most recent EHLO/LHLO, keyed by upper-cased keyword.
    capabilities: HashMap<String, EsmtpCapability>,
    /// Bytes received but not yet consumed as complete lines.
    read_buffer: Vec<u8>,
    timeouts: SmtpClientTimeouts,
    tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
    /// Whether the next mail transaction starts with RSET.
    use_rset: bool,
    /// Value `use_rset` is reset to after a successful transaction.
    enable_rset: bool,
    /// Allows pipelining when the server also advertised PIPELINING.
    enable_pipelining: bool,
    /// Send 8-bit data / non-ASCII envelopes even without 8BITMIME / SMTPUTF8.
    ignore_8bit_checks: bool,
}
223
/// Normalizes a peer address/hostname string to a bare host.
///
/// Strips an optional `:port` suffix, surrounding `[...]` brackets (address
/// literals) and one trailing `.`. Examples:
/// - `"foo.:25"` becomes `"foo"`
/// - `"[::1]:25"` becomes `"::1"`
/// - `"[::1]"` becomes `"::1"`
fn extract_hostname(hostname: &str) -> &str {
    // Bracketed literals ("[addr]" or "[addr]:port") are handled first: an IPv6
    // address contains ':' so splitting on the last ':' would cut "[::1]" into
    // "[:" when no port is present.
    let bracketed = hostname
        .strip_prefix('[')
        .and_then(|rest| rest.split_once(']'))
        .filter(|(_, after)| after.is_empty() || after.starts_with(':'))
        .map(|(inner, _)| inner);

    let hostname = match bracketed {
        Some(inner) => inner,
        None => {
            // Drop a ":port" suffix (everything after the last ':').
            let host = hostname
                .rsplit_once(':')
                .map_or(hostname, |(host, _port)| host);
            host.strip_prefix('[')
                .and_then(|h| h.strip_suffix(']'))
                .unwrap_or(host)
        }
    };

    hostname.strip_suffix('.').unwrap_or(hostname)
}
242
243impl SmtpClient {
244 pub async fn new<A: ToSocketAddrs + ToString + Clone>(
245 addr: A,
246 timeouts: SmtpClientTimeouts,
247 ) -> std::io::Result<Self> {
248 let stream = TcpStream::connect(addr.clone()).await?;
249 stream.set_nodelay(true)?;
251 Ok(Self::with_stream(stream, addr.to_string(), timeouts))
252 }
253
254 pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
255 stream: S,
256 peer_hostname: H,
257 timeouts: SmtpClientTimeouts,
258 ) -> Self {
259 let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
260
261 Self {
262 socket: Some(Box::new(stream)),
263 hostname,
264 capabilities: HashMap::new(),
265 read_buffer: Vec::with_capacity(1024),
266 timeouts,
267 tracer: None,
268 use_rset: false,
269 enable_rset: false,
270 enable_pipelining: false,
271 ignore_8bit_checks: false,
272 }
273 }
274
275 pub fn set_ignore_8bit_checks(&mut self, enable: bool) {
277 self.ignore_8bit_checks = enable;
278 }
279
280 pub fn is_connected(&self) -> bool {
281 self.socket.is_some()
282 }
283
284 pub fn set_enable_rset(&mut self, enable: bool) {
285 self.enable_rset = enable;
286 }
287
288 pub fn set_enable_pipelining(&mut self, enable: bool) {
289 self.enable_pipelining = enable;
290 }
291
292 pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
293 self.tracer.replace(tracer);
294 }
295
296 pub fn timeouts(&self) -> &SmtpClientTimeouts {
297 &self.timeouts
298 }
299
300 async fn read_line(
301 &mut self,
302 timeout_duration: Duration,
303 cmd: Option<&Command>,
304 ) -> Result<String, ClientError> {
305 let mut too_long = false;
306 loop {
307 let mut iter = self.read_buffer.iter().enumerate();
308 while let Some((i, &b)) = iter.next() {
309 if b != b'\r' {
310 continue;
311 }
312 if let Some((_, b'\n')) = iter.next() {
313 if too_long {
314 self.read_buffer.drain(0..i + 2);
315
316 if let Some(tracer) = &self.tracer {
317 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
318 level: Level::ERROR,
319 message: "Response too long".to_string(),
320 });
321 }
322
323 return Err(ClientError::ResponseTooLong);
324 }
325
326 let line = String::from_utf8(self.read_buffer[0..i].to_vec());
327 self.read_buffer.drain(0..i + 2);
328 return Ok(line?);
329 }
330 }
331 if self.read_buffer.len() > MAX_LINE_LEN {
332 self.read_buffer.clear();
333 too_long = true;
334 }
335
336 let mut data = [0u8; MAX_LINE_LEN];
338 let size = match self.socket.as_mut() {
339 Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
340 Ok(Ok(size)) => size,
341 Ok(Err(err)) => {
342 self.socket.take();
343 if let Some(tracer) = &self.tracer {
344 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
345 level: Level::ERROR,
346 message: format!("Error during read: {err:#}"),
347 });
348 tracer.trace_event(SmtpClientTraceEvent::Closed);
349 }
350 return Err(ClientError::ReadError {
351 command: cmd.cloned(),
352 error: format!("{err:#}"),
353 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
354 });
355 }
356 Err(_) => {
357 self.socket.take();
358 if let Some(tracer) = &self.tracer {
359 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
360 level: Level::ERROR,
361 message: format!("Read Timeout after {timeout_duration:?}"),
362 });
363 tracer.trace_event(SmtpClientTraceEvent::Closed);
364 }
365 return Err(ClientError::TimeOutResponse {
366 command: cmd.cloned(),
367 duration: timeout_duration,
368 });
369 }
370 },
371 None => {
372 return Err(ClientError::ReadError {
373 command: cmd.cloned(),
374 error: "the socket was closed in response to an earlier issue".to_string(),
375 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
376 });
377 }
378 };
379 if size == 0 {
380 self.socket.take();
381 if let Some(tracer) = &self.tracer {
382 tracer.trace_event(SmtpClientTraceEvent::Closed);
383 }
384 return Err(ClientError::ReadError {
385 command: cmd.cloned(),
386 error: "Connection closed by peer".to_string(),
387 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
388 });
389 }
390 if let Some(tracer) = &self.tracer {
391 tracer.lazy_trace(&ReadTracer {
392 data: &data[0..size],
393 });
394 }
395 self.read_buffer.extend_from_slice(&data[0..size]);
396 }
397 }
398
399 pub async fn read_response(
400 &mut self,
401 command: Option<&Command>,
402 timeout_duration: Duration,
403 ) -> Result<Response, ClientError> {
404 if let Some(sock) = self.socket.as_mut() {
405 if let Err(err) = sock.flush().await {
406 self.socket.take();
407 if let Some(tracer) = &self.tracer {
408 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
409 level: Level::ERROR,
410 message: format!("Error during flush: {err:#}"),
411 });
412 tracer.trace_event(SmtpClientTraceEvent::Closed);
413 }
414 return Err(ClientError::FlushError {
415 error: format!("{err:#}"),
416 });
417 }
418 }
419
420 let mut line = self.read_line(timeout_duration, command).await?;
421 tracing::trace!("recv<-{}: {line}", self.hostname);
422 let mut parsed = parse_response_line(&line)?;
423 let mut response_builder = ResponseBuilder::new(&parsed);
424
425 let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
426 while !parsed.is_final {
427 line = self
428 .read_line(subsequent_line_timeout_duration, command)
429 .await?;
430 parsed = parse_response_line(&line)?;
431 response_builder
432 .add_line(&parsed)
433 .map_err(ClientError::MalformedResponseLine)?;
434 }
435
436 let response = response_builder.build(command.map(|cmd| cmd.encode()));
437
438 tracing::trace!("{}: {response:?}", self.hostname);
439
440 Ok(response)
441 }
442
443 pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
444 self.write_command_request(command).await?;
445 self.read_response(Some(command), command.client_timeout(&self.timeouts))
446 .await
447 }
448
449 async fn write_all_with_timeout<F, G>(
453 &mut self,
454 timeout_duration: Duration,
455 bytes: &[u8],
456 make_timeout_err: F,
457 make_write_err: G,
458 ) -> Result<(), ClientError>
459 where
460 F: FnOnce() -> ClientError,
461 G: FnOnce(String) -> ClientError,
462 {
463 match self.socket.as_mut() {
464 Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
465 Ok(Ok(response)) => Ok(response),
466 Ok(Err(err)) => {
467 self.socket.take();
468 if let Some(tracer) = &self.tracer {
469 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
470 level: Level::ERROR,
471 message: format!("Error during write: {err:#}"),
472 });
473 tracer.trace_event(SmtpClientTraceEvent::Closed);
474 }
475 Err(make_write_err(format!("{err:#}")))
476 }
477 Err(_) => {
478 self.socket.take();
479 if let Some(tracer) = &self.tracer {
480 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
481 level: Level::ERROR,
482 message: format!("Write Timeout after {timeout_duration:?}"),
483 });
484 tracer.trace_event(SmtpClientTraceEvent::Closed);
485 }
486 Err(make_timeout_err())
487 }
488 },
489 None => Err(make_write_err(
490 "the socket was closed in response to an earlier issue".to_string(),
491 )),
492 }
493 }
494
495 async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
496 let total_timeout: Duration = commands
497 .iter()
498 .map(|cmd| cmd.client_timeout_request(&self.timeouts))
499 .sum();
500
501 let mut lines: Vec<String> = vec![];
502 let mut all = String::new();
503 for cmd in commands {
504 let line = cmd.encode();
505 all.push_str(&line);
506 lines.push(line);
507 }
508 tracing::trace!("send->{}: (PIPELINE) {all}", self.hostname);
509 if self.socket.is_some() {
510 if let Some(tracer) = &self.tracer {
511 for line in lines {
514 WriteTracer::trace(tracer, &line);
515 }
516 }
517 }
518 self.write_all_with_timeout(
519 total_timeout,
520 all.as_bytes(),
521 || ClientError::TimeOutRequest {
522 duration: total_timeout,
523 commands: commands.to_vec(),
524 },
525 |error| ClientError::WriteError {
526 error,
527 commands: commands.to_vec(),
528 },
529 )
530 .await
531 }
532
533 async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
534 let line = command.encode();
535 tracing::trace!("send->{}: {line}", self.hostname);
536 if self.socket.is_some() {
537 if let Some(tracer) = &self.tracer {
538 WriteTracer::trace(tracer, &line);
539 }
540 }
541
542 let timeout_duration = command.client_timeout_request(&self.timeouts);
543 self.write_all_with_timeout(
544 timeout_duration,
545 line.as_bytes(),
546 || ClientError::TimeOutRequest {
547 duration: timeout_duration,
548 commands: vec![command.clone()],
549 },
550 |error| ClientError::WriteError {
551 error,
552 commands: vec![command.clone()],
553 },
554 )
555 .await
556 }
557
558 async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
559 if self.socket.is_some() {
560 if let Some(tracer) = &self.tracer {
561 BinWriteTracer::trace(tracer, data);
562 }
563 }
564 let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
565 self.write_all_with_timeout(
566 timeout_duration,
567 data,
568 || ClientError::TimeOutData,
569 |error| ClientError::WriteError {
570 error,
571 commands: vec![],
572 },
573 )
574 .await
575 }
576
577 pub async fn pipeline_commands(
593 &mut self,
594 commands: Vec<Command>,
595 ) -> Vec<Result<Response, ClientError>> {
596 let mut results: Vec<Result<Response, ClientError>> = vec![];
597
598 let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
599 if pipeline {
600 if let Err(err) = self.write_pipeline_request(&commands).await {
601 let err: ClientError = err;
602 results.push(Err(err.clone()));
603 while results.len() < commands.len() {
604 results.push(Err(err.clone()));
606 }
607 return results;
608 }
609
610 for cmd in &commands {
612 results.push(
613 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
614 .await,
615 );
616 }
617 return results;
618 }
619
620 for cmd in &commands {
621 if let Err(err) = self.write_command_request(cmd).await {
622 let err: ClientError = err;
623 results.push(Err(err.clone()));
624 while results.len() < commands.len() {
625 results.push(Err(err.clone()));
627 }
628 return results;
629 }
630 results.push(
633 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
634 .await,
635 );
636 }
637 results
638 }
639
640 pub async fn ehlo_lhlo(
641 &mut self,
642 ehlo_name: &str,
643 use_lmtp: bool,
644 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
645 if use_lmtp {
646 self.lhlo(ehlo_name).await
647 } else {
648 self.ehlo(ehlo_name).await
649 }
650 }
651
652 pub async fn lhlo(
653 &mut self,
654 ehlo_name: &str,
655 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
656 let response = self
657 .send_command(&Command::Lhlo(Domain::Name(ehlo_name.to_string())))
658 .await?;
659 self.ehlo_common(response)
660 }
661
662 pub async fn ehlo(
663 &mut self,
664 ehlo_name: &str,
665 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
666 let response = self
667 .send_command(&Command::Ehlo(Domain::Name(ehlo_name.to_string())))
668 .await?;
669 self.ehlo_common(response)
670 }
671
672 fn ehlo_common(
673 &mut self,
674 response: Response,
675 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
676 if response.code != 250 {
677 return Err(ClientError::Rejected(response));
678 }
679
680 let mut capabilities = HashMap::new();
681
682 for line in response.content.lines().skip(1) {
683 let mut fields = line.splitn(2, ' ');
684 if let Some(name) = fields.next() {
685 let param = fields.next().map(|s| s.to_string());
686 let cap = EsmtpCapability {
687 name: name.to_string(),
688 param,
689 };
690 capabilities.insert(name.to_ascii_uppercase(), cap);
691 }
692 }
693
694 self.capabilities = capabilities;
695 Ok(&self.capabilities)
696 }
697
698 pub async fn auth_plain(
699 &mut self,
700 username: &str,
701 password: Option<&str>,
702 ) -> Result<(), ClientError> {
703 let password = password.unwrap_or("");
706 let payload = format!("\x00{username}\x00{password}");
707 let payload = data_encoding::BASE64.encode(payload.as_bytes());
708
709 let response = self
710 .send_command(&Command::Auth {
711 sasl_mech: "PLAIN".to_string(),
712 initial_response: Some(payload),
713 })
714 .await?;
715
716 if response.code != 235 {
717 return Err(ClientError::Rejected(response));
718 }
719
720 Ok(())
721 }
722
723 pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
729 let resp = self.send_command(&Command::StartTls).await?;
730 if resp.code != 220 {
731 return Err(ClientError::Rejected(resp));
732 }
733
734 let mut handshake_error = None;
735 let mut tls_info = TlsInformation::default();
736
737 let stream: BoxedAsyncReadAndWrite = if options.prefer_openssl
738 || !options.dane_tlsa.is_empty()
739 {
740 let connector = options
741 .build_openssl_connector(&self.hostname)
742 .map_err(|error| ClientError::FailedToBuildConnector {
743 error: error.to_string(),
744 })?;
745 let ssl = connector.into_ssl(self.hostname.as_str())?;
746
747 let (stream, dup_stream) = match self.socket.take() {
748 Some(s) => {
749 let d = s.try_dup();
750 (s, d)
751 }
752 None => return Err(ClientError::NotConnected),
753 };
754
755 let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;
756
757 if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
758 handshake_error.replace(format!("{err:#}"));
759 }
760
761 tls_info.provider_name = "openssl".to_string();
762 tls_info.cipher = match ssl_stream.ssl().current_cipher() {
763 Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
764 None => String::new(),
765 };
766 tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();
767
768 if let Some(cert) = ssl_stream.ssl().peer_certificate() {
769 tls_info.subject_name = subject_name(&cert);
770 }
771 if let Ok(authority) = ssl_stream.ssl().dane_authority() {
772 if let Some(cert) = &authority.cert {
773 tls_info.subject_name = subject_name(cert);
774 }
775 }
776
777 match (&handshake_error, dup_stream) {
778 (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
779 drop(ssl_stream);
784 Box::new(dup_stream)
785 }
786 _ => Box::new(ssl_stream),
787 }
788 } else {
789 tls_info.provider_name = "rustls".to_string();
790 let connector = options.build_tls_connector().await.map_err(|error| {
791 ClientError::FailedToBuildConnector {
792 error: error.to_string(),
793 }
794 })?;
795 let server_name = match IpAddr::from_str(self.hostname.as_str()) {
796 Ok(ip) => ServerName::IpAddress(ip.into()),
797 Err(_) => ServerName::try_from(self.hostname.clone())
798 .map_err(|_| ClientError::InvalidDnsName(self.hostname.clone()))?,
799 };
800
801 match connector
802 .connect(
803 server_name,
804 match self.socket.take() {
805 Some(s) => s,
806 None => return Err(ClientError::NotConnected),
807 },
808 )
809 .into_fallible()
810 .await
811 {
812 Ok(stream) => {
813 let (_, conn) = stream.get_ref();
814 tls_info.cipher = match conn.negotiated_cipher_suite() {
815 Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
816 None => String::new(),
817 };
818 tls_info.protocol_version = match conn.protocol_version() {
819 Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
820 None => String::new(),
821 };
822
823 if let Some(certs) = conn.peer_certificates() {
824 let peer_cert = &certs[0];
825 if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
826 tls_info.subject_name = subject_name(&cert);
827 }
828 }
829
830 Box::new(stream)
831 }
832 Err((err, stream)) => {
833 handshake_error.replace(format!("{err:#}"));
834 stream
835 }
836 }
837 };
838
839 if let Some(tracer) = &self.tracer {
840 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
841 level: Level::INFO,
842 message: match &handshake_error {
843 Some(error) => format!("STARTTLS handshake failed: {error:?}"),
844 None => format!("STARTTLS handshake -> {tls_info:?}"),
845 },
846 });
847 }
848
849 self.socket.replace(stream);
850 Ok(match handshake_error {
851 Some(error) => TlsStatus::FailedHandshake(error),
852 None => TlsStatus::Info(tls_info),
853 })
854 }
855
856 pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
857 &mut self,
858 sender: SENDER,
859 recipient: RECIP,
860 data: B,
861 ) -> Result<Response, ClientError> {
862 let recipient: ForwardPath = recipient.into();
863 let status = self
864 .send_mail_multi_recip(sender, vec![recipient], data)
865 .await?;
866 Ok(status.response)
867 }
868
869 pub async fn send_mail_multi_recip<B: AsRef<[u8]>, SENDER: Into<ReversePath>>(
870 &mut self,
871 sender: SENDER,
872 recipient_list: Vec<ForwardPath>,
873 data: B,
874 ) -> Result<BatchSendSuccess, ClientError> {
875 let sender = sender.into();
876
877 let data: &[u8] = data.as_ref();
878 let stuffed;
879
880 let data = match apply_dot_stuffing(data) {
881 Some(d) => {
882 stuffed = d;
883 &stuffed
884 }
885 None => data,
886 };
887
888 let data_is_8bit = data.iter().any(|&b| b >= 0x80);
889 let envelope_is_8bit =
890 !sender.is_ascii() || recipient_list.iter().any(|recipient| !recipient.is_ascii());
891
892 let mut mail_from_params = vec![];
893 if data_is_8bit {
894 if self.capabilities.contains_key("8BITMIME") {
895 mail_from_params.push(EsmtpParameter {
896 name: "BODY".to_string(),
897 value: Some("8BITMIME".to_string()),
898 });
899 } else if !self.ignore_8bit_checks {
900 return Err(ClientError::Rejected(Response {
901 code: 554,
902 command: None,
903 enhanced_code: Some(EnhancedStatusCode {
904 class: 5,
905 subject: 6,
906 detail: 3,
907 }),
908 content: "KumoMTA internal: DATA is 8bit, destination does \
909 not support 8BITMIME. Conversion via msg:check_fix_conformance \
910 during reception is required"
911 .to_string(),
912 }));
913 }
914 }
915
916 if envelope_is_8bit {
917 if self.capabilities.contains_key("SMTPUTF8") {
918 mail_from_params.push(EsmtpParameter {
919 name: "SMTPUTF8".to_string(),
920 value: None,
921 });
922 } else if !self.ignore_8bit_checks {
923 return Err(ClientError::Rejected(Response {
924 code: 554,
925 command: None,
926 enhanced_code: Some(EnhancedStatusCode {
927 class: 5,
928 subject: 6,
929 detail: 7,
930 }),
931 content: "KumoMTA internal: envelope is 8bit, destination does \
932 not support SMTPUTF8."
933 .to_string(),
934 }));
935 }
936 }
937
938 let mut commands = vec![];
939
940 let used_rset = self.use_rset;
950 if self.use_rset {
951 commands.push(Command::Rset);
952 }
953 commands.push(Command::MailFrom {
954 address: sender,
955 parameters: mail_from_params,
956 });
957
958 for recipient in &recipient_list {
959 commands.push(Command::RcptTo {
960 address: recipient.clone(),
961 parameters: vec![],
962 });
963 }
964 commands.push(Command::Data);
965
966 self.use_rset = true;
969
970 let mut responses = self.pipeline_commands(commands).await;
971
972 let is_err = responses.iter().any(|r| r.is_err());
980
981 if used_rset {
982 let rset_resp = responses.remove(0)?;
983 if rset_resp.code != 250 {
984 return Err(ClientError::Rejected(rset_resp));
985 }
986 }
987
988 let mail_resp = responses.remove(0)?;
989 if is_err && mail_resp.code != 250 {
990 return Err(ClientError::Rejected(mail_resp));
991 }
992
993 let mut rcpt_responses = vec![];
994 for _ in &recipient_list {
995 rcpt_responses.push(responses.remove(0)?);
996 }
997
998 if is_err && rcpt_responses.iter().all(|resp| resp.code != 250) {
999 return Err(ClientError::RejectedBatch(rcpt_responses));
1000 }
1001
1002 let data_resp = responses.remove(0)?;
1003 if is_err && data_resp.code != 354 {
1004 return Err(ClientError::Rejected(data_resp));
1005 }
1006
1007 if data_resp.code == 354
1008 && (mail_resp.code != 250 || rcpt_responses.iter().all(|resp| resp.code != 250))
1009 {
1010 self.write_data_with_timeout(b".\r\n").await?;
1019 let data_dot = Command::DataDot;
1020 let _ = self
1022 .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1023 .await?;
1024
1025 }
1029
1030 if mail_resp.code != 250 {
1031 return Err(ClientError::Rejected(mail_resp));
1032 }
1033 if rcpt_responses.iter().all(|resp| resp.code != 250) {
1034 if rcpt_responses.len() == 1 {
1035 return Err(ClientError::Rejected(
1036 rcpt_responses.pop().expect("have at least one"),
1037 ));
1038 }
1039 return Err(ClientError::RejectedBatch(rcpt_responses));
1040 }
1041 if data_resp.code != 354 {
1042 return Err(ClientError::Rejected(data_resp));
1043 }
1044
1045 let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);
1046
1047 tracing::trace!("message data is {} bytes", data.len());
1048
1049 self.write_data_with_timeout(data).await?;
1050
1051 let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };
1052
1053 tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());
1054
1055 self.write_data_with_timeout(marker.as_bytes()).await?;
1056
1057 let data_dot = Command::DataDot;
1058 let resp = self
1059 .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1060 .await?;
1061 if resp.code != 250 {
1062 return Err(ClientError::Rejected(resp));
1063 }
1064
1065 self.use_rset = self.enable_rset;
1068
1069 Ok(BatchSendSuccess {
1070 response: resp,
1071 rcpt_responses,
1072 })
1073 }
1074}
1075
/// Outcome of a successful `SmtpClient::send_mail_multi_recip` call.
#[derive(Debug)]
pub struct BatchSendSuccess {
    /// The server's reply to the end-of-data marker (a 250).
    pub response: Response,
    /// One RCPT TO reply per recipient, in order, including any rejected
    /// recipients (at least one was accepted).
    pub rcpt_responses: Vec<Response>,
}
1081
/// Result of a STARTTLS attempt that got as far as the handshake.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
    /// The handshake failed; holds the formatted handshake error.
    FailedHandshake(String),
    /// The handshake succeeded; holds details of the negotiated session.
    Info(TlsInformation),
}
1087
/// Details of a negotiated TLS session, as reported by `SmtpClient::starttls`.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
    /// Negotiated cipher suite name; empty if none was reported.
    pub cipher: String,
    /// Negotiated protocol version string; may be empty if unknown.
    pub protocol_version: String,
    /// `"SHORTNAME=value"` entries from the peer (or DANE authority) certificate subject.
    pub subject_name: Vec<String>,
    /// Which TLS implementation was used: `"openssl"` or `"rustls"`.
    pub provider_name: String,
}
1095
1096impl Drop for SmtpClient {
1097 fn drop(&mut self) {
1098 if let Some(tracer) = &self.tracer {
1099 if self.socket.is_some() {
1100 tracer.trace_event(SmtpClientTraceEvent::Closed);
1101 }
1102 }
1103 }
1104}
1105fn parse_response_line(line: &'_ str) -> Result<ResponseLine<'_>, ClientError> {
1106 if line.len() < 4 {
1107 return Err(ClientError::MalformedResponseLine(line.to_string()));
1108 }
1109
1110 match line.as_bytes()[3] {
1111 b' ' | b'-' => match line[0..3].parse::<u16>() {
1112 Ok(code) => Ok(ResponseLine {
1113 code,
1114 is_final: line.as_bytes()[3] == b' ',
1115 content: &line[4..],
1116 }),
1117 Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
1118 },
1119 _ => Err(ClientError::MalformedResponseLine(line.to_string())),
1120 }
1121}
1122
1123fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1124 static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1125
1126 if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1127 return None;
1128 }
1129
1130 let mut stuffed = vec![];
1131 if data.starts_with(b".") {
1132 stuffed.push(b'.');
1133 }
1134 let mut last_idx = 0;
1135 for i in LFDOT.find_iter(data) {
1136 stuffed.extend_from_slice(&data[last_idx..=i]);
1137 stuffed.push(b'.');
1138 last_idx = i + 1;
1139 }
1140 stuffed.extend_from_slice(&data[last_idx..]);
1141 Some(stuffed)
1142}
1143
1144pub fn subject_name(cert: &X509Ref) -> Vec<String> {
1151 let mut subject_name = vec![];
1152 for entry in cert.subject_name().entries() {
1153 if let Ok(obj) = entry.object().nid().short_name() {
1154 if let Ok(name) = entry.data().as_utf8() {
1155 subject_name.push(format!("{obj}={name}"));
1156 }
1157 }
1158 }
1159 subject_name
1160}
1161
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MailPath, Mailbox};

    /// A leading '.' and any '.' after '\n' are doubled; other data is untouched.
    #[test]
    fn test_stuffing() {
        assert_eq!(apply_dot_stuffing(b"foo"), None);
        assert_eq!(apply_dot_stuffing(b".foo").unwrap(), b"..foo");
        assert_eq!(apply_dot_stuffing(b"foo\n.bar").unwrap(), b"foo\n..bar");
        assert_eq!(
            apply_dot_stuffing(b"foo\n.bar\n..baz\n").unwrap(),
            b"foo\n..bar\n...baz\n"
        );
    }

    /// Single reply lines: ' ' marks the final line, '-' a continuation, and
    /// any other separator (or a non-numeric code) is malformed.
    #[test]
    fn response_line_parsing() {
        assert_eq!(
            parse_response_line("220 woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: true,
                content: "woot"
            }
        );
        assert_eq!(
            parse_response_line("220-woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: false,
                content: "woot"
            }
        );

        assert!(matches!(
            parse_response_line("220_woot"),
            Err(ClientError::MalformedResponseLine(_))
        ));
        assert!(matches!(
            parse_response_line("not really"),
            Err(ClientError::MalformedResponseLine(_))
        ));
    }

    /// Test helper: feeds `lines` through `ResponseBuilder` the same way
    /// `read_response` does and asserts the last line is final.
    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let mut parsed = parse_response_line(lines[0])?;
        let mut b = ResponseBuilder::new(&parsed);
        for line in &lines[1..] {
            parsed = parse_response_line(line)?;
            b.add_line(&parsed)
                .map_err(ClientError::MalformedResponseLine)?;
        }
        assert!(parsed.is_final);
        Ok(b.build(None))
    }

    /// Multi-line replies: content lines are joined with '\n', mismatched codes
    /// are rejected, and an enhanced status code is taken from the first line
    /// (later occurrences are stripped only when they match it).
    #[test]
    fn multi_line_response() {
        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: None,
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        let res = parse_multi_line(&["220-woot", "221-more", "220 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221-more"
            ),
            "got error {res:?}"
        );

        let res = parse_multi_line(&["220-woot", "220-more", "221 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221 done"
            ),
            "got error {res:?}"
        );

        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // A repeated matching enhanced code on continuation lines is stripped.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // A different enhanced code on a later line is kept as plain text.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\n5.5.5 done".to_string(),
                command: None
            }
        );
    }

    /// Port suffixes, brackets and a trailing dot are stripped from peer names.
    #[test]
    fn test_extract_hostname() {
        assert_eq!(extract_hostname("foo"), "foo");
        assert_eq!(extract_hostname("foo."), "foo");
        assert_eq!(extract_hostname("foo:25"), "foo");
        assert_eq!(extract_hostname("foo.:25"), "foo");
        assert_eq!(extract_hostname("[foo]:25"), "foo");
        assert_eq!(extract_hostname("[foo.]:25"), "foo");
        assert_eq!(extract_hostname("[::1]:25"), "::1");
        assert_eq!(extract_hostname("::1:25"), "::1");
    }

    /// Error messages render the command on one line (CRLF removed).
    #[test]
    fn test_format_error_command() {
        assert_eq!(
            format!(
                "{:#}",
                ClientError::TimeOutRequest {
                    commands: vec![Command::DataDot],
                    duration: Duration::from_secs(10),
                }
            ),
            "Timed Out after 10s writing cmd=."
        );
        assert_eq!(
            format!(
                "{:#}",
                ClientError::TimeOutResponse {
                    command: Some(Command::MailFrom {
                        address: ReversePath::Path(MailPath {
                            at_domain_list: vec![],
                            mailbox: Mailbox {
                                local_part: "user".to_string(),
                                domain: Domain::Name("host".to_string())
                            }
                        }),
                        parameters: vec![],
                    }),
                    duration: Duration::from_secs(10),
                }
            ),
            r#"Timed Out waiting 10s for response to cmd=MAIL FROM:<user@host>"#
        );
    }
}