rfc5321/
client.rs

1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::{
4    AsyncReadAndWrite, BoxedAsyncReadAndWrite, Command, Domain, EsmtpParameter, ForwardPath,
5    ReversePath,
6};
7use hickory_proto::rr::rdata::TLSA;
8use memchr::memmem::Finder;
9use openssl::x509::{X509Ref, X509};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::net::IpAddr;
13use std::str::FromStr;
14use std::sync::{Arc, LazyLock};
15use std::time::Duration;
16use thiserror::Error;
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::{TcpStream, ToSocketAddrs};
19use tokio::time::timeout;
20use tokio_rustls::rustls::pki_types::ServerName;
21use tracing::Level;
22
23pub use kumo_tls_helper::TlsOptions;
24pub use openssl;
25pub use tokio_rustls;
26
/// Maximum number of bytes buffered while looking for the end of a single
/// response line; longer lines are reported as `ClientError::ResponseTooLong`.
/// Also used as the size of each individual socket read.
const MAX_LINE_LEN: usize = 4096;
28
/// Errors that can be produced by [`SmtpClient`] operations.
#[derive(Error, Debug, Clone)]
pub enum ClientError {
    /// A response line was not valid UTF-8.
    #[error("response is not UTF8")]
    Utf8(#[from] std::string::FromUtf8Error),
    /// A response line could not be interpreted; carries a description
    /// of the problem.
    #[error("Malformed Response: {0}")]
    MalformedResponseLine(String),
    /// A single response line exceeded `MAX_LINE_LEN` bytes.
    #[error("Response line is too long")]
    ResponseTooLong,
    /// The operation needs the connection, but the socket had already
    /// been closed.
    #[error("Not connected")]
    NotConnected,
    /// The server replied to a command with an unexpected status code
    /// (eg: EHLO not answered with 250, AUTH not answered with 235).
    #[error("Command rejected {0:?}")]
    Rejected(Response),
    /// Like `Rejected`, but carrying a batch of responses.
    #[error("Commands rejected {0:?}")]
    RejectedBatch(Vec<Response>),
    /// The peer hostname cannot be used as a TLS server name for STARTTLS.
    #[error("STARTTLS: {0} is not a valid DNS name")]
    InvalidDnsName(String),
    /// The TLS connector could not be built from the supplied options.
    #[error("Invalid client certificate configured: {error:?}")]
    FailedToBuildConnector { error: String },
    /// No (complete) response arrived within `duration`. `command` is
    /// `None` when no particular command was being answered, as in a
    /// unilateral response check.
    #[error("Timed Out waiting {duration:?} for response to cmd={}", self.command_for_err())]
    TimeOutResponse {
        command: Option<Command>,
        duration: Duration,
    },
    /// Writing `commands` to the peer did not complete within `duration`.
    #[error("Timed Out after {duration:?} writing cmd={}", self.command_for_err())]
    TimeOutRequest {
        commands: Vec<Command>,
        duration: Duration,
    },
    /// Reading a response failed: an I/O error, the peer closed the
    /// connection, or the socket had already been closed. `partial` holds
    /// whatever had been buffered without yet forming a complete line.
    #[error("Error {error} reading response to cmd={}", self.command_for_err())]
    ReadError {
        command: Option<Command>,
        error: String,
        partial: String,
    },
    /// Flushing buffered output before reading a response failed.
    #[error("Error {error} flushing send buffer")]
    FlushError { error: String },
    /// Writing to the peer failed. `commands` is empty when the data being
    /// written was message payload rather than commands.
    #[error("Error {error} writing {}", self.command_for_err())]
    WriteError {
        commands: Vec<Command>,
        error: String,
    },
    /// Writing message payload data did not complete in time.
    #[error("Timed Out sending message payload data")]
    TimeOutData,
    /// An error reported by OpenSSL.
    #[error("SSL Error: {0}")]
    SslErrorStack(#[from] openssl::error::ErrorStack),
    /// NOTE(review): not constructed in the visible part of this file;
    /// presumably raised when DANE TLSA records are supplied for
    /// `hostname` but none of them are usable — confirm.
    #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
    NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
}
77
78impl ClientError {
79    /// Returns the command(s) string suitable for passing into a Response
80    pub fn command(&self) -> Option<String> {
81        match self {
82            Self::TimeOutResponse {
83                command: Some(command),
84                ..
85            }
86            | Self::ReadError {
87                command: Some(command),
88                ..
89            } => Some(command.encode()),
90            Self::TimeOutRequest { commands, .. } | Self::WriteError { commands, .. }
91                if !commands.is_empty() =>
92            {
93                let commands: Vec<String> = commands.into_iter().map(|cmd| cmd.encode()).collect();
94                Some(commands.join(""))
95            }
96            _ => None,
97        }
98    }
99
100    /// Returns the command(s) string suitable for embedding in the
101    /// overall error message
102    fn command_for_err(&self) -> String {
103        self.command()
104            .map(|cmd| cmd.replace("\r\n", ""))
105            .unwrap_or_else(|| "NONE".to_string())
106    }
107
108    /// If the error contents were likely caused by something
109    /// about the mostly recently attempted message, rather than
110    /// a transport issue, or a carry-over from a prior message
111    /// (eg: previous message was rejected and destination chose
112    /// to drop the connection, which we detect later in RSET
113    /// on the next message), then we return true.
114    /// The expectation is that the caller will transiently
115    /// fail the message for later retry.
116    /// If we return false then the caller might decide to
117    /// try that same message again more immediately on
118    /// a separate connection
119    pub fn was_due_to_message(&self) -> bool {
120        match self {
121            Self::Utf8(_)
122            | Self::MalformedResponseLine(_)
123            | Self::ResponseTooLong
124            | Self::NotConnected
125            | Self::InvalidDnsName(_)
126            | Self::FailedToBuildConnector { .. }
127            | Self::TimeOutResponse { .. }
128            | Self::TimeOutRequest { .. }
129            | Self::ReadError { .. }
130            | Self::FlushError { .. }
131            | Self::WriteError { .. }
132            | Self::TimeOutData
133            | Self::SslErrorStack(_)
134            | Self::NoUsableDaneTlsa { .. } => false,
135            Self::Rejected(response) => response.was_due_to_message(),
136            Self::RejectedBatch(responses) => match responses.len() {
137                1 => responses[0].was_due_to_message(),
138                _ => false,
139            },
140        }
141    }
142}
143
/// A single ESMTP extension advertised by the server in its EHLO/LHLO
/// response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EsmtpCapability {
    /// The extension keyword exactly as the server sent it (the
    /// capabilities map is keyed by the upper-cased form).
    pub name: String,
    /// Everything after the first space on the capability line, if any.
    pub param: Option<String>,
}
149
/// An event reported to an [`SmtpClientTracer`].
#[derive(Clone, Debug)]
pub enum SmtpClientTraceEvent {
    /// The connection was closed, either by the peer or by the client
    /// after an error.
    Closed,
    /// Raw bytes received from the peer.
    Read(Vec<u8>),
    /// Data written to the peer. Binary payload data is converted to
    /// UTF-8 lossily.
    Write(String),
    /// A diagnostic message about the session, such as an I/O error or
    /// the outcome of a STARTTLS handshake.
    Diagnostic {
        level: tracing::Level,
        message: String,
    },
}
160
/// Produces a trace event on demand.
///
/// Handed to [`SmtpClientTracer::lazy_trace`] so that building the event,
/// which may involve copying an I/O buffer, is left to the tracer.
/// Presumably tracers only call `trace` when the event will actually be
/// recorded — confirm against the implementations.
pub trait DeferredTracer {
    /// Build the trace event.
    fn trace(&self) -> SmtpClientTraceEvent;
}
164
/// Observer for the traffic and diagnostics of an [`SmtpClient`] session.
pub trait SmtpClientTracer: std::fmt::Debug {
    /// Record an already-constructed event.
    fn trace_event(&self, event: SmtpClientTraceEvent);
    /// Record an event whose construction is deferred to `deferred`.
    fn lazy_trace(&self, deferred: &dyn DeferredTracer);
}
169
170// helper to avoid making a second copy of every write buffer
171struct WriteTracer<'a> {
172    data: &'a str,
173}
174impl DeferredTracer for WriteTracer<'_> {
175    fn trace(&self) -> SmtpClientTraceEvent {
176        SmtpClientTraceEvent::Write(self.data.to_string())
177    }
178}
179impl<'a> WriteTracer<'a> {
180    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
181        tracer.lazy_trace(&Self { data });
182    }
183}
184
185struct BinWriteTracer<'a> {
186    data: &'a [u8],
187}
188impl DeferredTracer for BinWriteTracer<'_> {
189    fn trace(&self) -> SmtpClientTraceEvent {
190        let data = String::from_utf8_lossy(self.data).to_string();
191        SmtpClientTraceEvent::Write(data)
192    }
193}
194impl<'a> BinWriteTracer<'a> {
195    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
196        tracer.lazy_trace(&Self { data });
197    }
198}
199
200// A little bit of gymnastics to avoid making a second
201// copy of every read buffer
202struct ReadTracer<'a> {
203    data: &'a [u8],
204}
205impl DeferredTracer for ReadTracer<'_> {
206    fn trace(&self) -> SmtpClientTraceEvent {
207        SmtpClientTraceEvent::Read(self.data.to_vec())
208    }
209}
210
/// An SMTP (or LMTP) client session over a single connection.
#[derive(Debug)]
pub struct SmtpClient {
    /// The connection to the peer. Set to `None` after most read/write
    /// failures (and on EOF); later operations then report an error
    /// instead of reusing the stream.
    socket: Option<BoxedAsyncReadAndWrite>,
    /// The peer hostname with any `:port`, `[...]` brackets and trailing
    /// dot removed by `extract_hostname`. Used in trace logging and as the
    /// TLS server name for STARTTLS.
    hostname: String,
    /// Extensions from the most recent EHLO/LHLO response, keyed by the
    /// upper-cased keyword.
    capabilities: HashMap<String, EsmtpCapability>,
    /// Bytes received from the peer that have not yet been consumed as a
    /// complete response line.
    read_buffer: Vec<u8>,
    /// Per-command timeout configuration.
    timeouts: SmtpClientTimeouts,
    /// Optional observer for protocol traffic and diagnostics.
    tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
    // NOTE(review): only initialized in the visible part of this file;
    // presumably tracks whether RSET must be sent before the next
    // transaction — confirm.
    use_rset: bool,
    /// Set via `set_enable_rset`.
    enable_rset: bool,
    /// Whether `pipeline_commands` may pipeline when the server
    /// advertises PIPELINING.
    enable_pipelining: bool,
    /// When true, 8BITMIME and SMTPUTF8 checks are ignored.
    ignore_8bit_checks: bool,
}
224
/// Reduce a peer address string to just its hostname.
///
/// Accepts forms such as `host`, `host:port`, `[addr]`, `[addr]:port` and
/// bare IPv6 addresses. Any `:port` suffix and any square brackets around
/// an address literal are removed, as is a trailing FQDN dot.
///
/// A string with more than one colon and no brackets is treated as a bare
/// IPv6 address and left intact, since its colons belong to the address.
fn extract_hostname(hostname: &str) -> &str {
    let host = match hostname.strip_prefix('[') {
        // Bracketed address literal, optionally followed by ":port"
        Some(rest) => match rest.split_once(']') {
            Some((inner, tail)) if tail.is_empty() || tail.starts_with(':') => inner,
            _ => hostname,
        },
        None => match hostname.rsplit_once(':') {
            // Exactly one colon: it separates the host from the port
            Some((host, _port)) if !host.contains(':') => host,
            // Either no port at all, or a bare IPv6 address
            _ => hostname,
        },
    };

    // Remove any trailing FQDN dot
    host.strip_suffix('.').unwrap_or(host)
}
243
244impl SmtpClient {
245    pub async fn new<A: ToSocketAddrs + ToString + Clone>(
246        addr: A,
247        timeouts: SmtpClientTimeouts,
248    ) -> std::io::Result<Self> {
249        let stream = TcpStream::connect(addr.clone()).await?;
250        // No need for Nagle with SMTP request/response
251        stream.set_nodelay(true)?;
252        Ok(Self::with_stream(stream, addr.to_string(), timeouts))
253    }
254
255    pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
256        stream: S,
257        peer_hostname: H,
258        timeouts: SmtpClientTimeouts,
259    ) -> Self {
260        let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
261
262        Self {
263            socket: Some(Box::new(stream)),
264            hostname,
265            capabilities: HashMap::new(),
266            read_buffer: Vec::with_capacity(1024),
267            timeouts,
268            tracer: None,
269            use_rset: false,
270            enable_rset: false,
271            enable_pipelining: false,
272            ignore_8bit_checks: false,
273        }
274    }
275
276    /// Setting this to true causes 8BITMIME and SMTPUTF8 checks to be ignored
277    pub fn set_ignore_8bit_checks(&mut self, enable: bool) {
278        self.ignore_8bit_checks = enable;
279    }
280
281    pub fn is_connected(&self) -> bool {
282        self.socket.is_some()
283    }
284
285    pub fn set_enable_rset(&mut self, enable: bool) {
286        self.enable_rset = enable;
287    }
288
289    pub fn set_enable_pipelining(&mut self, enable: bool) {
290        self.enable_pipelining = enable;
291    }
292
293    pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
294        self.tracer.replace(tracer);
295    }
296
297    pub fn timeouts(&self) -> &SmtpClientTimeouts {
298        &self.timeouts
299    }
300
301    async fn read_line(
302        &mut self,
303        timeout_duration: Duration,
304        cmd: Option<&Command>,
305    ) -> Result<String, ClientError> {
306        let mut too_long = false;
307        loop {
308            let mut iter = self.read_buffer.iter().enumerate();
309            while let Some((i, &b)) = iter.next() {
310                if b != b'\r' {
311                    continue;
312                }
313                if let Some((_, b'\n')) = iter.next() {
314                    if too_long {
315                        self.read_buffer.drain(0..i + 2);
316
317                        if let Some(tracer) = &self.tracer {
318                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
319                                level: Level::ERROR,
320                                message: "Response too long".to_string(),
321                            });
322                        }
323
324                        return Err(ClientError::ResponseTooLong);
325                    }
326
327                    let line = String::from_utf8(self.read_buffer[0..i].to_vec());
328                    self.read_buffer.drain(0..i + 2);
329                    return Ok(line?);
330                }
331            }
332            if self.read_buffer.len() > MAX_LINE_LEN {
333                self.read_buffer.clear();
334                too_long = true;
335            }
336
337            // Didn't find a complete line, fill up the rest of the buffer
338            let mut data = [0u8; MAX_LINE_LEN];
339            let size = match self.socket.as_mut() {
340                Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
341                    Ok(Ok(size)) => size,
342                    Ok(Err(err)) => {
343                        self.socket.take();
344                        if let Some(tracer) = &self.tracer {
345                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
346                                level: Level::ERROR,
347                                message: format!("Error during read: {err:#}"),
348                            });
349                            tracer.trace_event(SmtpClientTraceEvent::Closed);
350                        }
351                        return Err(ClientError::ReadError {
352                            command: cmd.cloned(),
353                            error: format!("{err:#}"),
354                            partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
355                        });
356                    }
357                    Err(_) => {
358                        // If we're using a zero timeout we're probing
359                        // for liveness, so we don't want to treat this
360                        // timeout as a fatal error that closes the
361                        // connection
362                        if timeout_duration != Duration::ZERO {
363                            self.socket.take();
364                        }
365                        if let Some(tracer) = &self.tracer {
366                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
367                                level: Level::ERROR,
368                                message: format!("Read Timeout after {timeout_duration:?}"),
369                            });
370                            if self.socket.is_none() {
371                                tracer.trace_event(SmtpClientTraceEvent::Closed);
372                            }
373                        }
374                        return Err(ClientError::TimeOutResponse {
375                            command: cmd.cloned(),
376                            duration: timeout_duration,
377                        });
378                    }
379                },
380                None => {
381                    return Err(ClientError::ReadError {
382                        command: cmd.cloned(),
383                        error: "the socket was closed in response to an earlier issue".to_string(),
384                        partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
385                    });
386                }
387            };
388            if size == 0 {
389                self.socket.take();
390                if let Some(tracer) = &self.tracer {
391                    tracer.trace_event(SmtpClientTraceEvent::Closed);
392                }
393                return Err(ClientError::ReadError {
394                    command: cmd.cloned(),
395                    error: "Connection closed by peer".to_string(),
396                    partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
397                });
398            }
399            if let Some(tracer) = &self.tracer {
400                tracer.lazy_trace(&ReadTracer {
401                    data: &data[0..size],
402                });
403            }
404            self.read_buffer.extend_from_slice(&data[0..size]);
405        }
406    }
407
408    /// Check to see if there is either a unilateral response
409    /// or an error condition immediately available on the socket.
410    ///
411    /// Intended to be used after the connection has idled to see if
412    /// the peer has opted to close the connection, either gracefully
413    /// by sending what should be a 421, or abruptly by simply snipping
414    /// the connection.
415    pub async fn check_unilateral_response(&mut self) -> Result<Option<Response>, ClientError> {
416        match self.read_response(None, Duration::ZERO).await {
417            // Unilateral response was received
418            Ok(response) => Ok(Some(response)),
419            // Nothing immediately available: this should hopefully be the common case
420            Err(ClientError::TimeOutResponse { .. }) => Ok(None),
421            // An actual error
422            Err(err) => Err(err),
423        }
424    }
425
426    pub async fn read_response(
427        &mut self,
428        command: Option<&Command>,
429        timeout_duration: Duration,
430    ) -> Result<Response, ClientError> {
431        if let Some(sock) = self.socket.as_mut() {
432            if let Err(err) = sock.flush().await {
433                self.socket.take();
434                if let Some(tracer) = &self.tracer {
435                    tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
436                        level: Level::ERROR,
437                        message: format!("Error during flush: {err:#}"),
438                    });
439                    tracer.trace_event(SmtpClientTraceEvent::Closed);
440                }
441                return Err(ClientError::FlushError {
442                    error: format!("{err:#}"),
443                });
444            }
445        }
446
447        let mut line = self.read_line(timeout_duration, command).await?;
448        tracing::trace!("recv<-{}: {line}", self.hostname);
449        let mut parsed = parse_response_line(&line)?;
450        let mut response_builder = ResponseBuilder::new(&parsed);
451
452        let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
453        while !parsed.is_final {
454            line = self
455                .read_line(subsequent_line_timeout_duration, command)
456                .await?;
457            parsed = parse_response_line(&line)?;
458            response_builder
459                .add_line(&parsed)
460                .map_err(ClientError::MalformedResponseLine)?;
461        }
462
463        let response = response_builder.build(command.map(|cmd| cmd.encode()));
464
465        tracing::trace!("{}: {response:?}", self.hostname);
466
467        Ok(response)
468    }
469
470    pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
471        self.write_command_request(command).await?;
472        self.read_response(Some(command), command.client_timeout(&self.timeouts))
473            .await
474    }
475
476    /// Wrapper around socket.write_all() that will emit trace diagnostics and synthesize
477    /// a Close event to the tracer if a timeout or IO error occurs.
478    /// If an error or timeout, occurs ensures that the socket will not be reused.
479    async fn write_all_with_timeout<F, G>(
480        &mut self,
481        timeout_duration: Duration,
482        bytes: &[u8],
483        make_timeout_err: F,
484        make_write_err: G,
485    ) -> Result<(), ClientError>
486    where
487        F: FnOnce() -> ClientError,
488        G: FnOnce(String) -> ClientError,
489    {
490        match self.socket.as_mut() {
491            Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
492                Ok(Ok(response)) => Ok(response),
493                Ok(Err(err)) => {
494                    self.socket.take();
495                    if let Some(tracer) = &self.tracer {
496                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
497                            level: Level::ERROR,
498                            message: format!("Error during write: {err:#}"),
499                        });
500                        tracer.trace_event(SmtpClientTraceEvent::Closed);
501                    }
502                    Err(make_write_err(format!("{err:#}")))
503                }
504                Err(_) => {
505                    self.socket.take();
506                    if let Some(tracer) = &self.tracer {
507                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
508                            level: Level::ERROR,
509                            message: format!("Write Timeout after {timeout_duration:?}"),
510                        });
511                        tracer.trace_event(SmtpClientTraceEvent::Closed);
512                    }
513                    Err(make_timeout_err())
514                }
515            },
516            None => Err(make_write_err(
517                "the socket was closed in response to an earlier issue".to_string(),
518            )),
519        }
520    }
521
522    async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
523        let total_timeout: Duration = commands
524            .iter()
525            .map(|cmd| cmd.client_timeout_request(&self.timeouts))
526            .sum();
527
528        let mut lines: Vec<String> = vec![];
529        let mut all = String::new();
530        for cmd in commands {
531            let line = cmd.encode();
532            all.push_str(&line);
533            lines.push(line);
534        }
535        tracing::trace!("send->{}: (PIPELINE) {all}", self.hostname);
536        if self.socket.is_some() {
537            if let Some(tracer) = &self.tracer {
538                // Send the lines individually to the tracer, so that we
539                // don't break --terse mode
540                for line in lines {
541                    WriteTracer::trace(tracer, &line);
542                }
543            }
544        }
545        self.write_all_with_timeout(
546            total_timeout,
547            all.as_bytes(),
548            || ClientError::TimeOutRequest {
549                duration: total_timeout,
550                commands: commands.to_vec(),
551            },
552            |error| ClientError::WriteError {
553                error,
554                commands: commands.to_vec(),
555            },
556        )
557        .await
558    }
559
560    async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
561        let line = command.encode();
562        tracing::trace!("send->{}: {line}", self.hostname);
563        if self.socket.is_some() {
564            if let Some(tracer) = &self.tracer {
565                WriteTracer::trace(tracer, &line);
566            }
567        }
568
569        let timeout_duration = command.client_timeout_request(&self.timeouts);
570        self.write_all_with_timeout(
571            timeout_duration,
572            line.as_bytes(),
573            || ClientError::TimeOutRequest {
574                duration: timeout_duration,
575                commands: vec![command.clone()],
576            },
577            |error| ClientError::WriteError {
578                error,
579                commands: vec![command.clone()],
580            },
581        )
582        .await
583    }
584
585    async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
586        if self.socket.is_some() {
587            if let Some(tracer) = &self.tracer {
588                BinWriteTracer::trace(tracer, data);
589            }
590        }
591        let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
592        self.write_all_with_timeout(
593            timeout_duration,
594            data,
595            || ClientError::TimeOutData,
596            |error| ClientError::WriteError {
597                error,
598                commands: vec![],
599            },
600        )
601        .await
602    }
603
604    /// Issue a series of commands, and return the responses to
605    /// those commands.
606    ///
607    /// If the server advertised the RFC 2920 PIPELINING extension,
608    /// the commands are written one after the other before waiting
609    /// to read any data, resulting in lower overall latency due
610    /// to round-trip-times.
611    ///
612    /// If PIPELINING is not available, each command is written
613    /// and the response read before attempting to write the next
614    /// command.
615    ///
616    /// The number of returned responses may be smaller than the
617    /// number of requested commands if there is an issue with
618    /// the network connection.
619    pub async fn pipeline_commands(
620        &mut self,
621        commands: Vec<Command>,
622    ) -> Vec<Result<Response, ClientError>> {
623        let mut results: Vec<Result<Response, ClientError>> = vec![];
624
625        let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
626        if pipeline {
627            if let Err(err) = self.write_pipeline_request(&commands).await {
628                let err: ClientError = err;
629                results.push(Err(err.clone()));
630                while results.len() < commands.len() {
631                    // Synthesize failures for the remaining commands
632                    results.push(Err(err.clone()));
633                }
634                return results;
635            }
636
637            // Now read the responses effectively in a batch
638            for cmd in &commands {
639                results.push(
640                    self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
641                        .await,
642                );
643            }
644            return results;
645        }
646
647        for cmd in &commands {
648            if let Err(err) = self.write_command_request(cmd).await {
649                let err: ClientError = err;
650                results.push(Err(err.clone()));
651                while results.len() < commands.len() {
652                    // Synthesize failures for the remaining commands
653                    results.push(Err(err.clone()));
654                }
655                return results;
656            }
657            // Immediately request the response if the server
658            // doesn't support pipelining
659            results.push(
660                self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
661                    .await,
662            );
663        }
664        results
665    }
666
667    pub async fn ehlo_lhlo(
668        &mut self,
669        ehlo_name: &str,
670        use_lmtp: bool,
671    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
672        if use_lmtp {
673            self.lhlo(ehlo_name).await
674        } else {
675            self.ehlo(ehlo_name).await
676        }
677    }
678
679    pub async fn lhlo(
680        &mut self,
681        ehlo_name: &str,
682    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
683        let response = self
684            .send_command(&Command::Lhlo(Domain::Name(ehlo_name.to_string())))
685            .await?;
686        self.ehlo_common(response)
687    }
688
689    pub async fn ehlo(
690        &mut self,
691        ehlo_name: &str,
692    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
693        let response = self
694            .send_command(&Command::Ehlo(Domain::Name(ehlo_name.to_string())))
695            .await?;
696        self.ehlo_common(response)
697    }
698
699    fn ehlo_common(
700        &mut self,
701        response: Response,
702    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
703        if response.code != 250 {
704            return Err(ClientError::Rejected(response));
705        }
706
707        let mut capabilities = HashMap::new();
708
709        for line in response.content.lines().skip(1) {
710            let mut fields = line.splitn(2, ' ');
711            if let Some(name) = fields.next() {
712                let param = fields.next().map(|s| s.to_string());
713                let cap = EsmtpCapability {
714                    name: name.to_string(),
715                    param,
716                };
717                capabilities.insert(name.to_ascii_uppercase(), cap);
718            }
719        }
720
721        self.capabilities = capabilities;
722        Ok(&self.capabilities)
723    }
724
725    pub async fn auth_plain(
726        &mut self,
727        username: &str,
728        password: Option<&str>,
729    ) -> Result<(), ClientError> {
730        // RFC 4616 says that the format is:
731        // [authzid] NUL authcid NUL passwd
732        let password = password.unwrap_or("");
733        let payload = format!("\x00{username}\x00{password}");
734        let payload = data_encoding::BASE64.encode(payload.as_bytes());
735
736        let response = self
737            .send_command(&Command::Auth {
738                sasl_mech: "PLAIN".to_string(),
739                initial_response: Some(payload),
740            })
741            .await?;
742
743        if response.code != 235 {
744            return Err(ClientError::Rejected(response));
745        }
746
747        Ok(())
748    }
749
750    /// Attempt TLS handshake.
751    /// Returns Err for IO errors.
752    /// On completion, return an option that will be:
753    /// * Some(handshake_error) - if the handshake failed
754    /// * None - if the handshake succeeded
755    pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
756        let resp = self.send_command(&Command::StartTls).await?;
757        if resp.code != 220 {
758            return Err(ClientError::Rejected(resp));
759        }
760
761        let mut handshake_error = None;
762        let mut tls_info = TlsInformation::default();
763
764        let stream: BoxedAsyncReadAndWrite = if options.prefer_openssl
765            || !options.dane_tlsa.is_empty()
766        {
767            let connector = options
768                .build_openssl_connector(&self.hostname)
769                .map_err(|error| ClientError::FailedToBuildConnector {
770                    error: error.to_string(),
771                })?;
772            let ssl = connector.into_ssl(self.hostname.as_str())?;
773
774            let (stream, dup_stream) = match self.socket.take() {
775                Some(s) => {
776                    let d = s.try_dup();
777                    (s, d)
778                }
779                None => return Err(ClientError::NotConnected),
780            };
781
782            let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;
783
784            if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
785                handshake_error.replace(format!("{err:#}"));
786            }
787
788            tls_info.provider_name = "openssl".to_string();
789            tls_info.cipher = match ssl_stream.ssl().current_cipher() {
790                Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
791                None => String::new(),
792            };
793            tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();
794
795            if let Some(cert) = ssl_stream.ssl().peer_certificate() {
796                tls_info.subject_name = subject_name(&cert);
797            }
798            if let Ok(authority) = ssl_stream.ssl().dane_authority() {
799                if let Some(cert) = &authority.cert {
800                    tls_info.subject_name = subject_name(cert);
801                }
802            }
803
804            match (&handshake_error, dup_stream) {
805                (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
806                    // Try falling back to clear text on the duplicate stream.
807                    // This is imperfect: in a failed validation scenario we will
808                    // end up trying to read binary data as a string and get a UTF-8
809                    // error if the peer thinks the session is encrypted.
810                    drop(ssl_stream);
811                    Box::new(dup_stream)
812                }
813                _ => Box::new(ssl_stream),
814            }
815        } else {
816            tls_info.provider_name = "rustls".to_string();
817            let connector = options.build_tls_connector().await.map_err(|error| {
818                ClientError::FailedToBuildConnector {
819                    error: error.to_string(),
820                }
821            })?;
822            let server_name = match IpAddr::from_str(self.hostname.as_str()) {
823                Ok(ip) => ServerName::IpAddress(ip.into()),
824                Err(_) => ServerName::try_from(self.hostname.clone())
825                    .map_err(|_| ClientError::InvalidDnsName(self.hostname.clone()))?,
826            };
827
828            match connector
829                .connect(
830                    server_name,
831                    match self.socket.take() {
832                        Some(s) => s,
833                        None => return Err(ClientError::NotConnected),
834                    },
835                )
836                .into_fallible()
837                .await
838            {
839                Ok(stream) => {
840                    let (_, conn) = stream.get_ref();
841                    tls_info.cipher = match conn.negotiated_cipher_suite() {
842                        Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
843                        None => String::new(),
844                    };
845                    tls_info.protocol_version = match conn.protocol_version() {
846                        Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
847                        None => String::new(),
848                    };
849
850                    if let Some(certs) = conn.peer_certificates() {
851                        let peer_cert = &certs[0];
852                        if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
853                            tls_info.subject_name = subject_name(&cert);
854                        }
855                    }
856
857                    Box::new(stream)
858                }
859                Err((err, stream)) => {
860                    handshake_error.replace(format!("{err:#}"));
861                    stream
862                }
863            }
864        };
865
866        if let Some(tracer) = &self.tracer {
867            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
868                level: Level::INFO,
869                message: match &handshake_error {
870                    Some(error) => format!("STARTTLS handshake failed: {error:?}"),
871                    None => format!("STARTTLS handshake -> {tls_info:?}"),
872                },
873            });
874        }
875
876        self.socket.replace(stream);
877        Ok(match handshake_error {
878            Some(error) => TlsStatus::FailedHandshake(error),
879            None => TlsStatus::Info(tls_info),
880        })
881    }
882
883    pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
884        &mut self,
885        sender: SENDER,
886        recipient: RECIP,
887        data: B,
888    ) -> Result<Response, ClientError> {
889        let recipient: ForwardPath = recipient.into();
890        let status = self
891            .send_mail_multi_recip(sender, vec![recipient], data)
892            .await?;
893        Ok(status.response)
894    }
895
896    pub async fn send_mail_multi_recip<B: AsRef<[u8]>, SENDER: Into<ReversePath>>(
897        &mut self,
898        sender: SENDER,
899        recipient_list: Vec<ForwardPath>,
900        data: B,
901    ) -> Result<BatchSendSuccess, ClientError> {
902        let sender = sender.into();
903
904        let data: &[u8] = data.as_ref();
905        let stuffed;
906
907        let data = match apply_dot_stuffing(data) {
908            Some(d) => {
909                stuffed = d;
910                &stuffed
911            }
912            None => data,
913        };
914
915        let data_is_8bit = data.iter().any(|&b| b >= 0x80);
916        let envelope_is_8bit =
917            !sender.is_ascii() || recipient_list.iter().any(|recipient| !recipient.is_ascii());
918
919        let mut mail_from_params = vec![];
920        if data_is_8bit {
921            if self.capabilities.contains_key("8BITMIME") {
922                mail_from_params.push(EsmtpParameter {
923                    name: "BODY".to_string(),
924                    value: Some("8BITMIME".to_string()),
925                });
926            } else if !self.ignore_8bit_checks {
927                return Err(ClientError::Rejected(Response {
928                    code: 554,
929                    command: None,
930                    enhanced_code: Some(EnhancedStatusCode {
931                        class: 5,
932                        subject: 6,
933                        detail: 3,
934                    }),
935                    content: "KumoMTA internal: DATA is 8bit, destination does \
936                        not support 8BITMIME. Conversion via msg:check_fix_conformance \
937                        during reception is required"
938                        .to_string(),
939                }));
940            }
941        }
942
943        if envelope_is_8bit {
944            if self.capabilities.contains_key("SMTPUTF8") {
945                mail_from_params.push(EsmtpParameter {
946                    name: "SMTPUTF8".to_string(),
947                    value: None,
948                });
949            } else if !self.ignore_8bit_checks {
950                return Err(ClientError::Rejected(Response {
951                    code: 554,
952                    command: None,
953                    enhanced_code: Some(EnhancedStatusCode {
954                        class: 5,
955                        subject: 6,
956                        detail: 7,
957                    }),
958                    content: "KumoMTA internal: envelope is 8bit, destination does \
959                        not support SMTPUTF8."
960                        .to_string(),
961                }));
962            }
963        }
964
965        let mut commands = vec![];
966
967        // We want to avoid using RSET for the first message we send on
968        // a given connection, because postfix can run in a mode where
969        // it will not tolerate RSET because it considers it to be a "junk"
970        // command, and rejecting junk commands will cut down on its load
971        // when it is under stress; it is used as a load shedding approach.
972        // If we always RSET then we will never deliver to a site that is
973        // configured that way. If we take care to RSET only for subsequent
974        // sends, then we should get one message per connection through
975        // without being unfairly penalized for defensively RSETing.
976        let used_rset = self.use_rset;
977        if self.use_rset {
978            commands.push(Command::Rset);
979        }
980        commands.push(Command::MailFrom {
981            address: sender,
982            parameters: mail_from_params,
983        });
984
985        for recipient in &recipient_list {
986            commands.push(Command::RcptTo {
987                address: recipient.clone(),
988                parameters: vec![],
989            });
990        }
991        commands.push(Command::Data);
992
993        // Assume that something might break below: if it does, we want
994        // to ensure that we RSET the connection on the next go around.
995        self.use_rset = true;
996
997        let mut responses = self.pipeline_commands(commands).await;
998
999        // This is a little awkward. We want to handle the RFC 2090 3.1 case
1000        // below, which requires deferring checking the actual response codes
1001        // until later, but we must also handle the case where we had a hard
1002        // transport error partway through pipelining.
1003        // So we set a flag for that case and will then "eagerly", wrt. the
1004        // RFC 2090 3.1 logic, evaluate the SMTP response codes, so that we
1005        // can propagate the correct error disposition up to the caller.
1006        let is_err = responses.iter().any(|r| r.is_err());
1007
1008        if used_rset {
1009            let rset_resp = responses.remove(0)?;
1010            if rset_resp.code != 250 {
1011                return Err(ClientError::Rejected(rset_resp));
1012            }
1013        }
1014
1015        let mail_resp = responses.remove(0)?;
1016        if is_err && mail_resp.code != 250 {
1017            return Err(ClientError::Rejected(mail_resp));
1018        }
1019
1020        let mut rcpt_responses = vec![];
1021        for _ in &recipient_list {
1022            rcpt_responses.push(responses.remove(0)?);
1023        }
1024
1025        if is_err && rcpt_responses.iter().all(|resp| resp.code != 250) {
1026            return Err(ClientError::RejectedBatch(rcpt_responses));
1027        }
1028
1029        let data_resp = responses.remove(0)?;
1030        if is_err && data_resp.code != 354 {
1031            return Err(ClientError::Rejected(data_resp));
1032        }
1033
1034        if data_resp.code == 354
1035            && (mail_resp.code != 250 || rcpt_responses.iter().all(|resp| resp.code != 250))
1036        {
1037            // RFC 2920 3.1:
1038            // the client cannot assume that the DATA command will be rejected
1039            // just because none of the RCPT TO commands worked.  If the DATA
1040            // command was properly rejected the client SMTP can just issue
1041            // RSET, but if the DATA command was accepted the client SMTP
1042            // should send a single dot.
1043
1044            // Send dummy data
1045            self.write_data_with_timeout(b".\r\n").await?;
1046            let data_dot = Command::DataDot;
1047            // wait for its response
1048            let _ = self
1049                .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1050                .await?;
1051
1052            // Continue below: we will match one of the failure cases and
1053            // return a ClientError::Rejected from one of the earlier
1054            // commands
1055        }
1056
1057        if mail_resp.code != 250 {
1058            return Err(ClientError::Rejected(mail_resp));
1059        }
1060        if rcpt_responses.iter().all(|resp| resp.code != 250) {
1061            if rcpt_responses.len() == 1 {
1062                return Err(ClientError::Rejected(
1063                    rcpt_responses.pop().expect("have at least one"),
1064                ));
1065            }
1066            return Err(ClientError::RejectedBatch(rcpt_responses));
1067        }
1068        if data_resp.code != 354 {
1069            return Err(ClientError::Rejected(data_resp));
1070        }
1071
1072        let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);
1073
1074        tracing::trace!("message data is {} bytes", data.len());
1075
1076        self.write_data_with_timeout(data).await?;
1077
1078        let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };
1079
1080        tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());
1081
1082        self.write_data_with_timeout(marker.as_bytes()).await?;
1083
1084        let data_dot = Command::DataDot;
1085        let resp = self
1086            .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1087            .await?;
1088        if resp.code != 250 {
1089            return Err(ClientError::Rejected(resp));
1090        }
1091
1092        // If everything went well, respect the user preference for speculatively
1093        // issuing an RSET next time around
1094        self.use_rset = self.enable_rset;
1095
1096        Ok(BatchSendSuccess {
1097            response: resp,
1098            rcpt_responses,
1099        })
1100    }
1101}
1102
/// Successful outcome of [`SmtpClient::send_mail_multi_recip`].
#[derive(Debug)]
pub struct BatchSendSuccess {
    /// The server's response to the end-of-data `.` marker. It is always a
    /// 250, since any other code is turned into an error instead.
    pub response: Response,
    /// One RCPT TO response per recipient, in the same order as the
    /// recipient list that was passed in. Individual entries may be
    /// rejections: the send only fails outright when every recipient is
    /// rejected.
    pub rcpt_responses: Vec<Response>,
}
1108
/// Outcome of a STARTTLS negotiation attempt.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
    /// The TLS handshake failed. Holds the handshake error formatted with
    /// `{:#}`.
    ///
    /// NOTE(review): on the openssl path the client may continue on a
    /// clear-text duplicate of the original stream in this case.
    FailedHandshake(String),
    /// The handshake succeeded; describes the negotiated session.
    Info(TlsInformation),
}
1114
/// Details of a negotiated TLS session, reported back from STARTTLS and
/// emitted in the client's diagnostic trace.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
    /// Negotiated cipher suite name.
    ///
    /// Empty when the TLS library reported none. With rustls it is
    /// `"UNKNOWN"` for a suite without a known name.
    pub cipher: String,
    /// Negotiated protocol version string.
    ///
    /// With rustls it may be empty (none reported) or `"UNKNOWN"`
    /// (unrecognized version).
    pub protocol_version: String,
    /// `object=value` pairs from the subject of the peer certificate; see
    /// [`subject_name`].
    ///
    /// Empty when no certificate was available. On the openssl path, a DANE
    /// authority certificate takes precedence over the peer certificate.
    pub subject_name: Vec<String>,
    /// TLS implementation that performed the handshake: `"openssl"` or
    /// `"rustls"`.
    pub provider_name: String,
}
1122
1123impl Drop for SmtpClient {
1124    fn drop(&mut self) {
1125        if let Some(tracer) = &self.tracer {
1126            if self.socket.is_some() {
1127                tracer.trace_event(SmtpClientTraceEvent::Closed);
1128            }
1129        }
1130    }
1131}
1132fn parse_response_line(line: &'_ str) -> Result<ResponseLine<'_>, ClientError> {
1133    if line.len() < 4 {
1134        return Err(ClientError::MalformedResponseLine(line.to_string()));
1135    }
1136
1137    match line.as_bytes()[3] {
1138        b' ' | b'-' => match line[0..3].parse::<u16>() {
1139            Ok(code) => Ok(ResponseLine {
1140                code,
1141                is_final: line.as_bytes()[3] == b' ',
1142                content: &line[4..],
1143            }),
1144            Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
1145        },
1146        _ => Err(ClientError::MalformedResponseLine(line.to_string())),
1147    }
1148}
1149
1150fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1151    static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1152
1153    if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1154        return None;
1155    }
1156
1157    let mut stuffed = vec![];
1158    if data.starts_with(b".") {
1159        stuffed.push(b'.');
1160    }
1161    let mut last_idx = 0;
1162    for i in LFDOT.find_iter(data) {
1163        stuffed.extend_from_slice(&data[last_idx..=i]);
1164        stuffed.push(b'.');
1165        last_idx = i + 1;
1166    }
1167    stuffed.extend_from_slice(&data[last_idx..]);
1168    Some(stuffed)
1169}
1170
1171/// Extracts the object=name pairs of the subject name from a cert.
1172/// eg:
1173/// ```no_run
1174/// ["C=US", "ST=CA", "L=SanFrancisco", "O=Fort-Funston", "OU=MyOrganizationalUnit",
1175/// "CN=do.havedane.net", "name=EasyRSA", "emailAddress=me@myhost.mydomain"]
1176/// ```
1177pub fn subject_name(cert: &X509Ref) -> Vec<String> {
1178    let mut subject_name = vec![];
1179    for entry in cert.subject_name().entries() {
1180        if let Ok(obj) = entry.object().nid().short_name() {
1181            if let Ok(name) = entry.data().as_utf8() {
1182                subject_name.push(format!("{obj}={name}"));
1183            }
1184        }
1185    }
1186    subject_name
1187}
1188
// Unit tests for the pure helpers in this module: dot-stuffing, reply-line
// parsing, multi-line reply assembly, hostname extraction and error
// formatting.
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MailPath, Mailbox};

    // Covers the cases below:
    // - input that needs no stuffing (`None`);
    // - a leading dot;
    // - a dot after LF;
    // - an already-doubled dot, which gains a third.
    #[test]
    fn test_stuffing() {
        assert_eq!(apply_dot_stuffing(b"foo"), None);
        assert_eq!(apply_dot_stuffing(b".foo").unwrap(), b"..foo");
        assert_eq!(apply_dot_stuffing(b"foo\n.bar").unwrap(), b"foo\n..bar");
        assert_eq!(
            apply_dot_stuffing(b"foo\n.bar\n..baz\n").unwrap(),
            b"foo\n..bar\n...baz\n"
        );
    }

    // NOTE(review): disabled integration test. It needs an SMTP sink
    // listening on 127.0.0.1:2025, and it appears to predate the current
    // API: `read_response` is called with a single argument here, whereas
    // send_mail_multi_recip passes a command and a timeout.
    // TODO: update or delete.
    /*
    #[tokio::test]
    async fn test_against_sink() {
        use tokio::net::TcpStream;
        let stream = TcpStream::connect("127.0.0.1:2025").await.unwrap();
        let mut client =
            SmtpClient::with_stream(stream, "localhost", SmtpClientTimeouts::default());
        dbg!(client.read_response(None).await).unwrap();
        dbg!(client.ehlo("localhost").await).unwrap();
        let insecure = true;
        dbg!(client.starttls(insecure).await).unwrap();
        let resp = client
            .send_mail(
                ReversePath::try_from("wez@mail.example.com").unwrap(),
                ForwardPath::try_from("wez@mail.example.com").unwrap(),
                "Subject: hello\r\n\r\nwoot\r\n",
            )
            .await
            .unwrap();
        panic!("{resp:#?}");
    }
    */

    // A space after the code marks the final line and a dash marks a
    // continuation. Any other separator, or a non-numeric code, is malformed.
    #[test]
    fn response_line_parsing() {
        assert_eq!(
            parse_response_line("220 woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: true,
                content: "woot"
            }
        );
        assert_eq!(
            parse_response_line("220-woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: false,
                content: "woot"
            }
        );

        assert!(matches!(
            parse_response_line("220_woot"),
            Err(ClientError::MalformedResponseLine(_))
        ));
        assert!(matches!(
            parse_response_line("not really"),
            Err(ClientError::MalformedResponseLine(_))
        ));
    }

    // Test helper: feeds `lines` through parse_response_line and
    // ResponseBuilder, the way a multi-line reply is accumulated.
    // Asserts that the last line is marked final.
    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let mut parsed = parse_response_line(lines[0])?;
        let mut b = ResponseBuilder::new(&parsed);
        for line in &lines[1..] {
            parsed = parse_response_line(line)?;
            b.add_line(&parsed)
                .map_err(ClientError::MalformedResponseLine)?;
        }
        assert!(parsed.is_final);
        Ok(b.build(None))
    }

    #[test]
    fn multi_line_response() {
        // Continuation lines are joined with '\n' into a single response.
        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: None,
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // A continuation line whose code differs from the first line is
        // rejected, and the error carries the offending line.
        let res = parse_multi_line(&["220-woot", "221-more", "220 done"]).unwrap_err();
        assert!(
            matches!(
                    res,
                ClientError::MalformedResponseLine(ref err) if err == "221-more"
            ),
            "got error {res:?}"
        );

        // The same applies to a mismatched final line.
        let res = parse_multi_line(&["220-woot", "220-more", "221 done"]).unwrap_err();
        assert!(
            matches!(
                    res,
                ClientError::MalformedResponseLine(ref err) if err == "221 done"
            ),
            "got error {res:?}"
        );

        // An enhanced status code on the first line is extracted and
        // stripped from the content.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // Confirm that we strip the enhanced status code from each line
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // ... but only if the code matches that of the first line
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\n5.5.5 done".to_string(),
                command: None
            }
        );
    }

    // extract_hostname strips a trailing dot, a port suffix and square
    // brackets. This includes the bare "::1:25" form, where the final
    // segment is treated as the port.
    #[test]
    fn test_extract_hostname() {
        assert_eq!(extract_hostname("foo"), "foo");
        assert_eq!(extract_hostname("foo."), "foo");
        assert_eq!(extract_hostname("foo:25"), "foo");
        assert_eq!(extract_hostname("foo.:25"), "foo");
        assert_eq!(extract_hostname("[foo]:25"), "foo");
        assert_eq!(extract_hostname("[foo.]:25"), "foo");
        assert_eq!(extract_hostname("[::1]:25"), "::1");
        assert_eq!(extract_hostname("::1:25"), "::1");
    }

    // Timeout errors render the offending command in SMTP wire syntax,
    // e.g. "." for DataDot and "MAIL FROM:<user@host>".
    #[test]
    fn test_format_error_command() {
        assert_eq!(
            format!(
                "{:#}",
                ClientError::TimeOutRequest {
                    commands: vec![Command::DataDot],
                    duration: Duration::from_secs(10),
                }
            ),
            "Timed Out after 10s writing cmd=."
        );
        assert_eq!(
            format!(
                "{:#}",
                ClientError::TimeOutResponse {
                    command: Some(Command::MailFrom {
                        address: ReversePath::Path(MailPath {
                            at_domain_list: vec![],
                            mailbox: Mailbox {
                                local_part: "user".to_string(),
                                domain: Domain::Name("host".to_string())
                            }
                        }),
                        parameters: vec![],
                    }),
                    duration: Duration::from_secs(10),
                }
            ),
            r#"Timed Out waiting 10s for response to cmd=MAIL FROM:<user@host>"#
        );
    }
}