rfc5321/
client.rs

1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::{
4    AsyncReadAndWrite, BoxedAsyncReadAndWrite, Command, Domain, EsmtpParameter, ForwardPath,
5    ReversePath,
6};
7use hickory_proto::rr::rdata::tlsa::{CertUsage, Matching, Selector};
8use hickory_proto::rr::rdata::TLSA;
9use memchr::memmem::Finder;
10use openssl::pkey::PKey;
11use openssl::ssl::{DaneMatchType, DaneSelector, DaneUsage, SslOptions};
12use openssl::x509::{X509Ref, X509};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::net::IpAddr;
16use std::str::FromStr;
17use std::sync::{Arc, LazyLock};
18use std::time::Duration;
19use thiserror::Error;
20use tokio::io::{AsyncReadExt, AsyncWriteExt};
21use tokio::net::{TcpStream, ToSocketAddrs};
22use tokio::time::timeout;
23use tokio_rustls::rustls::pki_types::ServerName;
24use tracing::Level;
25
26pub use crate::tls::TlsOptions;
27pub use {openssl, tokio_rustls};
28
/// Upper bound, in bytes, on how much of a single response line is buffered
/// before it is discarded as too long. Also used as the size of each socket read.
const MAX_LINE_LEN: usize = 4 * 1024;
30
31#[derive(Error, Debug, Clone)]
32pub enum ClientError {
33    #[error("response is not UTF8")]
34    Utf8(#[from] std::string::FromUtf8Error),
35    #[error("Malformed Response: {0}")]
36    MalformedResponseLine(String),
37    #[error("Response line is too long")]
38    ResponseTooLong,
39    #[error("Not connected")]
40    NotConnected,
41    #[error("Command rejected {0:?}")]
42    Rejected(Response),
43    #[error("STARTTLS: {0} is not a valid DNS name")]
44    InvalidDnsName(String),
45    #[error("Invalid client certificate configured: {error:?}")]
46    FailedToBuildConnector { error: String },
47    #[error("Timed Out waiting {duration:?} for response to {command:?}")]
48    TimeOutResponse {
49        command: Option<Command>,
50        duration: Duration,
51    },
52    #[error("Timed Out writing {duration:?} {commands:?}")]
53    TimeOutRequest {
54        commands: Vec<Command>,
55        duration: Duration,
56    },
57    #[error("Error {error} reading response to {command:?}")]
58    ReadError {
59        command: Option<Command>,
60        error: String,
61        partial: String,
62    },
63    #[error("Error {error} flushing send buffer")]
64    FlushError { error: String },
65    #[error("Error {error} writing {commands:?}")]
66    WriteError {
67        commands: Vec<Command>,
68        error: String,
69    },
70    #[error("Timed Out sending message payload data")]
71    TimeOutData,
72    #[error("SSL Error: {0}")]
73    SslErrorStack(#[from] openssl::error::ErrorStack),
74    #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
75    NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
76}
77
78impl ClientError {
79    /// Returns the command(s) string suitable for passing into a Response
80    pub fn command(&self) -> Option<String> {
81        match self {
82            Self::TimeOutResponse {
83                command: Some(command),
84                ..
85            }
86            | Self::ReadError {
87                command: Some(command),
88                ..
89            } => Some(command.encode()),
90            Self::TimeOutRequest { commands, .. } | Self::WriteError { commands, .. }
91                if !commands.is_empty() =>
92            {
93                let commands: Vec<String> = commands.into_iter().map(|cmd| cmd.encode()).collect();
94                Some(commands.join(""))
95            }
96            _ => None,
97        }
98    }
99
100    /// If the error contents were likely caused by something
101    /// about the mostly recently attempted message, rather than
102    /// a transport issue, or a carry-over from a prior message
103    /// (eg: previous message was rejected and destination chose
104    /// to drop the connection, which we detect later in RSET
105    /// on the next message), then we return true.
106    /// The expectation is that the caller will transiently
107    /// fail the message for later retry.
108    /// If we return false then the caller might decide to
109    /// try that same message again more immediately on
110    /// a separate connection
111    pub fn was_due_to_message(&self) -> bool {
112        match self {
113            Self::Utf8(_)
114            | Self::MalformedResponseLine(_)
115            | Self::ResponseTooLong
116            | Self::NotConnected
117            | Self::InvalidDnsName(_)
118            | Self::FailedToBuildConnector { .. }
119            | Self::TimeOutResponse { .. }
120            | Self::TimeOutRequest { .. }
121            | Self::ReadError { .. }
122            | Self::FlushError { .. }
123            | Self::WriteError { .. }
124            | Self::TimeOutData
125            | Self::SslErrorStack(_)
126            | Self::NoUsableDaneTlsa { .. } => false,
127            Self::Rejected(response) => response.was_due_to_message(),
128        }
129    }
130}
131
/// One ESMTP extension advertised in an EHLO/LHLO reply.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EsmtpCapability {
    /// The extension keyword, as sent by the server (original case).
    pub name: String,
    /// Everything after the first space on the line, if present.
    pub param: Option<String>,
}
137
138#[derive(Clone, Debug)]
139pub enum SmtpClientTraceEvent {
140    Closed,
141    Read(Vec<u8>),
142    Write(String),
143    Diagnostic {
144        level: tracing::Level,
145        message: String,
146    },
147}
148
149pub trait DeferredTracer {
150    fn trace(&self) -> SmtpClientTraceEvent;
151}
152
153pub trait SmtpClientTracer: std::fmt::Debug {
154    fn trace_event(&self, event: SmtpClientTraceEvent);
155    fn lazy_trace(&self, deferred: &dyn DeferredTracer);
156}
157
158// helper to avoid making a second copy of every write buffer
159struct WriteTracer<'a> {
160    data: &'a str,
161}
162impl DeferredTracer for WriteTracer<'_> {
163    fn trace(&self) -> SmtpClientTraceEvent {
164        SmtpClientTraceEvent::Write(self.data.to_string())
165    }
166}
167impl<'a> WriteTracer<'a> {
168    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
169        tracer.lazy_trace(&Self { data });
170    }
171}
172
173struct BinWriteTracer<'a> {
174    data: &'a [u8],
175}
176impl DeferredTracer for BinWriteTracer<'_> {
177    fn trace(&self) -> SmtpClientTraceEvent {
178        let data = String::from_utf8_lossy(self.data).to_string();
179        SmtpClientTraceEvent::Write(data)
180    }
181}
182impl<'a> BinWriteTracer<'a> {
183    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
184        tracer.lazy_trace(&Self { data });
185    }
186}
187
188// A little bit of gymnastics to avoid making a second
189// copy of every read buffer
190struct ReadTracer<'a> {
191    data: &'a [u8],
192}
193impl DeferredTracer for ReadTracer<'_> {
194    fn trace(&self) -> SmtpClientTraceEvent {
195        SmtpClientTraceEvent::Read(self.data.to_vec())
196    }
197}
198
199#[derive(Debug)]
200pub struct SmtpClient {
201    socket: Option<BoxedAsyncReadAndWrite>,
202    hostname: String,
203    capabilities: HashMap<String, EsmtpCapability>,
204    read_buffer: Vec<u8>,
205    timeouts: SmtpClientTimeouts,
206    tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
207    use_rset: bool,
208    enable_rset: bool,
209    enable_pipelining: bool,
210}
211
/// Extract just the host portion of a peer address.
///
/// Accepts any of these forms:
/// - `host`
/// - `host:port`
/// - `[addr]`
/// - `[addr]:port`
/// - a bare IPv6 address such as `2001:db8::1`
///
/// Square brackets around an address literal are removed, a `:port` suffix
/// is dropped, and a trailing FQDN dot is stripped.
///
/// A bare IPv6 address is returned intact rather than having its last group
/// mistaken for a port, and `[addr]` without a port is unbracketed correctly.
fn extract_hostname(hostname: &str) -> &str {
    let host = if let Some(inner) = hostname
        .strip_prefix('[')
        .and_then(|rest| rest.split_once(']'))
        .map(|(inner, _port)| inner)
    {
        // Bracketed address literal, with or without a trailing `:port`.
        inner
    } else {
        match hostname.rsplit_once(':') {
            // `host:port`. If the left side still contains a colon, the whole
            // value is a bare IPv6 address and must be kept intact.
            Some((host, _port)) if !host.contains(':') => host,
            _ => hostname,
        }
    };

    // Remove any trailing FQDN dot
    host.strip_suffix('.').unwrap_or(host)
}
230
231impl SmtpClient {
232    pub async fn new<A: ToSocketAddrs + ToString + Clone>(
233        addr: A,
234        timeouts: SmtpClientTimeouts,
235    ) -> std::io::Result<Self> {
236        let stream = TcpStream::connect(addr.clone()).await?;
237        // No need for Nagle with SMTP request/response
238        stream.set_nodelay(true)?;
239        Ok(Self::with_stream(stream, addr.to_string(), timeouts))
240    }
241
242    pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
243        stream: S,
244        peer_hostname: H,
245        timeouts: SmtpClientTimeouts,
246    ) -> Self {
247        let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
248
249        Self {
250            socket: Some(Box::new(stream)),
251            hostname,
252            capabilities: HashMap::new(),
253            read_buffer: Vec::with_capacity(1024),
254            timeouts,
255            tracer: None,
256            use_rset: false,
257            enable_rset: false,
258            enable_pipelining: false,
259        }
260    }
261
262    pub fn is_connected(&self) -> bool {
263        self.socket.is_some()
264    }
265
266    pub fn set_enable_rset(&mut self, enable: bool) {
267        self.enable_rset = enable;
268    }
269
270    pub fn set_enable_pipelining(&mut self, enable: bool) {
271        self.enable_pipelining = enable;
272    }
273
274    pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
275        self.tracer.replace(tracer);
276    }
277
278    pub fn timeouts(&self) -> &SmtpClientTimeouts {
279        &self.timeouts
280    }
281
282    async fn read_line(
283        &mut self,
284        timeout_duration: Duration,
285        cmd: Option<&Command>,
286    ) -> Result<String, ClientError> {
287        let mut too_long = false;
288        loop {
289            let mut iter = self.read_buffer.iter().enumerate();
290            while let Some((i, &b)) = iter.next() {
291                if b != b'\r' {
292                    continue;
293                }
294                if let Some((_, b'\n')) = iter.next() {
295                    if too_long {
296                        self.read_buffer.drain(0..i + 2);
297
298                        if let Some(tracer) = &self.tracer {
299                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
300                                level: Level::ERROR,
301                                message: "Response too long".to_string(),
302                            });
303                        }
304
305                        return Err(ClientError::ResponseTooLong);
306                    }
307
308                    let line = String::from_utf8(self.read_buffer[0..i].to_vec());
309                    self.read_buffer.drain(0..i + 2);
310                    return Ok(line?);
311                }
312            }
313            if self.read_buffer.len() > MAX_LINE_LEN {
314                self.read_buffer.clear();
315                too_long = true;
316            }
317
318            // Didn't find a complete line, fill up the rest of the buffer
319            let mut data = [0u8; MAX_LINE_LEN];
320            let size = match self.socket.as_mut() {
321                Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
322                    Ok(Ok(size)) => size,
323                    Ok(Err(err)) => {
324                        self.socket.take();
325                        if let Some(tracer) = &self.tracer {
326                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
327                                level: Level::ERROR,
328                                message: format!("Error during read: {err:#}"),
329                            });
330                            tracer.trace_event(SmtpClientTraceEvent::Closed);
331                        }
332                        return Err(ClientError::ReadError {
333                            command: cmd.cloned(),
334                            error: format!("{err:#}"),
335                            partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
336                        });
337                    }
338                    Err(_) => {
339                        self.socket.take();
340                        if let Some(tracer) = &self.tracer {
341                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
342                                level: Level::ERROR,
343                                message: format!("Read Timeout after {timeout_duration:?}"),
344                            });
345                            tracer.trace_event(SmtpClientTraceEvent::Closed);
346                        }
347                        return Err(ClientError::TimeOutResponse {
348                            command: cmd.cloned(),
349                            duration: timeout_duration,
350                        });
351                    }
352                },
353                None => {
354                    return Err(ClientError::ReadError {
355                        command: cmd.cloned(),
356                        error: "the socket was closed in response to an earlier issue".to_string(),
357                        partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
358                    });
359                }
360            };
361            if size == 0 {
362                self.socket.take();
363                if let Some(tracer) = &self.tracer {
364                    tracer.trace_event(SmtpClientTraceEvent::Closed);
365                }
366                return Err(ClientError::ReadError {
367                    command: cmd.cloned(),
368                    error: "Connection closed by peer".to_string(),
369                    partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
370                });
371            }
372            if let Some(tracer) = &self.tracer {
373                tracer.lazy_trace(&ReadTracer {
374                    data: &data[0..size],
375                });
376            }
377            self.read_buffer.extend_from_slice(&data[0..size]);
378        }
379    }
380
381    pub async fn read_response(
382        &mut self,
383        command: Option<&Command>,
384        timeout_duration: Duration,
385    ) -> Result<Response, ClientError> {
386        if let Some(sock) = self.socket.as_mut() {
387            if let Err(err) = sock.flush().await {
388                self.socket.take();
389                if let Some(tracer) = &self.tracer {
390                    tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
391                        level: Level::ERROR,
392                        message: format!("Error during flush: {err:#}"),
393                    });
394                    tracer.trace_event(SmtpClientTraceEvent::Closed);
395                }
396                return Err(ClientError::FlushError {
397                    error: format!("{err:#}"),
398                });
399            }
400        }
401
402        let mut line = self.read_line(timeout_duration, command).await?;
403        tracing::trace!("recv<-{}: {line}", self.hostname);
404        let mut parsed = parse_response_line(&line)?;
405        let mut response_builder = ResponseBuilder::new(&parsed);
406
407        let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
408        while !parsed.is_final {
409            line = self
410                .read_line(subsequent_line_timeout_duration, command)
411                .await?;
412            parsed = parse_response_line(&line)?;
413            response_builder
414                .add_line(&parsed)
415                .map_err(ClientError::MalformedResponseLine)?;
416        }
417
418        let response = response_builder.build(command.map(|cmd| cmd.encode()));
419
420        tracing::trace!("{}: {response:?}", self.hostname);
421
422        Ok(response)
423    }
424
425    pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
426        self.write_command_request(command).await?;
427        self.read_response(Some(command), command.client_timeout(&self.timeouts))
428            .await
429    }
430
431    /// Wrapper around socket.write_all() that will emit trace diagnostics and synthesize
432    /// a Close event to the tracer if a timeout or IO error occurs.
433    /// If an error or timeout, occurs ensures that the socket will not be reused.
434    async fn write_all_with_timeout<F, G>(
435        &mut self,
436        timeout_duration: Duration,
437        bytes: &[u8],
438        make_timeout_err: F,
439        make_write_err: G,
440    ) -> Result<(), ClientError>
441    where
442        F: FnOnce() -> ClientError,
443        G: FnOnce(String) -> ClientError,
444    {
445        match self.socket.as_mut() {
446            Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
447                Ok(Ok(response)) => Ok(response),
448                Ok(Err(err)) => {
449                    self.socket.take();
450                    if let Some(tracer) = &self.tracer {
451                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
452                            level: Level::ERROR,
453                            message: format!("Error during write: {err:#}"),
454                        });
455                        tracer.trace_event(SmtpClientTraceEvent::Closed);
456                    }
457                    Err(make_write_err(format!("{err:#}")))
458                }
459                Err(_) => {
460                    self.socket.take();
461                    if let Some(tracer) = &self.tracer {
462                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
463                            level: Level::ERROR,
464                            message: format!("Write Timeout after {timeout_duration:?}"),
465                        });
466                        tracer.trace_event(SmtpClientTraceEvent::Closed);
467                    }
468                    Err(make_timeout_err())
469                }
470            },
471            None => Err(make_write_err(
472                "the socket was closed in response to an earlier issue".to_string(),
473            )),
474        }
475    }
476
477    async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
478        let total_timeout: Duration = commands
479            .iter()
480            .map(|cmd| cmd.client_timeout_request(&self.timeouts))
481            .sum();
482
483        let mut lines: Vec<String> = vec![];
484        let mut all = String::new();
485        for cmd in commands {
486            let line = cmd.encode();
487            all.push_str(&line);
488            lines.push(line);
489        }
490        tracing::trace!("send->{}: (PIPELINE) {all}", self.hostname);
491        if self.socket.is_some() {
492            if let Some(tracer) = &self.tracer {
493                // Send the lines individually to the tracer, so that we
494                // don't break --terse mode
495                for line in lines {
496                    WriteTracer::trace(tracer, &line);
497                }
498            }
499        }
500        self.write_all_with_timeout(
501            total_timeout,
502            all.as_bytes(),
503            || ClientError::TimeOutRequest {
504                duration: total_timeout,
505                commands: commands.to_vec(),
506            },
507            |error| ClientError::WriteError {
508                error,
509                commands: commands.to_vec(),
510            },
511        )
512        .await
513    }
514
515    async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
516        let line = command.encode();
517        tracing::trace!("send->{}: {line}", self.hostname);
518        if self.socket.is_some() {
519            if let Some(tracer) = &self.tracer {
520                WriteTracer::trace(tracer, &line);
521            }
522        }
523
524        let timeout_duration = command.client_timeout_request(&self.timeouts);
525        self.write_all_with_timeout(
526            timeout_duration,
527            line.as_bytes(),
528            || ClientError::TimeOutRequest {
529                duration: timeout_duration,
530                commands: vec![command.clone()],
531            },
532            |error| ClientError::WriteError {
533                error,
534                commands: vec![command.clone()],
535            },
536        )
537        .await
538    }
539
540    async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
541        if self.socket.is_some() {
542            if let Some(tracer) = &self.tracer {
543                BinWriteTracer::trace(tracer, data);
544            }
545        }
546        let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
547        self.write_all_with_timeout(
548            timeout_duration,
549            data,
550            || ClientError::TimeOutData,
551            |error| ClientError::WriteError {
552                error,
553                commands: vec![],
554            },
555        )
556        .await
557    }
558
559    /// Issue a series of commands, and return the responses to
560    /// those commands.
561    ///
562    /// If the server advertised the RFC 2920 PIPELINING extension,
563    /// the commands are written one after the other before waiting
564    /// to read any data, resulting in lower overall latency due
565    /// to round-trip-times.
566    ///
567    /// If PIPELINING is not available, each command is written
568    /// and the response read before attempting to write the next
569    /// command.
570    ///
571    /// The number of returned responses may be smaller than the
572    /// number of requested commands if there is an issue with
573    /// the network connection.
574    pub async fn pipeline_commands(
575        &mut self,
576        commands: Vec<Command>,
577    ) -> Vec<Result<Response, ClientError>> {
578        let mut results: Vec<Result<Response, ClientError>> = vec![];
579
580        let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
581        if pipeline {
582            if let Err(err) = self.write_pipeline_request(&commands).await {
583                let err: ClientError = err;
584                results.push(Err(err.clone()));
585                while results.len() < commands.len() {
586                    // Synthesize failures for the remaining commands
587                    results.push(Err(err.clone()));
588                }
589                return results;
590            }
591
592            // Now read the responses effectively in a batch
593            for cmd in &commands {
594                results.push(
595                    self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
596                        .await,
597                );
598            }
599            return results;
600        }
601
602        for cmd in &commands {
603            if let Err(err) = self.write_command_request(cmd).await {
604                let err: ClientError = err;
605                results.push(Err(err.clone()));
606                while results.len() < commands.len() {
607                    // Synthesize failures for the remaining commands
608                    results.push(Err(err.clone()));
609                }
610                return results;
611            }
612            // Immediately request the response if the server
613            // doesn't support pipelining
614            results.push(
615                self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
616                    .await,
617            );
618        }
619        results
620    }
621
622    pub async fn ehlo_lhlo(
623        &mut self,
624        ehlo_name: &str,
625        use_lmtp: bool,
626    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
627        if use_lmtp {
628            self.lhlo(ehlo_name).await
629        } else {
630            self.ehlo(ehlo_name).await
631        }
632    }
633
634    pub async fn lhlo(
635        &mut self,
636        ehlo_name: &str,
637    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
638        let response = self
639            .send_command(&Command::Lhlo(Domain::Name(ehlo_name.to_string())))
640            .await?;
641        self.ehlo_common(response)
642    }
643
644    pub async fn ehlo(
645        &mut self,
646        ehlo_name: &str,
647    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
648        let response = self
649            .send_command(&Command::Ehlo(Domain::Name(ehlo_name.to_string())))
650            .await?;
651        self.ehlo_common(response)
652    }
653
654    fn ehlo_common(
655        &mut self,
656        response: Response,
657    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
658        if response.code != 250 {
659            return Err(ClientError::Rejected(response));
660        }
661
662        let mut capabilities = HashMap::new();
663
664        for line in response.content.lines().skip(1) {
665            let mut fields = line.splitn(2, ' ');
666            if let Some(name) = fields.next() {
667                let param = fields.next().map(|s| s.to_string());
668                let cap = EsmtpCapability {
669                    name: name.to_string(),
670                    param,
671                };
672                capabilities.insert(name.to_ascii_uppercase(), cap);
673            }
674        }
675
676        self.capabilities = capabilities;
677        Ok(&self.capabilities)
678    }
679
680    pub async fn auth_plain(
681        &mut self,
682        username: &str,
683        password: Option<&str>,
684    ) -> Result<(), ClientError> {
685        // RFC 4616 says that the format is:
686        // [authzid] NUL authcid NUL passwd
687        let password = password.unwrap_or("");
688        let payload = format!("\x00{username}\x00{password}");
689        let payload = data_encoding::BASE64.encode(payload.as_bytes());
690
691        let response = self
692            .send_command(&Command::Auth {
693                sasl_mech: "PLAIN".to_string(),
694                initial_response: Some(payload),
695            })
696            .await?;
697
698        if response.code != 235 {
699            return Err(ClientError::Rejected(response));
700        }
701
702        Ok(())
703    }
704
705    /// Attempt TLS handshake.
706    /// Returns Err for IO errors.
707    /// On completion, return an option that will be:
708    /// * Some(handshake_error) - if the handshake failed
709    /// * None - if the handshake succeeded
710    pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
711        let resp = self.send_command(&Command::StartTls).await?;
712        if resp.code != 220 {
713            return Err(ClientError::Rejected(resp));
714        }
715
716        let mut handshake_error = None;
717        let mut tls_info = TlsInformation::default();
718
719        let stream: BoxedAsyncReadAndWrite = if options.prefer_openssl
720            || !options.dane_tlsa.is_empty()
721        {
722            let connector = options
723                .build_openssl_connector(&self.hostname)
724                .map_err(|error| ClientError::FailedToBuildConnector {
725                    error: error.to_string(),
726                })?;
727            let ssl = connector.into_ssl(self.hostname.as_str())?;
728
729            let (stream, dup_stream) = match self.socket.take() {
730                Some(s) => {
731                    let d = s.try_dup();
732                    (s, d)
733                }
734                None => return Err(ClientError::NotConnected),
735            };
736
737            let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;
738
739            if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
740                handshake_error.replace(format!("{err:#}"));
741            }
742
743            tls_info.provider_name = "openssl".to_string();
744            tls_info.cipher = match ssl_stream.ssl().current_cipher() {
745                Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
746                None => String::new(),
747            };
748            tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();
749
750            if let Some(cert) = ssl_stream.ssl().peer_certificate() {
751                tls_info.subject_name = subject_name(&cert);
752            }
753            if let Ok(authority) = ssl_stream.ssl().dane_authority() {
754                if let Some(cert) = &authority.cert {
755                    tls_info.subject_name = subject_name(cert);
756                }
757            }
758
759            match (&handshake_error, dup_stream) {
760                (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
761                    // Try falling back to clear text on the duplicate stream.
762                    // This is imperfect: in a failed validation scenario we will
763                    // end up trying to read binary data as a string and get a UTF-8
764                    // error if the peer thinks the session is encrypted.
765                    drop(ssl_stream);
766                    Box::new(dup_stream)
767                }
768                _ => Box::new(ssl_stream),
769            }
770        } else {
771            tls_info.provider_name = "rustls".to_string();
772            let connector = options.build_tls_connector().await.map_err(|error| {
773                ClientError::FailedToBuildConnector {
774                    error: error.to_string(),
775                }
776            })?;
777            let server_name = match IpAddr::from_str(self.hostname.as_str()) {
778                Ok(ip) => ServerName::IpAddress(ip.into()),
779                Err(_) => ServerName::try_from(self.hostname.clone())
780                    .map_err(|_| ClientError::InvalidDnsName(self.hostname.clone()))?,
781            };
782
783            match connector
784                .connect(
785                    server_name,
786                    match self.socket.take() {
787                        Some(s) => s,
788                        None => return Err(ClientError::NotConnected),
789                    },
790                )
791                .into_fallible()
792                .await
793            {
794                Ok(stream) => {
795                    let (_, conn) = stream.get_ref();
796                    tls_info.cipher = match conn.negotiated_cipher_suite() {
797                        Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
798                        None => String::new(),
799                    };
800                    tls_info.protocol_version = match conn.protocol_version() {
801                        Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
802                        None => String::new(),
803                    };
804
805                    if let Some(certs) = conn.peer_certificates() {
806                        let peer_cert = &certs[0];
807                        if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
808                            tls_info.subject_name = subject_name(&cert);
809                        }
810                    }
811
812                    Box::new(stream)
813                }
814                Err((err, stream)) => {
815                    handshake_error.replace(format!("{err:#}"));
816                    stream
817                }
818            }
819        };
820
821        if let Some(tracer) = &self.tracer {
822            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
823                level: Level::INFO,
824                message: match &handshake_error {
825                    Some(error) => format!("STARTTLS handshake failed: {error:?}"),
826                    None => format!("STARTTLS handshake -> {tls_info:?}"),
827                },
828            });
829        }
830
831        self.socket.replace(stream);
832        Ok(match handshake_error {
833            Some(error) => TlsStatus::FailedHandshake(error),
834            None => TlsStatus::Info(tls_info),
835        })
836    }
837
    /// Sends one message from `sender` to `recipient` with body `data`.
    ///
    /// The body is dot-stuffed via `apply_dot_stuffing`. Then `RSET` (only when
    /// `use_rset` is set), `MAIL FROM`, `RCPT TO` and `DATA` are handed to
    /// `pipeline_commands` as a single batch. If MAIL and RCPT are answered with
    /// 250 and DATA with 354, the body and the end-of-data marker are written.
    /// The reply to that marker is returned when its code is 250.
    ///
    /// `BODY=8BITMIME` is added to MAIL FROM only when the (stuffed) body has a
    /// byte >= 0x80 and `capabilities` contains `8BITMIME`. `SMTPUTF8` is added
    /// only when the sender or recipient is non-ASCII and `capabilities`
    /// contains `SMTPUTF8`. When the capability is missing, the parameter is
    /// simply omitted.
    ///
    /// `use_rset` is forced to `true` before any I/O and restored to
    /// `enable_rset` only after a fully successful send. As a result, any
    /// failure makes the next call on this connection start with RSET.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::Rejected` carrying the first reply, in RSET, MAIL,
    /// RCPT, DATA, end-of-data order, that did not have the expected code.
    /// Otherwise returns any error propagated from pipelining, writing data, or
    /// reading a reply.
    pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
        &mut self,
        sender: SENDER,
        recipient: RECIP,
        data: B,
    ) -> Result<Response, ClientError> {
        let sender = sender.into();
        let recipient = recipient.into();

        let data: &[u8] = data.as_ref();
        // Owns the stuffed copy (if any) so that `data` can borrow from it.
        let stuffed;

        let data = match apply_dot_stuffing(data) {
            Some(d) => {
                stuffed = d;
                &stuffed
            }
            None => data,
        };

        let data_is_8bit = data.iter().any(|&b| b >= 0x80);
        let envelope_is_8bit = !sender.is_ascii() || !recipient.is_ascii();

        let mut mail_from_params = vec![];
        if data_is_8bit && self.capabilities.contains_key("8BITMIME") {
            mail_from_params.push(EsmtpParameter {
                name: "BODY".to_string(),
                value: Some("8BITMIME".to_string()),
            });
        }

        if envelope_is_8bit && self.capabilities.contains_key("SMTPUTF8") {
            mail_from_params.push(EsmtpParameter {
                name: "SMTPUTF8".to_string(),
                value: None,
            });
        }

        let mut commands = vec![];

        // We want to avoid using RSET for the first message we send on
        // a given connection, because postfix can run in a mode where
        // it will not tolerate RSET because it considers it to be a "junk"
        // command, and rejecting junk commands will cut down on its load
        // when it is under stress; it is used as a load shedding approach.
        // If we always RSET then we will never deliver to a site that is
        // configured that way. If we take care to RSET only for subsequent
        // sends, then we should get one message per connection through
        // without being unfairly penalized for defensively RSETing.
        let used_rset = self.use_rset;
        if self.use_rset {
            commands.push(Command::Rset);
        }
        commands.push(Command::MailFrom {
            address: sender,
            parameters: mail_from_params,
        });
        commands.push(Command::RcptTo {
            address: recipient,
            parameters: vec![],
        });
        commands.push(Command::Data);

        // Assume that something might break below: if it does, we want
        // to ensure that we RSET the connection on the next go around.
        self.use_rset = true;

        // NOTE(review): the `responses.remove(0)` calls below assume
        // `pipeline_commands` yields one entry per command sent;
        // `Vec::remove` panics on an empty vector.
        let mut responses = self.pipeline_commands(commands).await;

        // This is a little awkward. We want to handle the RFC 2920 3.1 case
        // below, which requires deferring checking the actual response codes
        // until later, but we must also handle the case where we had a hard
        // transport error partway through pipelining.
        // So we set a flag for that case and will then "eagerly", wrt. the
        // RFC 2920 3.1 logic, evaluate the SMTP response codes, so that we
        // can propagate the correct error disposition up to the caller.
        // An individual `Err` entry is still returned by the `?` on the
        // `responses.remove(0)` that reaches it; the flag only lets an earlier,
        // successfully read rejection take precedence over that error.
        let is_err = responses.iter().any(|r| r.is_err());

        if used_rset {
            // NOTE(review): this returns before DATA's reply is examined, so if
            // RSET were rejected while the pipelined DATA got 354, no dummy "."
            // would be sent and the server would stay in the data phase.
            // Confirm that servers cannot reply that way.
            let rset_resp = responses.remove(0)?;
            if rset_resp.code != 250 {
                return Err(ClientError::Rejected(rset_resp));
            }
        }

        let mail_resp = responses.remove(0)?;
        if is_err && mail_resp.code != 250 {
            return Err(ClientError::Rejected(mail_resp));
        }

        let rcpt_resp = responses.remove(0)?;
        if is_err && rcpt_resp.code != 250 {
            return Err(ClientError::Rejected(rcpt_resp));
        }

        let data_resp = responses.remove(0)?;
        if is_err && data_resp.code != 354 {
            return Err(ClientError::Rejected(data_resp));
        }

        if data_resp.code == 354 && (mail_resp.code != 250 || rcpt_resp.code != 250) {
            // RFC 2920 3.1:
            // the client cannot assume that the DATA command will be rejected
            // just because none of the RCPT TO commands worked.  If the DATA
            // command was properly rejected the client SMTP can just issue
            // RSET, but if the DATA command was accepted the client SMTP
            // should send a single dot.

            // Send dummy data
            self.write_data_with_timeout(b".\r\n").await?;
            let data_dot = Command::DataDot;
            // wait for its response
            let _ = self
                .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
                .await?;

            // Continue below: we will match one of the failure cases and
            // return a ClientError::Rejected from one of the earlier
            // commands
        }

        if mail_resp.code != 250 {
            return Err(ClientError::Rejected(mail_resp));
        }
        if rcpt_resp.code != 250 {
            return Err(ClientError::Rejected(rcpt_resp));
        }
        if data_resp.code != 354 {
            return Err(ClientError::Rejected(data_resp));
        }

        // Empty data, or data not ending in LF, needs a CRLF before the
        // terminating dot so that the dot starts its own line.
        let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);

        tracing::trace!("message data is {} bytes", data.len());

        self.write_data_with_timeout(data).await?;

        let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };

        tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());

        self.write_data_with_timeout(marker.as_bytes()).await?;

        let data_dot = Command::DataDot;
        let resp = self
            .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
            .await?;
        if resp.code != 250 {
            return Err(ClientError::Rejected(resp));
        }

        // If everything went well, respect the user preference for speculatively
        // issuing an RSET next time around
        self.use_rset = self.enable_rset;

        Ok(resp)
    }
