1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::{
4 AsyncReadAndWrite, BoxedAsyncReadAndWrite, Command, Domain, EsmtpParameter, ForwardPath,
5 ReversePath,
6};
7use hickory_proto::rr::rdata::tlsa::{CertUsage, Matching, Selector};
8use hickory_proto::rr::rdata::TLSA;
9use memchr::memmem::Finder;
10use openssl::ssl::{DaneMatchType, DaneSelector, DaneUsage, SslOptions};
11use openssl::x509::{X509Ref, X509};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::net::IpAddr;
15use std::str::FromStr;
16use std::sync::{Arc, LazyLock};
17use std::time::Duration;
18use thiserror::Error;
19use tokio::io::{AsyncReadExt, AsyncWriteExt};
20use tokio::net::{TcpStream, ToSocketAddrs};
21use tokio::time::timeout;
22use tokio_rustls::rustls::pki_types::ServerName;
23use tracing::Level;
24
25pub use crate::tls::TlsOptions;
26pub use {openssl, tokio_rustls};
27
/// Maximum number of bytes buffered while waiting for the end of a single response line.
/// Also the size of the scratch buffer used for each socket read in `SmtpClient::read_line`.
const MAX_LINE_LEN: usize = 4096;
29
/// Errors produced by [`SmtpClient`] operations.
#[derive(Error, Debug, Clone)]
pub enum ClientError {
    /// A response line was not valid UTF-8.
    #[error("response is not UTF8")]
    Utf8(#[from] std::string::FromUtf8Error),
    /// A response line could not be parsed, or did not continue the current reply
    /// consistently. Carries the offending line text (or the response builder's message).
    #[error("Malformed Response: {0}")]
    MalformedResponseLine(String),
    /// More than `MAX_LINE_LEN` bytes arrived without a CRLF.
    #[error("Response line is too long")]
    ResponseTooLong,
    /// An operation that needs the connection found no socket.
    #[error("Not connected")]
    NotConnected,
    /// The server answered with a reply code the operation did not accept.
    #[error("Command rejected {0:?}")]
    Rejected(Response),
    /// The peer hostname could not be turned into a TLS server name (rustls path).
    #[error("STARTTLS: {0} is not a valid DNS name")]
    InvalidDnsName(String),
    /// No (complete) reply arrived within `duration`; `command` is the command awaiting it.
    #[error("Timed Out waiting {duration:?} for response to {command:?}")]
    TimeOutResponse {
        command: Option<Command>,
        duration: Duration,
    },
    /// Writing `commands` did not complete within `duration`.
    #[error("Timed Out writing {duration:?} {commands:?}")]
    TimeOutRequest {
        commands: Vec<Command>,
        duration: Duration,
    },
    /// Reading failed (I/O error, EOF, or the socket was already dropped). `partial` holds
    /// the buffered but not yet consumed bytes, lossily decoded as UTF-8.
    #[error("Error {error} reading response to {command:?}")]
    ReadError {
        command: Option<Command>,
        error: String,
        partial: String,
    },
    /// Flushing pending output before reading a reply failed.
    #[error("Error {error} flushing send buffer")]
    FlushError { error: String },
    /// Writing failed. `commands` is empty when the data being written was message payload.
    #[error("Error {error} writing {commands:?}")]
    WriteError {
        commands: Vec<Command>,
        error: String,
    },
    /// Writing message payload data (or the end-of-data marker) timed out.
    #[error("Timed Out sending message payload data")]
    TimeOutData,
    /// An OpenSSL operation failed.
    #[error("SSL Error: {0}")]
    SslErrorStack(#[from] openssl::error::ErrorStack),
    /// DANE TLSA records were supplied, but OpenSSL reported none of them as usable.
    #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
    NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
}
74
75impl ClientError {
76 pub fn command(&self) -> Option<String> {
78 match self {
79 Self::TimeOutResponse {
80 command: Some(command),
81 ..
82 }
83 | Self::ReadError {
84 command: Some(command),
85 ..
86 } => Some(command.encode()),
87 Self::TimeOutRequest { commands, .. } | Self::WriteError { commands, .. }
88 if !commands.is_empty() =>
89 {
90 let commands: Vec<String> = commands.into_iter().map(|cmd| cmd.encode()).collect();
91 Some(commands.join(""))
92 }
93 _ => None,
94 }
95 }
96
97 pub fn was_due_to_message(&self) -> bool {
109 match self {
110 Self::Utf8(_)
111 | Self::MalformedResponseLine(_)
112 | Self::ResponseTooLong
113 | Self::NotConnected
114 | Self::InvalidDnsName(_)
115 | Self::TimeOutResponse { .. }
116 | Self::TimeOutRequest { .. }
117 | Self::ReadError { .. }
118 | Self::FlushError { .. }
119 | Self::WriteError { .. }
120 | Self::TimeOutData
121 | Self::SslErrorStack(_)
122 | Self::NoUsableDaneTlsa { .. } => false,
123 Self::Rejected(response) => response.was_due_to_message(),
124 }
125 }
126}
127
/// A single extension advertised on one line of an EHLO/LHLO reply.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EsmtpCapability {
    /// The extension keyword, with its case exactly as sent by the server (e.g. `SIZE`).
    pub name: String,
    /// Everything after the first space on the reply line, if there was a space
    /// (e.g. `35882577` for `SIZE 35882577`).
    pub param: Option<String>,
}
133
/// An event delivered to an [`SmtpClientTracer`].
#[derive(Clone, Debug)]
pub enum SmtpClientTraceEvent {
    /// The connection went away: the peer closed it, the client dropped it after an error
    /// or timeout, or the client itself was dropped while still connected.
    Closed,
    /// Raw bytes received from the peer, exactly as read.
    Read(Vec<u8>),
    /// Data written to the peer (binary payloads are lossily decoded as UTF-8).
    Write(String),
    /// A message generated by the client itself, such as an I/O error, a timeout or the
    /// STARTTLS outcome, together with its severity.
    Diagnostic {
        level: tracing::Level,
        message: String,
    },
}
144
/// A trace event whose construction is postponed until a tracer asks for it.
pub trait DeferredTracer {
    /// Builds the event.
    fn trace(&self) -> SmtpClientTraceEvent;
}
148
/// Receiver for an [`SmtpClient`]'s protocol trace.
pub trait SmtpClientTracer: std::fmt::Debug {
    /// Records an event that has already been built.
    fn trace_event(&self, event: SmtpClientTraceEvent);
    /// Records an event supplied only as a recipe; call `deferred.trace()` to obtain it.
    /// NOTE(review): presumably this lets implementations skip building (and copying the
    /// data of) events they would discard. Confirm against the implementors.
    fn lazy_trace(&self, deferred: &dyn DeferredTracer);
}
153
154struct WriteTracer<'a> {
156 data: &'a str,
157}
158impl DeferredTracer for WriteTracer<'_> {
159 fn trace(&self) -> SmtpClientTraceEvent {
160 SmtpClientTraceEvent::Write(self.data.to_string())
161 }
162}
163impl<'a> WriteTracer<'a> {
164 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
165 tracer.lazy_trace(&Self { data });
166 }
167}
168
169struct BinWriteTracer<'a> {
170 data: &'a [u8],
171}
172impl DeferredTracer for BinWriteTracer<'_> {
173 fn trace(&self) -> SmtpClientTraceEvent {
174 let data = String::from_utf8_lossy(self.data).to_string();
175 SmtpClientTraceEvent::Write(data)
176 }
177}
178impl<'a> BinWriteTracer<'a> {
179 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
180 tracer.lazy_trace(&Self { data });
181 }
182}
183
184struct ReadTracer<'a> {
187 data: &'a [u8],
188}
189impl DeferredTracer for ReadTracer<'_> {
190 fn trace(&self) -> SmtpClientTraceEvent {
191 SmtpClientTraceEvent::Read(self.data.to_vec())
192 }
193}
194
/// An SMTP/LMTP client bound to a single connection.
#[derive(Debug)]
pub struct SmtpClient {
    /// The connection. It becomes `None` once dropped after an I/O error, timeout or EOF,
    /// and is temporarily taken while STARTTLS wraps it.
    socket: Option<BoxedAsyncReadAndWrite>,
    /// Peer name produced by `extract_hostname`; used in trace logs and as the TLS name.
    hostname: String,
    /// Capabilities from the last successful EHLO/LHLO, keyed by upper-cased keyword.
    capabilities: HashMap<String, EsmtpCapability>,
    /// Bytes received from the peer that have not yet been returned as complete lines.
    read_buffer: Vec<u8>,
    /// Per-command timeouts.
    timeouts: SmtpClientTimeouts,
    /// Optional receiver for protocol trace events.
    tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
    /// Whether the next `send_mail` transaction must begin with RSET.
    use_rset: bool,
    /// Whether transactions following a successful `send_mail` should begin with RSET.
    enable_rset: bool,
    /// Whether commands may be pipelined when the server advertises PIPELINING.
    enable_pipelining: bool,
}
207
/// Derives a bare host name from a peer address string.
///
/// Three things are removed, in this order: a `:port` suffix (everything after the last
/// `:`), one pair of enclosing `[` `]` brackets, and a single trailing `.`. Because the last
/// `:` is always treated as a port separator, an unbracketed IPv6 literal such as `::1:25`
/// yields `::1`.
fn extract_hostname(hostname: &str) -> &str {
    let without_port = match hostname.rsplit_once(':') {
        Some((host, _port)) => host,
        None => hostname,
    };

    let unbracketed = without_port
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(without_port);

    unbracketed.strip_suffix('.').unwrap_or(unbracketed)
}
226
227impl SmtpClient {
228 pub async fn new<A: ToSocketAddrs + ToString + Clone>(
229 addr: A,
230 timeouts: SmtpClientTimeouts,
231 ) -> std::io::Result<Self> {
232 let stream = TcpStream::connect(addr.clone()).await?;
233 stream.set_nodelay(true)?;
235 Ok(Self::with_stream(stream, addr.to_string(), timeouts))
236 }
237
238 pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
239 stream: S,
240 peer_hostname: H,
241 timeouts: SmtpClientTimeouts,
242 ) -> Self {
243 let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
244
245 Self {
246 socket: Some(Box::new(stream)),
247 hostname,
248 capabilities: HashMap::new(),
249 read_buffer: Vec::with_capacity(1024),
250 timeouts,
251 tracer: None,
252 use_rset: false,
253 enable_rset: false,
254 enable_pipelining: false,
255 }
256 }
257
258 pub fn is_connected(&self) -> bool {
259 self.socket.is_some()
260 }
261
262 pub fn set_enable_rset(&mut self, enable: bool) {
263 self.enable_rset = enable;
264 }
265
266 pub fn set_enable_pipelining(&mut self, enable: bool) {
267 self.enable_pipelining = enable;
268 }
269
270 pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
271 self.tracer.replace(tracer);
272 }
273
274 pub fn timeouts(&self) -> &SmtpClientTimeouts {
275 &self.timeouts
276 }
277
278 async fn read_line(
279 &mut self,
280 timeout_duration: Duration,
281 cmd: Option<&Command>,
282 ) -> Result<String, ClientError> {
283 let mut too_long = false;
284 loop {
285 let mut iter = self.read_buffer.iter().enumerate();
286 while let Some((i, &b)) = iter.next() {
287 if b != b'\r' {
288 continue;
289 }
290 if let Some((_, b'\n')) = iter.next() {
291 if too_long {
292 self.read_buffer.drain(0..i + 2);
293
294 if let Some(tracer) = &self.tracer {
295 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
296 level: Level::ERROR,
297 message: "Response too long".to_string(),
298 });
299 }
300
301 return Err(ClientError::ResponseTooLong);
302 }
303
304 let line = String::from_utf8(self.read_buffer[0..i].to_vec());
305 self.read_buffer.drain(0..i + 2);
306 return Ok(line?);
307 }
308 }
309 if self.read_buffer.len() > MAX_LINE_LEN {
310 self.read_buffer.clear();
311 too_long = true;
312 }
313
314 let mut data = [0u8; MAX_LINE_LEN];
316 let size = match self.socket.as_mut() {
317 Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
318 Ok(Ok(size)) => size,
319 Ok(Err(err)) => {
320 self.socket.take();
321 if let Some(tracer) = &self.tracer {
322 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
323 level: Level::ERROR,
324 message: format!("Error during read: {err:#}"),
325 });
326 tracer.trace_event(SmtpClientTraceEvent::Closed);
327 }
328 return Err(ClientError::ReadError {
329 command: cmd.cloned(),
330 error: format!("{err:#}"),
331 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
332 });
333 }
334 Err(_) => {
335 self.socket.take();
336 if let Some(tracer) = &self.tracer {
337 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
338 level: Level::ERROR,
339 message: format!("Read Timeout after {timeout_duration:?}"),
340 });
341 tracer.trace_event(SmtpClientTraceEvent::Closed);
342 }
343 return Err(ClientError::TimeOutResponse {
344 command: cmd.cloned(),
345 duration: timeout_duration,
346 });
347 }
348 },
349 None => {
350 return Err(ClientError::ReadError {
351 command: cmd.cloned(),
352 error: "the socket was closed in response to an earlier issue".to_string(),
353 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
354 });
355 }
356 };
357 if size == 0 {
358 self.socket.take();
359 if let Some(tracer) = &self.tracer {
360 tracer.trace_event(SmtpClientTraceEvent::Closed);
361 }
362 return Err(ClientError::ReadError {
363 command: cmd.cloned(),
364 error: "Connection closed by peer".to_string(),
365 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
366 });
367 }
368 if let Some(tracer) = &self.tracer {
369 tracer.lazy_trace(&ReadTracer {
370 data: &data[0..size],
371 });
372 }
373 self.read_buffer.extend_from_slice(&data[0..size]);
374 }
375 }
376
377 pub async fn read_response(
378 &mut self,
379 command: Option<&Command>,
380 timeout_duration: Duration,
381 ) -> Result<Response, ClientError> {
382 if let Some(sock) = self.socket.as_mut() {
383 if let Err(err) = sock.flush().await {
384 self.socket.take();
385 if let Some(tracer) = &self.tracer {
386 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
387 level: Level::ERROR,
388 message: format!("Error during flush: {err:#}"),
389 });
390 tracer.trace_event(SmtpClientTraceEvent::Closed);
391 }
392 return Err(ClientError::FlushError {
393 error: format!("{err:#}"),
394 });
395 }
396 }
397
398 let mut line = self.read_line(timeout_duration, command).await?;
399 tracing::trace!("recv<-{}: {line}", self.hostname);
400 let mut parsed = parse_response_line(&line)?;
401 let mut response_builder = ResponseBuilder::new(&parsed);
402
403 let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
404 while !parsed.is_final {
405 line = self
406 .read_line(subsequent_line_timeout_duration, command)
407 .await?;
408 parsed = parse_response_line(&line)?;
409 response_builder
410 .add_line(&parsed)
411 .map_err(ClientError::MalformedResponseLine)?;
412 }
413
414 let response = response_builder.build(command.map(|cmd| cmd.encode()));
415
416 tracing::trace!("{}: {response:?}", self.hostname);
417
418 Ok(response)
419 }
420
421 pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
422 self.write_command_request(command).await?;
423 self.read_response(Some(command), command.client_timeout(&self.timeouts))
424 .await
425 }
426
427 async fn write_all_with_timeout<F, G>(
431 &mut self,
432 timeout_duration: Duration,
433 bytes: &[u8],
434 make_timeout_err: F,
435 make_write_err: G,
436 ) -> Result<(), ClientError>
437 where
438 F: FnOnce() -> ClientError,
439 G: FnOnce(String) -> ClientError,
440 {
441 match self.socket.as_mut() {
442 Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
443 Ok(Ok(response)) => Ok(response),
444 Ok(Err(err)) => {
445 self.socket.take();
446 if let Some(tracer) = &self.tracer {
447 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
448 level: Level::ERROR,
449 message: format!("Error during write: {err:#}"),
450 });
451 tracer.trace_event(SmtpClientTraceEvent::Closed);
452 }
453 Err(make_write_err(format!("{err:#}")))
454 }
455 Err(_) => {
456 self.socket.take();
457 if let Some(tracer) = &self.tracer {
458 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
459 level: Level::ERROR,
460 message: format!("Write Timeout after {timeout_duration:?}"),
461 });
462 tracer.trace_event(SmtpClientTraceEvent::Closed);
463 }
464 Err(make_timeout_err())
465 }
466 },
467 None => Err(make_write_err(
468 "the socket was closed in response to an earlier issue".to_string(),
469 )),
470 }
471 }
472
473 async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
474 let total_timeout: Duration = commands
475 .iter()
476 .map(|cmd| cmd.client_timeout_request(&self.timeouts))
477 .sum();
478
479 let mut lines: Vec<String> = vec![];
480 let mut all = String::new();
481 for cmd in commands {
482 let line = cmd.encode();
483 all.push_str(&line);
484 lines.push(line);
485 }
486 tracing::trace!("send->{}: (PIPELINE) {all}", self.hostname);
487 if self.socket.is_some() {
488 if let Some(tracer) = &self.tracer {
489 for line in lines {
492 WriteTracer::trace(tracer, &line);
493 }
494 }
495 }
496 self.write_all_with_timeout(
497 total_timeout,
498 all.as_bytes(),
499 || ClientError::TimeOutRequest {
500 duration: total_timeout,
501 commands: commands.to_vec(),
502 },
503 |error| ClientError::WriteError {
504 error,
505 commands: commands.to_vec(),
506 },
507 )
508 .await
509 }
510
511 async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
512 let line = command.encode();
513 tracing::trace!("send->{}: {line}", self.hostname);
514 if self.socket.is_some() {
515 if let Some(tracer) = &self.tracer {
516 WriteTracer::trace(tracer, &line);
517 }
518 }
519
520 let timeout_duration = command.client_timeout_request(&self.timeouts);
521 self.write_all_with_timeout(
522 timeout_duration,
523 line.as_bytes(),
524 || ClientError::TimeOutRequest {
525 duration: timeout_duration,
526 commands: vec![command.clone()],
527 },
528 |error| ClientError::WriteError {
529 error,
530 commands: vec![command.clone()],
531 },
532 )
533 .await
534 }
535
536 async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
537 if self.socket.is_some() {
538 if let Some(tracer) = &self.tracer {
539 BinWriteTracer::trace(tracer, data);
540 }
541 }
542 let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
543 self.write_all_with_timeout(
544 timeout_duration,
545 data,
546 || ClientError::TimeOutData,
547 |error| ClientError::WriteError {
548 error,
549 commands: vec![],
550 },
551 )
552 .await
553 }
554
555 pub async fn pipeline_commands(
571 &mut self,
572 commands: Vec<Command>,
573 ) -> Vec<Result<Response, ClientError>> {
574 let mut results: Vec<Result<Response, ClientError>> = vec![];
575
576 let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
577 if pipeline {
578 if let Err(err) = self.write_pipeline_request(&commands).await {
579 let err: ClientError = err;
580 results.push(Err(err.clone()));
581 while results.len() < commands.len() {
582 results.push(Err(err.clone()));
584 }
585 return results;
586 }
587
588 for cmd in &commands {
590 results.push(
591 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
592 .await,
593 );
594 }
595 return results;
596 }
597
598 for cmd in &commands {
599 if let Err(err) = self.write_command_request(cmd).await {
600 let err: ClientError = err;
601 results.push(Err(err.clone()));
602 while results.len() < commands.len() {
603 results.push(Err(err.clone()));
605 }
606 return results;
607 }
608 results.push(
611 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
612 .await,
613 );
614 }
615 results
616 }
617
618 pub async fn ehlo_lhlo(
619 &mut self,
620 ehlo_name: &str,
621 use_lmtp: bool,
622 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
623 if use_lmtp {
624 self.lhlo(ehlo_name).await
625 } else {
626 self.ehlo(ehlo_name).await
627 }
628 }
629
630 pub async fn lhlo(
631 &mut self,
632 ehlo_name: &str,
633 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
634 let response = self
635 .send_command(&Command::Lhlo(Domain::Name(ehlo_name.to_string())))
636 .await?;
637 self.ehlo_common(response)
638 }
639
640 pub async fn ehlo(
641 &mut self,
642 ehlo_name: &str,
643 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
644 let response = self
645 .send_command(&Command::Ehlo(Domain::Name(ehlo_name.to_string())))
646 .await?;
647 self.ehlo_common(response)
648 }
649
650 fn ehlo_common(
651 &mut self,
652 response: Response,
653 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
654 if response.code != 250 {
655 return Err(ClientError::Rejected(response));
656 }
657
658 let mut capabilities = HashMap::new();
659
660 for line in response.content.lines().skip(1) {
661 let mut fields = line.splitn(2, ' ');
662 if let Some(name) = fields.next() {
663 let param = fields.next().map(|s| s.to_string());
664 let cap = EsmtpCapability {
665 name: name.to_string(),
666 param,
667 };
668 capabilities.insert(name.to_ascii_uppercase(), cap);
669 }
670 }
671
672 self.capabilities = capabilities;
673 Ok(&self.capabilities)
674 }
675
676 pub async fn auth_plain(
677 &mut self,
678 username: &str,
679 password: Option<&str>,
680 ) -> Result<(), ClientError> {
681 let password = password.unwrap_or("");
684 let payload = format!("\x00{username}\x00{password}");
685 let payload = data_encoding::BASE64.encode(payload.as_bytes());
686
687 let response = self
688 .send_command(&Command::Auth {
689 sasl_mech: "PLAIN".to_string(),
690 initial_response: Some(payload),
691 })
692 .await?;
693
694 if response.code != 235 {
695 return Err(ClientError::Rejected(response));
696 }
697
698 Ok(())
699 }
700
701 pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
707 let resp = self.send_command(&Command::StartTls).await?;
708 if resp.code != 220 {
709 return Err(ClientError::Rejected(resp));
710 }
711
712 let mut handshake_error = None;
713 let mut tls_info = TlsInformation::default();
714
715 let stream: BoxedAsyncReadAndWrite =
716 if options.prefer_openssl || !options.dane_tlsa.is_empty() {
717 let connector = options.build_openssl_connector(&self.hostname)?;
718 let ssl = connector.into_ssl(self.hostname.as_str())?;
719
720 let (stream, dup_stream) = match self.socket.take() {
721 Some(s) => {
722 let d = s.try_dup();
723 (s, d)
724 }
725 None => return Err(ClientError::NotConnected),
726 };
727
728 let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;
729
730 if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
731 handshake_error.replace(format!("{err:#}"));
732 }
733
734 tls_info.provider_name = "openssl".to_string();
735 tls_info.cipher = match ssl_stream.ssl().current_cipher() {
736 Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
737 None => String::new(),
738 };
739 tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();
740
741 if let Some(cert) = ssl_stream.ssl().peer_certificate() {
742 tls_info.subject_name = subject_name(&cert);
743 }
744 if let Ok(authority) = ssl_stream.ssl().dane_authority() {
745 if let Some(cert) = &authority.cert {
746 tls_info.subject_name = subject_name(cert);
747 }
748 }
749
750 match (&handshake_error, dup_stream) {
751 (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
752 drop(ssl_stream);
757 Box::new(dup_stream)
758 }
759 _ => Box::new(ssl_stream),
760 }
761 } else {
762 tls_info.provider_name = "rustls".to_string();
763 let connector = options.build_tls_connector().await;
764 let server_name = match IpAddr::from_str(self.hostname.as_str()) {
765 Ok(ip) => ServerName::IpAddress(ip.into()),
766 Err(_) => ServerName::try_from(self.hostname.clone())
767 .map_err(|_| ClientError::InvalidDnsName(self.hostname.clone()))?,
768 };
769
770 match connector
771 .connect(
772 server_name,
773 match self.socket.take() {
774 Some(s) => s,
775 None => return Err(ClientError::NotConnected),
776 },
777 )
778 .into_fallible()
779 .await
780 {
781 Ok(stream) => {
782 let (_, conn) = stream.get_ref();
783 tls_info.cipher = match conn.negotiated_cipher_suite() {
784 Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
785 None => String::new(),
786 };
787 tls_info.protocol_version = match conn.protocol_version() {
788 Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
789 None => String::new(),
790 };
791
792 if let Some(certs) = conn.peer_certificates() {
793 let peer_cert = &certs[0];
794 if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
795 tls_info.subject_name = subject_name(&cert);
796 }
797 }
798
799 Box::new(stream)
800 }
801 Err((err, stream)) => {
802 handshake_error.replace(format!("{err:#}"));
803 stream
804 }
805 }
806 };
807
808 if let Some(tracer) = &self.tracer {
809 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
810 level: Level::INFO,
811 message: match &handshake_error {
812 Some(error) => format!("STARTTLS handshake failed: {error:?}"),
813 None => format!("STARTTLS handshake -> {tls_info:?}"),
814 },
815 });
816 }
817
818 self.socket.replace(stream);
819 Ok(match handshake_error {
820 Some(error) => TlsStatus::FailedHandshake(error),
821 None => TlsStatus::Info(tls_info),
822 })
823 }
824
825 pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
826 &mut self,
827 sender: SENDER,
828 recipient: RECIP,
829 data: B,
830 ) -> Result<Response, ClientError> {
831 let sender = sender.into();
832 let recipient = recipient.into();
833
834 let data: &[u8] = data.as_ref();
835 let stuffed;
836
837 let data = match apply_dot_stuffing(data) {
838 Some(d) => {
839 stuffed = d;
840 &stuffed
841 }
842 None => data,
843 };
844
845 let data_is_8bit = data.iter().any(|&b| b >= 0x80);
846 let envelope_is_8bit = !sender.is_ascii() || !recipient.is_ascii();
847
848 let mut mail_from_params = vec![];
849 if data_is_8bit && self.capabilities.contains_key("8BITMIME") {
850 mail_from_params.push(EsmtpParameter {
851 name: "BODY".to_string(),
852 value: Some("8BITMIME".to_string()),
853 });
854 }
855
856 if envelope_is_8bit && self.capabilities.contains_key("SMTPUTF8") {
857 mail_from_params.push(EsmtpParameter {
858 name: "SMTPUTF8".to_string(),
859 value: None,
860 });
861 }
862
863 let mut commands = vec![];
864
865 let used_rset = self.use_rset;
875 if self.use_rset {
876 commands.push(Command::Rset);
877 }
878 commands.push(Command::MailFrom {
879 address: sender,
880 parameters: mail_from_params,
881 });
882 commands.push(Command::RcptTo {
883 address: recipient,
884 parameters: vec![],
885 });
886 commands.push(Command::Data);
887
888 self.use_rset = true;
891
892 let mut responses = self.pipeline_commands(commands).await;
893
894 let is_err = responses.iter().any(|r| r.is_err());
902
903 if used_rset {
904 let rset_resp = responses.remove(0)?;
905 if rset_resp.code != 250 {
906 return Err(ClientError::Rejected(rset_resp));
907 }
908 }
909
910 let mail_resp = responses.remove(0)?;
911 if is_err && mail_resp.code != 250 {
912 return Err(ClientError::Rejected(mail_resp));
913 }
914
915 let rcpt_resp = responses.remove(0)?;
916 if is_err && rcpt_resp.code != 250 {
917 return Err(ClientError::Rejected(rcpt_resp));
918 }
919
920 let data_resp = responses.remove(0)?;
921 if is_err && data_resp.code != 354 {
922 return Err(ClientError::Rejected(data_resp));
923 }
924
925 if data_resp.code == 354 && (mail_resp.code != 250 || rcpt_resp.code != 250) {
926 self.write_data_with_timeout(b".\r\n").await?;
935 let data_dot = Command::DataDot;
936 let _ = self
938 .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
939 .await?;
940
941 }
945
946 if mail_resp.code != 250 {
947 return Err(ClientError::Rejected(mail_resp));
948 }
949 if rcpt_resp.code != 250 {
950 return Err(ClientError::Rejected(rcpt_resp));
951 }
952 if data_resp.code != 354 {
953 return Err(ClientError::Rejected(data_resp));
954 }
955
956 let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);
957
958 tracing::trace!("message data is {} bytes", data.len());
959
960 self.write_data_with_timeout(data).await?;
961
962 let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };
963
964 tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());
965
966 self.write_data_with_timeout(marker.as_bytes()).await?;
967
968 let data_dot = Command::DataDot;
969 let resp = self
970 .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
971 .await?;
972 if resp.code != 250 {
973 return Err(ClientError::Rejected(resp));
974 }
975
976 self.use_rset = self.enable_rset;
979
980 Ok(resp)
981 }
982}
983
/// Outcome of [`SmtpClient::starttls`] once the server has accepted STARTTLS.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
    /// The TLS handshake failed; carries the handshake error text.
    FailedHandshake(String),
    /// The handshake completed; carries details of the negotiated session.
    Info(TlsInformation),
}
989
/// Details about a TLS session negotiated by [`SmtpClient::starttls`].
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
    /// Negotiated cipher name; empty when the TLS library reported none.
    pub cipher: String,
    /// Negotiated protocol version; empty when the TLS library reported none.
    pub protocol_version: String,
    /// Subject-name entries of the peer (or DANE-matched) certificate, as `SHORTNAME=value`.
    pub subject_name: Vec<String>,
    /// TLS implementation used: `"openssl"` or `"rustls"`.
    pub provider_name: String,
}
997
998impl Drop for SmtpClient {
999 fn drop(&mut self) {
1000 if let Some(tracer) = &self.tracer {
1001 if self.socket.is_some() {
1002 tracer.trace_event(SmtpClientTraceEvent::Closed);
1003 }
1004 }
1005 }
1006}
1007fn parse_response_line(line: &str) -> Result<ResponseLine, ClientError> {
1008 if line.len() < 4 {
1009 return Err(ClientError::MalformedResponseLine(line.to_string()));
1010 }
1011
1012 match line.as_bytes()[3] {
1013 b' ' | b'-' => match line[0..3].parse::<u16>() {
1014 Ok(code) => Ok(ResponseLine {
1015 code,
1016 is_final: line.as_bytes()[3] == b' ',
1017 content: &line[4..],
1018 }),
1019 Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
1020 },
1021 _ => Err(ClientError::MalformedResponseLine(line.to_string())),
1022 }
1023}
1024
impl TlsOptions {
    /// Builds an OpenSSL client configuration for a TLS session with `hostname`.
    ///
    /// The following are applied when set:
    /// * `openssl_cipher_list` (via `set_cipher_list`);
    /// * `openssl_cipher_suites` (via `set_ciphersuites`);
    /// * `openssl_options`, which *replaces* rather than extends the builder's option set.
    ///
    /// `insecure` switches peer verification off (`SslVerifyMode::NONE`).
    ///
    /// When `dane_tlsa` is non-empty, DANE is enabled on the builder and for `hostname`,
    /// DANE-EE name checks are disabled, and every TLSA record is loaded.
    ///
    /// # Errors
    /// * [`ClientError::SslErrorStack`] for any OpenSSL failure.
    /// * [`ClientError::NoUsableDaneTlsa`] when TLSA records were supplied but OpenSSL
    ///   reported none of them as usable.
    pub fn build_openssl_connector(
        &self,
        hostname: &str,
    ) -> Result<openssl::ssl::ConnectConfiguration, ClientError> {
        tracing::trace!("build_openssl_connector for {hostname}");
        let mut builder =
            openssl::ssl::SslConnector::builder(openssl::ssl::SslMethod::tls_client())?;

        if let Some(list) = &self.openssl_cipher_list {
            builder.set_cipher_list(list)?;
        }

        if let Some(suites) = &self.openssl_cipher_suites {
            builder.set_ciphersuites(suites)?;
        }

        // Replace, rather than add to, the connector's default options.
        if let Some(options) = &self.openssl_options {
            builder.clear_options(SslOptions::all());
            builder.set_options(*options);
        }

        if self.insecure {
            builder.set_verify(openssl::ssl::SslVerifyMode::NONE);
        }

        // Context-level DANE setup must happen before the connector is built.
        if !self.dane_tlsa.is_empty() {
            builder.dane_enable()?;
            builder.set_no_dane_ee_namechecks();
        }

        let connector = builder.build();

        let mut config = connector.configure()?;

        if !self.dane_tlsa.is_empty() {
            config.dane_enable(hostname)?;
            let mut any_usable = false;
            for tlsa in &self.dane_tlsa {
                // Map the hickory TLSA fields onto OpenSSL's DANE constants. Unassigned values
                // are passed through raw; OpenSSL's return value says whether the record is
                // usable.
                let usable = config.dane_tlsa_add(
                    match tlsa.cert_usage() {
                        CertUsage::PkixTa => DaneUsage::PKIX_TA,
                        CertUsage::PkixEe => DaneUsage::PKIX_EE,
                        CertUsage::DaneTa => DaneUsage::DANE_TA,
                        CertUsage::DaneEe => DaneUsage::DANE_EE,
                        CertUsage::Unassigned(n) => DaneUsage::from_raw(n),
                        CertUsage::Private => DaneUsage::PRIV_CERT,
                    },
                    match tlsa.selector() {
                        Selector::Full => DaneSelector::CERT,
                        Selector::Spki => DaneSelector::SPKI,
                        Selector::Unassigned(n) => DaneSelector::from_raw(n),
                        Selector::Private => DaneSelector::PRIV_SEL,
                    },
                    match tlsa.matching() {
                        Matching::Raw => DaneMatchType::FULL,
                        Matching::Sha256 => DaneMatchType::SHA2_256,
                        Matching::Sha512 => DaneMatchType::SHA2_512,
                        Matching::Unassigned(n) => DaneMatchType::from_raw(n),
                        Matching::Private => DaneMatchType::PRIV_MATCH,
                    },
                    tlsa.cert_data(),
                )?;

                tracing::trace!("build_dane_connector usable={usable} {tlsa:?}");
                if usable {
                    any_usable = true;
                }
            }

            // Fail early rather than attempt a handshake with no usable TLSA record.
            if !any_usable {
                return Err(ClientError::NoUsableDaneTlsa {
                    hostname: hostname.to_string(),
                    tlsa: self.dane_tlsa.clone(),
                });
            }
        }

        Ok(config)
    }
}
1106
1107fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1108 static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1109
1110 if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1111 return None;
1112 }
1113
1114 let mut stuffed = vec![];
1115 if data.starts_with(b".") {
1116 stuffed.push(b'.');
1117 }
1118 let mut last_idx = 0;
1119 for i in LFDOT.find_iter(data) {
1120 stuffed.extend_from_slice(&data[last_idx..=i]);
1121 stuffed.push(b'.');
1122 last_idx = i + 1;
1123 }
1124 stuffed.extend_from_slice(&data[last_idx..]);
1125 Some(stuffed)
1126}
1127
1128fn subject_name(cert: &X509Ref) -> Vec<String> {
1135 let mut subject_name = vec![];
1136 for entry in cert.subject_name().entries() {
1137 if let Ok(obj) = entry.object().nid().short_name() {
1138 if let Ok(name) = entry.data().as_utf8() {
1139 subject_name.push(format!("{obj}={name}"));
1140 }
1141 }
1142 }
1143 subject_name
1144}
1145
#[cfg(test)]
mod test {
    use super::*;

