use crate::client_types::*;
use crate::{
AsyncReadAndWrite, BoxedAsyncReadAndWrite, Command, Domain, EsmtpParameter, ForwardPath,
ReversePath,
};
use hickory_proto::rr::rdata::tlsa::{CertUsage, Matching, Selector};
use hickory_proto::rr::rdata::TLSA;
use memchr::memmem::Finder;
use openssl::ssl::{DaneMatchType, DaneSelector, DaneUsage, SslOptions};
use openssl::x509::{X509Ref, X509};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::IpAddr;
use std::str::FromStr;
use std::sync::{Arc, LazyLock};
use std::time::Duration;
use thiserror::Error;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpStream, ToSocketAddrs};
use tokio::time::timeout;
use tokio_rustls::rustls::pki_types::ServerName;
use tracing::Level;
pub use crate::tls::TlsOptions;
pub use {openssl, tokio_rustls};
/// Number of response bytes that may be buffered while waiting for a line's
/// CRLF before the line is treated as too long; also the size of each read
/// from the socket.
const MAX_LINE_LEN: usize = 4 * 1024;
/// Errors returned by [`SmtpClient`] operations.
#[derive(Error, Debug)]
pub enum ClientError {
/// A response line was not valid UTF-8.
#[error("response is not UTF8")]
Utf8(#[from] std::string::FromUtf8Error),
/// An I/O error reported by the underlying stream (e.g. on read or flush;
/// failed writes are reported as `NotConnected` instead).
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
/// A response line could not be parsed; carries the offending line text.
#[error("Malformed Response: {0}")]
MalformedResponseLine(String),
/// More than `MAX_LINE_LEN` bytes arrived without a CRLF terminator.
#[error("Response line is too long")]
ResponseTooLong,
/// There is no connection: it was never present, the peer closed it, or it
/// was dropped after a failed or timed-out write.
#[error("Not connected")]
NotConnected,
/// The server answered with a reply code other than the one the operation
/// required; carries the full response.
#[error("Command rejected {0:?}")]
Rejected(Response),
/// The peer hostname is neither an IP address nor a valid DNS name, so it
/// cannot be used as the rustls server name for STARTTLS.
#[error("STARTTLS: {0} is not a valid DNS name")]
InvalidDnsName(String),
/// No complete response was read within the allotted time.
#[error("Timed Out waiting {duration:?} for response to {command:?}")]
TimeOutResponse {
/// The command whose reply was awaited, if one was supplied.
command: Option<Command>,
/// The timeout that elapsed.
duration: Duration,
},
/// Writing a command line did not complete within the allotted time.
#[error("Timed Out writing {duration:?} {command:?}")]
TimeOutRequest {
/// The command that was being written.
command: Command,
/// The timeout that elapsed.
duration: Duration,
},
/// Writing message payload (or the terminating dot) timed out.
#[error("Timed Out sending message payload data")]
TimeOutData,
/// An OpenSSL error stack, e.g. while building the TLS configuration.
#[error("SSL Error: {0}")]
SslErrorStack(#[from] openssl::error::ErrorStack),
/// An OpenSSL SSL-level error.
#[error("SSL Error: {0}")]
SslError(#[from] openssl::ssl::Error),
/// DANE TLSA records were supplied, but OpenSSL reported none of them as
/// usable for this host.
#[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
}
/// One capability line from an EHLO/LHLO reply.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EsmtpCapability {
/// The capability keyword exactly as the server sent it (original case).
pub name: String,
/// Everything after the first space on the line, if there was a space.
pub param: Option<String>,
}
/// A wire-level event reported to an [`SmtpClientTracer`].
#[derive(Clone, Debug)]
pub enum SmtpClientTraceEvent {
/// The connection closed (peer EOF, or the client was dropped while still
/// connected).
Closed,
/// Raw bytes received from the socket.
Read(Vec<u8>),
/// Data written to the socket, as text (binary payloads are decoded lossily).
Write(String),
/// A client-generated note, such as a STARTTLS outcome or an over-long
/// response.
Diagnostic {
level: tracing::Level,
message: String,
},
}
/// Builds a trace event only when asked to.
///
/// Passed to [`SmtpClientTracer::lazy_trace`], presumably so tracers that
/// are not recording can skip the copy/allocation of the event payload.
pub trait DeferredTracer {
/// Materializes the trace event.
fn trace(&self) -> SmtpClientTraceEvent;
}
/// Receives trace events from an [`SmtpClient`].
pub trait SmtpClientTracer: std::fmt::Debug {
/// Records an already-constructed event.
fn trace_event(&self, event: SmtpClientTraceEvent);
/// Records an event that is produced on demand by calling `deferred.trace()`.
fn lazy_trace(&self, deferred: &dyn DeferredTracer);
}
/// Defers copying a textual write until a tracer actually asks for it.
struct WriteTracer<'a> {
    data: &'a str,
}

impl DeferredTracer for WriteTracer<'_> {
    fn trace(&self) -> SmtpClientTraceEvent {
        let owned = self.data.to_owned();
        SmtpClientTraceEvent::Write(owned)
    }
}