995}
996
/// Outcome of a STARTTLS handshake attempt.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
    /// The handshake failed; carries a human-readable description of the error.
    FailedHandshake(String),
    /// The handshake succeeded; carries details of the negotiated session.
    Info(TlsInformation),
}
1002
/// Details about a TLS session negotiated via STARTTLS.
///
/// Fields are emitted in declaration order by the derived `Debug` and
/// `Serialize` impls, so reordering them changes that output.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
    /// Name of the negotiated cipher suite. The rustls STARTTLS path stores
    /// "UNKNOWN" when the suite has no known name, and an empty string when
    /// no suite was negotiated.
    pub cipher: String,
    /// Name of the negotiated protocol version. The rustls STARTTLS path uses
    /// the same "UNKNOWN"/empty conventions as `cipher`.
    pub protocol_version: String,
    /// `short_name=value` entries from the peer certificate's subject, as
    /// produced by `subject_name`. The rustls path fills this only when the
    /// first peer certificate parses as DER X.509.
    pub subject_name: Vec<String>,
    /// NOTE(review): not assigned in the visible code; presumably names the
    /// TLS implementation that performed the handshake. Confirm this at the
    /// construction site.
    pub provider_name: String,
}
1010
1011impl Drop for SmtpClient {
1012    fn drop(&mut self) {
1013        if let Some(tracer) = &self.tracer {
1014            if self.socket.is_some() {
1015                tracer.trace_event(SmtpClientTraceEvent::Closed);
1016            }
1017        }
1018    }
1019}
1020fn parse_response_line(line: &'_ str) -> Result<ResponseLine<'_>, ClientError> {
1021    if line.len() < 4 {
1022        return Err(ClientError::MalformedResponseLine(line.to_string()));
1023    }
1024
1025    match line.as_bytes()[3] {
1026        b' ' | b'-' => match line[0..3].parse::<u16>() {
1027            Ok(code) => Ok(ResponseLine {
1028                code,
1029                is_final: line.as_bytes()[3] == b' ',
1030                content: &line[4..],
1031            }),
1032            Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
1033        },
1034        _ => Err(ClientError::MalformedResponseLine(line.to_string())),
1035    }
1036}
1037
1038impl TlsOptions {
1039    pub fn build_openssl_connector(
1040        &self,
1041        hostname: &str,
1042    ) -> Result<openssl::ssl::ConnectConfiguration, ClientError> {
1043        tracing::trace!("build_openssl_connector for {hostname}");
1044        let mut builder =
1045            openssl::ssl::SslConnector::builder(openssl::ssl::SslMethod::tls_client())?;
1046
1047        if let (Some(cert_data), Some(key_data)) =
1048            (&self.certificate_from_pem, &self.private_key_from_pem)
1049        {
1050            let cert = X509::from_pem(cert_data)?;
1051            builder.set_certificate(&cert)?;
1052
1053            let key = PKey::private_key_from_pem(key_data)?;
1054            builder.set_private_key(&key)?;
1055
1056            builder.check_private_key()?;
1057        }
1058
1059        if let Some(list) = &self.openssl_cipher_list {
1060            builder.set_cipher_list(list)?;
1061        }
1062
1063        if let Some(suites) = &self.openssl_cipher_suites {
1064            builder.set_ciphersuites(suites)?;
1065        }
1066
1067        if let Some(options) = &self.openssl_options {
1068            builder.clear_options(SslOptions::all());
1069            builder.set_options(*options);
1070        }
1071
1072        if self.insecure {
1073            builder.set_verify(openssl::ssl::SslVerifyMode::NONE);
1074        }
1075
1076        if !self.dane_tlsa.is_empty() {
1077            builder.dane_enable()?;
1078            builder.set_no_dane_ee_namechecks();
1079        }
1080
1081        let connector = builder.build();
1082
1083        let mut config = connector.configure()?;
1084
1085        if !self.dane_tlsa.is_empty() {
1086            config.dane_enable(hostname)?;
1087            let mut any_usable = false;
1088            for tlsa in &self.dane_tlsa {
1089                let usable = config.dane_tlsa_add(
1090                    match tlsa.cert_usage() {
1091                        CertUsage::PkixTa => DaneUsage::PKIX_TA,
1092                        CertUsage::PkixEe => DaneUsage::PKIX_EE,
1093                        CertUsage::DaneTa => DaneUsage::DANE_TA,
1094                        CertUsage::DaneEe => DaneUsage::DANE_EE,
1095                        CertUsage::Unassigned(n) => DaneUsage::from_raw(n),
1096                        CertUsage::Private => DaneUsage::PRIV_CERT,
1097                    },
1098                    match tlsa.selector() {
1099                        Selector::Full => DaneSelector::CERT,
1100                        Selector::Spki => DaneSelector::SPKI,
1101                        Selector::Unassigned(n) => DaneSelector::from_raw(n),
1102                        Selector::Private => DaneSelector::PRIV_SEL,
1103                    },
1104                    match tlsa.matching() {
1105                        Matching::Raw => DaneMatchType::FULL,
1106                        Matching::Sha256 => DaneMatchType::SHA2_256,
1107                        Matching::Sha512 => DaneMatchType::SHA2_512,
1108                        Matching::Unassigned(n) => DaneMatchType::from_raw(n),
1109                        Matching::Private => DaneMatchType::PRIV_MATCH,
1110                    },
1111                    tlsa.cert_data(),
1112                )?;
1113
1114                tracing::trace!("build_dane_connector usable={usable} {tlsa:?}");
1115                if usable {
1116                    any_usable = true;
1117                }
1118            }
1119
1120            if !any_usable {
1121                return Err(ClientError::NoUsableDaneTlsa {
1122                    hostname: hostname.to_string(),
1123                    tlsa: self.dane_tlsa.clone(),
1124                });
1125            }
1126        }
1127
1128        Ok(config)
1129    }
1130}
1131
1132fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1133    static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1134
1135    if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1136        return None;
1137    }
1138
1139    let mut stuffed = vec![];
1140    if data.starts_with(b".") {
1141        stuffed.push(b'.');
1142    }
1143    let mut last_idx = 0;
1144    for i in LFDOT.find_iter(data) {
1145        stuffed.extend_from_slice(&data[last_idx..=i]);
1146        stuffed.push(b'.');
1147        last_idx = i + 1;
1148    }
1149    stuffed.extend_from_slice(&data[last_idx..]);
1150    Some(stuffed)
1151}
1152
1153/// Extracts the object=name pairs of the subject name from a cert.
1154/// eg:
1155/// ```norun
1156/// ["C=US", "ST=CA", "L=SanFrancisco", "O=Fort-Funston", "OU=MyOrganizationalUnit",
1157/// "CN=do.havedane.net", "name=EasyRSA", "emailAddress=me@myhost.mydomain"]
1158/// ```
1159pub fn subject_name(cert: &X509Ref) -> Vec<String> {
1160    let mut subject_name = vec![];
1161    for entry in cert.subject_name().entries() {
1162        if let Ok(obj) = entry.object().nid().short_name() {
1163            if let Ok(name) = entry.data().as_utf8() {
1164                subject_name.push(format!("{obj}={name}"));
1165            }
1166        }
1167    }
1168    subject_name
1169}
1170
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_stuffing() {
        // Nothing to escape: the caller keeps using its input unchanged.
        assert_eq!(apply_dot_stuffing(b"foo"), None);
        assert_eq!(apply_dot_stuffing(b""), None);
        // A dot that does not start a line needs no escaping.
        assert_eq!(apply_dot_stuffing(b"a.b\nc."), None);

        assert_eq!(apply_dot_stuffing(b".foo").unwrap(), b"..foo");
        assert_eq!(apply_dot_stuffing(b"foo\n.bar").unwrap(), b"foo\n..bar");
        assert_eq!(
            apply_dot_stuffing(b"foo\n.bar\n..baz\n").unwrap(),
            b"foo\n..bar\n...baz\n"
        );
        // CRLF line endings, including lines made of a lone dot, which would
        // otherwise end DATA early.
        assert_eq!(
            apply_dot_stuffing(b".\r\n.\r\n.").unwrap(),
            b"..\r\n..\r\n.."
        );
    }

    #[test]
    fn response_line_parsing() {
        assert_eq!(
            parse_response_line("220 woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: true,
                content: "woot"
            }
        );
        assert_eq!(
            parse_response_line("220-woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: false,
                content: "woot"
            }
        );
        // A final line whose text is empty.
        assert_eq!(
            parse_response_line("250 ").unwrap(),
            ResponseLine {
                code: 250,
                is_final: true,
                content: ""
            }
        );

        assert!(matches!(
            parse_response_line("220_woot"),
            Err(ClientError::MalformedResponseLine(_))
        ));
        assert!(matches!(
            parse_response_line("not really"),
            Err(ClientError::MalformedResponseLine(_))
        ));
    }

    /// Parses `lines` as one multi-line reply via `parse_response_line` and
    /// `ResponseBuilder`, asserting that the last line is marked final.
    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let mut parsed = parse_response_line(lines[0])?;
        let mut b = ResponseBuilder::new(&parsed);
        for line in &lines[1..] {
            parsed = parse_response_line(line)?;
            b.add_line(&parsed)
                .map_err(ClientError::MalformedResponseLine)?;
        }
        assert!(parsed.is_final);
        Ok(b.build(None))
    }

    #[test]
    fn multi_line_response() {
        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: None,
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        let res = parse_multi_line(&["220-woot", "221-more", "220 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221-more"
            ),
            "got error {res:?}"
        );

        let res = parse_multi_line(&["220-woot", "220-more", "221 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221 done"
            ),
            "got error {res:?}"
        );

        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // Confirm that we strip the enhanced status code from each line
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // ... but only if the code matches that of the first line
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\n5.5.5 done".to_string(),
                command: None
            }
        );
    }

    #[test]
    fn test_extract_hostname() {
        assert_eq!(extract_hostname("foo"), "foo");
        assert_eq!(extract_hostname("foo."), "foo");
        assert_eq!(extract_hostname("foo:25"), "foo");
        assert_eq!(extract_hostname("foo.:25"), "foo");
        assert_eq!(extract_hostname("[foo]:25"), "foo");
        assert_eq!(extract_hostname("[foo.]:25"), "foo");
        assert_eq!(extract_hostname("[::1]:25"), "::1");
        assert_eq!(extract_hostname("::1:25"), "::1");
    }
}