    // Dot-stuffing leaves dot-free bodies alone (`None`) and doubles every dot that starts
    // a line, including one at the very beginning of the data.
    #[test]
    fn test_stuffing() {
        assert_eq!(apply_dot_stuffing(b"foo"), None);
        assert_eq!(apply_dot_stuffing(b".foo").unwrap(), b"..foo");
        assert_eq!(apply_dot_stuffing(b"foo\n.bar").unwrap(), b"foo\n..bar");
        assert_eq!(
            apply_dot_stuffing(b"foo\n.bar\n..baz\n").unwrap(),
            b"foo\n..bar\n...baz\n"
        );
    }

    // A reply line is `NNN<SP>text` (final) or `NNN-text` (continuation); any other
    // separator, or a non-numeric code, is malformed.
    #[test]
    fn response_line_parsing() {
        assert_eq!(
            parse_response_line("220 woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: true,
                content: "woot"
            }
        );
        assert_eq!(
            parse_response_line("220-woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: false,
                content: "woot"
            }
        );

        assert!(matches!(
            parse_response_line("220_woot"),
            Err(ClientError::MalformedResponseLine(_))
        ));
        assert!(matches!(
            parse_response_line("not really"),
            Err(ClientError::MalformedResponseLine(_))
        ));
    }