impl<'a> WriteTracer<'a> {
    /// Offers `data` to `tracer` as a lazily built `Write` event.
    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
        let deferred = WriteTracer { data };
        tracer.lazy_trace(&deferred);
    }
}
/// Defers decoding of a binary write (message payload) until a tracer
/// actually asks for it.
struct BinWriteTracer<'a> {
    data: &'a [u8],
}

impl DeferredTracer for BinWriteTracer<'_> {
    /// Produces a `Write` event, replacing invalid UTF-8 with U+FFFD.
    fn trace(&self) -> SmtpClientTraceEvent {
        // `into_owned` reuses the buffer `from_utf8_lossy` already allocated
        // for invalid input instead of copying it a second time.
        let data = String::from_utf8_lossy(self.data).into_owned();
        SmtpClientTraceEvent::Write(data)
    }
}

impl<'a> BinWriteTracer<'a> {
    /// Offers `data` to `tracer` as a lazily built `Write` event.
    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
        tracer.lazy_trace(&Self { data });
    }
}
/// Defers copying received bytes until a tracer actually asks for them.
struct ReadTracer<'a> {
    data: &'a [u8],
}

impl DeferredTracer for ReadTracer<'_> {
    fn trace(&self) -> SmtpClientTraceEvent {
        let bytes = Vec::from(self.data);
        SmtpClientTraceEvent::Read(bytes)
    }
}
/// An SMTP/LMTP client session over a single connection.
#[derive(Debug)]
pub struct SmtpClient {
/// The connection; `None` once it has been closed or lost.
socket: Option<BoxedAsyncReadAndWrite>,
/// Peer hostname with any `:port` suffix and surrounding brackets removed;
/// used as the TLS server name and in trace output.
hostname: String,
/// Capabilities from the most recent successful EHLO/LHLO, keyed by the
/// upper-cased keyword.
capabilities: HashMap<String, EsmtpCapability>,
/// Received bytes not yet consumed as complete response lines.
read_buffer: Vec<u8>,
/// Timeouts applied to command writes, replies and data writes.
timeouts: SmtpClientTimeouts,
/// Optional sink for wire-level trace events.
tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
}
/// Reduces a `host[:port]` or `[literal][:port]` string to the bare host.
///
/// A trailing `:port` is removed by splitting at the last colon. As a result,
/// an unbracketed IPv6 address is only handled correctly when a port follows
/// it (`"::1:25"` becomes `"::1"`). Surrounding `[` `]` are stripped whether
/// or not a port is present, so `"[::1]:25"` and `"[::1]"` both become `"::1"`.
fn extract_hostname(hostname: &str) -> &str {
    // Handle a bracketed literal with no port first. Otherwise its inner
    // colons would be mistaken for the port separator.
    if let Some(inner) = hostname
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
    {
        return inner;
    }
    let host = hostname
        .rsplit_once(':')
        .map_or(hostname, |(host, _port)| host);
    host.strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
        .unwrap_or(host)
}
impl SmtpClient {
pub async fn new<A: ToSocketAddrs + ToString + Clone>(
addr: A,
timeouts: SmtpClientTimeouts,
) -> std::io::Result<Self> {
let stream = TcpStream::connect(addr.clone()).await?;
stream.set_nodelay(true)?;
Ok(Self::with_stream(stream, addr.to_string(), timeouts))
}
pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
stream: S,
peer_hostname: H,
timeouts: SmtpClientTimeouts,
) -> Self {
let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
Self {
socket: Some(Box::new(stream)),
hostname,
capabilities: HashMap::new(),
read_buffer: Vec::with_capacity(1024),
timeouts,
tracer: None,
}
}
pub fn is_connected(&self) -> bool {
self.socket.is_some()
}
pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
self.tracer.replace(tracer);
}
pub fn timeouts(&self) -> &SmtpClientTimeouts {
&self.timeouts
}
async fn read_line(
&mut self,
timeout_duration: Duration,
cmd: Option<&Command>,
) -> Result<String, ClientError> {
let mut too_long = false;
loop {
let mut iter = self.read_buffer.iter().enumerate();
while let Some((i, &b)) = iter.next() {
if b != b'\r' {
continue;
}
if let Some((_, b'\n')) = iter.next() {
if too_long {
self.read_buffer.drain(0..i + 2);
if let Some(tracer) = &self.tracer {
tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
level: Level::ERROR,
message: "Response too long".to_string(),
});
}
return Err(ClientError::ResponseTooLong);
}
let line = String::from_utf8(self.read_buffer[0..i].to_vec());
self.read_buffer.drain(0..i + 2);
return Ok(line?);
}
}
if self.read_buffer.len() > MAX_LINE_LEN {
self.read_buffer.clear();
too_long = true;
}
let mut data = [0u8; MAX_LINE_LEN];
let size = match self.socket.as_mut() {
Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
Ok(result) => result?,
Err(_) => {
return Err(ClientError::TimeOutResponse {
command: cmd.cloned(),
duration: timeout_duration,
})
}
},
None => return Err(ClientError::NotConnected),
};
if size == 0 {
self.socket.take();
if let Some(tracer) = &self.tracer {
tracer.trace_event(SmtpClientTraceEvent::Closed);
}
return Err(ClientError::NotConnected);
}
if let Some(tracer) = &self.tracer {
tracer.lazy_trace(&ReadTracer {
data: &data[0..size],
});
}
self.read_buffer.extend_from_slice(&data[0..size]);
}
}
pub async fn read_response(
&mut self,
command: Option<&Command>,
timeout_duration: Duration,
) -> Result<Response, ClientError> {
if let Some(sock) = self.socket.as_mut() {
sock.flush().await?;
}
let mut line = self.read_line(timeout_duration, command).await?;
tracing::trace!("recv<-{}: {line}", self.hostname);
let mut parsed = parse_response_line(&line)?;
let mut response_builder = ResponseBuilder::new(&parsed);
let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
while !parsed.is_final {
line = self
.read_line(subsequent_line_timeout_duration, command)
.await?;
parsed = parse_response_line(&line)?;
response_builder
.add_line(&parsed)
.map_err(ClientError::MalformedResponseLine)?;
}
let response = response_builder.build(command.map(|cmd| cmd.encode()));
tracing::trace!("{}: {response:?}", self.hostname);
Ok(response)
}
pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
self.write_command_request(command, false).await?;
self.read_response(Some(command), command.client_timeout(&self.timeouts))
.await
}
async fn write_command_request(
&mut self,
command: &Command,
is_pipeline: bool,
) -> Result<(), ClientError> {
let line = command.encode();
tracing::trace!(
"send->{}: {}{line}",
self.hostname,
if is_pipeline { "(PIPELINE) " } else { "" },
);
match self.socket.as_mut() {
Some(socket) => {
if let Some(tracer) = &self.tracer {
WriteTracer::trace(tracer, &line);
}
match timeout(
command.client_timeout_request(&self.timeouts),
socket.write_all(line.as_bytes()),
)
.await
{
Ok(result) => result.map_err(|_| {
self.socket.take();
ClientError::NotConnected
}),
Err(_) => {
self.socket.take();
return Err(ClientError::TimeOutRequest {
command: command.clone(),
duration: command.client_timeout_request(&self.timeouts),
});
}
}
}
None => Err(ClientError::NotConnected),
}
}
async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
match self.socket.as_mut() {
Some(sock) => {
if let Some(tracer) = &self.tracer {
BinWriteTracer::trace(tracer, &data);
}
match timeout(
Command::Data.client_timeout_request(&self.timeouts),
sock.write_all(data),
)
.await
{
Ok(result) => result.map_err(|_| {
self.socket.take();
ClientError::NotConnected
}),
Err(_) => {
self.socket.take();
Err(ClientError::TimeOutData)
}
}
}
None => Err(ClientError::NotConnected),
}
}
pub async fn pipeline_commands(
&mut self,
commands: Vec<Command>,
) -> Vec<Result<Response, ClientError>> {
let pipeline = self.capabilities.contains_key("PIPELINING");
let mut results: Vec<Result<Response, ClientError>> = vec![];
for cmd in &commands {
if let Err(err) = self.write_command_request(cmd, pipeline).await {
results.push(Err(err.into()));
return results;
}
if !pipeline {
results.push(
self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
.await,
);
}
}
if pipeline {
for cmd in &commands {
results.push(
self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
.await,
);
}
}
results
}
pub async fn ehlo_lhlo(
&mut self,
ehlo_name: &str,
use_lmtp: bool,
) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
if use_lmtp {
self.lhlo(ehlo_name).await
} else {
self.ehlo(ehlo_name).await
}
}
pub async fn lhlo(
&mut self,
ehlo_name: &str,
) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
let response = self
.send_command(&Command::Lhlo(Domain::Name(ehlo_name.to_string())))
.await?;
self.ehlo_common(response)
}
pub async fn ehlo(
&mut self,
ehlo_name: &str,
) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
let response = self
.send_command(&Command::Ehlo(Domain::Name(ehlo_name.to_string())))
.await?;
self.ehlo_common(response)
}
fn ehlo_common(
&mut self,
response: Response,
) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
if response.code != 250 {
return Err(ClientError::Rejected(response));
}
let mut capabilities = HashMap::new();
for line in response.content.lines().skip(1) {
let mut fields = line.splitn(2, ' ');
if let Some(name) = fields.next() {
let param = fields.next().map(|s| s.to_string());
let cap = EsmtpCapability {
name: name.to_string(),
param,
};
capabilities.insert(name.to_ascii_uppercase(), cap);
}
}
self.capabilities = capabilities;
Ok(&self.capabilities)
}
pub async fn auth_plain(
&mut self,
username: &str,
password: Option<&str>,
) -> Result<(), ClientError> {
let password = password.unwrap_or("");
let payload = format!("\x00{username}\x00{password}");
let payload = data_encoding::BASE64.encode(payload.as_bytes());
let response = self
.send_command(&Command::Auth {
sasl_mech: "PLAIN".to_string(),
initial_response: Some(payload),
})
.await?;
if response.code != 235 {
return Err(ClientError::Rejected(response));
}
Ok(())
}
pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
let resp = self.send_command(&Command::StartTls).await?;
if resp.code != 220 {
return Err(ClientError::Rejected(resp));
}
let mut handshake_error = None;
let mut tls_info = TlsInformation::default();
let stream: BoxedAsyncReadAndWrite =
if options.prefer_openssl || !options.dane_tlsa.is_empty() {
let connector = options.build_openssl_connector(&self.hostname)?;
let ssl = connector.into_ssl(self.hostname.as_str())?;
let (stream, dup_stream) = match self.socket.take() {
Some(s) => {
let d = s.try_dup();
(s, d)
}
None => return Err(ClientError::NotConnected),
};
let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;
if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
handshake_error.replace(format!("{err:#}"));
}
tls_info.provider_name = "openssl".to_string();
tls_info.cipher = match ssl_stream.ssl().current_cipher() {
Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
None => String::new(),
};
tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();
if let Some(cert) = ssl_stream.ssl().peer_certificate() {
tls_info.subject_name = subject_name(&cert);
}
if let Ok(authority) = ssl_stream.ssl().dane_authority() {
if let Some(cert) = &authority.cert {
tls_info.subject_name = subject_name(cert);
}
}
match (&handshake_error, dup_stream) {
(Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
drop(ssl_stream);
Box::new(dup_stream)
}
_ => Box::new(ssl_stream),
}
} else {
tls_info.provider_name = "rustls".to_string();
let connector = options.build_tls_connector();
let server_name = match IpAddr::from_str(self.hostname.as_str()) {
Ok(ip) => ServerName::IpAddress(ip.into()),
Err(_) => ServerName::try_from(self.hostname.clone())
.map_err(|_| ClientError::InvalidDnsName(self.hostname.clone()))?,
};
match connector
.connect(
server_name,
match self.socket.take() {
Some(s) => s,
None => return Err(ClientError::NotConnected),
},
)
.into_fallible()
.await
{
Ok(stream) => {
let (_, conn) = stream.get_ref();
tls_info.cipher = match conn.negotiated_cipher_suite() {
Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
None => String::new(),
};
tls_info.protocol_version = match conn.protocol_version() {
Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
None => String::new(),
};
if let Some(certs) = conn.peer_certificates() {
let peer_cert = &certs[0];
if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
tls_info.subject_name = subject_name(&cert);
}
}
Box::new(stream)
}
Err((err, stream)) => {
handshake_error.replace(format!("{err:#}"));
stream
}
}
};
if let Some(tracer) = &self.tracer {
tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
level: Level::INFO,
message: match &handshake_error {
Some(error) => format!("STARTTLS handshake failed: {error:?}"),
None => format!("STARTTLS handshake -> {tls_info:?}"),
},
});
}
self.socket.replace(stream);
Ok(match handshake_error {
Some(error) => TlsStatus::FailedHandshake(error),
None => TlsStatus::Info(tls_info),
})
}
pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
&mut self,
sender: SENDER,
recipient: RECIP,
data: B,
) -> Result<Response, ClientError> {
let sender = sender.into();
let recipient = recipient.into();
let data: &[u8] = data.as_ref();
let stuffed;
let data = match apply_dot_stuffing(data) {
Some(d) => {
stuffed = d;
&stuffed
}
None => data,
};
let data_is_8bit = data.iter().any(|&b| b >= 0x80);
let envelope_is_8bit = !sender.is_ascii() || !recipient.is_ascii();
let mut mail_from_params = vec![];
if data_is_8bit && self.capabilities.contains_key("8BITMIME") {
mail_from_params.push(EsmtpParameter {
name: "BODY".to_string(),
value: Some("8BITMIME".to_string()),
});
}
if envelope_is_8bit && self.capabilities.contains_key("SMTPUTF8") {
mail_from_params.push(EsmtpParameter {
name: "SMTPUTF8".to_string(),
value: None,
});
}
let mut responses = self
.pipeline_commands(vec![
Command::Rset,
Command::MailFrom {
address: sender,
parameters: mail_from_params,
},
Command::RcptTo {
address: recipient,
parameters: vec![],
},
Command::Data,
])
.await;
if responses.is_empty() {
return Err(ClientError::NotConnected);
}
let rset_resp = responses.remove(0)?;
let mail_resp = responses.remove(0)?;
let rcpt_resp = responses.remove(0)?;
let data_resp = responses.remove(0)?;
if data_resp.code == 354
&& (rset_resp.code != 250 || mail_resp.code != 250 || rcpt_resp.code != 250)
{
self.write_data_with_timeout(b".\r\n").await?;
let data_dot = Command::DataDot;
let _ = self
.read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
.await?;
}
if rset_resp.code != 250 {
return Err(ClientError::Rejected(rset_resp));
}
if mail_resp.code != 250 {
return Err(ClientError::Rejected(mail_resp));
}
if rcpt_resp.code != 250 {
return Err(ClientError::Rejected(rcpt_resp));
}
if data_resp.code != 354 {
return Err(ClientError::Rejected(data_resp));
}
let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);
tracing::trace!("message data is {} bytes", data.len());
self.write_data_with_timeout(&data).await?;
let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };
tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());
self.write_data_with_timeout(marker.as_bytes()).await?;
let data_dot = Command::DataDot;
let resp = self
.read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
.await?;
if resp.code != 250 {
return Err(ClientError::Rejected(resp));
}
Ok(resp)
}
}
/// Outcome of [`SmtpClient::starttls`] once the server has accepted STARTTLS.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
/// The TLS handshake failed; carries the formatted handshake error.
FailedHandshake(String),
/// The handshake succeeded; carries details of the negotiated session.
Info(TlsInformation),
}
/// Details of a negotiated TLS session, as gathered by [`SmtpClient::starttls`].
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
/// Negotiated cipher name; empty when none was negotiated.
pub cipher: String,
/// Negotiated protocol version string; empty if unknown (rustls).
pub protocol_version: String,
/// Peer certificate subject entries formatted as `SN=value`. With OpenSSL
/// DANE, this is taken from the DANE-matched certificate when there is one.
pub subject_name: Vec<String>,
/// Which TLS implementation was used: `"openssl"` or `"rustls"`.
pub provider_name: String,
}
impl Drop for SmtpClient {
    /// Reports `Closed` to the tracer if the client is dropped while still
    /// holding a connection.
    fn drop(&mut self) {
        let Some(tracer) = self.tracer.as_ref() else {
            return;
        };
        if self.socket.is_none() {
            return;
        }
        tracer.trace_event(SmtpClientTraceEvent::Closed);
    }
}
/// Parses one SMTP reply line of the form `NNN<SP>text` or `NNN-text`.
///
/// The first three bytes must be ASCII digits; they form the reply code. The
/// fourth byte is `' '` on the final line of a reply and `'-'` on a
/// continuation line. Everything after it is returned as `content`.
///
/// # Errors
/// Returns [`ClientError::MalformedResponseLine`] carrying the original line
/// in any of these cases:
/// - the line is shorter than four bytes;
/// - the code is not exactly three ASCII digits (so `"+22 x"` is rejected,
///   although `u16` parsing alone would accept it);
/// - the separator is neither `' '` nor `'-'`.
fn parse_response_line(line: &str) -> Result<ResponseLine, ClientError> {
    let malformed = || ClientError::MalformedResponseLine(line.to_string());
    let bytes = line.as_bytes();
    if bytes.len() < 4 || !bytes[..3].iter().all(u8::is_ascii_digit) {
        return Err(malformed());
    }
    let is_final = match bytes[3] {
        b' ' => true,
        b'-' => false,
        _ => return Err(malformed()),
    };
    // The first four bytes are ASCII, so offsets 3 and 4 are char boundaries.
    let code = line[..3].parse::<u16>().map_err(|_| malformed())?;
    Ok(ResponseLine {
        code,
        is_final,
        content: &line[4..],
    })
}
impl TlsOptions {
pub fn build_openssl_connector(
&self,
hostname: &str,
) -> Result<openssl::ssl::ConnectConfiguration, ClientError> {
tracing::trace!("build_openssl_connector for {hostname}");
let mut builder =
openssl::ssl::SslConnector::builder(openssl::ssl::SslMethod::tls_client())?;
if let Some(list) = &self.openssl_cipher_list {
builder.set_cipher_list(&list)?;
}
if let Some(suites) = &self.openssl_cipher_suites {
builder.set_ciphersuites(&suites)?;
}
if let Some(options) = &self.openssl_options {
builder.clear_options(SslOptions::all());
builder.set_options(*options);
}
if self.insecure {
builder.set_verify(openssl::ssl::SslVerifyMode::NONE);
}
if !self.dane_tlsa.is_empty() {
builder.dane_enable()?;
builder.set_no_dane_ee_namechecks();
}
let connector = builder.build();
let mut config = connector.configure()?;
if !self.dane_tlsa.is_empty() {
config.dane_enable(hostname)?;
let mut any_usable = false;
for tlsa in &self.dane_tlsa {
let usable = config.dane_tlsa_add(
match tlsa.cert_usage() {
CertUsage::CA => DaneUsage::PKIX_TA,
CertUsage::Service => DaneUsage::PKIX_EE,
CertUsage::TrustAnchor => DaneUsage::DANE_TA,
CertUsage::DomainIssued => DaneUsage::DANE_EE,
CertUsage::Unassigned(n) => DaneUsage::from_raw(n),
CertUsage::Private => DaneUsage::PRIV_CERT,
},
match tlsa.selector() {
Selector::Full => DaneSelector::CERT,
Selector::Spki => DaneSelector::SPKI,
Selector::Unassigned(n) => DaneSelector::from_raw(n),
Selector::Private => DaneSelector::PRIV_SEL,
},
match tlsa.matching() {
Matching::Raw => DaneMatchType::FULL,
Matching::Sha256 => DaneMatchType::SHA2_256,
Matching::Sha512 => DaneMatchType::SHA2_512,
Matching::Unassigned(n) => DaneMatchType::from_raw(n),
Matching::Private => DaneMatchType::PRIV_MATCH,
},
tlsa.cert_data(),
)?;
tracing::trace!("build_dane_connector usable={usable} {tlsa:?}");
if usable {
any_usable = true;
}
}
if !any_usable {
return Err(ClientError::NoUsableDaneTlsa {
hostname: hostname.to_string(),
tlsa: self.dane_tlsa.clone(),
});
}
}
Ok(config)
}
}
fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
if !data.starts_with(b".") && LFDOT.find(&data).is_none() {
return None;
}
let mut stuffed = vec![];
if data.starts_with(b".") {
stuffed.push(b'.');
}
let mut last_idx = 0;
for i in LFDOT.find_iter(&data) {
stuffed.extend_from_slice(&data[last_idx..=i]);
stuffed.push(b'.');
last_idx = i + 1;
}
stuffed.extend_from_slice(&data[last_idx..]);
Some(stuffed)
}
/// Renders a certificate's subject name as `SN=value` strings, in order.
///
/// Entries whose NID has no short name, or whose value cannot be converted to
/// UTF-8, are skipped.
fn subject_name(cert: &X509Ref) -> Vec<String> {
    cert.subject_name()
        .entries()
        .filter_map(|entry| {
            let short = entry.object().nid().short_name().ok()?;
            let value = entry.data().as_utf8().ok()?;
            Some(format!("{short}={value}"))
        })
        .collect()
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_stuffing() {
        // With no line-leading dots there is nothing to stuff.
        assert_eq!(apply_dot_stuffing(b"foo"), None);
        let cases: [(&[u8], &[u8]); 3] = [
            (b".foo", b"..foo"),
            (b"foo\n.bar", b"foo\n..bar"),
            (b"foo\n.bar\n..baz\n", b"foo\n..bar\n...baz\n"),
        ];
        for (input, expected) in cases {
            assert_eq!(apply_dot_stuffing(input).unwrap(), expected, "input {input:?}");
        }
    }

    #[test]
    fn response_line_parsing() {
        for (line, is_final) in [("220 woot", true), ("220-woot", false)] {
            assert_eq!(
                parse_response_line(line).unwrap(),
                ResponseLine {
                    code: 220,
                    is_final,
                    content: "woot"
                }
            );
        }
        for bad in ["220_woot", "not really"] {
            assert!(matches!(
                parse_response_line(bad),
                Err(ClientError::MalformedResponseLine(_))
            ));
        }
    }

    /// Feeds `lines` through the parser and builder exactly as
    /// `read_response` does, and asserts that the last line was final.
    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let (first, rest) = (lines[0], &lines[1..]);
        let mut current = parse_response_line(first)?;
        let mut builder = ResponseBuilder::new(&current);
        for text in rest {
            current = parse_response_line(text)?;
            builder
                .add_line(&current)
                .map_err(ClientError::MalformedResponseLine)?;
        }
        assert!(current.is_final);
        Ok(builder.build(None))
    }

    #[test]
    fn multi_line_response() {
        let plain = Response {
            code: 220,
            enhanced_code: None,
            content: "woot\nmore\ndone".to_string(),
            command: None,
        };
        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done"]).unwrap(),
            plain
        );

        // A line whose code differs from the first line's is rejected.
        for (lines, offender) in [
            (["220-woot", "221-more", "220 done"], "221-more"),
            (["220-woot", "220-more", "221 done"], "221 done"),
        ] {
            let res = parse_multi_line(&lines).unwrap_err();
            assert!(
                matches!(
                    res,
                    ClientError::MalformedResponseLine(ref err) if err == offender
                ),
                "got error {res:?}"
            );
        }

        let with_enhanced = |content: &str| Response {
            code: 220,
            enhanced_code: Some(EnhancedStatusCode {
                class: 4,
                subject: 1,
                detail: 0,
            }),
            content: content.to_string(),
            command: None,
        };
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done"]).unwrap(),
            with_enhanced("woot\nmore\ndone")
        );
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done"]).unwrap(),
            with_enhanced("woot\nmore\ndone")
        );
        // A different enhanced code on a later line is kept as plain text.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done"]).unwrap(),
            with_enhanced("woot\nmore\n5.5.5 done")
        );
    }

    #[test]
    fn test_extract_hostname() {
        for (input, expected) in [
            ("foo", "foo"),
            ("foo:25", "foo"),
            ("[foo]:25", "foo"),
            ("[::1]:25", "::1"),
            ("::1:25", "::1"),
        ] {
            assert_eq!(extract_hostname(input), expected, "input {input:?}");
        }
    }
}