1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::{
4 AsyncReadAndWrite, BoxedAsyncReadAndWrite, Command, Domain, EsmtpParameter, ForwardPath,
5 ReversePath,
6};
7use hickory_proto::rr::rdata::tlsa::{CertUsage, Matching, Selector};
8use hickory_proto::rr::rdata::TLSA;
9use memchr::memmem::Finder;
10use openssl::pkey::PKey;
11use openssl::ssl::{DaneMatchType, DaneSelector, DaneUsage, SslOptions};
12use openssl::x509::{X509Ref, X509};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::net::IpAddr;
16use std::str::FromStr;
17use std::sync::{Arc, LazyLock};
18use std::time::Duration;
19use thiserror::Error;
20use tokio::io::{AsyncReadExt, AsyncWriteExt};
21use tokio::net::{TcpStream, ToSocketAddrs};
22use tokio::time::timeout;
23use tokio_rustls::rustls::pki_types::ServerName;
24use tracing::Level;
25
26pub use crate::tls::TlsOptions;
27pub use {openssl, tokio_rustls};
28
/// Size, in bytes, of the chunk buffer used for each socket read in
/// `SmtpClient::read_line`, and the threshold above which buffered data
/// without a CRLF is treated as an over-long response line.
const MAX_LINE_LEN: usize = 4096;
30
/// Errors reported by [`SmtpClient`] operations.
#[derive(Error, Debug, Clone)]
pub enum ClientError {
    /// A response line was not valid UTF-8.
    #[error("response is not UTF8")]
    Utf8(#[from] std::string::FromUtf8Error),
    /// A response line could not be parsed, or could not be combined with
    /// the preceding lines of a multi-line response. Carries the offending
    /// line or a description of the problem.
    #[error("Malformed Response: {0}")]
    MalformedResponseLine(String),
    /// More than `MAX_LINE_LEN` bytes were buffered without finding the end
    /// of the response line.
    #[error("Response line is too long")]
    ResponseTooLong,
    /// The operation requires a socket but the client no longer has one.
    #[error("Not connected")]
    NotConnected,
    /// A single command received a status code other than the one the
    /// client required. Also synthesized locally by
    /// `send_mail_multi_recip` when 8-bit content cannot be sent.
    #[error("Command rejected {0:?}")]
    Rejected(Response),
    /// A group of commands (the `RCPT TO` commands of one transaction) was
    /// rejected; carries each individual response.
    #[error("Commands rejected {0:?}")]
    RejectedBatch(Vec<Response>),
    /// The peer hostname is neither an IP address nor a name accepted as a
    /// TLS server name.
    #[error("STARTTLS: {0} is not a valid DNS name")]
    InvalidDnsName(String),
    /// Building the TLS connector from the supplied options failed.
    #[error("Invalid client certificate configured: {error:?}")]
    FailedToBuildConnector { error: String },
    /// No data arrived within `duration` while waiting for the response to
    /// `command`.
    #[error("Timed Out waiting {duration:?} for response to {command:?}")]
    TimeOutResponse {
        command: Option<Command>,
        duration: Duration,
    },
    /// Writing `commands` to the socket did not complete within `duration`.
    #[error("Timed Out writing {duration:?} {commands:?}")]
    TimeOutRequest {
        commands: Vec<Command>,
        duration: Duration,
    },
    /// Reading the response to `command` failed (I/O error, connection
    /// closed by the peer, or socket already closed). `partial` holds a
    /// lossy UTF-8 rendering of whatever was buffered at the time.
    #[error("Error {error} reading response to {command:?}")]
    ReadError {
        command: Option<Command>,
        error: String,
        partial: String,
    },
    /// Flushing pending output before reading a response failed.
    #[error("Error {error} flushing send buffer")]
    FlushError { error: String },
    /// Writing `commands` failed. `commands` is empty when the failed write
    /// was message payload data rather than a command.
    #[error("Error {error} writing {commands:?}")]
    WriteError {
        commands: Vec<Command>,
        error: String,
    },
    /// Writing message payload data did not complete in time.
    #[error("Timed Out sending message payload data")]
    TimeOutData,
    /// An error reported by OpenSSL.
    #[error("SSL Error: {0}")]
    SslErrorStack(#[from] openssl::error::ErrorStack),
    /// DANE was requested but OpenSSL reported none of the supplied TLSA
    /// records as usable.
    #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
    NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
}
79
80impl ClientError {
81 pub fn command(&self) -> Option<String> {
83 match self {
84 Self::TimeOutResponse {
85 command: Some(command),
86 ..
87 }
88 | Self::ReadError {
89 command: Some(command),
90 ..
91 } => Some(command.encode()),
92 Self::TimeOutRequest { commands, .. } | Self::WriteError { commands, .. }
93 if !commands.is_empty() =>
94 {
95 let commands: Vec<String> = commands.into_iter().map(|cmd| cmd.encode()).collect();
96 Some(commands.join(""))
97 }
98 _ => None,
99 }
100 }
101
102 pub fn was_due_to_message(&self) -> bool {
114 match self {
115 Self::Utf8(_)
116 | Self::MalformedResponseLine(_)
117 | Self::ResponseTooLong
118 | Self::NotConnected
119 | Self::InvalidDnsName(_)
120 | Self::FailedToBuildConnector { .. }
121 | Self::TimeOutResponse { .. }
122 | Self::TimeOutRequest { .. }
123 | Self::ReadError { .. }
124 | Self::FlushError { .. }
125 | Self::WriteError { .. }
126 | Self::TimeOutData
127 | Self::SslErrorStack(_)
128 | Self::NoUsableDaneTlsa { .. } => false,
129 Self::Rejected(response) => response.was_due_to_message(),
130 Self::RejectedBatch(responses) => match responses.len() {
131 1 => responses[0].was_due_to_message(),
132 _ => false,
133 },
134 }
135 }
136}
137
/// One capability line from an EHLO/LHLO response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EsmtpCapability {
    /// Capability keyword exactly as the server sent it (case preserved).
    pub name: String,
    /// Remainder of the line after the first space, if there was one.
    pub param: Option<String>,
}
143
/// Events delivered to an [`SmtpClientTracer`].
#[derive(Clone, Debug)]
pub enum SmtpClientTraceEvent {
    /// The client let go of its socket (after an error, on end-of-stream,
    /// or when the client is dropped while still connected).
    Closed,
    /// Raw bytes received from the peer.
    Read(Vec<u8>),
    /// Text sent to the peer; binary payloads are rendered as lossy UTF-8.
    Write(String),
    /// A diagnostic message produced by the client itself.
    Diagnostic {
        level: tracing::Level,
        message: String,
    },
}
154
/// A value that can build a trace event on demand, so the cost of
/// constructing the event is only paid if a tracer calls [`Self::trace`].
pub trait DeferredTracer {
    /// Builds the trace event.
    fn trace(&self) -> SmtpClientTraceEvent;
}
158
/// Receiver for trace events emitted by an [`SmtpClient`].
pub trait SmtpClientTracer: std::fmt::Debug {
    /// Receives an already-constructed event.
    fn trace_event(&self, event: SmtpClientTraceEvent);
    /// Receives a deferred event; presumably the implementation calls
    /// `deferred.trace()` only when it wants the event (confirm against
    /// the implementors).
    fn lazy_trace(&self, deferred: &dyn DeferredTracer);
}
163
164struct WriteTracer<'a> {
166 data: &'a str,
167}
168impl DeferredTracer for WriteTracer<'_> {
169 fn trace(&self) -> SmtpClientTraceEvent {
170 SmtpClientTraceEvent::Write(self.data.to_string())
171 }
172}
173impl<'a> WriteTracer<'a> {
174 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
175 tracer.lazy_trace(&Self { data });
176 }
177}
178
179struct BinWriteTracer<'a> {
180 data: &'a [u8],
181}
182impl DeferredTracer for BinWriteTracer<'_> {
183 fn trace(&self) -> SmtpClientTraceEvent {
184 let data = String::from_utf8_lossy(self.data).to_string();
185 SmtpClientTraceEvent::Write(data)
186 }
187}
188impl<'a> BinWriteTracer<'a> {
189 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
190 tracer.lazy_trace(&Self { data });
191 }
192}
193
194struct ReadTracer<'a> {
197 data: &'a [u8],
198}
199impl DeferredTracer for ReadTracer<'_> {
200 fn trace(&self) -> SmtpClientTraceEvent {
201 SmtpClientTraceEvent::Read(self.data.to_vec())
202 }
203}
204
/// An SMTP/LMTP client bound to a single connection.
#[derive(Debug)]
pub struct SmtpClient {
    /// The connection; set to `None` once it is closed after a read/write
    /// failure, timeout or end-of-stream.
    socket: Option<BoxedAsyncReadAndWrite>,
    /// Peer hostname as produced by `extract_hostname`; used for TLS and in
    /// trace logging.
    hostname: String,
    /// Capabilities from the most recent successful EHLO/LHLO, keyed by the
    /// upper-cased capability name.
    capabilities: HashMap<String, EsmtpCapability>,
    /// Bytes received from the peer that have not yet been consumed as
    /// response lines.
    read_buffer: Vec<u8>,
    /// Timeouts applied to requests and responses.
    timeouts: SmtpClientTimeouts,
    /// Optional receiver for trace events.
    tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
    /// Whether the next `send_mail_multi_recip` starts with `RSET`. Set to
    /// `true` when a transaction begins and reset to `enable_rset` once the
    /// transaction completes successfully.
    use_rset: bool,
    /// Configured via `set_enable_rset`.
    enable_rset: bool,
    /// Configured via `set_enable_pipelining`; pipelining is only used when
    /// the peer also advertised `PIPELINING`.
    enable_pipelining: bool,
    /// When set, 8-bit content or envelope addresses are sent even if the
    /// peer did not advertise `8BITMIME` / `SMTPUTF8`.
    ignore_8bit_checks: bool,
}
218
/// Extracts the bare host from a `host`, `host:port` or `[v6]:port` string.
///
/// The `:port` suffix, the square brackets around an IPv6 literal and one
/// trailing dot are removed, in that order. IPv6 literals that carry no
/// port (`[::1]` or `::1`) are recognized, so their own colons are not
/// mistaken for a port separator.
fn extract_hostname(hostname: &str) -> &str {
    // A string ending in `]` is a bracketed IPv6 literal with nothing after
    // it, and a string that parses as an IPv6 address is a bare literal;
    // neither has a port, so every colon belongs to the address.
    let has_no_port =
        hostname.ends_with(']') || hostname.parse::<std::net::Ipv6Addr>().is_ok();
    let hostname = if has_no_port {
        hostname
    } else {
        // Just the hostname, without any :port
        hostname
            .rsplit_once(':')
            .map_or(hostname, |(host, _port)| host)
    };

    // Remove the brackets from `[v6-literal]`
    let hostname = hostname
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(hostname);

    // Normalize any trailing dot
    hostname.strip_suffix('.').unwrap_or(hostname)
}
237
238impl SmtpClient {
239 pub async fn new<A: ToSocketAddrs + ToString + Clone>(
240 addr: A,
241 timeouts: SmtpClientTimeouts,
242 ) -> std::io::Result<Self> {
243 let stream = TcpStream::connect(addr.clone()).await?;
244 stream.set_nodelay(true)?;
246 Ok(Self::with_stream(stream, addr.to_string(), timeouts))
247 }
248
249 pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
250 stream: S,
251 peer_hostname: H,
252 timeouts: SmtpClientTimeouts,
253 ) -> Self {
254 let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
255
256 Self {
257 socket: Some(Box::new(stream)),
258 hostname,
259 capabilities: HashMap::new(),
260 read_buffer: Vec::with_capacity(1024),
261 timeouts,
262 tracer: None,
263 use_rset: false,
264 enable_rset: false,
265 enable_pipelining: false,
266 ignore_8bit_checks: false,
267 }
268 }
269
    /// When `enable` is true, `send_mail_multi_recip` no longer fails
    /// locally if the message or envelope is 8-bit and the peer did not
    /// advertise `8BITMIME` / `SMTPUTF8`.
    pub fn set_ignore_8bit_checks(&mut self, enable: bool) {
        self.ignore_8bit_checks = enable;
    }