    // Test helper: assemble reply lines the way `SmtpClient::read_response` does, and
    // assert that the last line given is the final one.
    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let mut parsed = parse_response_line(lines[0])?;
        let mut b = ResponseBuilder::new(&parsed);
        for line in &lines[1..] {
            parsed = parse_response_line(line)?;
            b.add_line(&parsed)
                .map_err(ClientError::MalformedResponseLine)?;
        }
        assert!(parsed.is_final);
        Ok(b.build(None))
    }

    #[test]
    fn multi_line_response() {
        // Continuation lines are joined with '\n'.
        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: None,
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // Every line must repeat the first line's reply code; the mismatching line is
        // reported.
        let res = parse_multi_line(&["220-woot", "221-more", "220 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221-more"
            ),
            "got error {res:?}"
        );

        let res = parse_multi_line(&["220-woot", "220-more", "221 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221 done"
            ),
            "got error {res:?}"
        );

        // An enhanced status code on the first line is extracted from the text.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // The same enhanced code repeated on a later line is stripped as well.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // A *different* enhanced code on a later line is kept as ordinary text.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\n5.5.5 done".to_string(),
                command: None
            }
        );
    }

    // Port suffixes, brackets and a trailing dot are removed; the last ':' is always
    // treated as the port separator.
    #[test]
    fn test_extract_hostname() {
        assert_eq!(extract_hostname("foo"), "foo");
        assert_eq!(extract_hostname("foo."), "foo");
        assert_eq!(extract_hostname("foo:25"), "foo");
        assert_eq!(extract_hostname("foo.:25"), "foo");
        assert_eq!(extract_hostname("[foo]:25"), "foo");
        assert_eq!(extract_hostname("[foo.]:25"), "foo");
        assert_eq!(extract_hostname("[::1]:25"), "::1");
        assert_eq!(extract_hostname("::1:25"), "::1");
    }
}