rfc5321/
client.rs

1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::parser::{Command, Domain, EsmtpParameter, ForwardPath, ReversePath};
4use crate::{AsyncReadAndWrite, BoxedAsyncReadAndWrite};
5use bstr::ByteSlice;
6use hickory_proto::rr::rdata::TLSA;
7use memchr::memmem::Finder;
8use nom_utils::DomainString;
9use openssl::x509::{X509Ref, X509};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::net::IpAddr;
13use std::str::FromStr;
14use std::sync::{Arc, LazyLock};
15use std::time::Duration;
16use thiserror::Error;
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::{TcpStream, ToSocketAddrs};
19use tokio::time::timeout;
20use tokio_rustls::rustls::pki_types::ServerName;
21use tracing::Level;
22
23pub use kumo_tls_helper::TlsOptions;
24pub use openssl;
25pub use tokio_rustls;
26
/// Upper bound, in bytes, on how much unterminated data `read_line` will
/// accumulate before declaring the current response line too long; also the
/// size of the scratch buffer used for each individual socket read.
const MAX_LINE_LEN: usize = 4 * 1024;
28
29#[derive(Error, Debug, Clone)]
30pub enum ClientError {
31    #[error("response is not UTF8")]
32    Utf8(#[from] std::string::FromUtf8Error),
33    #[error("Malformed Response: {0}")]
34    MalformedResponseLine(String),
35    #[error("Response line is too long")]
36    ResponseTooLong,
37    #[error("Not connected")]
38    NotConnected,
39    #[error("Command rejected {0:?}")]
40    Rejected(Response),
41    #[error("Commands rejected {0:?}")]
42    RejectedBatch(Vec<Response>),
43    #[error("STARTTLS: {0} is not a valid DNS name")]
44    InvalidDnsName(String),
45    #[error("Invalid client certificate configured: {error:?}")]
46    FailedToBuildConnector { error: String },
47    #[error("Timed Out waiting {duration:?} for response to cmd={}", self.command_for_err())]
48    TimeOutResponse {
49        command: Option<Command>,
50        duration: Duration,
51    },
52    #[error("Timed Out after {duration:?} writing cmd={}", self.command_for_err())]
53    TimeOutRequest {
54        commands: Vec<Command>,
55        duration: Duration,
56    },
57    #[error("Error {error} reading response to cmd={}", self.command_for_err())]
58    ReadError {
59        command: Option<Command>,
60        error: String,
61        partial: String,
62    },
63    #[error("Error {error} flushing send buffer")]
64    FlushError { error: String },
65    #[error("Error {error} writing {}", self.command_for_err())]
66    WriteError {
67        commands: Vec<Command>,
68        error: String,
69    },
70    #[error("Timed Out sending message payload data")]
71    TimeOutData,
72    #[error("SSL Error: {0}")]
73    SslErrorStack(#[from] openssl::error::ErrorStack),
74    #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
75    NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
76}
77
78impl ClientError {
79    /// Returns the command(s) string suitable for passing into a Response
80    pub fn command(&self) -> Option<String> {
81        match self {
82            Self::TimeOutResponse {
83                command: Some(command),
84                ..
85            }
86            | Self::ReadError {
87                command: Some(command),
88                ..
89            } => Some(command.encode().to_string()),
90            Self::TimeOutRequest { commands, .. } | Self::WriteError { commands, .. }
91                if !commands.is_empty() =>
92            {
93                let s: String = commands
94                    .iter()
95                    .map(|cmd| cmd.encode().to_string())
96                    .collect();
97                Some(s)
98            }
99            _ => None,
100        }
101    }
102
103    /// Returns the command(s) string suitable for embedding in the
104    /// overall error message
105    fn command_for_err(&self) -> String {
106        self.command()
107            .map(|cmd| cmd.replace("\r\n", ""))
108            .unwrap_or_else(|| "NONE".to_string())
109    }
110
111    /// If the error contents were likely caused by something
112    /// about the mostly recently attempted message, rather than
113    /// a transport issue, or a carry-over from a prior message
114    /// (eg: previous message was rejected and destination chose
115    /// to drop the connection, which we detect later in RSET
116    /// on the next message), then we return true.
117    /// The expectation is that the caller will transiently
118    /// fail the message for later retry.
119    /// If we return false then the caller might decide to
120    /// try that same message again more immediately on
121    /// a separate connection
122    pub fn was_due_to_message(&self) -> bool {
123        match self {
124            Self::Utf8(_)
125            | Self::MalformedResponseLine(_)
126            | Self::ResponseTooLong
127            | Self::NotConnected
128            | Self::InvalidDnsName(_)
129            | Self::FailedToBuildConnector { .. }
130            | Self::TimeOutResponse { .. }
131            | Self::TimeOutRequest { .. }
132            | Self::ReadError { .. }
133            | Self::FlushError { .. }
134            | Self::WriteError { .. }
135            | Self::TimeOutData
136            | Self::SslErrorStack(_)
137            | Self::NoUsableDaneTlsa { .. } => false,
138            Self::Rejected(response) => response.was_due_to_message(),
139            Self::RejectedBatch(responses) => match responses.len() {
140                1 => responses[0].was_due_to_message(),
141                _ => false,
142            },
143        }
144    }
145}
146
/// One ESMTP extension advertised by the server in its EHLO/LHLO reply.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EsmtpCapability {
    /// The extension keyword exactly as the server sent it (case preserved).
    pub name: String,
    /// Everything after the first space on the capability line, if present.
    pub param: Option<String>,
}
152
153#[derive(Clone, Debug)]
154pub enum SmtpClientTraceEvent {
155    Closed,
156    Read(Vec<u8>),
157    Write(String),
158    Diagnostic {
159        level: tracing::Level,
160        message: String,
161    },
162}
163
164pub trait DeferredTracer {
165    fn trace(&self) -> SmtpClientTraceEvent;
166}
167
168pub trait SmtpClientTracer: std::fmt::Debug {
169    fn trace_event(&self, event: SmtpClientTraceEvent);
170    fn lazy_trace(&self, deferred: &dyn DeferredTracer);
171}
172
173// helper to avoid making a second copy of every write buffer
174struct WriteTracer<'a> {
175    data: &'a str,
176}
177impl DeferredTracer for WriteTracer<'_> {
178    fn trace(&self) -> SmtpClientTraceEvent {
179        SmtpClientTraceEvent::Write(self.data.to_string())
180    }
181}
182impl<'a> WriteTracer<'a> {
183    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
184        tracer.lazy_trace(&Self { data });
185    }
186}
187
188struct BinWriteTracer<'a> {
189    data: &'a [u8],
190}
191impl DeferredTracer for BinWriteTracer<'_> {
192    fn trace(&self) -> SmtpClientTraceEvent {
193        let data = String::from_utf8_lossy(self.data).to_string();
194        SmtpClientTraceEvent::Write(data)
195    }
196}
197impl<'a> BinWriteTracer<'a> {
198    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
199        tracer.lazy_trace(&Self { data });
200    }
201}
202
203// A little bit of gymnastics to avoid making a second
204// copy of every read buffer
205struct ReadTracer<'a> {
206    data: &'a [u8],
207}
208impl DeferredTracer for ReadTracer<'_> {
209    fn trace(&self) -> SmtpClientTraceEvent {
210        SmtpClientTraceEvent::Read(self.data.to_vec())
211    }
212}
213
214#[derive(Debug)]
215pub struct SmtpClient {
216    socket: Option<BoxedAsyncReadAndWrite>,
217    hostname: String,
218    capabilities: HashMap<String, EsmtpCapability>,
219    read_buffer: Vec<u8>,
220    timeouts: SmtpClientTimeouts,
221    tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
222    use_rset: bool,
223    enable_rset: bool,
224    enable_pipelining: bool,
225    ignore_8bit_checks: bool,
226}
227
/// Reduce a peer address string to the bare host name, which is used for log
/// messages and as the TLS server name.
///
/// Accepted forms:
/// * `host`
/// * `host:port`
/// * `[literal]`
/// * `[literal]:port`
/// * unbracketed IPv6 literals such as `::1`
///
/// Enclosing brackets, any `:port` suffix and a trailing FQDN dot are removed.
fn extract_hostname(hostname: &str) -> &str {
    // A bracketed literal is delimited by its closing bracket. Splitting on
    // the last ':' instead would cut an IPv6 literal such as "[::1]" in half.
    let bracketed = hostname
        .strip_prefix('[')
        .and_then(|rest| rest.find(']').map(|end| &rest[..end]));

    let host = match bracketed {
        Some(literal) => literal,
        None => match hostname.rsplit_once(':') {
            // Exactly one colon separates host from port. More than one
            // colon means an unbracketed IPv6 literal, which must stay intact.
            Some((host, _port)) if !host.contains(':') => host,
            _ => hostname,
        },
    };

    // Remove any trailing FQDN dot
    host.strip_suffix('.').unwrap_or(host)
}
246
247impl SmtpClient {
248    pub async fn new<A: ToSocketAddrs + ToString + Clone>(
249        addr: A,
250        timeouts: SmtpClientTimeouts,
251    ) -> std::io::Result<Self> {
252        let stream = TcpStream::connect(addr.clone()).await?;
253        // No need for Nagle with SMTP request/response
254        stream.set_nodelay(true)?;
255        Ok(Self::with_stream(stream, addr.to_string(), timeouts))
256    }
257
258    pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
259        stream: S,
260        peer_hostname: H,
261        timeouts: SmtpClientTimeouts,
262    ) -> Self {
263        let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
264
265        Self {
266            socket: Some(Box::new(stream)),
267            hostname,
268            capabilities: HashMap::new(),
269            read_buffer: Vec::with_capacity(1024),
270            timeouts,
271            tracer: None,
272            use_rset: false,
273            enable_rset: false,
274            enable_pipelining: false,
275            ignore_8bit_checks: false,
276        }
277    }
278
279    /// Setting this to true causes 8BITMIME and SMTPUTF8 checks to be ignored
280    pub fn set_ignore_8bit_checks(&mut self, enable: bool) {
281        self.ignore_8bit_checks = enable;
282    }
283
284    pub fn is_connected(&self) -> bool {
285        self.socket.is_some()
286    }
287
288    pub fn set_enable_rset(&mut self, enable: bool) {
289        self.enable_rset = enable;
290    }
291
292    pub fn set_enable_pipelining(&mut self, enable: bool) {
293        self.enable_pipelining = enable;
294    }
295
296    pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
297        self.tracer.replace(tracer);
298    }
299
300    pub fn timeouts(&self) -> &SmtpClientTimeouts {
301        &self.timeouts
302    }
303
304    async fn read_line(
305        &mut self,
306        timeout_duration: Duration,
307        cmd: Option<&Command>,
308    ) -> Result<String, ClientError> {
309        let mut too_long = false;
310        loop {
311            let mut iter = self.read_buffer.iter().enumerate();
312            while let Some((i, &b)) = iter.next() {
313                if b != b'\r' {
314                    continue;
315                }
316                if let Some((_, b'\n')) = iter.next() {
317                    if too_long {
318                        self.read_buffer.drain(0..i + 2);
319
320                        if let Some(tracer) = &self.tracer {
321                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
322                                level: Level::ERROR,
323                                message: "Response too long".to_string(),
324                            });
325                        }
326
327                        return Err(ClientError::ResponseTooLong);
328                    }
329
330                    let line = String::from_utf8(self.read_buffer[0..i].to_vec());
331                    self.read_buffer.drain(0..i + 2);
332                    return Ok(line?);
333                }
334            }
335            if self.read_buffer.len() > MAX_LINE_LEN {
336                self.read_buffer.clear();
337                too_long = true;
338            }
339
340            // Didn't find a complete line, fill up the rest of the buffer
341            let mut data = [0u8; MAX_LINE_LEN];
342            let size = match self.socket.as_mut() {
343                Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
344                    Ok(Ok(size)) => size,
345                    Ok(Err(err)) => {
346                        self.socket.take();
347                        if let Some(tracer) = &self.tracer {
348                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
349                                level: Level::ERROR,
350                                message: format!("Error during read: {err:#}"),
351                            });
352                            tracer.trace_event(SmtpClientTraceEvent::Closed);
353                        }
354                        return Err(ClientError::ReadError {
355                            command: cmd.cloned(),
356                            error: format!("{err:#}"),
357                            partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
358                        });
359                    }
360                    Err(_) => {
361                        // If we're using a zero timeout we're probing
362                        // for liveness, so we don't want to treat this
363                        // timeout as a fatal error that closes the
364                        // connection
365                        if timeout_duration != Duration::ZERO {
366                            self.socket.take();
367                        }
368                        if let Some(tracer) = &self.tracer {
369                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
370                                level: Level::ERROR,
371                                message: format!("Read Timeout after {timeout_duration:?}"),
372                            });
373                            if self.socket.is_none() {
374                                tracer.trace_event(SmtpClientTraceEvent::Closed);
375                            }
376                        }
377                        return Err(ClientError::TimeOutResponse {
378                            command: cmd.cloned(),
379                            duration: timeout_duration,
380                        });
381                    }
382                },
383                None => {
384                    return Err(ClientError::ReadError {
385                        command: cmd.cloned(),
386                        error: "the socket was closed in response to an earlier issue".to_string(),
387                        partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
388                    });
389                }
390            };
391            if size == 0 {
392                self.socket.take();
393                if let Some(tracer) = &self.tracer {
394                    tracer.trace_event(SmtpClientTraceEvent::Closed);
395                }
396                return Err(ClientError::ReadError {
397                    command: cmd.cloned(),
398                    error: "Connection closed by peer".to_string(),
399                    partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
400                });
401            }
402            if let Some(tracer) = &self.tracer {
403                tracer.lazy_trace(&ReadTracer {
404                    data: &data[0..size],
405                });
406            }
407            self.read_buffer.extend_from_slice(&data[0..size]);
408        }
409    }
410
411    /// Check to see if there is either a unilateral response
412    /// or an error condition immediately available on the socket.
413    ///
414    /// Intended to be used after the connection has idled to see if
415    /// the peer has opted to close the connection, either gracefully
416    /// by sending what should be a 421, or abruptly by simply snipping
417    /// the connection.
418    pub async fn check_unilateral_response(&mut self) -> Result<Option<Response>, ClientError> {
419        match self.read_response(None, Duration::ZERO).await {
420            // Unilateral response was received
421            Ok(response) => Ok(Some(response)),
422            // Nothing immediately available: this should hopefully be the common case
423            Err(ClientError::TimeOutResponse { .. }) => Ok(None),
424            // An actual error
425            Err(err) => Err(err),
426        }
427    }
428
429    pub async fn read_response(
430        &mut self,
431        command: Option<&Command>,
432        timeout_duration: Duration,
433    ) -> Result<Response, ClientError> {
434        if let Some(sock) = self.socket.as_mut() {
435            if let Err(err) = sock.flush().await {
436                self.socket.take();
437                if let Some(tracer) = &self.tracer {
438                    tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
439                        level: Level::ERROR,
440                        message: format!("Error during flush: {err:#}"),
441                    });
442                    tracer.trace_event(SmtpClientTraceEvent::Closed);
443                }
444                return Err(ClientError::FlushError {
445                    error: format!("{err:#}"),
446                });
447            }
448        }
449
450        let mut line = self.read_line(timeout_duration, command).await?;
451        tracing::trace!("recv<-{}: {line}", self.hostname);
452        let mut parsed = parse_response_line(&line)?;
453        let mut response_builder = ResponseBuilder::new(&parsed);
454
455        let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
456        while !parsed.is_final {
457            line = self
458                .read_line(subsequent_line_timeout_duration, command)
459                .await?;
460            parsed = parse_response_line(&line)?;
461            response_builder
462                .add_line(&parsed)
463                .map_err(ClientError::MalformedResponseLine)?;
464        }
465
466        let response = response_builder.build(command.map(|cmd| cmd.encode().to_string()));
467
468        tracing::trace!("{}: {response:?}", self.hostname);
469
470        Ok(response)
471    }
472
473    pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
474        self.write_command_request(command).await?;
475        self.read_response(Some(command), command.client_timeout(&self.timeouts))
476            .await
477    }
478
479    /// Wrapper around socket.write_all() that will emit trace diagnostics and synthesize
480    /// a Close event to the tracer if a timeout or IO error occurs.
481    /// If an error or timeout, occurs ensures that the socket will not be reused.
482    async fn write_all_with_timeout<F, G>(
483        &mut self,
484        timeout_duration: Duration,
485        bytes: &[u8],
486        make_timeout_err: F,
487        make_write_err: G,
488    ) -> Result<(), ClientError>
489    where
490        F: FnOnce() -> ClientError,
491        G: FnOnce(String) -> ClientError,
492    {
493        match self.socket.as_mut() {
494            Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
495                Ok(Ok(response)) => Ok(response),
496                Ok(Err(err)) => {
497                    self.socket.take();
498                    if let Some(tracer) = &self.tracer {
499                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
500                            level: Level::ERROR,
501                            message: format!("Error during write: {err:#}"),
502                        });
503                        tracer.trace_event(SmtpClientTraceEvent::Closed);
504                    }
505                    Err(make_write_err(format!("{err:#}")))
506                }
507                Err(_) => {
508                    self.socket.take();
509                    if let Some(tracer) = &self.tracer {
510                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
511                            level: Level::ERROR,
512                            message: format!("Write Timeout after {timeout_duration:?}"),
513                        });
514                        tracer.trace_event(SmtpClientTraceEvent::Closed);
515                    }
516                    Err(make_timeout_err())
517                }
518            },
519            None => Err(make_write_err(
520                "the socket was closed in response to an earlier issue".to_string(),
521            )),
522        }
523    }
524
525    async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
526        let total_timeout: Duration = commands
527            .iter()
528            .map(|cmd| cmd.client_timeout_request(&self.timeouts))
529            .sum();
530
531        let mut lines: Vec<String> = vec![];
532        let mut all: Vec<u8> = vec![];
533        for cmd in commands {
534            let line = cmd.encode();
535            all.extend_from_slice(&line);
536            lines.push(line.to_string());
537        }
538        tracing::trace!(
539            "send->{}: (PIPELINE) {}",
540            self.hostname,
541            all.as_bstr().escape_bytes()
542        );
543        if self.socket.is_some() {
544            if let Some(tracer) = &self.tracer {
545                // Send the lines individually to the tracer, so that we
546                // don't break --terse mode
547                for line in &lines {
548                    WriteTracer::trace(tracer, line);
549                }
550            }
551        }
552        self.write_all_with_timeout(
553            total_timeout,
554            &all,
555            || ClientError::TimeOutRequest {
556                duration: total_timeout,
557                commands: commands.to_vec(),
558            },
559            |error| ClientError::WriteError {
560                error,
561                commands: commands.to_vec(),
562            },
563        )
564        .await
565    }
566
567    async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
568        let line = command.encode();
569        tracing::trace!("send->{}: {line}", self.hostname);
570        if self.socket.is_some() {
571            if let Some(tracer) = &self.tracer {
572                WriteTracer::trace(tracer, &line.to_string());
573            }
574        }
575
576        let timeout_duration = command.client_timeout_request(&self.timeouts);
577        self.write_all_with_timeout(
578            timeout_duration,
579            &line,
580            || ClientError::TimeOutRequest {
581                duration: timeout_duration,
582                commands: vec![command.clone()],
583            },
584            |error| ClientError::WriteError {
585                error,
586                commands: vec![command.clone()],
587            },
588        )
589        .await
590    }
591
592    async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
593        if self.socket.is_some() {
594            if let Some(tracer) = &self.tracer {
595                BinWriteTracer::trace(tracer, data);
596            }
597        }
598        let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
599        self.write_all_with_timeout(
600            timeout_duration,
601            data,
602            || ClientError::TimeOutData,
603            |error| ClientError::WriteError {
604                error,
605                commands: vec![],
606            },
607        )
608        .await
609    }
610
611    /// Issue a series of commands, and return the responses to
612    /// those commands.
613    ///
614    /// If the server advertised the RFC 2920 PIPELINING extension,
615    /// the commands are written one after the other before waiting
616    /// to read any data, resulting in lower overall latency due
617    /// to round-trip-times.
618    ///
619    /// If PIPELINING is not available, each command is written
620    /// and the response read before attempting to write the next
621    /// command.
622    ///
623    /// The number of returned responses may be smaller than the
624    /// number of requested commands if there is an issue with
625    /// the network connection.
626    pub async fn pipeline_commands(
627        &mut self,
628        commands: Vec<Command>,
629    ) -> Vec<Result<Response, ClientError>> {
630        let mut results: Vec<Result<Response, ClientError>> = vec![];
631
632        let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
633        if pipeline {
634            if let Err(err) = self.write_pipeline_request(&commands).await {
635                let err: ClientError = err;
636                results.push(Err(err.clone()));
637                while results.len() < commands.len() {
638                    // Synthesize failures for the remaining commands
639                    results.push(Err(err.clone()));
640                }
641                return results;
642            }
643
644            // Now read the responses effectively in a batch
645            for cmd in &commands {
646                results.push(
647                    self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
648                        .await,
649                );
650            }
651            return results;
652        }
653
654        for cmd in &commands {
655            if let Err(err) = self.write_command_request(cmd).await {
656                let err: ClientError = err;
657                results.push(Err(err.clone()));
658                while results.len() < commands.len() {
659                    // Synthesize failures for the remaining commands
660                    results.push(Err(err.clone()));
661                }
662                return results;
663            }
664            // Immediately request the response if the server
665            // doesn't support pipelining
666            results.push(
667                self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
668                    .await,
669            );
670        }
671        results
672    }
673
674    pub async fn ehlo_lhlo(
675        &mut self,
676        ehlo_name: &str,
677        use_lmtp: bool,
678    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
679        if use_lmtp {
680            self.lhlo(ehlo_name).await
681        } else {
682            self.ehlo(ehlo_name).await
683        }
684    }
685
686    pub async fn lhlo(
687        &mut self,
688        ehlo_name: &str,
689    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
690        let response = self
691            .send_command(&Command::Lhlo(Domain::DomainName(
692                ehlo_name
693                    .parse::<DomainString>()
694                    .map_err(|_| ClientError::InvalidDnsName(ehlo_name.to_string()))?,
695            )))
696            .await?;
697        self.ehlo_common(response)
698    }
699
700    pub async fn ehlo(
701        &mut self,
702        ehlo_name: &str,
703    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
704        let response = self
705            .send_command(&Command::Ehlo(Domain::DomainName(
706                ehlo_name
707                    .parse::<DomainString>()
708                    .map_err(|_| ClientError::InvalidDnsName(ehlo_name.to_string()))?,
709            )))
710            .await?;
711        self.ehlo_common(response)
712    }
713
714    fn ehlo_common(
715        &mut self,
716        response: Response,
717    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
718        if response.code != 250 {
719            return Err(ClientError::Rejected(response));
720        }
721
722        let mut capabilities = HashMap::new();
723
724        for line in response.content.lines().skip(1) {
725            let mut fields = line.splitn(2, ' ');
726            if let Some(name) = fields.next() {
727                let param = fields.next().map(|s| s.to_string());
728                let cap = EsmtpCapability {
729                    name: name.to_string(),
730                    param,
731                };
732                capabilities.insert(name.to_ascii_uppercase(), cap);
733            }
734        }
735
736        self.capabilities = capabilities;
737        Ok(&self.capabilities)
738    }
739
740    pub async fn auth_plain(
741        &mut self,
742        username: &str,
743        password: Option<&str>,
744    ) -> Result<(), ClientError> {
745        // RFC 4616 says that the format is:
746        // [authzid] NUL authcid NUL passwd
747        let password = password.unwrap_or("");
748        let payload = format!("\x00{username}\x00{password}");
749        let payload = data_encoding::BASE64.encode(payload.as_bytes());
750
751        let response = self
752            .send_command(&Command::Auth {
753                sasl_mech: "PLAIN".to_string(),
754                initial_response: Some(payload),
755            })
756            .await?;
757
758        if response.code != 235 {
759            return Err(ClientError::Rejected(response));
760        }
761
762        Ok(())
763    }
764
765    /// Attempt TLS handshake.
766    /// Returns Err for IO errors.
767    /// On completion, return an option that will be:
768    /// * Some(handshake_error) - if the handshake failed
769    /// * None - if the handshake succeeded
770    pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
771        let resp = self.send_command(&Command::StartTls).await?;
772        if resp.code != 220 {
773            return Err(ClientError::Rejected(resp));
774        }
775
776        let mut handshake_error = None;
777        let mut tls_info = TlsInformation::default();
778
779        let stream: BoxedAsyncReadAndWrite = if options.prefer_openssl
780            || !options.dane_tlsa.is_empty()
781        {
782            let connector = options
783                .build_openssl_connector(&self.hostname)
784                .map_err(|error| ClientError::FailedToBuildConnector {
785                    error: error.to_string(),
786                })?;
787            let ssl = connector.into_ssl(self.hostname.as_str())?;
788
789            let (stream, dup_stream) = match self.socket.take() {
790                Some(s) => {
791                    let d = s.try_dup();
792                    (s, d)
793                }
794                None => return Err(ClientError::NotConnected),
795            };
796
797            let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;
798
799            if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
800                handshake_error.replace(format!("{err:#}"));
801            }
802
803            tls_info.provider_name = "openssl".to_string();
804            tls_info.cipher = match ssl_stream.ssl().current_cipher() {
805                Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
806                None => String::new(),
807            };
808            tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();
809
810            if let Some(cert) = ssl_stream.ssl().peer_certificate() {
811                tls_info.subject_name = subject_name(&cert);
812            }
813            if let Ok(authority) = ssl_stream.ssl().dane_authority() {
814                if let Some(cert) = &authority.cert {
815                    tls_info.subject_name = subject_name(cert);
816                }
817            }
818
819            match (&handshake_error, dup_stream) {
820                (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
821                    // Try falling back to clear text on the duplicate stream.
822                    // This is imperfect: in a failed validation scenario we will
823                    // end up trying to read binary data as a string and get a UTF-8
824                    // error if the peer thinks the session is encrypted.
825                    drop(ssl_stream);
826                    Box::new(dup_stream)
827                }
828                _ => Box::new(ssl_stream),
829            }
830        } else {
831            tls_info.provider_name = "rustls".to_string();
832            let connector = options.build_tls_connector().await.map_err(|error| {
833                ClientError::FailedToBuildConnector {
834                    error: error.to_string(),
835                }
836            })?;
837            let server_name = match IpAddr::from_str(self.hostname.as_str()) {
838                Ok(ip) => ServerName::IpAddress(ip.into()),
839                Err(_) => ServerName::try_from(self.hostname.clone())
840                    .map_err(|_| ClientError::InvalidDnsName(self.hostname.clone()))?,
841            };
842
843            match connector
844                .connect(
845                    server_name,
846                    match self.socket.take() {
847                        Some(s) => s,
848                        None => return Err(ClientError::NotConnected),
849                    },
850                )
851                .into_fallible()
852                .await
853            {
854                Ok(stream) => {
855                    let (_, conn) = stream.get_ref();
856                    tls_info.cipher = match conn.negotiated_cipher_suite() {
857                        Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
858                        None => String::new(),
859                    };
860                    tls_info.protocol_version = match conn.protocol_version() {
861                        Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
862                        None => String::new(),
863                    };
864
865                    if let Some(certs) = conn.peer_certificates() {
866                        let peer_cert = &certs[0];
867                        if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
868                            tls_info.subject_name = subject_name(&cert);
869                        }
870                    }
871
872                    Box::new(stream)
873                }
874                Err((err, stream)) => {
875                    handshake_error.replace(format!("{err:#}"));
876                    stream
877                }
878            }
879        };
880
881        if let Some(tracer) = &self.tracer {
882            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
883                level: Level::INFO,
884                message: match &handshake_error {
885                    Some(error) => format!("STARTTLS handshake failed: {error:?}"),
886                    None => format!("STARTTLS handshake -> {tls_info:?}"),
887                },
888            });
889        }
890
891        self.socket.replace(stream);
892        Ok(match handshake_error {
893            Some(error) => TlsStatus::FailedHandshake(error),
894            None => TlsStatus::Info(tls_info),
895        })
896    }
897
898    pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
899        &mut self,
900        sender: SENDER,
901        recipient: RECIP,
902        data: B,
903    ) -> Result<Response, ClientError> {
904        let recipient: ForwardPath = recipient.into();
905        let status = self
906            .send_mail_multi_recip(sender, vec![recipient], data)
907            .await?;
908        Ok(status.response)
909    }
910
911    pub async fn send_mail_multi_recip<B: AsRef<[u8]>, SENDER: Into<ReversePath>>(
912        &mut self,
913        sender: SENDER,
914        recipient_list: Vec<ForwardPath>,
915        data: B,
916    ) -> Result<BatchSendSuccess, ClientError> {
917        let sender = sender.into();
918
919        let data: &[u8] = data.as_ref();
920        let stuffed;
921
922        let data = match apply_dot_stuffing(data) {
923            Some(d) => {
924                stuffed = d;
925                &stuffed
926            }
927            None => data,
928        };
929
930        let data_is_8bit = data.iter().any(|&b| b >= 0x80);
931        let envelope_is_8bit =
932            !sender.is_ascii() || recipient_list.iter().any(|recipient| !recipient.is_ascii());
933
934        let mut mail_from_params = vec![];
935        if data_is_8bit {
936            if self.capabilities.contains_key("8BITMIME") {
937                mail_from_params.push(EsmtpParameter {
938                    name: "BODY".to_string(),
939                    value: Some("8BITMIME".to_string()),
940                });
941            } else if !self.ignore_8bit_checks {
942                return Err(ClientError::Rejected(Response {
943                    code: 554,
944                    command: None,
945                    enhanced_code: Some(EnhancedStatusCode {
946                        class: 5,
947                        subject: 6,
948                        detail: 3,
949                    }),
950                    content: "KumoMTA internal: DATA is 8bit, destination does \
951                        not support 8BITMIME. Conversion via msg:check_fix_conformance \
952                        during reception is required"
953                        .to_string(),
954                }));
955            }
956        }
957
958        if envelope_is_8bit {
959            if self.capabilities.contains_key("SMTPUTF8") {
960                mail_from_params.push(EsmtpParameter {
961                    name: "SMTPUTF8".to_string(),
962                    value: None,
963                });
964            } else if !self.ignore_8bit_checks {
965                return Err(ClientError::Rejected(Response {
966                    code: 554,
967                    command: None,
968                    enhanced_code: Some(EnhancedStatusCode {
969                        class: 5,
970                        subject: 6,
971                        detail: 7,
972                    }),
973                    content: "KumoMTA internal: envelope is 8bit, destination does \
974                        not support SMTPUTF8."
975                        .to_string(),
976                }));
977            }
978        }
979
980        let mut commands = vec![];
981
982        // We want to avoid using RSET for the first message we send on
983        // a given connection, because postfix can run in a mode where
984        // it will not tolerate RSET because it considers it to be a "junk"
985        // command, and rejecting junk commands will cut down on its load
986        // when it is under stress; it is used as a load shedding approach.
987        // If we always RSET then we will never deliver to a site that is
988        // configured that way. If we take care to RSET only for subsequent
989        // sends, then we should get one message per connection through
990        // without being unfairly penalized for defensively RSETing.
991        let used_rset = self.use_rset;
992        if self.use_rset {
993            commands.push(Command::Rset);
994        }
995        commands.push(Command::MailFrom {
996            address: sender,
997            parameters: mail_from_params,
998        });
999
1000        for recipient in &recipient_list {
1001            commands.push(Command::RcptTo {
1002                address: recipient.clone(),
1003                parameters: vec![],
1004            });
1005        }
1006        commands.push(Command::Data);
1007
1008        // Assume that something might break below: if it does, we want
1009        // to ensure that we RSET the connection on the next go around.
1010        self.use_rset = true;
1011
1012        let mut responses = self.pipeline_commands(commands).await;
1013
1014        // This is a little awkward. We want to handle the RFC 2090 3.1 case
1015        // below, which requires deferring checking the actual response codes
1016        // until later, but we must also handle the case where we had a hard
1017        // transport error partway through pipelining.
1018        // So we set a flag for that case and will then "eagerly", wrt. the
1019        // RFC 2090 3.1 logic, evaluate the SMTP response codes, so that we
1020        // can propagate the correct error disposition up to the caller.
1021        let is_err = responses.iter().any(|r| r.is_err());
1022
1023        if used_rset {
1024            let rset_resp = responses.remove(0)?;
1025            if rset_resp.code != 250 {
1026                return Err(ClientError::Rejected(rset_resp));
1027            }
1028        }
1029
1030        let mail_resp = responses.remove(0)?;
1031        if is_err && mail_resp.code != 250 {
1032            return Err(ClientError::Rejected(mail_resp));
1033        }
1034
1035        let mut rcpt_responses = vec![];
1036        for _ in &recipient_list {
1037            rcpt_responses.push(responses.remove(0)?);
1038        }
1039
1040        if is_err && rcpt_responses.iter().all(|resp| resp.code != 250) {
1041            return Err(ClientError::RejectedBatch(rcpt_responses));
1042        }
1043
1044        let data_resp = responses.remove(0)?;
1045        if is_err && data_resp.code != 354 {
1046            return Err(ClientError::Rejected(data_resp));
1047        }
1048
1049        if data_resp.code == 354
1050            && (mail_resp.code != 250 || rcpt_responses.iter().all(|resp| resp.code != 250))
1051        {
1052            // RFC 2920 3.1:
1053            // the client cannot assume that the DATA command will be rejected
1054            // just because none of the RCPT TO commands worked.  If the DATA
1055            // command was properly rejected the client SMTP can just issue
1056            // RSET, but if the DATA command was accepted the client SMTP
1057            // should send a single dot.
1058
1059            // Send dummy data
1060            self.write_data_with_timeout(b".\r\n").await?;
1061            let data_dot = Command::DataDot;
1062            // wait for its response
1063            let _ = self
1064                .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1065                .await?;
1066
1067            // Continue below: we will match one of the failure cases and
1068            // return a ClientError::Rejected from one of the earlier
1069            // commands
1070        }
1071
1072        if mail_resp.code != 250 {
1073            return Err(ClientError::Rejected(mail_resp));
1074        }
1075        if rcpt_responses.iter().all(|resp| resp.code != 250) {
1076            if rcpt_responses.len() == 1 {
1077                return Err(ClientError::Rejected(
1078                    rcpt_responses.pop().expect("have at least one"),
1079                ));
1080            }
1081            return Err(ClientError::RejectedBatch(rcpt_responses));
1082        }
1083        if data_resp.code != 354 {
1084            return Err(ClientError::Rejected(data_resp));
1085        }
1086
1087        let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);
1088
1089        tracing::trace!("message data is {} bytes", data.len());
1090
1091        self.write_data_with_timeout(data).await?;
1092
1093        let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };
1094
1095        tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());
1096
1097        self.write_data_with_timeout(marker.as_bytes()).await?;
1098
1099        let data_dot = Command::DataDot;
1100        let resp = self
1101            .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1102            .await?;
1103        if resp.code != 250 {
1104            return Err(ClientError::Rejected(resp));
1105        }
1106
1107        // If everything went well, respect the user preference for speculatively
1108        // issuing an RSET next time around
1109        self.use_rset = self.enable_rset;
1110
1111        Ok(BatchSendSuccess {
1112            response: resp,
1113            rcpt_responses,
1114        })
1115    }
1116}
1117
/// Successful outcome of `SmtpClient::send_mail_multi_recip`.
#[derive(Debug)]
pub struct BatchSendSuccess {
    /// Response to the terminating `.` of the DATA phase (always a 250 here).
    pub response: Response,
    /// One RCPT TO response per recipient, in the order the recipients were
    /// supplied. Individual entries may be rejections (non-250); the batch
    /// only succeeds if at least one of them is a 250.
    pub rcpt_responses: Vec<Response>,
}
1123
/// Result of a STARTTLS handshake attempt.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
    /// The TLS provider reported a handshake error; carries that error
    /// formatted as a string.
    FailedHandshake(String),
    /// The handshake succeeded; describes the negotiated session.
    Info(TlsInformation),
}
1129
/// Details of a negotiated TLS session, as collected after STARTTLS.
///
/// Field order is part of the Debug/serde output, so keep it stable.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
    /// Negotiated cipher suite name; empty when none was reported.
    pub cipher: String,
    /// Negotiated protocol version as reported by the TLS provider.
    pub protocol_version: String,
    /// `object=value` pairs from the peer certificate's subject name
    /// (see `subject_name`); empty if no certificate was available.
    pub subject_name: Vec<String>,
    /// Which TLS implementation performed the handshake ("openssl" or "rustls").
    pub provider_name: String,
}
1137
1138impl Drop for SmtpClient {
1139    fn drop(&mut self) {
1140        if let Some(tracer) = &self.tracer {
1141            if self.socket.is_some() {
1142                tracer.trace_event(SmtpClientTraceEvent::Closed);
1143            }
1144        }
1145    }
1146}
1147fn parse_response_line(line: &'_ str) -> Result<ResponseLine<'_>, ClientError> {
1148    if line.len() < 4 {
1149        return Err(ClientError::MalformedResponseLine(line.to_string()));
1150    }
1151
1152    match line.as_bytes()[3] {
1153        b' ' | b'-' => match line[0..3].parse::<u16>() {
1154            Ok(code) => Ok(ResponseLine {
1155                code,
1156                is_final: line.as_bytes()[3] == b' ',
1157                content: &line[4..],
1158            }),
1159            Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
1160        },
1161        _ => Err(ClientError::MalformedResponseLine(line.to_string())),
1162    }
1163}
1164
1165fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1166    static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1167
1168    if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1169        return None;
1170    }
1171
1172    let mut stuffed = vec![];
1173    if data.starts_with(b".") {
1174        stuffed.push(b'.');
1175    }
1176    let mut last_idx = 0;
1177    for i in LFDOT.find_iter(data) {
1178        stuffed.extend_from_slice(&data[last_idx..=i]);
1179        stuffed.push(b'.');
1180        last_idx = i + 1;
1181    }
1182    stuffed.extend_from_slice(&data[last_idx..]);
1183    Some(stuffed)
1184}
1185
1186/// Extracts the object=name pairs of the subject name from a cert.
1187/// eg:
1188/// ```no_run
1189/// ["C=US", "ST=CA", "L=SanFrancisco", "O=Fort-Funston", "OU=MyOrganizationalUnit",
1190/// "CN=do.havedane.net", "name=EasyRSA", "emailAddress=me@myhost.mydomain"]
1191/// ```
1192pub fn subject_name(cert: &X509Ref) -> Vec<String> {
1193    let mut subject_name = vec![];
1194    for entry in cert.subject_name().entries() {
1195        if let Ok(obj) = entry.object().nid().short_name() {
1196            if let Ok(name) = entry.data().as_utf8() {
1197                subject_name.push(format!("{obj}={name}"));
1198            }
1199        }
1200    }
1201    subject_name
1202}
1203
#[cfg(test)]
mod test {
    use super::*;
    use crate::parser::{EnvelopeAddress, ReversePath};