274
    /// Returns true while the client still holds its socket. The socket is
    /// discarded after read/write errors, timeouts and end-of-stream.
    pub fn is_connected(&self) -> bool {
        self.socket.is_some()
    }
278
    /// Sets the value that `use_rset` is reset to after each successful
    /// mail transaction, i.e. whether the following transaction begins
    /// with `RSET`.
    pub fn set_enable_rset(&mut self, enable: bool) {
        self.enable_rset = enable;
    }
282
    /// Allows `pipeline_commands` to send commands as a single batch,
    /// provided the peer also advertised `PIPELINING`.
    pub fn set_enable_pipelining(&mut self, enable: bool) {
        self.enable_pipelining = enable;
    }
286
287 pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
288 self.tracer.replace(tracer);
289 }
290
    /// Returns the timeouts this client was constructed with.
    pub fn timeouts(&self) -> &SmtpClientTimeouts {
        &self.timeouts
    }
294
295 async fn read_line(
296 &mut self,
297 timeout_duration: Duration,
298 cmd: Option<&Command>,
299 ) -> Result<String, ClientError> {
300 let mut too_long = false;
301 loop {
302 let mut iter = self.read_buffer.iter().enumerate();
303 while let Some((i, &b)) = iter.next() {
304 if b != b'\r' {
305 continue;
306 }
307 if let Some((_, b'\n')) = iter.next() {
308 if too_long {
309 self.read_buffer.drain(0..i + 2);
310
311 if let Some(tracer) = &self.tracer {
312 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
313 level: Level::ERROR,
314 message: "Response too long".to_string(),
315 });
316 }
317
318 return Err(ClientError::ResponseTooLong);
319 }
320
321 let line = String::from_utf8(self.read_buffer[0..i].to_vec());
322 self.read_buffer.drain(0..i + 2);
323 return Ok(line?);
324 }
325 }
326 if self.read_buffer.len() > MAX_LINE_LEN {
327 self.read_buffer.clear();
328 too_long = true;
329 }
330
331 let mut data = [0u8; MAX_LINE_LEN];
333 let size = match self.socket.as_mut() {
334 Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
335 Ok(Ok(size)) => size,
336 Ok(Err(err)) => {
337 self.socket.take();
338 if let Some(tracer) = &self.tracer {
339 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
340 level: Level::ERROR,
341 message: format!("Error during read: {err:#}"),
342 });
343 tracer.trace_event(SmtpClientTraceEvent::Closed);
344 }
345 return Err(ClientError::ReadError {
346 command: cmd.cloned(),
347 error: format!("{err:#}"),
348 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
349 });
350 }
351 Err(_) => {
352 self.socket.take();
353 if let Some(tracer) = &self.tracer {
354 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
355 level: Level::ERROR,
356 message: format!("Read Timeout after {timeout_duration:?}"),
357 });
358 tracer.trace_event(SmtpClientTraceEvent::Closed);
359 }
360 return Err(ClientError::TimeOutResponse {
361 command: cmd.cloned(),
362 duration: timeout_duration,
363 });
364 }
365 },
366 None => {
367 return Err(ClientError::ReadError {
368 command: cmd.cloned(),
369 error: "the socket was closed in response to an earlier issue".to_string(),
370 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
371 });
372 }
373 };
374 if size == 0 {
375 self.socket.take();
376 if let Some(tracer) = &self.tracer {
377 tracer.trace_event(SmtpClientTraceEvent::Closed);
378 }
379 return Err(ClientError::ReadError {
380 command: cmd.cloned(),
381 error: "Connection closed by peer".to_string(),
382 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
383 });
384 }
385 if let Some(tracer) = &self.tracer {
386 tracer.lazy_trace(&ReadTracer {
387 data: &data[0..size],
388 });
389 }
390 self.read_buffer.extend_from_slice(&data[0..size]);
391 }
392 }
393
394 pub async fn read_response(
395 &mut self,
396 command: Option<&Command>,
397 timeout_duration: Duration,
398 ) -> Result<Response, ClientError> {
399 if let Some(sock) = self.socket.as_mut() {
400 if let Err(err) = sock.flush().await {
401 self.socket.take();
402 if let Some(tracer) = &self.tracer {
403 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
404 level: Level::ERROR,
405 message: format!("Error during flush: {err:#}"),
406 });
407 tracer.trace_event(SmtpClientTraceEvent::Closed);
408 }
409 return Err(ClientError::FlushError {
410 error: format!("{err:#}"),
411 });
412 }
413 }
414
415 let mut line = self.read_line(timeout_duration, command).await?;
416 tracing::trace!("recv<-{}: {line}", self.hostname);
417 let mut parsed = parse_response_line(&line)?;
418 let mut response_builder = ResponseBuilder::new(&parsed);
419
420 let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
421 while !parsed.is_final {
422 line = self
423 .read_line(subsequent_line_timeout_duration, command)
424 .await?;
425 parsed = parse_response_line(&line)?;
426 response_builder
427 .add_line(&parsed)
428 .map_err(ClientError::MalformedResponseLine)?;
429 }
430
431 let response = response_builder.build(command.map(|cmd| cmd.encode()));
432
433 tracing::trace!("{}: {response:?}", self.hostname);
434
435 Ok(response)
436 }
437
438 pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
439 self.write_command_request(command).await?;
440 self.read_response(Some(command), command.client_timeout(&self.timeouts))
441 .await
442 }
443
444 async fn write_all_with_timeout<F, G>(
448 &mut self,
449 timeout_duration: Duration,
450 bytes: &[u8],
451 make_timeout_err: F,
452 make_write_err: G,
453 ) -> Result<(), ClientError>
454 where
455 F: FnOnce() -> ClientError,
456 G: FnOnce(String) -> ClientError,
457 {
458 match self.socket.as_mut() {
459 Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
460 Ok(Ok(response)) => Ok(response),
461 Ok(Err(err)) => {
462 self.socket.take();
463 if let Some(tracer) = &self.tracer {
464 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
465 level: Level::ERROR,
466 message: format!("Error during write: {err:#}"),
467 });
468 tracer.trace_event(SmtpClientTraceEvent::Closed);
469 }
470 Err(make_write_err(format!("{err:#}")))
471 }
472 Err(_) => {
473 self.socket.take();
474 if let Some(tracer) = &self.tracer {
475 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
476 level: Level::ERROR,
477 message: format!("Write Timeout after {timeout_duration:?}"),
478 });
479 tracer.trace_event(SmtpClientTraceEvent::Closed);
480 }
481 Err(make_timeout_err())
482 }
483 },
484 None => Err(make_write_err(
485 "the socket was closed in response to an earlier issue".to_string(),
486 )),
487 }
488 }
489
490 async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
491 let total_timeout: Duration = commands
492 .iter()
493 .map(|cmd| cmd.client_timeout_request(&self.timeouts))
494 .sum();
495
496 let mut lines: Vec<String> = vec![];
497 let mut all = String::new();
498 for cmd in commands {
499 let line = cmd.encode();
500 all.push_str(&line);
501 lines.push(line);
502 }
503 tracing::trace!("send->{}: (PIPELINE) {all}", self.hostname);
504 if self.socket.is_some() {
505 if let Some(tracer) = &self.tracer {
506 for line in lines {
509 WriteTracer::trace(tracer, &line);
510 }
511 }
512 }
513 self.write_all_with_timeout(
514 total_timeout,
515 all.as_bytes(),
516 || ClientError::TimeOutRequest {
517 duration: total_timeout,
518 commands: commands.to_vec(),
519 },
520 |error| ClientError::WriteError {
521 error,
522 commands: commands.to_vec(),
523 },
524 )
525 .await
526 }
527
528 async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
529 let line = command.encode();
530 tracing::trace!("send->{}: {line}", self.hostname);
531 if self.socket.is_some() {
532 if let Some(tracer) = &self.tracer {
533 WriteTracer::trace(tracer, &line);
534 }
535 }
536
537 let timeout_duration = command.client_timeout_request(&self.timeouts);
538 self.write_all_with_timeout(
539 timeout_duration,
540 line.as_bytes(),
541 || ClientError::TimeOutRequest {
542 duration: timeout_duration,
543 commands: vec![command.clone()],
544 },
545 |error| ClientError::WriteError {
546 error,
547 commands: vec![command.clone()],
548 },
549 )
550 .await
551 }
552
553 async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
554 if self.socket.is_some() {
555 if let Some(tracer) = &self.tracer {
556 BinWriteTracer::trace(tracer, data);
557 }
558 }
559 let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
560 self.write_all_with_timeout(
561 timeout_duration,
562 data,
563 || ClientError::TimeOutData,
564 |error| ClientError::WriteError {
565 error,
566 commands: vec![],
567 },
568 )
569 .await
570 }
571
572 pub async fn pipeline_commands(
588 &mut self,
589 commands: Vec<Command>,
590 ) -> Vec<Result<Response, ClientError>> {
591 let mut results: Vec<Result<Response, ClientError>> = vec![];
592
593 let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
594 if pipeline {
595 if let Err(err) = self.write_pipeline_request(&commands).await {
596 let err: ClientError = err;
597 results.push(Err(err.clone()));
598 while results.len() < commands.len() {
599 results.push(Err(err.clone()));
601 }
602 return results;
603 }
604
605 for cmd in &commands {
607 results.push(
608 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
609 .await,
610 );
611 }
612 return results;
613 }
614
615 for cmd in &commands {
616 if let Err(err) = self.write_command_request(cmd).await {
617 let err: ClientError = err;
618 results.push(Err(err.clone()));
619 while results.len() < commands.len() {
620 results.push(Err(err.clone()));
622 }
623 return results;
624 }
625 results.push(
628 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
629 .await,
630 );
631 }
632 results
633 }
634
635 pub async fn ehlo_lhlo(
636 &mut self,
637 ehlo_name: &str,
638 use_lmtp: bool,
639 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
640 if use_lmtp {
641 self.lhlo(ehlo_name).await
642 } else {
643 self.ehlo(ehlo_name).await
644 }
645 }
646
647 pub async fn lhlo(
648 &mut self,
649 ehlo_name: &str,
650 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
651 let response = self
652 .send_command(&Command::Lhlo(Domain::Name(ehlo_name.to_string())))
653 .await?;
654 self.ehlo_common(response)
655 }
656
657 pub async fn ehlo(
658 &mut self,
659 ehlo_name: &str,
660 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
661 let response = self
662 .send_command(&Command::Ehlo(Domain::Name(ehlo_name.to_string())))
663 .await?;
664 self.ehlo_common(response)
665 }
666
667 fn ehlo_common(
668 &mut self,
669 response: Response,
670 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
671 if response.code != 250 {
672 return Err(ClientError::Rejected(response));
673 }
674
675 let mut capabilities = HashMap::new();
676
677 for line in response.content.lines().skip(1) {
678 let mut fields = line.splitn(2, ' ');
679 if let Some(name) = fields.next() {
680 let param = fields.next().map(|s| s.to_string());
681 let cap = EsmtpCapability {
682 name: name.to_string(),
683 param,
684 };
685 capabilities.insert(name.to_ascii_uppercase(), cap);
686 }
687 }
688
689 self.capabilities = capabilities;
690 Ok(&self.capabilities)
691 }
692
693 pub async fn auth_plain(
694 &mut self,
695 username: &str,
696 password: Option<&str>,
697 ) -> Result<(), ClientError> {
698 let password = password.unwrap_or("");
701 let payload = format!("\x00{username}\x00{password}");
702 let payload = data_encoding::BASE64.encode(payload.as_bytes());
703
704 let response = self
705 .send_command(&Command::Auth {
706 sasl_mech: "PLAIN".to_string(),
707 initial_response: Some(payload),
708 })
709 .await?;
710
711 if response.code != 235 {
712 return Err(ClientError::Rejected(response));
713 }
714
715 Ok(())
716 }
717
    /// Issues `STARTTLS` and attempts to upgrade the connection to TLS.
    ///
    /// A response other than 220 is returned as `Rejected`. OpenSSL is used
    /// when `options.prefer_openssl` is set or DANE TLSA records are
    /// present; otherwise rustls is used.
    ///
    /// A failed TLS handshake is not an `Err`: it is reported as
    /// `Ok(TlsStatus::FailedHandshake(..))` and the client keeps whatever
    /// stream the chosen TLS library left it with. On success the
    /// negotiated parameters are returned as `Ok(TlsStatus::Info(..))`.
    ///
    /// # Errors
    /// `Rejected`, `FailedToBuildConnector`, `SslErrorStack`,
    /// `InvalidDnsName`, `NotConnected`, or any error from sending the
    /// command.
    pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
        let resp = self.send_command(&Command::StartTls).await?;
        if resp.code != 220 {
            return Err(ClientError::Rejected(resp));
        }

        let mut handshake_error = None;
        let mut tls_info = TlsInformation::default();

        let stream: BoxedAsyncReadAndWrite = if options.prefer_openssl
            || !options.dane_tlsa.is_empty()
        {
            let connector = options
                .build_openssl_connector(&self.hostname)
                .map_err(|error| ClientError::FailedToBuildConnector {
                    error: error.to_string(),
                })?;
            let ssl = connector.into_ssl(self.hostname.as_str())?;

            // Take ownership of the socket, first asking it for a duplicate
            // handle (presumably `None` when the stream type cannot be
            // duplicated — confirm against `try_dup`).
            let (stream, dup_stream) = match self.socket.take() {
                Some(s) => {
                    let d = s.try_dup();
                    (s, d)
                }
                None => return Err(ClientError::NotConnected),
            };

            let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;

            // A handshake failure is recorded rather than returned, so the
            // TLS details below are still collected.
            if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
                handshake_error.replace(format!("{err:#}"));
            }

            tls_info.provider_name = "openssl".to_string();
            tls_info.cipher = match ssl_stream.ssl().current_cipher() {
                Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
                None => String::new(),
            };
            tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();

            if let Some(cert) = ssl_stream.ssl().peer_certificate() {
                tls_info.subject_name = subject_name(&cert);
            }
            // When a DANE authority certificate is available its subject
            // replaces the peer certificate's subject.
            if let Ok(authority) = ssl_stream.ssl().dane_authority() {
                if let Some(cert) = &authority.cert {
                    tls_info.subject_name = subject_name(cert);
                }
            }

            match (&handshake_error, dup_stream) {
                // The handshake failed before completing and a duplicate
                // handle exists: discard the TLS stream and continue on the
                // duplicate instead.
                (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
                    drop(ssl_stream);
                    Box::new(dup_stream)
                }
                _ => Box::new(ssl_stream),
            }
        } else {
            tls_info.provider_name = "rustls".to_string();
            let connector = options.build_tls_connector().await.map_err(|error| {
                ClientError::FailedToBuildConnector {
                    error: error.to_string(),
                }
            })?;
            // The server name is the peer IP when the hostname parses as
            // one, otherwise a DNS name.
            let server_name = match IpAddr::from_str(self.hostname.as_str()) {
                Ok(ip) => ServerName::IpAddress(ip.into()),
                Err(_) => ServerName::try_from(self.hostname.clone())
                    .map_err(|_| ClientError::InvalidDnsName(self.hostname.clone()))?,
            };

            match connector
                .connect(
                    server_name,
                    match self.socket.take() {
                        Some(s) => s,
                        None => return Err(ClientError::NotConnected),
                    },
                )
                .into_fallible()
                .await
            {
                Ok(stream) => {
                    let (_, conn) = stream.get_ref();
                    tls_info.cipher = match conn.negotiated_cipher_suite() {
                        Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
                        None => String::new(),
                    };
                    tls_info.protocol_version = match conn.protocol_version() {
                        Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
                        None => String::new(),
                    };

                    if let Some(certs) = conn.peer_certificates() {
                        // NOTE(review): indexing assumes the list is
                        // non-empty whenever it is present — confirm.
                        let peer_cert = &certs[0];
                        if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
                            tls_info.subject_name = subject_name(&cert);
                        }
                    }

                    Box::new(stream)
                }
                // The fallible connect hands the original stream back when
                // the handshake fails, so the client keeps it.
                Err((err, stream)) => {
                    handshake_error.replace(format!("{err:#}"));
                    stream
                }
            }
        };

        if let Some(tracer) = &self.tracer {
            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
                level: Level::INFO,
                message: match &handshake_error {
                    Some(error) => format!("STARTTLS handshake failed: {error:?}"),
                    None => format!("STARTTLS handshake -> {tls_info:?}"),
                },
            });
        }

        self.socket.replace(stream);
        Ok(match handshake_error {
            Some(error) => TlsStatus::FailedHandshake(error),
            None => TlsStatus::Info(tls_info),
        })
    }
