1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::parser::{Command, Domain, EsmtpParameter, ForwardPath, ReversePath};
4use crate::{AsyncReadAndWrite, BoxedAsyncReadAndWrite};
5use bstr::ByteSlice;
6use hickory_proto::rr::rdata::TLSA;
7use hickory_proto::rr::Name;
8use memchr::memmem::Finder;
9use nom_utils::DomainString;
10use openssl::x509::{X509Ref, X509};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::net::IpAddr;
14use std::str::FromStr;
15use std::sync::{Arc, LazyLock};
16use std::time::Duration;
17use thiserror::Error;
18use tokio::io::{AsyncReadExt, AsyncWriteExt};
19use tokio::net::{TcpStream, ToSocketAddrs};
20use tokio::time::timeout;
21use tokio_rustls::rustls::pki_types::ServerName;
22use tracing::Level;
23
24pub use kumo_tls_helper::TlsOptions;
25pub use openssl;
26pub use tokio_rustls;
27
/// Upper bound, in bytes, on a single SMTP response line.
///
/// `SmtpClient::read_line` uses it both as the size of each socket read and as
/// the threshold past which an unterminated line is treated as
/// `ClientError::ResponseTooLong`.
const MAX_LINE_LEN: usize = 4 * 1024;
29
/// Errors produced by [`SmtpClient`] operations.
///
/// `Clone` is relied upon by `SmtpClient::pipeline_commands`, which reports a
/// single write failure against every command of a batch.
#[derive(Error, Debug, Clone)]
pub enum ClientError {
    /// A response line was not valid UTF-8.
    #[error("response is not UTF8")]
    Utf8(#[from] std::string::FromUtf8Error),
    /// A response line did not match the SMTP reply syntax, or a continuation
    /// line was inconsistent with the lines before it. Carries the offending
    /// line or the builder's description of the problem.
    #[error("Malformed Response: {0}")]
    MalformedResponseLine(String),
    /// More than `MAX_LINE_LEN` bytes were buffered without a CRLF terminator.
    #[error("Response line is too long")]
    ResponseTooLong,
    /// STARTTLS was attempted while no socket was present.
    #[error("Not connected")]
    NotConnected,
    /// The server (or a local pre-flight check) rejected a single command.
    #[error("Command rejected {0:?}")]
    Rejected(Response),
    /// Every recipient of a multi-recipient transaction was rejected; carries
    /// each RCPT TO reply in order.
    #[error("Commands rejected {0:?}")]
    RejectedBatch(Vec<Response>),
    /// A name could not be used as a DNS name. Raised for the TLS server name
    /// during STARTTLS and also for invalid EHLO/LHLO names.
    #[error("STARTTLS: {0} is not a valid DNS name")]
    InvalidDnsName(String),
    /// The TLS connector could not be constructed from the supplied options.
    #[error("Invalid client certificate configured: {error:?}")]
    FailedToBuildConnector { error: String },
    /// No reply arrived within `duration`; `command` is the command whose
    /// reply was awaited, if any.
    #[error("Timed Out waiting {duration:?} for response to cmd={}", self.command_for_err())]
    TimeOutResponse {
        command: Option<Command>,
        duration: Duration,
    },
    /// Writing `commands` did not complete within `duration`.
    #[error("Timed Out after {duration:?} writing cmd={}", self.command_for_err())]
    TimeOutRequest {
        commands: Vec<Command>,
        duration: Duration,
    },
    /// Reading a reply failed (I/O error, peer closed, or socket already gone).
    /// `partial` holds any buffered bytes that did not yet form a complete
    /// line, decoded lossily.
    #[error("Error {error} reading response to cmd={}", self.command_for_err())]
    ReadError {
        command: Option<Command>,
        error: String,
        partial: String,
    },
    /// Flushing buffered output before reading a reply failed.
    #[error("Error {error} flushing send buffer")]
    FlushError { error: String },
    /// Writing failed. `commands` is empty when the failed write was message
    /// payload rather than a command.
    #[error("Error {error} writing {}", self.command_for_err())]
    WriteError {
        commands: Vec<Command>,
        error: String,
    },
    /// Writing message payload data timed out.
    #[error("Timed Out sending message payload data")]
    TimeOutData,
    /// An OpenSSL error surfaced while setting up the TLS session.
    #[error("SSL Error: {0}")]
    SslErrorStack(#[from] openssl::error::ErrorStack),
    /// DANE verification had no usable TLSA record for `hostname`.
    /// NOTE(review): not constructed in this file; presumably raised by the
    /// TLS option/connector code.
    #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
    NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
}
78
79impl ClientError {
80 pub fn command(&self) -> Option<String> {
82 match self {
83 Self::TimeOutResponse {
84 command: Some(command),
85 ..
86 }
87 | Self::ReadError {
88 command: Some(command),
89 ..
90 } => Some(command.encode().to_string()),
91 Self::TimeOutRequest { commands, .. } | Self::WriteError { commands, .. }
92 if !commands.is_empty() =>
93 {
94 let s: String = commands
95 .iter()
96 .map(|cmd| cmd.encode().to_string())
97 .collect();
98 Some(s)
99 }
100 _ => None,
101 }
102 }
103
104 fn command_for_err(&self) -> String {
107 self.command()
108 .map(|cmd| cmd.replace("\r\n", ""))
109 .unwrap_or_else(|| "NONE".to_string())
110 }
111
112 pub fn was_due_to_message(&self) -> bool {
124 match self {
125 Self::Utf8(_)
126 | Self::MalformedResponseLine(_)
127 | Self::ResponseTooLong
128 | Self::NotConnected
129 | Self::InvalidDnsName(_)
130 | Self::FailedToBuildConnector { .. }
131 | Self::TimeOutResponse { .. }
132 | Self::TimeOutRequest { .. }
133 | Self::ReadError { .. }
134 | Self::FlushError { .. }
135 | Self::WriteError { .. }
136 | Self::TimeOutData
137 | Self::SslErrorStack(_)
138 | Self::NoUsableDaneTlsa { .. } => false,
139 Self::Rejected(response) => response.was_due_to_message(),
140 Self::RejectedBatch(responses) => match responses.len() {
141 1 => responses[0].was_due_to_message(),
142 _ => false,
143 },
144 }
145 }
146}
147
/// One capability line from an EHLO/LHLO response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EsmtpCapability {
    /// The capability keyword exactly as the server sent it (case preserved;
    /// the client's capability map is keyed by the upper-cased form).
    pub name: String,
    /// Everything after the first space on the capability line, if any.
    pub param: Option<String>,
}
153
/// Events reported to an [`SmtpClientTracer`] while the client talks to a peer.
#[derive(Clone, Debug)]
pub enum SmtpClientTraceEvent {
    /// The connection was closed: by the peer, after an I/O error or timeout,
    /// or because a still-connected client was dropped.
    Closed,
    /// Raw bytes from a single socket read, before any line splitting.
    Read(Vec<u8>),
    /// Data written to the peer; message payload bytes are decoded lossily
    /// as UTF-8 before being reported.
    Write(String),
    /// A client-side diagnostic (I/O errors, timeouts, STARTTLS outcome).
    Diagnostic {
        level: tracing::Level,
        message: String,
    },
}
164
/// A trace event whose construction is postponed until a tracer asks for it.
///
/// This lets [`SmtpClientTracer::lazy_trace`] implementations avoid building
/// (and copying payload data into) an event they are not going to record.
pub trait DeferredTracer {
    /// Builds the trace event.
    fn trace(&self) -> SmtpClientTraceEvent;
}
168
/// Observer for the protocol traffic and diagnostics of an [`SmtpClient`].
pub trait SmtpClientTracer: std::fmt::Debug {
    /// Records an already-constructed event.
    fn trace_event(&self, event: SmtpClientTraceEvent);
    /// Records an event that only materializes if the implementation calls
    /// [`DeferredTracer::trace`]; the client uses this for read/write payloads.
    fn lazy_trace(&self, deferred: &dyn DeferredTracer);
}
173
174struct WriteTracer<'a> {
176 data: &'a str,
177}
178impl DeferredTracer for WriteTracer<'_> {
179 fn trace(&self) -> SmtpClientTraceEvent {
180 SmtpClientTraceEvent::Write(self.data.to_string())
181 }
182}
183impl<'a> WriteTracer<'a> {
184 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
185 tracer.lazy_trace(&Self { data });
186 }
187}
188
189struct BinWriteTracer<'a> {
190 data: &'a [u8],
191}
192impl DeferredTracer for BinWriteTracer<'_> {
193 fn trace(&self) -> SmtpClientTraceEvent {
194 let data = String::from_utf8_lossy(self.data).to_string();
195 SmtpClientTraceEvent::Write(data)
196 }
197}
198impl<'a> BinWriteTracer<'a> {
199 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
200 tracer.lazy_trace(&Self { data });
201 }
202}
203
204struct ReadTracer<'a> {
207 data: &'a [u8],
208}
209impl DeferredTracer for ReadTracer<'_> {
210 fn trace(&self) -> SmtpClientTraceEvent {
211 SmtpClientTraceEvent::Read(self.data.to_vec())
212 }
213}
214
/// An SMTP/LMTP client driving a single connection.
///
/// Most I/O errors and non-zero timeouts drop the socket; afterwards it stays
/// `None` (until STARTTLS re-installs a stream) and further reads/writes fail
/// with "the socket was closed in response to an earlier issue".
#[derive(Debug)]
pub struct SmtpClient {
    /// The transport; `None` once closed after an error.
    socket: Option<BoxedAsyncReadAndWrite>,
    /// Peer host derived from the address given at construction; used in trace
    /// output and as the TLS server name for STARTTLS.
    hostname: String,
    /// Capabilities from the most recent EHLO/LHLO, keyed by upper-cased keyword.
    capabilities: HashMap<String, EsmtpCapability>,
    /// Received bytes not yet consumed as complete CRLF-terminated lines.
    read_buffer: Vec<u8>,
    /// Per-command timeout configuration.
    timeouts: SmtpClientTimeouts,
    /// Optional observer for protocol I/O and diagnostics.
    tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
    /// Whether the next transaction must begin with RSET.
    use_rset: bool,
    /// Whether RSET is issued before the next transaction even after a
    /// successful send.
    enable_rset: bool,
    /// Whether pipelining may be used; the peer must also advertise PIPELINING.
    enable_pipelining: bool,
    /// Skip the local 8BITMIME / SMTPUTF8 capability checks before MAIL FROM.
    ignore_8bit_checks: bool,
}
228
/// Extracts the bare host part from a peer address string.
///
/// Accepted shapes: `host`, `host:port`, `[literal]`, and `[literal]:port`.
/// Surrounding brackets and a single trailing `.` are removed. An unbracketed
/// string containing `:` is always treated as having a port after its last
/// colon, so `::1:25` yields `::1`.
fn extract_hostname(hostname: &str) -> &str {
    // A fully bracketed literal such as `[::1]` carries no port; splitting it on
    // its last ':' would cut the address itself in half.
    let hostname = if hostname.starts_with('[') && hostname.ends_with(']') {
        hostname
    } else {
        match hostname.rsplit_once(':') {
            Some((host, _port)) => host,
            None => hostname,
        }
    };

    let hostname = hostname
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(hostname);

    hostname.strip_suffix('.').unwrap_or(hostname)
}
247
248impl SmtpClient {
249 pub async fn new<A: ToSocketAddrs + ToString + Clone>(
250 addr: A,
251 timeouts: SmtpClientTimeouts,
252 ) -> std::io::Result<Self> {
253 let stream = TcpStream::connect(addr.clone()).await?;
254 stream.set_nodelay(true)?;
256 Ok(Self::with_stream(stream, addr.to_string(), timeouts))
257 }
258
259 pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
260 stream: S,
261 peer_hostname: H,
262 timeouts: SmtpClientTimeouts,
263 ) -> Self {
264 let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
265
266 Self {
267 socket: Some(Box::new(stream)),
268 hostname,
269 capabilities: HashMap::new(),
270 read_buffer: Vec::with_capacity(1024),
271 timeouts,
272 tracer: None,
273 use_rset: false,
274 enable_rset: false,
275 enable_pipelining: false,
276 ignore_8bit_checks: false,
277 }
278 }
279
280 pub fn set_ignore_8bit_checks(&mut self, enable: bool) {
282 self.ignore_8bit_checks = enable;
283 }
284
285 pub fn is_connected(&self) -> bool {
286 self.socket.is_some()
287 }
288
289 pub fn set_enable_rset(&mut self, enable: bool) {
290 self.enable_rset = enable;
291 }
292
293 pub fn set_enable_pipelining(&mut self, enable: bool) {
294 self.enable_pipelining = enable;
295 }
296
297 pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
298 self.tracer.replace(tracer);
299 }
300
301 pub fn timeouts(&self) -> &SmtpClientTimeouts {
302 &self.timeouts
303 }
304
305 async fn read_line(
306 &mut self,
307 timeout_duration: Duration,
308 cmd: Option<&Command>,
309 ) -> Result<String, ClientError> {
310 let mut too_long = false;
311 loop {
312 let mut iter = self.read_buffer.iter().enumerate();
313 while let Some((i, &b)) = iter.next() {
314 if b != b'\r' {
315 continue;
316 }
317 if let Some((_, b'\n')) = iter.next() {
318 if too_long {
319 self.read_buffer.drain(0..i + 2);
320
321 if let Some(tracer) = &self.tracer {
322 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
323 level: Level::ERROR,
324 message: "Response too long".to_string(),
325 });
326 }
327
328 return Err(ClientError::ResponseTooLong);
329 }
330
331 let line = String::from_utf8(self.read_buffer[0..i].to_vec());
332 self.read_buffer.drain(0..i + 2);
333 return Ok(line?);
334 }
335 }
336 if self.read_buffer.len() > MAX_LINE_LEN {
337 self.read_buffer.clear();
338 too_long = true;
339 }
340
341 let mut data = [0u8; MAX_LINE_LEN];
343 let size = match self.socket.as_mut() {
344 Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
345 Ok(Ok(size)) => size,
346 Ok(Err(err)) => {
347 self.socket.take();
348 if let Some(tracer) = &self.tracer {
349 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
350 level: Level::ERROR,
351 message: format!("Error during read: {err:#}"),
352 });
353 tracer.trace_event(SmtpClientTraceEvent::Closed);
354 }
355 return Err(ClientError::ReadError {
356 command: cmd.cloned(),
357 error: format!("{err:#}"),
358 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
359 });
360 }
361 Err(_) => {
362 if timeout_duration != Duration::ZERO {
367 self.socket.take();
368 }
369 if let Some(tracer) = &self.tracer {
370 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
371 level: Level::ERROR,
372 message: format!("Read Timeout after {timeout_duration:?}"),
373 });
374 if self.socket.is_none() {
375 tracer.trace_event(SmtpClientTraceEvent::Closed);
376 }
377 }
378 return Err(ClientError::TimeOutResponse {
379 command: cmd.cloned(),
380 duration: timeout_duration,
381 });
382 }
383 },
384 None => {
385 return Err(ClientError::ReadError {
386 command: cmd.cloned(),
387 error: "the socket was closed in response to an earlier issue".to_string(),
388 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
389 });
390 }
391 };
392 if size == 0 {
393 self.socket.take();
394 if let Some(tracer) = &self.tracer {
395 tracer.trace_event(SmtpClientTraceEvent::Closed);
396 }
397 return Err(ClientError::ReadError {
398 command: cmd.cloned(),
399 error: "Connection closed by peer".to_string(),
400 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
401 });
402 }
403 if let Some(tracer) = &self.tracer {
404 tracer.lazy_trace(&ReadTracer {
405 data: &data[0..size],
406 });
407 }
408 self.read_buffer.extend_from_slice(&data[0..size]);
409 }
410 }
411
412 pub async fn check_unilateral_response(&mut self) -> Result<Option<Response>, ClientError> {
420 match self.read_response(None, Duration::ZERO).await {
421 Ok(response) => Ok(Some(response)),
423 Err(ClientError::TimeOutResponse { .. }) => Ok(None),
425 Err(err) => Err(err),
427 }
428 }
429
430 pub async fn read_response(
431 &mut self,
432 command: Option<&Command>,
433 timeout_duration: Duration,
434 ) -> Result<Response, ClientError> {
435 if let Some(sock) = self.socket.as_mut() {
436 if let Err(err) = sock.flush().await {
437 self.socket.take();
438 if let Some(tracer) = &self.tracer {
439 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
440 level: Level::ERROR,
441 message: format!("Error during flush: {err:#}"),
442 });
443 tracer.trace_event(SmtpClientTraceEvent::Closed);
444 }
445 return Err(ClientError::FlushError {
446 error: format!("{err:#}"),
447 });
448 }
449 }
450
451 let mut line = self.read_line(timeout_duration, command).await?;
452 tracing::trace!("recv<-{}: {line}", self.hostname);
453 let mut parsed = parse_response_line(&line)?;
454 let mut response_builder = ResponseBuilder::new(&parsed);
455
456 let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
457 while !parsed.is_final {
458 line = self
459 .read_line(subsequent_line_timeout_duration, command)
460 .await?;
461 parsed = parse_response_line(&line)?;
462 response_builder
463 .add_line(&parsed)
464 .map_err(ClientError::MalformedResponseLine)?;
465 }
466
467 let response = response_builder.build(command.map(|cmd| cmd.encode().to_string()));
468
469 tracing::trace!("{}: {response:?}", self.hostname);
470
471 Ok(response)
472 }
473
474 pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
475 self.write_command_request(command).await?;
476 self.read_response(Some(command), command.client_timeout(&self.timeouts))
477 .await
478 }
479
480 async fn write_all_with_timeout<F, G>(
484 &mut self,
485 timeout_duration: Duration,
486 bytes: &[u8],
487 make_timeout_err: F,
488 make_write_err: G,
489 ) -> Result<(), ClientError>
490 where
491 F: FnOnce() -> ClientError,
492 G: FnOnce(String) -> ClientError,
493 {
494 match self.socket.as_mut() {
495 Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
496 Ok(Ok(response)) => Ok(response),
497 Ok(Err(err)) => {
498 self.socket.take();
499 if let Some(tracer) = &self.tracer {
500 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
501 level: Level::ERROR,
502 message: format!("Error during write: {err:#}"),
503 });
504 tracer.trace_event(SmtpClientTraceEvent::Closed);
505 }
506 Err(make_write_err(format!("{err:#}")))
507 }
508 Err(_) => {
509 self.socket.take();
510 if let Some(tracer) = &self.tracer {
511 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
512 level: Level::ERROR,
513 message: format!("Write Timeout after {timeout_duration:?}"),
514 });
515 tracer.trace_event(SmtpClientTraceEvent::Closed);
516 }
517 Err(make_timeout_err())
518 }
519 },
520 None => Err(make_write_err(
521 "the socket was closed in response to an earlier issue".to_string(),
522 )),
523 }
524 }
525
526 async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
527 let total_timeout: Duration = commands
528 .iter()
529 .map(|cmd| cmd.client_timeout_request(&self.timeouts))
530 .sum();
531
532 let mut lines: Vec<String> = vec![];
533 let mut all: Vec<u8> = vec![];
534 for cmd in commands {
535 let line = cmd.encode();
536 all.extend_from_slice(&line);
537 lines.push(line.to_string());
538 }
539 tracing::trace!(
540 "send->{}: (PIPELINE) {}",
541 self.hostname,
542 all.as_bstr().escape_bytes()
543 );
544 if self.socket.is_some() {
545 if let Some(tracer) = &self.tracer {
546 for line in &lines {
549 WriteTracer::trace(tracer, line);
550 }
551 }
552 }
553 self.write_all_with_timeout(
554 total_timeout,
555 &all,
556 || ClientError::TimeOutRequest {
557 duration: total_timeout,
558 commands: commands.to_vec(),
559 },
560 |error| ClientError::WriteError {
561 error,
562 commands: commands.to_vec(),
563 },
564 )
565 .await
566 }
567
568 async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
569 let line = command.encode();
570 tracing::trace!("send->{}: {line}", self.hostname);
571 if self.socket.is_some() {
572 if let Some(tracer) = &self.tracer {
573 WriteTracer::trace(tracer, &line.to_string());
574 }
575 }
576
577 let timeout_duration = command.client_timeout_request(&self.timeouts);
578 self.write_all_with_timeout(
579 timeout_duration,
580 &line,
581 || ClientError::TimeOutRequest {
582 duration: timeout_duration,
583 commands: vec![command.clone()],
584 },
585 |error| ClientError::WriteError {
586 error,
587 commands: vec![command.clone()],
588 },
589 )
590 .await
591 }
592
593 async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
594 if self.socket.is_some() {
595 if let Some(tracer) = &self.tracer {
596 BinWriteTracer::trace(tracer, data);
597 }
598 }
599 let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
600 self.write_all_with_timeout(
601 timeout_duration,
602 data,
603 || ClientError::TimeOutData,
604 |error| ClientError::WriteError {
605 error,
606 commands: vec![],
607 },
608 )
609 .await
610 }
611
612 pub async fn pipeline_commands(
628 &mut self,
629 commands: Vec<Command>,
630 ) -> Vec<Result<Response, ClientError>> {
631 let mut results: Vec<Result<Response, ClientError>> = vec![];
632
633 let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
634 if pipeline {
635 if let Err(err) = self.write_pipeline_request(&commands).await {
636 let err: ClientError = err;
637 results.push(Err(err.clone()));
638 while results.len() < commands.len() {
639 results.push(Err(err.clone()));
641 }
642 return results;
643 }
644
645 for cmd in &commands {
647 results.push(
648 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
649 .await,
650 );
651 }
652 return results;
653 }
654
655 for cmd in &commands {
656 if let Err(err) = self.write_command_request(cmd).await {
657 let err: ClientError = err;
658 results.push(Err(err.clone()));
659 while results.len() < commands.len() {
660 results.push(Err(err.clone()));
662 }
663 return results;
664 }
665 results.push(
668 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
669 .await,
670 );
671 }
672 results
673 }
674
675 pub async fn ehlo_lhlo(
676 &mut self,
677 ehlo_name: &str,
678 use_lmtp: bool,
679 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
680 if use_lmtp {
681 self.lhlo(ehlo_name).await
682 } else {
683 self.ehlo(ehlo_name).await
684 }
685 }
686
687 pub async fn lhlo(
688 &mut self,
689 ehlo_name: &str,
690 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
691 let response = self
692 .send_command(&Command::Lhlo(Domain::DomainName(
693 ehlo_name
694 .parse::<DomainString>()
695 .map_err(|_| ClientError::InvalidDnsName(ehlo_name.to_string()))?,
696 )))
697 .await?;
698 self.ehlo_common(response)
699 }
700
701 pub async fn ehlo(
702 &mut self,
703 ehlo_name: &str,
704 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
705 let response = self
706 .send_command(&Command::Ehlo(Domain::DomainName(
707 ehlo_name
708 .parse::<DomainString>()
709 .map_err(|_| ClientError::InvalidDnsName(ehlo_name.to_string()))?,
710 )))
711 .await?;
712 self.ehlo_common(response)
713 }
714
715 fn ehlo_common(
716 &mut self,
717 response: Response,
718 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
719 if response.code != 250 {
720 return Err(ClientError::Rejected(response));
721 }
722
723 let mut capabilities = HashMap::new();
724
725 for line in response.content.lines().skip(1) {
726 let mut fields = line.splitn(2, ' ');
727 if let Some(name) = fields.next() {
728 let param = fields.next().map(|s| s.to_string());
729 let cap = EsmtpCapability {
730 name: name.to_string(),
731 param,
732 };
733 capabilities.insert(name.to_ascii_uppercase(), cap);
734 }
735 }
736
737 self.capabilities = capabilities;
738 Ok(&self.capabilities)
739 }
740
741 pub async fn auth_plain(
742 &mut self,
743 username: &str,
744 password: Option<&str>,
745 ) -> Result<(), ClientError> {
746 let password = password.unwrap_or("");
749 let payload = format!("\x00{username}\x00{password}");
750 let payload = data_encoding::BASE64.encode(payload.as_bytes());
751
752 let response = self
753 .send_command(&Command::Auth {
754 sasl_mech: "PLAIN".to_string(),
755 initial_response: Some(payload),
756 })
757 .await?;
758
759 if response.code != 235 {
760 return Err(ClientError::Rejected(response));
761 }
762
763 Ok(())
764 }
765
766 pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
772 let resp = self.send_command(&Command::StartTls).await?;
773 if resp.code != 220 {
774 return Err(ClientError::Rejected(resp));
775 }
776
777 let mut handshake_error = None;
778 let mut tls_info = TlsInformation::default();
779
780 let stream: BoxedAsyncReadAndWrite = if options.prefer_openssl
781 || !options.dane_tlsa.is_empty()
782 {
783 let connector = options
784 .build_openssl_connector(&self.hostname)
785 .map_err(|error| ClientError::FailedToBuildConnector {
786 error: error.to_string(),
787 })?;
788 let ssl = connector.into_ssl(self.hostname.as_str())?;
789
790 let (stream, dup_stream) = match self.socket.take() {
791 Some(s) => {
792 let d = s.try_dup();
793 (s, d)
794 }
795 None => return Err(ClientError::NotConnected),
796 };
797
798 let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;
799
800 if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
801 handshake_error.replace(format!("{err:#}"));
802 }
803
804 tls_info.provider_name = "openssl".to_string();
805 tls_info.cipher = match ssl_stream.ssl().current_cipher() {
806 Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
807 None => String::new(),
808 };
809 tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();
810
811 if let Some(cert) = ssl_stream.ssl().peer_certificate() {
812 tls_info.subject_name = subject_name(&cert);
813 }
814 if let Ok(authority) = ssl_stream.ssl().dane_authority() {
815 if let Some(cert) = &authority.cert {
816 tls_info.subject_name = subject_name(cert);
817 }
818 }
819
820 match (&handshake_error, dup_stream) {
821 (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
822 drop(ssl_stream);
827 Box::new(dup_stream)
828 }
829 _ => Box::new(ssl_stream),
830 }
831 } else {
832 tls_info.provider_name = "rustls".to_string();
833 let connector = options.build_tls_connector().await.map_err(|error| {
834 ClientError::FailedToBuildConnector {
835 error: error.to_string(),
836 }
837 })?;
838 let server_name = parse_server_name(self.hostname.as_str())?;
839
840 match connector
841 .connect(
842 server_name,
843 match self.socket.take() {
844 Some(s) => s,
845 None => return Err(ClientError::NotConnected),
846 },
847 )
848 .into_fallible()
849 .await
850 {
851 Ok(stream) => {
852 let (_, conn) = stream.get_ref();
853 tls_info.cipher = match conn.negotiated_cipher_suite() {
854 Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
855 None => String::new(),
856 };
857 tls_info.protocol_version = match conn.protocol_version() {
858 Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
859 None => String::new(),
860 };
861
862 if let Some(certs) = conn.peer_certificates() {
863 let peer_cert = &certs[0];
864 if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
865 tls_info.subject_name = subject_name(&cert);
866 }
867 }
868
869 Box::new(stream)
870 }
871 Err((err, stream)) => {
872 handshake_error.replace(format!("{err:#}"));
873 stream
874 }
875 }
876 };
877
878 if let Some(tracer) = &self.tracer {
879 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
880 level: Level::INFO,
881 message: match &handshake_error {
882 Some(error) => format!("STARTTLS handshake failed: {error:?}"),
883 None => format!("STARTTLS handshake -> {tls_info:?}"),
884 },
885 });
886 }
887
888 self.socket.replace(stream);
889 Ok(match handshake_error {
890 Some(error) => TlsStatus::FailedHandshake(error),
891 None => TlsStatus::Info(tls_info),
892 })
893 }
894
895 pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
896 &mut self,
897 sender: SENDER,
898 recipient: RECIP,
899 data: B,
900 ) -> Result<Response, ClientError> {
901 let recipient: ForwardPath = recipient.into();
902 let status = self
903 .send_mail_multi_recip(sender, vec![recipient], data)
904 .await?;
905 Ok(status.response)
906 }
907
908 pub async fn send_mail_multi_recip<B: AsRef<[u8]>, SENDER: Into<ReversePath>>(
909 &mut self,
910 sender: SENDER,
911 recipient_list: Vec<ForwardPath>,
912 data: B,
913 ) -> Result<BatchSendSuccess, ClientError> {
914 let sender = sender.into();
915
916 let data: &[u8] = data.as_ref();
917 let stuffed;
918
919 let data = match apply_dot_stuffing(data) {
920 Some(d) => {
921 stuffed = d;
922 &stuffed
923 }
924 None => data,
925 };
926
927 let data_is_8bit = data.iter().any(|&b| b >= 0x80);
928 let envelope_is_8bit =
929 !sender.is_ascii() || recipient_list.iter().any(|recipient| !recipient.is_ascii());
930
931 let mut mail_from_params = vec![];
932 if data_is_8bit {
933 if self.capabilities.contains_key("8BITMIME") {
934 mail_from_params.push(EsmtpParameter {
935 name: "BODY".to_string(),
936 value: Some("8BITMIME".to_string()),
937 });
938 } else if !self.ignore_8bit_checks {
939 return Err(ClientError::Rejected(Response {
940 code: 554,
941 command: None,
942 enhanced_code: Some(EnhancedStatusCode {
943 class: 5,
944 subject: 6,
945 detail: 3,
946 }),
947 content: "KumoMTA internal: DATA is 8bit, destination does \
948 not support 8BITMIME. Conversion via msg:check_fix_conformance \
949 during reception is required"
950 .to_string(),
951 }));
952 }
953 }
954
955 if envelope_is_8bit {
956 if self.capabilities.contains_key("SMTPUTF8") {
957 mail_from_params.push(EsmtpParameter {
958 name: "SMTPUTF8".to_string(),
959 value: None,
960 });
961 } else if !self.ignore_8bit_checks {
962 return Err(ClientError::Rejected(Response {
963 code: 554,
964 command: None,
965 enhanced_code: Some(EnhancedStatusCode {
966 class: 5,
967 subject: 6,
968 detail: 7,
969 }),
970 content: "KumoMTA internal: envelope is 8bit, destination does \
971 not support SMTPUTF8."
972 .to_string(),
973 }));
974 }
975 }
976
977 let mut commands = vec![];
978
979 let used_rset = self.use_rset;
989 if self.use_rset {
990 commands.push(Command::Rset);
991 }
992 commands.push(Command::MailFrom {
993 address: sender,
994 parameters: mail_from_params,
995 });
996
997 for recipient in &recipient_list {
998 commands.push(Command::RcptTo {
999 address: recipient.clone(),
1000 parameters: vec![],
1001 });
1002 }
1003 commands.push(Command::Data);
1004
1005 self.use_rset = true;
1008
1009 let mut responses = self.pipeline_commands(commands).await;
1010
1011 let is_err = responses.iter().any(|r| r.is_err());
1019
1020 if used_rset {
1021 let rset_resp = responses.remove(0)?;
1022 if rset_resp.code != 250 {
1023 return Err(ClientError::Rejected(rset_resp));
1024 }
1025 }
1026
1027 let mail_resp = responses.remove(0)?;
1028 if is_err && mail_resp.code != 250 {
1029 return Err(ClientError::Rejected(mail_resp));
1030 }
1031
1032 let mut rcpt_responses = vec![];
1033 for _ in &recipient_list {
1034 rcpt_responses.push(responses.remove(0)?);
1035 }
1036
1037 if is_err && rcpt_responses.iter().all(|resp| resp.code != 250) {
1038 return Err(ClientError::RejectedBatch(rcpt_responses));
1039 }
1040
1041 let data_resp = responses.remove(0)?;
1042 if is_err && data_resp.code != 354 {
1043 return Err(ClientError::Rejected(data_resp));
1044 }
1045
1046 if data_resp.code == 354
1047 && (mail_resp.code != 250 || rcpt_responses.iter().all(|resp| resp.code != 250))
1048 {
1049 self.write_data_with_timeout(b".\r\n").await?;
1058 let data_dot = Command::DataDot;
1059 let _ = self
1061 .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1062 .await?;
1063
1064 }
1068
1069 if mail_resp.code != 250 {
1070 return Err(ClientError::Rejected(mail_resp));
1071 }
1072 if rcpt_responses.iter().all(|resp| resp.code != 250) {
1073 if rcpt_responses.len() == 1 {
1074 return Err(ClientError::Rejected(
1075 rcpt_responses.pop().expect("have at least one"),
1076 ));
1077 }
1078 return Err(ClientError::RejectedBatch(rcpt_responses));
1079 }
1080 if data_resp.code != 354 {
1081 return Err(ClientError::Rejected(data_resp));
1082 }
1083
1084 let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);
1085
1086 tracing::trace!("message data is {} bytes", data.len());
1087
1088 self.write_data_with_timeout(data).await?;
1089
1090 let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };
1091
1092 tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());
1093
1094 self.write_data_with_timeout(marker.as_bytes()).await?;
1095
1096 let data_dot = Command::DataDot;
1097 let resp = self
1098 .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1099 .await?;
1100 if resp.code != 250 {
1101 return Err(ClientError::Rejected(resp));
1102 }
1103
1104 self.use_rset = self.enable_rset;
1107
1108 Ok(BatchSendSuccess {
1109 response: resp,
1110 rcpt_responses,
1111 })
1112 }
1113}
1114
/// Outcome of a transaction the server accepted.
#[derive(Debug)]
pub struct BatchSendSuccess {
    /// The server's reply to the end-of-data `.` marker.
    pub response: Response,
    /// The reply to each RCPT TO, in recipient order. Individual recipients
    /// may have been rejected, as long as at least one got a 250.
    pub rcpt_responses: Vec<Response>,
}
1120
/// Result of a STARTTLS attempt that reached the TLS handshake.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
    /// The handshake failed; carries the error text. The client has
    /// reinstalled a stream and remains usable at the caller's discretion.
    FailedHandshake(String),
    /// The handshake succeeded; carries the negotiated parameters.
    Info(TlsInformation),
}
1126
/// Details of a negotiated TLS session, as reported by `SmtpClient::starttls`.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
    /// Cipher suite name; empty if none was reported.
    pub cipher: String,
    /// Protocol version string; empty if none was reported.
    pub protocol_version: String,
    /// Peer certificate subject as `SHORTNAME=value` entries (see `subject_name`).
    pub subject_name: Vec<String>,
    /// The TLS implementation used: `"openssl"` or `"rustls"`.
    pub provider_name: String,
}
1134
1135impl Drop for SmtpClient {
1136 fn drop(&mut self) {
1137 if let Some(tracer) = &self.tracer {
1138 if self.socket.is_some() {
1139 tracer.trace_event(SmtpClientTraceEvent::Closed);
1140 }
1141 }
1142 }
1143}
1144fn parse_response_line(line: &'_ str) -> Result<ResponseLine<'_>, ClientError> {
1145 if line.len() < 4 {
1146 return Err(ClientError::MalformedResponseLine(line.to_string()));
1147 }
1148
1149 match line.as_bytes()[3] {
1150 b' ' | b'-' => match line[0..3].parse::<u16>() {
1151 Ok(code) => Ok(ResponseLine {
1152 code,
1153 is_final: line.as_bytes()[3] == b' ',
1154 content: &line[4..],
1155 }),
1156 Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
1157 },
1158 _ => Err(ClientError::MalformedResponseLine(line.to_string())),
1159 }
1160}
1161
1162fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1163 static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1164
1165 if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1166 return None;
1167 }
1168
1169 let mut stuffed = vec![];
1170 if data.starts_with(b".") {
1171 stuffed.push(b'.');
1172 }
1173 let mut last_idx = 0;
1174 for i in LFDOT.find_iter(data) {
1175 stuffed.extend_from_slice(&data[last_idx..=i]);
1176 stuffed.push(b'.');
1177 last_idx = i + 1;
1178 }
1179 stuffed.extend_from_slice(&data[last_idx..]);
1180 Some(stuffed)
1181}
1182
1183pub fn subject_name(cert: &X509Ref) -> Vec<String> {
1190 let mut subject_name = vec![];
1191 for entry in cert.subject_name().entries() {
1192 if let Ok(obj) = entry.object().nid().short_name() {
1193 if let Ok(name) = entry.data().as_utf8() {
1194 subject_name.push(format!("{obj}={name}"));
1195 }
1196 }
1197 }
1198 subject_name
1199}
1200
1201fn parse_server_name(input: &str) -> Result<ServerName<'static>, ClientError> {
1202 match IpAddr::from_str(input) {
1203 Ok(ip) => Ok(ServerName::IpAddress(ip.into())),
1204 Err(_) => {
1205 let name = Name::from_str_relaxed(input)
1206 .map_err(|_| ClientError::InvalidDnsName(input.to_string()))?;
1207 ServerName::try_from(name.to_ascii())
1208 .map_err(|_| ClientError::InvalidDnsName(name.to_ascii()))
1209 }
1210 }
1211}
1212
#[cfg(test)]
mod test {
    use super::*;
    use crate::parser::{EnvelopeAddress, ReversePath};