    #[test]
    fn test_stuffing() {
        assert_eq!(apply_dot_stuffing(b"foo"), None);
        assert_eq!(apply_dot_stuffing(b".foo").unwrap(), b"..foo");
        assert_eq!(apply_dot_stuffing(b"foo\n.bar").unwrap(), b"foo\n..bar");
        assert_eq!(
            apply_dot_stuffing(b"foo\n.bar\n..baz\n").unwrap(),
            b"foo\n..bar\n...baz\n"
        );

        // Nothing to stuff: the caller's buffer is used unchanged
        assert_eq!(apply_dot_stuffing(b""), None);
        assert_eq!(apply_dot_stuffing(b"foo\r\nbar.\r\n"), None);

        // A lone dot, and a dot immediately after a newline
        assert_eq!(apply_dot_stuffing(b".").unwrap(), b"..");
        assert_eq!(apply_dot_stuffing(b"\n.").unwrap(), b"\n..");

        // CRLF line endings: the dot following the LF is stuffed
        assert_eq!(
            apply_dot_stuffing(b"foo\r\n.\r\nbar").unwrap(),
            b"foo\r\n..\r\nbar"
        );

        // Leading dot combined with later dotted lines
        assert_eq!(apply_dot_stuffing(b".a\n.b").unwrap(), b"..a\n..b");
    }