850
851 pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
852 &mut self,
853 sender: SENDER,
854 recipient: RECIP,
855 data: B,
856 ) -> Result<Response, ClientError> {
857 let recipient: ForwardPath = recipient.into();
858 let status = self
859 .send_mail_multi_recip(sender, vec![recipient], data)
860 .await?;
861 Ok(status.response)
862 }
863
    /// Runs one complete mail transaction: optional `RSET`, `MAIL FROM`,
    /// one `RCPT TO` per recipient, `DATA`, the dot-stuffed payload, and
    /// the terminating dot.
    ///
    /// On success returns the final response together with the individual
    /// `RCPT TO` responses; the transaction proceeds as long as at least
    /// one recipient was accepted with 250.
    ///
    /// # Errors
    /// * `Rejected` with a locally generated 554 response when the payload
    ///   or envelope is 8-bit, the peer lacks `8BITMIME` / `SMTPUTF8`, and
    ///   `ignore_8bit_checks` is off. Nothing is sent in that case.
    /// * `Rejected` / `RejectedBatch` when the server refuses a command or
    ///   all recipients.
    /// * Any transport error from writing commands/data or reading
    ///   responses.
    pub async fn send_mail_multi_recip<B: AsRef<[u8]>, SENDER: Into<ReversePath>>(
        &mut self,
        sender: SENDER,
        recipient_list: Vec<ForwardPath>,
        data: B,
    ) -> Result<BatchSendSuccess, ClientError> {
        let sender = sender.into();

        let data: &[u8] = data.as_ref();
        // Holds the stuffed copy, when one is needed, so `data` can borrow it.
        let stuffed;

        let data = match apply_dot_stuffing(data) {
            Some(d) => {
                stuffed = d;
                &stuffed
            }
            None => data,
        };

        let data_is_8bit = data.iter().any(|&b| b >= 0x80);
        let envelope_is_8bit =
            !sender.is_ascii() || recipient_list.iter().any(|recipient| !recipient.is_ascii());

        let mut mail_from_params = vec![];
        if data_is_8bit {
            if self.capabilities.contains_key("8BITMIME") {
                mail_from_params.push(EsmtpParameter {
                    name: "BODY".to_string(),
                    value: Some("8BITMIME".to_string()),
                });
            } else if !self.ignore_8bit_checks {
                return Err(ClientError::Rejected(Response {
                    code: 554,
                    command: None,
                    enhanced_code: Some(EnhancedStatusCode {
                        class: 5,
                        subject: 6,
                        detail: 3,
                    }),
                    content: "KumoMTA internal: DATA is 8bit, destination does \
                        not support 8BITMIME. Conversion via msg:check_fix_conformance \
                        during reception is required"
                        .to_string(),
                }));
            }
        }

        if envelope_is_8bit {
            if self.capabilities.contains_key("SMTPUTF8") {
                mail_from_params.push(EsmtpParameter {
                    name: "SMTPUTF8".to_string(),
                    value: None,
                });
            } else if !self.ignore_8bit_checks {
                return Err(ClientError::Rejected(Response {
                    code: 554,
                    command: None,
                    enhanced_code: Some(EnhancedStatusCode {
                        class: 5,
                        subject: 6,
                        detail: 7,
                    }),
                    content: "KumoMTA internal: envelope is 8bit, destination does \
                        not support SMTPUTF8."
                        .to_string(),
                }));
            }
        }

        let mut commands = vec![];

        // Remember whether an RSET is part of this batch so that its
        // response can be consumed below.
        let used_rset = self.use_rset;
        if self.use_rset {
            commands.push(Command::Rset);
        }
        commands.push(Command::MailFrom {
            address: sender,
            parameters: mail_from_params,
        });

        for recipient in &recipient_list {
            commands.push(Command::RcptTo {
                address: recipient.clone(),
                parameters: vec![],
            });
        }
        commands.push(Command::Data);

        // Until this transaction completes successfully, any following
        // transaction must start with RSET; the flag is only relaxed at the
        // end of this function.
        self.use_rset = true;

        let mut responses = self.pipeline_commands(commands).await;

        // If any command hit a transport-level error, the status codes of
        // the responses that did arrive are checked eagerly below so that a
        // rejection is reported in preference to the later transport error.
        let is_err = responses.iter().any(|r| r.is_err());

        if used_rset {
            let rset_resp = responses.remove(0)?;
            if rset_resp.code != 250 {
                return Err(ClientError::Rejected(rset_resp));
            }
        }

        let mail_resp = responses.remove(0)?;
        if is_err && mail_resp.code != 250 {
            return Err(ClientError::Rejected(mail_resp));
        }

        let mut rcpt_responses = vec![];
        for _ in &recipient_list {
            rcpt_responses.push(responses.remove(0)?);
        }

        if is_err && rcpt_responses.iter().all(|resp| resp.code != 250) {
            return Err(ClientError::RejectedBatch(rcpt_responses));
        }

        let data_resp = responses.remove(0)?;
        if is_err && data_resp.code != 354 {
            return Err(ClientError::Rejected(data_resp));
        }

        // The server accepted DATA even though MAIL FROM or every RCPT TO
        // was refused: finish the data phase with an empty body (a lone
        // dot) and read its response, then fall through to the checks
        // below, which report the earlier rejection.
        if data_resp.code == 354
            && (mail_resp.code != 250 || rcpt_responses.iter().all(|resp| resp.code != 250))
        {
            self.write_data_with_timeout(b".\r\n").await?;
            let data_dot = Command::DataDot;
            let _ = self
                .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
                .await?;

        }

        if mail_resp.code != 250 {
            return Err(ClientError::Rejected(mail_resp));
        }
        if rcpt_responses.iter().all(|resp| resp.code != 250) {
            // A single recipient is reported as a plain rejection rather
            // than a batch.
            if rcpt_responses.len() == 1 {
                return Err(ClientError::Rejected(
                    rcpt_responses.pop().expect("have at least one"),
                ));
            }
            return Err(ClientError::RejectedBatch(rcpt_responses));
        }
        if data_resp.code != 354 {
            return Err(ClientError::Rejected(data_resp));
        }

        // The terminating dot must start on its own line; add a CRLF first
        // when the payload is empty or does not already end with LF.
        let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);

        tracing::trace!("message data is {} bytes", data.len());

        self.write_data_with_timeout(data).await?;

        let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };

        tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());

        self.write_data_with_timeout(marker.as_bytes()).await?;

        let data_dot = Command::DataDot;
        let resp = self
            .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
            .await?;
        if resp.code != 250 {
            return Err(ClientError::Rejected(resp));
        }

        // The transaction completed cleanly, so RSET before the next one is
        // only needed if the caller asked for it.
        self.use_rset = self.enable_rset;

        Ok(BatchSendSuccess {
            response: resp,
            rcpt_responses,
        })
    }
