rfc5321/
client.rs

1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::{
4    AsyncReadAndWrite, BoxedAsyncReadAndWrite, Command, Domain, EsmtpParameter, ForwardPath,
5    ReversePath,
6};
7use hickory_proto::rr::rdata::tlsa::{CertUsage, Matching, Selector};
8use hickory_proto::rr::rdata::TLSA;
9use memchr::memmem::Finder;
10use openssl::ssl::{DaneMatchType, DaneSelector, DaneUsage, SslOptions};
11use openssl::x509::{X509Ref, X509};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::net::IpAddr;
15use std::str::FromStr;
16use std::sync::{Arc, LazyLock};
17use std::time::Duration;
18use thiserror::Error;
19use tokio::io::{AsyncReadExt, AsyncWriteExt};
20use tokio::net::{TcpStream, ToSocketAddrs};
21use tokio::time::timeout;
22use tokio_rustls::rustls::pki_types::ServerName;
23use tracing::Level;
24
25pub use crate::tls::TlsOptions;
26pub use {openssl, tokio_rustls};
27
/// Largest number of bytes `SmtpClient::read_line` will buffer while waiting
/// for a CRLF; once exceeded, the partial line is discarded and reported as
/// `ClientError::ResponseTooLong`. Also used as the size of each socket read.
const MAX_LINE_LEN: usize = 4096;
29
/// Errors produced by [`SmtpClient`] operations.
#[derive(Error, Debug, Clone)]
pub enum ClientError {
    /// A response line was not valid UTF-8.
    #[error("response is not UTF8")]
    Utf8(#[from] std::string::FromUtf8Error),
    /// A response line was rejected while assembling a (multi-line) reply;
    /// the payload describes the problem.
    #[error("Malformed Response: {0}")]
    MalformedResponseLine(String),
    /// More than `MAX_LINE_LEN` bytes arrived without a CRLF.
    #[error("Response line is too long")]
    ResponseTooLong,
    /// The operation needed the connection, but the socket has already been
    /// dropped.
    #[error("Not connected")]
    NotConnected,
    /// The server answered with a reply code the operation treats as failure.
    #[error("Command rejected {0:?}")]
    Rejected(Response),
    /// The peer hostname cannot be used as a TLS server name.
    #[error("STARTTLS: {0} is not a valid DNS name")]
    InvalidDnsName(String),
    /// No (complete) response arrived within `duration`.
    #[error("Timed Out waiting {duration:?} for response to {command:?}")]
    TimeOutResponse {
        command: Option<Command>,
        duration: Duration,
    },
    /// Writing `commands` did not finish within `duration`.
    #[error("Timed Out writing {duration:?} {commands:?}")]
    TimeOutRequest {
        commands: Vec<Command>,
        duration: Duration,
    },
    /// Reading a response failed (I/O error, EOF, or no socket); `partial`
    /// holds whatever bytes were already buffered, decoded lossily.
    #[error("Error {error} reading response to {command:?}")]
    ReadError {
        command: Option<Command>,
        error: String,
        partial: String,
    },
    /// Flushing buffered outgoing data before reading a response failed.
    #[error("Error {error} flushing send buffer")]
    FlushError { error: String },
    /// Writing `commands` (empty for message data) failed.
    #[error("Error {error} writing {commands:?}")]
    WriteError {
        commands: Vec<Command>,
        error: String,
    },
    /// Writing message payload data did not finish in time.
    #[error("Timed Out sending message payload data")]
    TimeOutData,
    /// An error reported by OpenSSL.
    #[error("SSL Error: {0}")]
    SslErrorStack(#[from] openssl::error::ErrorStack),
    /// None of the DANE TLSA records for `hostname` could be used.
    #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
    NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
}
74
75impl ClientError {
76    /// Returns the command(s) string suitable for passing into a Response
77    pub fn command(&self) -> Option<String> {
78        match self {
79            Self::TimeOutResponse {
80                command: Some(command),
81                ..
82            }
83            | Self::ReadError {
84                command: Some(command),
85                ..
86            } => Some(command.encode()),
87            Self::TimeOutRequest { commands, .. } | Self::WriteError { commands, .. }
88                if !commands.is_empty() =>
89            {
90                let commands: Vec<String> = commands.into_iter().map(|cmd| cmd.encode()).collect();
91                Some(commands.join(""))
92            }
93            _ => None,
94        }
95    }
96
97    /// If the error contents were likely caused by something
98    /// about the mostly recently attempted message, rather than
99    /// a transport issue, or a carry-over from a prior message
100    /// (eg: previous message was rejected and destination chose
101    /// to drop the connection, which we detect later in RSET
102    /// on the next message), then we return true.
103    /// The expectation is that the caller will transiently
104    /// fail the message for later retry.
105    /// If we return false then the caller might decide to
106    /// try that same message again more immediately on
107    /// a separate connection
108    pub fn was_due_to_message(&self) -> bool {
109        match self {
110            Self::Utf8(_)
111            | Self::MalformedResponseLine(_)
112            | Self::ResponseTooLong
113            | Self::NotConnected
114            | Self::InvalidDnsName(_)
115            | Self::TimeOutResponse { .. }
116            | Self::TimeOutRequest { .. }
117            | Self::ReadError { .. }
118            | Self::FlushError { .. }
119            | Self::WriteError { .. }
120            | Self::TimeOutData
121            | Self::SslErrorStack(_)
122            | Self::NoUsableDaneTlsa { .. } => false,
123            Self::Rejected(response) => response.was_due_to_message(),
124        }
125    }
126}
127
/// One ESMTP extension advertised in an EHLO/LHLO reply.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EsmtpCapability {
    /// The extension keyword exactly as the server sent it (original case).
    pub name: String,
    /// Everything after the first space on the capability line, if any.
    pub param: Option<String>,
}
133
/// Events reported to an [`SmtpClientTracer`] while the client talks to the peer.
#[derive(Clone, Debug)]
pub enum SmtpClientTraceEvent {
    /// The client dropped its connection (after an error, timeout or EOF).
    Closed,
    /// Raw bytes received from the peer.
    Read(Vec<u8>),
    /// Data written to the peer; binary payloads are converted lossily to UTF-8.
    Write(String),
    /// A free-form diagnostic about the session (errors, TLS details, ...).
    Diagnostic {
        level: tracing::Level,
        message: String,
    },
}
144
/// Builds a [`SmtpClientTraceEvent`] on demand.
///
/// Handed to [`SmtpClientTracer::lazy_trace`] so that the event (and the copy
/// of the underlying buffer it requires) need only be created if the tracer
/// calls `trace()` — presumably only when it actually wants the event.
pub trait DeferredTracer {
    /// Materialise the event.
    fn trace(&self) -> SmtpClientTraceEvent;
}
148
/// Observer for an [`SmtpClient`]'s traffic and diagnostics (installed with
/// `SmtpClient::set_tracer`).
pub trait SmtpClientTracer: std::fmt::Debug {
    /// Record an already-constructed event.
    fn trace_event(&self, event: SmtpClientTraceEvent);
    /// Record an event whose construction is deferred; call
    /// `deferred.trace()` to obtain it.
    fn lazy_trace(&self, deferred: &dyn DeferredTracer);
}
153
154// helper to avoid making a second copy of every write buffer
155struct WriteTracer<'a> {
156    data: &'a str,
157}
158impl DeferredTracer for WriteTracer<'_> {
159    fn trace(&self) -> SmtpClientTraceEvent {
160        SmtpClientTraceEvent::Write(self.data.to_string())
161    }
162}
163impl<'a> WriteTracer<'a> {
164    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
165        tracer.lazy_trace(&Self { data });
166    }
167}
168
169struct BinWriteTracer<'a> {
170    data: &'a [u8],
171}
172impl DeferredTracer for BinWriteTracer<'_> {
173    fn trace(&self) -> SmtpClientTraceEvent {
174        let data = String::from_utf8_lossy(self.data).to_string();
175        SmtpClientTraceEvent::Write(data)
176    }
177}
178impl<'a> BinWriteTracer<'a> {
179    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
180        tracer.lazy_trace(&Self { data });
181    }
182}
183
184// A little bit of gymnastics to avoid making a second
185// copy of every read buffer
186struct ReadTracer<'a> {
187    data: &'a [u8],
188}
189impl DeferredTracer for ReadTracer<'_> {
190    fn trace(&self) -> SmtpClientTraceEvent {
191        SmtpClientTraceEvent::Read(self.data.to_vec())
192    }
193}
194
/// An SMTP/LMTP client session over a single connection.
#[derive(Debug)]
pub struct SmtpClient {
    /// The connection; reset to `None` once an I/O error, timeout or EOF
    /// makes it unusable, so it is never reused afterwards.
    socket: Option<BoxedAsyncReadAndWrite>,
    /// Peer host name or IP literal, normalized by `extract_hostname`;
    /// used as the TLS server name and in trace logging.
    hostname: String,
    /// Capabilities from the most recent successful EHLO/LHLO, keyed by the
    /// upper-cased extension keyword.
    capabilities: HashMap<String, EsmtpCapability>,
    /// Bytes received from the peer that have not yet been consumed as lines.
    read_buffer: Vec<u8>,
    timeouts: SmtpClientTimeouts,
    /// Optional observer for protocol traffic and diagnostics.
    tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
    // NOTE(review): appears to record whether RSET should be sent before the
    // next transaction; it is maintained by code outside this view — confirm.
    use_rset: bool,
    /// Whether this client is permitted to issue RSET (see `set_enable_rset`).
    enable_rset: bool,
    /// Whether `pipeline_commands` may pipeline when the server advertises
    /// PIPELINING.
    enable_pipelining: bool,
}
207
/// Reduce a peer address such as `mx.example.com:25` or `[::1]:25` to the bare
/// host name or IP literal (`mx.example.com`, `::1`).
///
/// * `[literal]` and `[literal]:port` yield the text inside the brackets.
/// * `host:port` (exactly one colon) yields `host`.
/// * Anything else — a plain name, or an unbracketed IPv6 address such as
///   `2001:db8::1` whose colons belong to the address — is kept unchanged.
///
/// A single trailing FQDN dot is removed from the result.
fn extract_hostname(hostname: &str) -> &str {
    // Handle brackets first: splitting on the last ':' before looking for
    // brackets would cut IPv6 literals such as "[::1]" in half.
    let bracketed = hostname
        .strip_prefix('[')
        .and_then(|rest| rest.find(']').map(|end| &rest[..end]));

    let host = match bracketed {
        Some(inner) => inner,
        None => match hostname.rsplit_once(':') {
            // Exactly one colon: the suffix is a port.
            Some((host, _port)) if !host.contains(':') => host,
            // No colon at all, or a bare IPv6 address: nothing to strip.
            _ => hostname,
        },
    };

    // Remove any trailing FQDN dot
    host.strip_suffix('.').unwrap_or(host)
}
226
227impl SmtpClient {
228    pub async fn new<A: ToSocketAddrs + ToString + Clone>(
229        addr: A,
230        timeouts: SmtpClientTimeouts,
231    ) -> std::io::Result<Self> {
232        let stream = TcpStream::connect(addr.clone()).await?;
233        // No need for Nagle with SMTP request/response
234        stream.set_nodelay(true)?;
235        Ok(Self::with_stream(stream, addr.to_string(), timeouts))
236    }
237
238    pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
239        stream: S,
240        peer_hostname: H,
241        timeouts: SmtpClientTimeouts,
242    ) -> Self {
243        let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
244
245        Self {
246            socket: Some(Box::new(stream)),
247            hostname,
248            capabilities: HashMap::new(),
249            read_buffer: Vec::with_capacity(1024),
250            timeouts,
251            tracer: None,
252            use_rset: false,
253            enable_rset: false,
254            enable_pipelining: false,
255        }
256    }
257
258    pub fn is_connected(&self) -> bool {
259        self.socket.is_some()
260    }
261
262    pub fn set_enable_rset(&mut self, enable: bool) {
263        self.enable_rset = enable;
264    }
265
266    pub fn set_enable_pipelining(&mut self, enable: bool) {
267        self.enable_pipelining = enable;
268    }
269
270    pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
271        self.tracer.replace(tracer);
272    }
273
274    pub fn timeouts(&self) -> &SmtpClientTimeouts {
275        &self.timeouts
276    }
277
278    async fn read_line(
279        &mut self,
280        timeout_duration: Duration,
281        cmd: Option<&Command>,
282    ) -> Result<String, ClientError> {
283        let mut too_long = false;
284        loop {
285            let mut iter = self.read_buffer.iter().enumerate();
286            while let Some((i, &b)) = iter.next() {
287                if b != b'\r' {
288                    continue;
289                }
290                if let Some((_, b'\n')) = iter.next() {
291                    if too_long {
292                        self.read_buffer.drain(0..i + 2);
293
294                        if let Some(tracer) = &self.tracer {
295                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
296                                level: Level::ERROR,
297                                message: "Response too long".to_string(),
298                            });
299                        }
300
301                        return Err(ClientError::ResponseTooLong);
302                    }
303
304                    let line = String::from_utf8(self.read_buffer[0..i].to_vec());
305                    self.read_buffer.drain(0..i + 2);
306                    return Ok(line?);
307                }
308            }
309            if self.read_buffer.len() > MAX_LINE_LEN {
310                self.read_buffer.clear();
311                too_long = true;
312            }
313
314            // Didn't find a complete line, fill up the rest of the buffer
315            let mut data = [0u8; MAX_LINE_LEN];
316            let size = match self.socket.as_mut() {
317                Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
318                    Ok(Ok(size)) => size,
319                    Ok(Err(err)) => {
320                        self.socket.take();
321                        if let Some(tracer) = &self.tracer {
322                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
323                                level: Level::ERROR,
324                                message: format!("Error during read: {err:#}"),
325                            });
326                            tracer.trace_event(SmtpClientTraceEvent::Closed);
327                        }
328                        return Err(ClientError::ReadError {
329                            command: cmd.cloned(),
330                            error: format!("{err:#}"),
331                            partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
332                        });
333                    }
334                    Err(_) => {
335                        self.socket.take();
336                        if let Some(tracer) = &self.tracer {
337                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
338                                level: Level::ERROR,
339                                message: format!("Read Timeout after {timeout_duration:?}"),
340                            });
341                            tracer.trace_event(SmtpClientTraceEvent::Closed);
342                        }
343                        return Err(ClientError::TimeOutResponse {
344                            command: cmd.cloned(),
345                            duration: timeout_duration,
346                        });
347                    }
348                },
349                None => {
350                    return Err(ClientError::ReadError {
351                        command: cmd.cloned(),
352                        error: "the socket was closed in response to an earlier issue".to_string(),
353                        partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
354                    });
355                }
356            };
357            if size == 0 {
358                self.socket.take();
359                if let Some(tracer) = &self.tracer {
360                    tracer.trace_event(SmtpClientTraceEvent::Closed);
361                }
362                return Err(ClientError::ReadError {
363                    command: cmd.cloned(),
364                    error: "Connection closed by peer".to_string(),
365                    partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
366                });
367            }
368            if let Some(tracer) = &self.tracer {
369                tracer.lazy_trace(&ReadTracer {
370                    data: &data[0..size],
371                });
372            }
373            self.read_buffer.extend_from_slice(&data[0..size]);
374        }
375    }
376
377    pub async fn read_response(
378        &mut self,
379        command: Option<&Command>,
380        timeout_duration: Duration,
381    ) -> Result<Response, ClientError> {
382        if let Some(sock) = self.socket.as_mut() {
383            if let Err(err) = sock.flush().await {
384                self.socket.take();
385                if let Some(tracer) = &self.tracer {
386                    tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
387                        level: Level::ERROR,
388                        message: format!("Error during flush: {err:#}"),
389                    });
390                    tracer.trace_event(SmtpClientTraceEvent::Closed);
391                }
392                return Err(ClientError::FlushError {
393                    error: format!("{err:#}"),
394                });
395            }
396        }
397
398        let mut line = self.read_line(timeout_duration, command).await?;
399        tracing::trace!("recv<-{}: {line}", self.hostname);
400        let mut parsed = parse_response_line(&line)?;
401        let mut response_builder = ResponseBuilder::new(&parsed);
402
403        let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
404        while !parsed.is_final {
405            line = self
406                .read_line(subsequent_line_timeout_duration, command)
407                .await?;
408            parsed = parse_response_line(&line)?;
409            response_builder
410                .add_line(&parsed)
411                .map_err(ClientError::MalformedResponseLine)?;
412        }
413
414        let response = response_builder.build(command.map(|cmd| cmd.encode()));
415
416        tracing::trace!("{}: {response:?}", self.hostname);
417
418        Ok(response)
419    }
420
421    pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
422        self.write_command_request(command).await?;
423        self.read_response(Some(command), command.client_timeout(&self.timeouts))
424            .await
425    }
426
427    /// Wrapper around socket.write_all() that will emit trace diagnostics and synthesize
428    /// a Close event to the tracer if a timeout or IO error occurs.
429    /// If an error or timeout, occurs ensures that the socket will not be reused.
430    async fn write_all_with_timeout<F, G>(
431        &mut self,
432        timeout_duration: Duration,
433        bytes: &[u8],
434        make_timeout_err: F,
435        make_write_err: G,
436    ) -> Result<(), ClientError>
437    where
438        F: FnOnce() -> ClientError,
439        G: FnOnce(String) -> ClientError,
440    {
441        match self.socket.as_mut() {
442            Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
443                Ok(Ok(response)) => Ok(response),
444                Ok(Err(err)) => {
445                    self.socket.take();
446                    if let Some(tracer) = &self.tracer {
447                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
448                            level: Level::ERROR,
449                            message: format!("Error during write: {err:#}"),
450                        });
451                        tracer.trace_event(SmtpClientTraceEvent::Closed);
452                    }
453                    Err(make_write_err(format!("{err:#}")))
454                }
455                Err(_) => {
456                    self.socket.take();
457                    if let Some(tracer) = &self.tracer {
458                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
459                            level: Level::ERROR,
460                            message: format!("Write Timeout after {timeout_duration:?}"),
461                        });
462                        tracer.trace_event(SmtpClientTraceEvent::Closed);
463                    }
464                    Err(make_timeout_err())
465                }
466            },
467            None => Err(make_write_err(
468                "the socket was closed in response to an earlier issue".to_string(),
469            )),
470        }
471    }
472
473    async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
474        let total_timeout: Duration = commands
475            .iter()
476            .map(|cmd| cmd.client_timeout_request(&self.timeouts))
477            .sum();
478
479        let mut lines: Vec<String> = vec![];
480        let mut all = String::new();
481        for cmd in commands {
482            let line = cmd.encode();
483            all.push_str(&line);
484            lines.push(line);
485        }
486        tracing::trace!("send->{}: (PIPELINE) {all}", self.hostname);
487        if self.socket.is_some() {
488            if let Some(tracer) = &self.tracer {
489                // Send the lines individually to the tracer, so that we
490                // don't break --terse mode
491                for line in lines {
492                    WriteTracer::trace(tracer, &line);
493                }
494            }
495        }
496        self.write_all_with_timeout(
497            total_timeout,
498            all.as_bytes(),
499            || ClientError::TimeOutRequest {
500                duration: total_timeout,
501                commands: commands.to_vec(),
502            },
503            |error| ClientError::WriteError {
504                error,
505                commands: commands.to_vec(),
506            },
507        )
508        .await
509    }
510
511    async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
512        let line = command.encode();
513        tracing::trace!("send->{}: {line}", self.hostname);
514        if self.socket.is_some() {
515            if let Some(tracer) = &self.tracer {
516                WriteTracer::trace(tracer, &line);
517            }
518        }
519
520        let timeout_duration = command.client_timeout_request(&self.timeouts);
521        self.write_all_with_timeout(
522            timeout_duration,
523            line.as_bytes(),
524            || ClientError::TimeOutRequest {
525                duration: timeout_duration,
526                commands: vec![command.clone()],
527            },
528            |error| ClientError::WriteError {
529                error,
530                commands: vec![command.clone()],
531            },
532        )
533        .await
534    }
535
536    async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
537        if self.socket.is_some() {
538            if let Some(tracer) = &self.tracer {
539                BinWriteTracer::trace(tracer, data);
540            }
541        }
542        let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
543        self.write_all_with_timeout(
544            timeout_duration,
545            data,
546            || ClientError::TimeOutData,
547            |error| ClientError::WriteError {
548                error,
549                commands: vec![],
550            },
551        )
552        .await
553    }
554
555    /// Issue a series of commands, and return the responses to
556    /// those commands.
557    ///
558    /// If the server advertised the RFC 2920 PIPELINING extension,
559    /// the commands are written one after the other before waiting
560    /// to read any data, resulting in lower overall latency due
561    /// to round-trip-times.
562    ///
563    /// If PIPELINING is not available, each command is written
564    /// and the response read before attempting to write the next
565    /// command.
566    ///
567    /// The number of returned responses may be smaller than the
568    /// number of requested commands if there is an issue with
569    /// the network connection.
570    pub async fn pipeline_commands(
571        &mut self,
572        commands: Vec<Command>,
573    ) -> Vec<Result<Response, ClientError>> {
574        let mut results: Vec<Result<Response, ClientError>> = vec![];
575
576        let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
577        if pipeline {
578            if let Err(err) = self.write_pipeline_request(&commands).await {
579                let err: ClientError = err;
580                results.push(Err(err.clone()));
581                while results.len() < commands.len() {
582                    // Synthesize failures for the remaining commands
583                    results.push(Err(err.clone()));
584                }
585                return results;
586            }
587
588            // Now read the responses effectively in a batch
589            for cmd in &commands {
590                results.push(
591                    self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
592                        .await,
593                );
594            }
595            return results;
596        }
597
598        for cmd in &commands {
599            if let Err(err) = self.write_command_request(cmd).await {
600                let err: ClientError = err;
601                results.push(Err(err.clone()));
602                while results.len() < commands.len() {
603                    // Synthesize failures for the remaining commands
604                    results.push(Err(err.clone()));
605                }
606                return results;
607            }
608            // Immediately request the response if the server
609            // doesn't support pipelining
610            results.push(
611                self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
612                    .await,
613            );
614        }
615        results
616    }
617
618    pub async fn ehlo_lhlo(
619        &mut self,
620        ehlo_name: &str,
621        use_lmtp: bool,
622    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
623        if use_lmtp {
624            self.lhlo(ehlo_name).await
625        } else {
626            self.ehlo(ehlo_name).await
627        }
628    }
629
630    pub async fn lhlo(
631        &mut self,
632        ehlo_name: &str,
633    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
634        let response = self
635            .send_command(&Command::Lhlo(Domain::Name(ehlo_name.to_string())))
636            .await?;
637        self.ehlo_common(response)
638    }
639
640    pub async fn ehlo(
641        &mut self,
642        ehlo_name: &str,
643    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
644        let response = self
645            .send_command(&Command::Ehlo(Domain::Name(ehlo_name.to_string())))
646            .await?;
647        self.ehlo_common(response)
648    }
649
650    fn ehlo_common(
651        &mut self,
652        response: Response,
653    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
654        if response.code != 250 {
655            return Err(ClientError::Rejected(response));
656        }
657
658        let mut capabilities = HashMap::new();
659
660        for line in response.content.lines().skip(1) {
661            let mut fields = line.splitn(2, ' ');
662            if let Some(name) = fields.next() {
663                let param = fields.next().map(|s| s.to_string());
664                let cap = EsmtpCapability {
665                    name: name.to_string(),
666                    param,
667                };
668                capabilities.insert(name.to_ascii_uppercase(), cap);
669            }
670        }
671
672        self.capabilities = capabilities;
673        Ok(&self.capabilities)
674    }
675
676    pub async fn auth_plain(
677        &mut self,
678        username: &str,
679        password: Option<&str>,
680    ) -> Result<(), ClientError> {
681        // RFC 4616 says that the format is:
682        // [authzid] NUL authcid NUL passwd
683        let password = password.unwrap_or("");
684        let payload = format!("\x00{username}\x00{password}");
685        let payload = data_encoding::BASE64.encode(payload.as_bytes());
686
687        let response = self
688            .send_command(&Command::Auth {
689                sasl_mech: "PLAIN".to_string(),
690                initial_response: Some(payload),
691            })
692            .await?;
693
694        if response.code != 235 {
695            return Err(ClientError::Rejected(response));
696        }
697
698        Ok(())
699    }
700
701    /// Attempt TLS handshake.
702    /// Returns Err for IO errors.
703    /// On completion, return an option that will be:
704    /// * Some(handshake_error) - if the handshake failed
705    /// * None - if the handshake succeeded
706    pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
707        let resp = self.send_command(&Command::StartTls).await?;
708        if resp.code != 220 {
709            return Err(ClientError::Rejected(resp));
710        }
711
712        let mut handshake_error = None;
713        let mut tls_info = TlsInformation::default();
714
715        let stream: BoxedAsyncReadAndWrite =
716            if options.prefer_openssl || !options.dane_tlsa.is_empty() {
717                let connector = options.build_openssl_connector(&self.hostname)?;
718                let ssl = connector.into_ssl(self.hostname.as_str())?;
719
720                let (stream, dup_stream) = match self.socket.take() {
721                    Some(s) => {
722                        let d = s.try_dup();
723                        (s, d)
724                    }
725                    None => return Err(ClientError::NotConnected),
726                };
727
728                let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;
729
730                if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
731                    handshake_error.replace(format!("{err:#}"));
732                }
733
734                tls_info.provider_name = "openssl".to_string();
735                tls_info.cipher = match ssl_stream.ssl().current_cipher() {
736                    Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
737                    None => String::new(),
738                };
739                tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();
740
741                if let Some(cert) = ssl_stream.ssl().peer_certificate() {
742                    tls_info.subject_name = subject_name(&cert);
743                }
744                if let Ok(authority) = ssl_stream.ssl().dane_authority() {
745                    if let Some(cert) = &authority.cert {
746                        tls_info.subject_name = subject_name(cert);
747                    }
748                }
749
750                match (&handshake_error, dup_stream) {
751                    (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
752                        // Try falling back to clear text on the duplicate stream.
753                        // This is imperfect: in a failed validation scenario we will
754                        // end up trying to read binary data as a string and get a UTF-8
755                        // error if the peer thinks the session is encrypted.
756                        drop(ssl_stream);
757                        Box::new(dup_stream)
758                    }
759                    _ => Box::new(ssl_stream),
760                }
761            } else {
762                tls_info.provider_name = "rustls".to_string();
763                let connector = options.build_tls_connector().await;
764                let server_name = match IpAddr::from_str(self.hostname.as_str()) {
765                    Ok(ip) => ServerName::IpAddress(ip.into()),
766                    Err(_) => ServerName::try_from(self.hostname.clone())
767                        .map_err(|_| ClientError::InvalidDnsName(self.hostname.clone()))?,
768                };
769
770                match connector
771                    .connect(
772                        server_name,
773                        match self.socket.take() {
774                            Some(s) => s,
775                            None => return Err(ClientError::NotConnected),
776                        },
777                    )
778                    .into_fallible()
779                    .await
780                {
781                    Ok(stream) => {
782                        let (_, conn) = stream.get_ref();
783                        tls_info.cipher = match conn.negotiated_cipher_suite() {
784                            Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
785                            None => String::new(),
786                        };
787                        tls_info.protocol_version = match conn.protocol_version() {
788                            Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
789                            None => String::new(),
790                        };
791
792                        if let Some(certs) = conn.peer_certificates() {
793                            let peer_cert = &certs[0];
794                            if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
795                                tls_info.subject_name = subject_name(&cert);
796                            }
797                        }
798
799                        Box::new(stream)
800                    }
801                    Err((err, stream)) => {
802                        handshake_error.replace(format!("{err:#}"));
803                        stream
804                    }
805                }
806            };
807
808        if let Some(tracer) = &self.tracer {
809            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
810                level: Level::INFO,
811                message: match &handshake_error {
812                    Some(error) => format!("STARTTLS handshake failed: {error:?}"),
813                    None => format!("STARTTLS handshake -> {tls_info:?}"),
814                },
815            });
816        }
817
818        self.socket.replace(stream);
819        Ok(match handshake_error {
820            Some(error) => TlsStatus::FailedHandshake(error),
821            None => TlsStatus::Info(tls_info),
822        })
823    }
824
825    pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
826        &mut self,
827        sender: SENDER,
828        recipient: RECIP,
829        data: B,
830    ) -> Result<Response, ClientError> {
831        let sender = sender.into();
832        let recipient = recipient.into();
833
834        let data: &[u8] = data.as_ref();
835        let stuffed;
836
837        let data = match apply_dot_stuffing(data) {
838            Some(d) => {
839                stuffed = d;
840                &stuffed
841            }
842            None => data,
843        };
844
845        let data_is_8bit = data.iter().any(|&b| b >= 0x80);
846        let envelope_is_8bit = !sender.is_ascii() || !recipient.is_ascii();
847
848        let mut mail_from_params = vec![];
849        if data_is_8bit && self.capabilities.contains_key("8BITMIME") {
850            mail_from_params.push(EsmtpParameter {
851                name: "BODY".to_string(),
852                value: Some("8BITMIME".to_string()),
853            });
854        }
855
856        if envelope_is_8bit && self.capabilities.contains_key("SMTPUTF8") {
857            mail_from_params.push(EsmtpParameter {
858                name: "SMTPUTF8".to_string(),
859                value: None,
860            });
861        }
862
863        let mut commands = vec![];
864
865        // We want to avoid using RSET for the first message we send on
866        // a given connection, because postfix can run in a mode where
867        // it will not tolerate RSET because it considers it to be a "junk"
868        // command, and rejecting junk commands will cut down on its load
869        // when it is under stress; it is used as a load shedding approach.
870        // If we always RSET then we will never deliver to a site that is
871        // configured that way. If we take care to RSET only for subsequent
872        // sends, then we should get one message per connection through
873        // without being unfairly penalized for defensively RSETing.
874        let used_rset = self.use_rset;
875        if self.use_rset {
876            commands.push(Command::Rset);
877        }
878        commands.push(Command::MailFrom {
879            address: sender,
880            parameters: mail_from_params,
881        });
882        commands.push(Command::RcptTo {
883            address: recipient,
884            parameters: vec![],
885        });
886        commands.push(Command::Data);
887
888        // Assume that something might break below: if it does, we want
889        // to ensure that we RSET the connection on the next go around.
890        self.use_rset = true;
891
892        let mut responses = self.pipeline_commands(commands).await;
893
894        // This is a little awkward. We want to handle the RFC 2090 3.1 case
895        // below, which requires deferring checking the actual response codes
896        // until later, but we must also handle the case where we had a hard
897        // transport error partway through pipelining.
898        // So we set a flag for that case and will then "eagerly", wrt. the
899        // RFC 2090 3.1 logic, evaluate the SMTP response codes, so that we
900        // can propagate the correct error disposition up to the caller.
901        let is_err = responses.iter().any(|r| r.is_err());
902
903        if used_rset {
904            let rset_resp = responses.remove(0)?;
905            if rset_resp.code != 250 {
906                return Err(ClientError::Rejected(rset_resp));
907            }
908        }
909
910        let mail_resp = responses.remove(0)?;
911        if is_err && mail_resp.code != 250 {
912            return Err(ClientError::Rejected(mail_resp));
913        }
914
915        let rcpt_resp = responses.remove(0)?;
916        if is_err && rcpt_resp.code != 250 {
917            return Err(ClientError::Rejected(rcpt_resp));
918        }
919
920        let data_resp = responses.remove(0)?;
921        if is_err && data_resp.code != 354 {
922            return Err(ClientError::Rejected(data_resp));
923        }
924
925        if data_resp.code == 354 && (mail_resp.code != 250 || rcpt_resp.code != 250) {
926            // RFC 2920 3.1:
927            // the client cannot assume that the DATA command will be rejected
928            // just because none of the RCPT TO commands worked.  If the DATA
929            // command was properly rejected the client SMTP can just issue
930            // RSET, but if the DATA command was accepted the client SMTP
931            // should send a single dot.
932
933            // Send dummy data
934            self.write_data_with_timeout(b".\r\n").await?;
935            let data_dot = Command::DataDot;
936            // wait for its response
937            let _ = self
938                .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
939                .await?;
940
941            // Continue below: we will match one of the failure cases and
942            // return a ClientError::Rejected from one of the earlier
943            // commands
944        }
945
946        if mail_resp.code != 250 {
947            return Err(ClientError::Rejected(mail_resp));
948        }
949        if rcpt_resp.code != 250 {
950            return Err(ClientError::Rejected(rcpt_resp));
951        }
952        if data_resp.code != 354 {
953            return Err(ClientError::Rejected(data_resp));
954        }
955
956        let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);
957
958        tracing::trace!("message data is {} bytes", data.len());
959
960        self.write_data_with_timeout(data).await?;
961
962        let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };
963
964        tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());
965
966        self.write_data_with_timeout(marker.as_bytes()).await?;
967
968        let data_dot = Command::DataDot;
969        let resp = self
970            .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
971            .await?;
972        if resp.code != 250 {
973            return Err(ClientError::Rejected(resp));
974        }
975
976        // If everything went well, respect the user preference for speculatively
977        // issuing an RSET next time around
978        self.use_rset = self.enable_rset;
979
980        Ok(resp)
981    }
982}
983
/// Outcome of a STARTTLS handshake attempt.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
    /// The handshake failed; holds a textual description of the error.
    FailedHandshake(String),
    /// The handshake succeeded; holds details about the negotiated session.
    Info(TlsInformation),
}
989
/// Details about a negotiated TLS session, as reported after STARTTLS.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
    /// Name of the negotiated cipher suite. Left empty when the TLS library
    /// reports no suite; may be "UNKNOWN" when the suite has no known name.
    pub cipher: String,
    /// Name of the negotiated protocol version. Left empty when the TLS
    /// library reports none; may be "UNKNOWN" when the version has no name.
    pub protocol_version: String,
    /// `object=value` entries from the peer certificate's subject name.
    /// Empty when no peer certificate could be decoded.
    pub subject_name: Vec<String>,
    /// NOTE(review): presumably identifies the TLS implementation/provider
    /// used for the session; it is not populated in the visible code, so confirm.
    pub provider_name: String,
}
997
998impl Drop for SmtpClient {
999    fn drop(&mut self) {
1000        if let Some(tracer) = &self.tracer {
1001            if self.socket.is_some() {
1002                tracer.trace_event(SmtpClientTraceEvent::Closed);
1003            }
1004        }
1005    }
1006}
1007fn parse_response_line(line: &str) -> Result<ResponseLine, ClientError> {
1008    if line.len() < 4 {
1009        return Err(ClientError::MalformedResponseLine(line.to_string()));
1010    }
1011
1012    match line.as_bytes()[3] {
1013        b' ' | b'-' => match line[0..3].parse::<u16>() {
1014            Ok(code) => Ok(ResponseLine {
1015                code,
1016                is_final: line.as_bytes()[3] == b' ',
1017                content: &line[4..],
1018            }),
1019            Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
1020        },
1021        _ => Err(ClientError::MalformedResponseLine(line.to_string())),
1022    }
1023}
1024
1025impl TlsOptions {
1026    pub fn build_openssl_connector(
1027        &self,
1028        hostname: &str,
1029    ) -> Result<openssl::ssl::ConnectConfiguration, ClientError> {
1030        tracing::trace!("build_openssl_connector for {hostname}");
1031        let mut builder =
1032            openssl::ssl::SslConnector::builder(openssl::ssl::SslMethod::tls_client())?;
1033
1034        if let Some(list) = &self.openssl_cipher_list {
1035            builder.set_cipher_list(list)?;
1036        }
1037
1038        if let Some(suites) = &self.openssl_cipher_suites {
1039            builder.set_ciphersuites(suites)?;
1040        }
1041
1042        if let Some(options) = &self.openssl_options {
1043            builder.clear_options(SslOptions::all());
1044            builder.set_options(*options);
1045        }
1046
1047        if self.insecure {
1048            builder.set_verify(openssl::ssl::SslVerifyMode::NONE);
1049        }
1050
1051        if !self.dane_tlsa.is_empty() {
1052            builder.dane_enable()?;
1053            builder.set_no_dane_ee_namechecks();
1054        }
1055
1056        let connector = builder.build();
1057
1058        let mut config = connector.configure()?;
1059
1060        if !self.dane_tlsa.is_empty() {
1061            config.dane_enable(hostname)?;
1062            let mut any_usable = false;
1063            for tlsa in &self.dane_tlsa {
1064                let usable = config.dane_tlsa_add(
1065                    match tlsa.cert_usage() {
1066                        CertUsage::PkixTa => DaneUsage::PKIX_TA,
1067                        CertUsage::PkixEe => DaneUsage::PKIX_EE,
1068                        CertUsage::DaneTa => DaneUsage::DANE_TA,
1069                        CertUsage::DaneEe => DaneUsage::DANE_EE,
1070                        CertUsage::Unassigned(n) => DaneUsage::from_raw(n),
1071                        CertUsage::Private => DaneUsage::PRIV_CERT,
1072                    },
1073                    match tlsa.selector() {
1074                        Selector::Full => DaneSelector::CERT,
1075                        Selector::Spki => DaneSelector::SPKI,
1076                        Selector::Unassigned(n) => DaneSelector::from_raw(n),
1077                        Selector::Private => DaneSelector::PRIV_SEL,
1078                    },
1079                    match tlsa.matching() {
1080                        Matching::Raw => DaneMatchType::FULL,
1081                        Matching::Sha256 => DaneMatchType::SHA2_256,
1082                        Matching::Sha512 => DaneMatchType::SHA2_512,
1083                        Matching::Unassigned(n) => DaneMatchType::from_raw(n),
1084                        Matching::Private => DaneMatchType::PRIV_MATCH,
1085                    },
1086                    tlsa.cert_data(),
1087                )?;
1088
1089                tracing::trace!("build_dane_connector usable={usable} {tlsa:?}");
1090                if usable {
1091                    any_usable = true;
1092                }
1093            }
1094
1095            if !any_usable {
1096                return Err(ClientError::NoUsableDaneTlsa {
1097                    hostname: hostname.to_string(),
1098                    tlsa: self.dane_tlsa.clone(),
1099                });
1100            }
1101        }
1102
1103        Ok(config)
1104    }
1105}
1106
1107fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1108    static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1109
1110    if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1111        return None;
1112    }
1113
1114    let mut stuffed = vec![];
1115    if data.starts_with(b".") {
1116        stuffed.push(b'.');
1117    }
1118    let mut last_idx = 0;
1119    for i in LFDOT.find_iter(data) {
1120        stuffed.extend_from_slice(&data[last_idx..=i]);
1121        stuffed.push(b'.');
1122        last_idx = i + 1;
1123    }
1124    stuffed.extend_from_slice(&data[last_idx..]);
1125    Some(stuffed)
1126}
1127
1128/// Extracts the object=name pairs of the subject name from a cert.
1129/// eg:
1130/// ```norun
1131/// ["C=US", "ST=CA", "L=SanFrancisco", "O=Fort-Funston", "OU=MyOrganizationalUnit",
1132/// "CN=do.havedane.net", "name=EasyRSA", "emailAddress=me@myhost.mydomain"]
1133/// ```
1134fn subject_name(cert: &X509Ref) -> Vec<String> {
1135    let mut subject_name = vec![];
1136    for entry in cert.subject_name().entries() {
1137        if let Ok(obj) = entry.object().nid().short_name() {
1138            if let Ok(name) = entry.data().as_utf8() {
1139                subject_name.push(format!("{obj}={name}"));
1140            }
1141        }
1142    }
1143    subject_name
1144}
1145
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_stuffing() {
        assert_eq!(apply_dot_stuffing(b"foo"), None);
        assert_eq!(apply_dot_stuffing(b""), None);
        assert_eq!(apply_dot_stuffing(b".foo").unwrap(), b"..foo");
        assert_eq!(apply_dot_stuffing(b"foo\n.bar").unwrap(), b"foo\n..bar");
        assert_eq!(
            apply_dot_stuffing(b"foo\n.bar\n..baz\n").unwrap(),
            b"foo\n..bar\n...baz\n"
        );
        // CRLF line endings, as actually used on the wire
        assert_eq!(
            apply_dot_stuffing(b"foo\r\n.\r\nbar\r\n").unwrap(),
            b"foo\r\n..\r\nbar\r\n"
        );
        // A dot at the very end of the data still starts a line
        assert_eq!(apply_dot_stuffing(b"foo\n.").unwrap(), b"foo\n..");
        // A dot that does not start a line is left alone
        assert_eq!(apply_dot_stuffing(b"foo.bar\r\n"), None);
    }

    #[test]
    fn response_line_parsing() {
        assert_eq!(
            parse_response_line("220 woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: true,
                content: "woot"
            }
        );
        assert_eq!(
            parse_response_line("220-woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: false,
                content: "woot"
            }
        );

        assert!(matches!(
            parse_response_line("220_woot"),
            Err(ClientError::MalformedResponseLine(_))
        ));
        assert!(matches!(
            parse_response_line("not really"),
            Err(ClientError::MalformedResponseLine(_))
        ));
    }

    /// Parses `lines` as a single multi-line reply, asserting that the last
    /// line is marked final.
    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let mut parsed = parse_response_line(lines[0])?;
        let mut b = ResponseBuilder::new(&parsed);
        for line in &lines[1..] {
            parsed = parse_response_line(line)?;
            b.add_line(&parsed)
                .map_err(ClientError::MalformedResponseLine)?;
        }
        assert!(parsed.is_final);
        Ok(b.build(None))
    }

    #[test]
    fn multi_line_response() {
        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done"]).unwrap(),
            Response {
                code: 220,
                enhanced_code: None,
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        let res = parse_multi_line(&["220-woot", "221-more", "220 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221-more"
            ),
            "got error {res:?}"
        );

        let res = parse_multi_line(&["220-woot", "220-more", "221 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221 done"
            ),
            "got error {res:?}"
        );

        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done"]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // Confirm that we strip the enhanced status code from each line
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done"]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // ... but only if the code matches that of the first line
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done"]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\n5.5.5 done".to_string(),
                command: None
            }
        );
    }

    #[test]
    fn test_extract_hostname() {
        assert_eq!(extract_hostname("foo"), "foo");
        assert_eq!(extract_hostname("foo."), "foo");
        assert_eq!(extract_hostname("foo:25"), "foo");
        assert_eq!(extract_hostname("foo.:25"), "foo");
        assert_eq!(extract_hostname("[foo]:25"), "foo");
        assert_eq!(extract_hostname("[foo.]:25"), "foo");
        assert_eq!(extract_hostname("[::1]:25"), "::1");
        assert_eq!(extract_hostname("::1:25"), "::1");
    }
}