    // NOTE: stale manual test against a local SMTP sink; it no longer matches
    // the current API (e.g. read_response now also takes a timeout).
    /*
    #[tokio::test]
    async fn test_against_sink() {
        use tokio::net::TcpStream;
        let stream = TcpStream::connect("127.0.0.1:2025").await.unwrap();
        let mut client =
            SmtpClient::with_stream(stream, "localhost", SmtpClientTimeouts::default());
        dbg!(client.read_response(None).await).unwrap();
        dbg!(client.ehlo("localhost").await).unwrap();
        let insecure = true;
        dbg!(client.starttls(insecure).await).unwrap();
        let resp = client
            .send_mail(
                ReversePath::try_from("wez@mail.example.com").unwrap(),
                ForwardPath::try_from("wez@mail.example.com").unwrap(),
                "Subject: hello\r\n\r\nwoot\r\n",
            )
            .await
            .unwrap();
        panic!("{resp:#?}");
    }
    */

    #[test]
    fn response_line_parsing() {
        assert_eq!(
            parse_response_line("220 woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: true,
                content: "woot"
            }
        );
        assert_eq!(
            parse_response_line("220-woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: false,
                content: "woot"
            }
        );
        // A final line with empty text is still valid
        assert_eq!(
            parse_response_line("250 ").unwrap(),
            ResponseLine {
                code: 250,
                is_final: true,
                content: ""
            }
        );

        assert!(matches!(
            parse_response_line("220_woot"),
            Err(ClientError::MalformedResponseLine(_))
        ));
        assert!(matches!(
            parse_response_line("not really"),
            Err(ClientError::MalformedResponseLine(_))
        ));
        // Too short to hold a code plus separator
        assert!(matches!(
            parse_response_line(""),
            Err(ClientError::MalformedResponseLine(_))
        ));
        assert!(matches!(
            parse_response_line("220"),
            Err(ClientError::MalformedResponseLine(_))
        ));
        // Correct separator but non-numeric code
        assert!(matches!(
            parse_response_line("abc woot"),
            Err(ClientError::MalformedResponseLine(_))
        ));
    }

    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let mut parsed = parse_response_line(lines[0])?;
        let mut b = ResponseBuilder::new(&parsed);
        for line in &lines[1..] {
            parsed = parse_response_line(line)?;
            b.add_line(&parsed)
                .map_err(ClientError::MalformedResponseLine)?;
        }
        assert!(parsed.is_final);
        Ok(b.build(None))
    }

    #[test]
    fn multi_line_response() {
        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: None,
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        let res = parse_multi_line(&["220-woot", "221-more", "220 done"]).unwrap_err();
        assert!(
            matches!(
                    res,
                ClientError::MalformedResponseLine(ref err) if err == "221-more"
            ),
            "got error {res:?}"
        );

        let res = parse_multi_line(&["220-woot", "220-more", "221 done"]).unwrap_err();
        assert!(
            matches!(
                    res,
                ClientError::MalformedResponseLine(ref err) if err == "221 done"
            ),
            "got error {res:?}"
        );

        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // Confirm that we strip the enhanced status code from each line
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // ... but only if the code matches that of the first line
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\n5.5.5 done".to_string(),
                command: None
            }
        );
    }

    #[test]
    fn test_extract_hostname() {
        assert_eq!(extract_hostname("foo"), "foo");
        assert_eq!(extract_hostname("foo."), "foo");
        assert_eq!(extract_hostname("foo:25"), "foo");
        assert_eq!(extract_hostname("foo.:25"), "foo");
        assert_eq!(extract_hostname("[foo]:25"), "foo");
        assert_eq!(extract_hostname("[foo.]:25"), "foo");
        assert_eq!(extract_hostname("[::1]:25"), "::1");
        assert_eq!(extract_hostname("::1:25"), "::1");
    }

    #[test]
    fn test_format_error_command() {
        assert_eq!(
            format!(
                "{:#}",
                ClientError::TimeOutRequest {
                    commands: vec![Command::DataDot],
                    duration: Duration::from_secs(10),
                }
            ),
            "Timed Out after 10s writing cmd=."
        );
        assert_eq!(
            format!(
                "{:#}",
                ClientError::TimeOutResponse {
                    command: Some(Command::MailFrom {
                        address: {
                            let EnvelopeAddress::Path(p) =
                                EnvelopeAddress::parse("user@host").unwrap()
                            else {
                                panic!("expected Path")
                            };
                            ReversePath::Path(p)
                        },
                        parameters: vec![],
                    }),
                    duration: Duration::from_secs(10),
                }
            ),
            r#"Timed Out waiting 10s for response to cmd=MAIL FROM:<user@host>"#
        );
    }
}