1069}
1070
/// Result of a successful `SmtpClient::send_mail_multi_recip` call.
#[derive(Debug)]
pub struct BatchSendSuccess {
    /// The server's final response to the message data.
    pub response: Response,
    /// The response to each `RCPT TO`, in recipient order. Individual
    /// entries may be rejections, as long as at least one was accepted.
    pub rcpt_responses: Vec<Response>,
}
1076
/// Outcome of a `STARTTLS` attempt that the server accepted.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
    /// The TLS handshake failed; carries the error text.
    FailedHandshake(String),
    /// The TLS handshake succeeded; carries the negotiated parameters.
    Info(TlsInformation),
}
1082
/// Details of an established TLS session.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
    /// Name of the negotiated cipher; empty if unavailable.
    pub cipher: String,
    /// Negotiated protocol version; empty if unavailable.
    pub protocol_version: String,
    /// Subject name entries of the peer (or DANE authority) certificate,
    /// each formatted as `KEY=value`.
    pub subject_name: Vec<String>,
    /// TLS implementation used: `"openssl"` or `"rustls"`.
    pub provider_name: String,
}
1090
1091impl Drop for SmtpClient {
1092 fn drop(&mut self) {
1093 if let Some(tracer) = &self.tracer {
1094 if self.socket.is_some() {
1095 tracer.trace_event(SmtpClientTraceEvent::Closed);
1096 }
1097 }
1098 }
1099}
1100fn parse_response_line(line: &'_ str) -> Result<ResponseLine<'_>, ClientError> {
1101 if line.len() < 4 {
1102 return Err(ClientError::MalformedResponseLine(line.to_string()));
1103 }
1104
1105 match line.as_bytes()[3] {
1106 b' ' | b'-' => match line[0..3].parse::<u16>() {
1107 Ok(code) => Ok(ResponseLine {
1108 code,
1109 is_final: line.as_bytes()[3] == b' ',
1110 content: &line[4..],
1111 }),
1112 Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
1113 },
1114 _ => Err(ClientError::MalformedResponseLine(line.to_string())),
1115 }
1116}
1117
impl TlsOptions {
    /// Builds an OpenSSL connection configuration from these options for a
    /// connection to `hostname`.
    ///
    /// Applies, when configured: the client certificate and private key
    /// (both must be present, and are checked against each other), the
    /// cipher list, the cipher suites, the raw SSL options, and disabled
    /// verification when `insecure` is set. When DANE TLSA records are
    /// present, DANE is enabled for `hostname` and every record is added.
    ///
    /// # Errors
    /// `SslErrorStack` for any OpenSSL failure, or `NoUsableDaneTlsa` when
    /// TLSA records were supplied but OpenSSL reported none as usable.
    pub fn build_openssl_connector(
        &self,
        hostname: &str,
    ) -> Result<openssl::ssl::ConnectConfiguration, ClientError> {
        tracing::trace!("build_openssl_connector for {hostname}");
        let mut builder =
            openssl::ssl::SslConnector::builder(openssl::ssl::SslMethod::tls_client())?;

        if let (Some(cert_data), Some(key_data)) =
            (&self.certificate_from_pem, &self.private_key_from_pem)
        {
            let cert = X509::from_pem(cert_data)?;
            builder.set_certificate(&cert)?;

            let key = PKey::private_key_from_pem(key_data)?;
            builder.set_private_key(&key)?;

            // Fail now if the key does not belong to the certificate.
            builder.check_private_key()?;
        }

        if let Some(list) = &self.openssl_cipher_list {
            builder.set_cipher_list(list)?;
        }

        if let Some(suites) = &self.openssl_cipher_suites {
            builder.set_ciphersuites(suites)?;
        }

        if let Some(options) = &self.openssl_options {
            // Clear everything first so the configured value replaces the
            // builder's options instead of being added to them.
            builder.clear_options(SslOptions::all());
            builder.set_options(*options);
        }

        if self.insecure {
            builder.set_verify(openssl::ssl::SslVerifyMode::NONE);
        }

        // DANE must be enabled on the context before it can be enabled on
        // the per-connection configuration below.
        if !self.dane_tlsa.is_empty() {
            builder.dane_enable()?;
            builder.set_no_dane_ee_namechecks();
        }

        let connector = builder.build();

        let mut config = connector.configure()?;

        if !self.dane_tlsa.is_empty() {
            config.dane_enable(hostname)?;
            let mut any_usable = false;
            for tlsa in &self.dane_tlsa {
                // Translate the DNS record's usage/selector/matching fields
                // into their OpenSSL equivalents.
                let usable = config.dane_tlsa_add(
                    match tlsa.cert_usage() {
                        CertUsage::PkixTa => DaneUsage::PKIX_TA,
                        CertUsage::PkixEe => DaneUsage::PKIX_EE,
                        CertUsage::DaneTa => DaneUsage::DANE_TA,
                        CertUsage::DaneEe => DaneUsage::DANE_EE,
                        CertUsage::Unassigned(n) => DaneUsage::from_raw(n),
                        CertUsage::Private => DaneUsage::PRIV_CERT,
                    },
                    match tlsa.selector() {
                        Selector::Full => DaneSelector::CERT,
                        Selector::Spki => DaneSelector::SPKI,
                        Selector::Unassigned(n) => DaneSelector::from_raw(n),
                        Selector::Private => DaneSelector::PRIV_SEL,
                    },
                    match tlsa.matching() {
                        Matching::Raw => DaneMatchType::FULL,
                        Matching::Sha256 => DaneMatchType::SHA2_256,
                        Matching::Sha512 => DaneMatchType::SHA2_512,
                        Matching::Unassigned(n) => DaneMatchType::from_raw(n),
                        Matching::Private => DaneMatchType::PRIV_MATCH,
                    },
                    tlsa.cert_data(),
                )?;

                tracing::trace!("build_dane_connector usable={usable} {tlsa:?}");
                if usable {
                    any_usable = true;
                }
            }

            // Refuse to continue when no record was usable.
            if !any_usable {
                return Err(ClientError::NoUsableDaneTlsa {
                    hostname: hostname.to_string(),
                    tlsa: self.dane_tlsa.clone(),
                });
            }
        }

        Ok(config)
    }
}
1211
1212fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1213 static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1214
1215 if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1216 return None;
1217 }
1218
1219 let mut stuffed = vec![];
1220 if data.starts_with(b".") {
1221 stuffed.push(b'.');
1222 }
1223 let mut last_idx = 0;
1224 for i in LFDOT.find_iter(data) {
1225 stuffed.extend_from_slice(&data[last_idx..=i]);
1226 stuffed.push(b'.');
1227 last_idx = i + 1;
1228 }
1229 stuffed.extend_from_slice(&data[last_idx..]);
1230 Some(stuffed)
1231}
1232
1233pub fn subject_name(cert: &X509Ref) -> Vec<String> {
1240 let mut subject_name = vec![];
1241 for entry in cert.subject_name().entries() {
1242 if let Ok(obj) = entry.object().nid().short_name() {
1243 if let Ok(name) = entry.data().as_utf8() {
1244 subject_name.push(format!("{obj}={name}"));
1245 }
1246 }
1247 }
1248 subject_name
1249}
1250
#[cfg(test)]
mod test {
    use super::*;