    #[test]
    fn test_stuffing() {
        assert_eq!(apply_dot_stuffing(b"foo"), None);

        let cases: [(&[u8], &[u8]); 3] = [
            (b".foo", b"..foo"),
            (b"foo\n.bar", b"foo\n..bar"),
            (b"foo\n.bar\n..baz\n", b"foo\n..bar\n...baz\n"),
        ];
        for (input, expected) in cases {
            assert_eq!(
                apply_dot_stuffing(input).as_deref(),
                Some(expected),
                "stuffing {input:?}"
            );
        }
    }

    #[test]
    fn response_line_parsing() {
        assert_eq!(
            parse_response_line("220 woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: true,
                content: "woot"
            }
        );
        assert_eq!(
            parse_response_line("220-woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: false,
                content: "woot"
            }
        );

        for bad in ["220_woot", "not really"] {
            assert!(
                matches!(
                    parse_response_line(bad),
                    Err(ClientError::MalformedResponseLine(_))
                ),
                "{bad:?} should be rejected"
            );
        }
    }

    /// Feeds `lines` through the same parse/build steps as `read_response`.
    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let (first, rest) = lines.split_first().expect("at least one line");
        let mut parsed = parse_response_line(first)?;
        let mut builder = ResponseBuilder::new(&parsed);
        for line in rest {
            parsed = parse_response_line(line)?;
            builder
                .add_line(&parsed)
                .map_err(ClientError::MalformedResponseLine)?;
        }
        assert!(parsed.is_final);
        Ok(builder.build(None))
    }

    /// Asserts that `lines` fail to assemble, blaming exactly `bad`.
    fn assert_malformed_at(lines: &[&str], bad: &str) {
        let res = parse_multi_line(lines).unwrap_err();
        assert!(
            matches!(res, ClientError::MalformedResponseLine(ref err) if err == bad),
            "got error {res:?}"
        );
    }

    #[test]
    fn multi_line_response() {
        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done"]).unwrap(),
            Response {
                code: 220,
                enhanced_code: None,
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        assert_malformed_at(&["220-woot", "221-more", "220 done"], "221-more");
        assert_malformed_at(&["220-woot", "220-more", "221 done"], "221 done");

        let enhanced = Some(EnhancedStatusCode {
            class: 4,
            subject: 1,
            detail: 0,
        });

        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done"]).unwrap(),
            Response {
                code: 220,
                enhanced_code: enhanced.clone(),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done"]).unwrap(),
            Response {
                code: 220,
                enhanced_code: enhanced.clone(),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done"]).unwrap(),
            Response {
                code: 220,
                enhanced_code: enhanced,
                content: "woot\nmore\n5.5.5 done".to_string(),
                command: None
            }
        );
    }

    #[test]
    fn test_extract_hostname() {
        let cases = [
            ("foo", "foo"),
            ("foo.", "foo"),
            ("foo:25", "foo"),
            ("foo.:25", "foo"),
            ("[foo]:25", "foo"),
            ("[foo.]:25", "foo"),
            ("[::1]:25", "::1"),
            ("::1:25", "::1"),
        ];
        for (input, expected) in cases {
            assert_eq!(extract_hostname(input), expected, "extract_hostname({input:?})");
        }
    }

    #[test]
    fn test_format_error_command() {
        let write_timeout = ClientError::TimeOutRequest {
            commands: vec![Command::DataDot],
            duration: Duration::from_secs(10),
        };
        assert_eq!(
            format!("{write_timeout:#}"),
            "Timed Out after 10s writing cmd=."
        );

        let EnvelopeAddress::Path(path) = EnvelopeAddress::parse("user@host").unwrap() else {
            panic!("expected Path")
        };
        let read_timeout = ClientError::TimeOutResponse {
            command: Some(Command::MailFrom {
                address: ReversePath::Path(path),
                parameters: vec![],
            }),
            duration: Duration::from_secs(10),
        };
        assert_eq!(
            format!("{read_timeout:#}"),
            r#"Timed Out waiting 10s for response to cmd=MAIL FROM:<user@host>"#
        );
    }

    #[test]
    fn test_issue_533() {
        let _name = parse_server_name("münchen.de").unwrap();
    }
}