    /// Input without a leading dot or `\n.` yields `None`; otherwise each such
    /// dot is doubled.
    #[test]
    fn test_stuffing() {
        assert_eq!(apply_dot_stuffing(b"foo"), None);

        let cases: [(&[u8], &[u8]); 3] = [
            (b".foo", b"..foo"),
            (b"foo\n.bar", b"foo\n..bar"),
            (b"foo\n.bar\n..baz\n", b"foo\n..bar\n...baz\n"),
        ];
        for (input, expected) in cases {
            assert_eq!(apply_dot_stuffing(input).unwrap(), expected);
        }
    }

    /// A space after the code marks a final line, a dash marks a continuation,
    /// and anything else is rejected as malformed.
    #[test]
    fn response_line_parsing() {
        for (line, is_final) in [("220 woot", true), ("220-woot", false)] {
            assert_eq!(
                parse_response_line(line).unwrap(),
                ResponseLine {
                    code: 220,
                    is_final,
                    content: "woot"
                }
            );
        }

        for malformed in ["220_woot", "not really"] {
            assert!(matches!(
                parse_response_line(malformed),
                Err(ClientError::MalformedResponseLine(_))
            ));
        }
    }

    /// Feeds `lines` through `parse_response_line` and `ResponseBuilder`,
    /// asserting that the last line parsed was a final line.
    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let first = parse_response_line(lines[0])?;
        let mut builder = ResponseBuilder::new(&first);
        let mut last_was_final = first.is_final;

        for raw in lines.iter().skip(1) {
            let next = parse_response_line(raw)?;
            builder
                .add_line(&next)
                .map_err(ClientError::MalformedResponseLine)?;
            last_was_final = next.is_final;
        }

        assert!(last_was_final);
        Ok(builder.build(None))
    }

    #[test]
    fn multi_line_response() {
        // Expected values shared by the successful cases below.
        let enhanced_410 = || {
            Some(EnhancedStatusCode {
                class: 4,
                subject: 1,
                detail: 0,
            })
        };
        let expect_220 = |enhanced_code: Option<EnhancedStatusCode>, content: &str| Response {
            code: 220,
            enhanced_code,
            content: content.to_string(),
            command: None,
        };

        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done"]).unwrap(),
            expect_220(None, "woot\nmore\ndone")
        );

        // A line whose code differs from the first line is reported verbatim
        // as the malformed line.
        for (lines, offending) in [
            (["220-woot", "221-more", "220 done"], "221-more"),
            (["220-woot", "220-more", "221 done"], "221 done"),
        ] {
            let res = parse_multi_line(&lines).unwrap_err();
            assert!(
                matches!(
                    res,
                    ClientError::MalformedResponseLine(ref err) if err == offending
                ),
                "got error {res:?}"
            );
        }

        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done"]).unwrap(),
            expect_220(enhanced_410(), "woot\nmore\ndone")
        );

        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done"]).unwrap(),
            expect_220(enhanced_410(), "woot\nmore\ndone")
        );

        // A differing enhanced code on a later line stays in the content.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done"]).unwrap(),
            expect_220(enhanced_410(), "woot\nmore\n5.5.5 done")
        );
    }

    #[test]
    fn test_extract_hostname() {
        for (input, expected) in [
            ("foo", "foo"),
            ("foo.", "foo"),
            ("foo:25", "foo"),
            ("foo.:25", "foo"),
            ("[foo]:25", "foo"),
            ("[foo.]:25", "foo"),
            ("[::1]:25", "::1"),
            ("::1:25", "::1"),
        ] {
            assert_eq!(extract_hostname(input), expected);
        }
    }
}