use amqprs::callbacks::{DefaultChannelCallback, DefaultConnectionCallback};
use amqprs::channel::{BasicPublishArguments, Channel, ConfirmSelectArguments};
use amqprs::connection::{Connection, OpenConnectionArguments};
use amqprs::tls::TlsAdaptor;
use amqprs::{BasicProperties, FieldTable, TimeStamp};
use deadpool::managed::{Manager, Metrics, Pool, RecycleError, RecycleResult};
use kumo_server_memory::subscribe_to_memory_status_changes_async;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, LazyLock};
use std::time::Duration;
/// Process-wide registry of connection pools, keyed by `ConnectionInfo`.
///
/// Initialized lazily on first access via `Pools::new`, which also spawns the
/// memory-monitor task.
static POOLS: LazyLock<Mutex<Pools>> = LazyLock::new(Pools::new);
/// Registry mapping each distinct `ConnectionInfo` to its shared connection pool.
struct Pools {
    // One pool per exact configuration; the `Arc` lets callers keep using a pool
    // after the registry lock has been released.
    map: HashMap<ConnectionInfo, Arc<Pool<ConnectionManager>>>,
}
impl Pools {
    /// Build the empty registry wrapped in its mutex and spawn `Self::memory_monitor`
    /// as a background task.
    ///
    /// Uses `tokio::spawn`, so this must run inside a tokio runtime context.
    pub fn new() -> Mutex<Self> {
        tokio::spawn(Self::memory_monitor());
        Mutex::new(Self {
            map: HashMap::new(),
        })
    }

    /// Look up the pool registered for `info`, if any.
    pub fn get(&self, info: &ConnectionInfo) -> Option<&Arc<Pool<ConnectionManager>>> {
        self.map.get(info)
    }

    /// Register `pool` under `info`, replacing any pool previously stored for that key.
    pub fn insert(&mut self, info: ConnectionInfo, pool: Arc<Pool<ConnectionManager>>) {
        self.map.insert(info, pool);
    }

    /// Background task. After a 10 second delay, subscribe to memory-status changes.
    /// Whenever a change is observed while `get_headroom()` reports 0, purge all pools.
    ///
    /// Returns once `changed()` yields an error.
    async fn memory_monitor() {
        tracing::debug!("starting memory monitor");
        tokio::time::sleep(Duration::from_secs(10)).await;
        let mut memory_status = subscribe_to_memory_status_changes_async().await;
        while let Ok(()) = memory_status.changed().await {
            if kumo_server_memory::get_headroom() == 0 {
                Self::purge().await;
            }
        }
        // Without this, purging on memory pressure would stop with no trace at all.
        tracing::debug!("memory monitor exiting");
    }

    /// Snapshot every registered pool.
    ///
    /// The `Arc`s are cloned so the registry lock (a temporary in this statement)
    /// is released before `purge` starts awaiting.
    fn all_pools() -> Vec<Arc<Pool<ConnectionManager>>> {
        POOLS.lock().map.values().cloned().collect()
    }

    /// Remove every object each pool's `retain` sees (the predicate rejects all),
    /// then close each removed channel and connection.
    ///
    /// Close failures are logged, not propagated.
    async fn purge() {
        for pool in Self::all_pools() {
            let info = &pool.manager().0;
            let result = pool.retain(|_, _| false);
            // Nothing was removed from this pool; don't emit an error-level line for a no-op.
            if result.removed.is_empty() {
                continue;
            }
            tracing::error!(
                "purging {} amqprs connections for {}:{:?}",
                result.removed.len(),
                info.host,
                info.port
            );
            for connection in result.removed {
                if let Err(err) = connection.channel.close().await {
                    tracing::error!(
                        "Error closing channel to {}:{:?}: {err:#}",
                        info.host,
                        info.port
                    );
                }
                if let Err(err) = connection.connection.connection.close().await {
                    tracing::error!(
                        "Error closing connection to {}:{:?}: {err:#}",
                        info.host,
                        info.port
                    );
                }
            }
        }
    }
}
/// Connection and pooling parameters for an AMQP broker.
///
/// This struct is also the `HashMap` key under which pools are cached. `Hash`/`Eq` are
/// derived over every field, so configurations that differ in any field (including
/// timeouts and `pool_size`) get separate pools.
///
/// `Debug` is implemented by hand so that the password is never rendered.
#[derive(Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct ConnectionInfo {
    /// Broker host. When `enable_tls` is set, it is also passed to `TlsAdaptor` as the domain.
    pub host: String,
    /// Broker port; `connect` falls back to 5672 when unset.
    pub port: Option<u16>,
    /// Login user; `connect` falls back to "guest" when unset.
    pub username: Option<String>,
    /// Login password; `connect` falls back to "guest" when unset. Redacted in `Debug`.
    pub password: Option<String>,
    /// Virtual host; applied only when set.
    pub vhost: Option<String>,
    /// Connection name; applied only when set.
    pub connection_name: Option<String>,
    /// Heartbeat value passed to the connection arguments; applied only when set.
    pub heartbeat: Option<u16>,
    /// Wrap the connection in TLS; `false` when omitted.
    #[serde(default)]
    pub enable_tls: bool,
    /// Optional path to a root CA certificate used for TLS.
    pub root_ca_cert: Option<String>,
    /// Path to a client certificate; must be set together with `client_private_key`.
    pub client_cert: Option<String>,
    /// Path to the client private key; must be set together with `client_cert`.
    pub client_private_key: Option<String>,
    /// Put each pooled channel into confirm mode on creation; `false` when omitted.
    #[serde(default)]
    pub confirm_select: bool,
    /// Maximum pool size; the pool builder's own default is used when unset.
    #[serde(default)]
    pub pool_size: Option<usize>,
    /// Passed to the pool builder as `create_timeout`.
    #[serde(default, with = "duration_serde")]
    pub connect_timeout: Option<Duration>,
    /// Passed to the pool builder as `recycle_timeout`.
    #[serde(default, with = "duration_serde")]
    pub recycle_timeout: Option<Duration>,
    /// Passed to the pool builder as `wait_timeout`.
    #[serde(default, with = "duration_serde")]
    pub wait_timeout: Option<Duration>,
    /// Upper bound on a single `basic_publish`; `publish` uses 60 seconds when unset.
    #[serde(default, with = "duration_serde")]
    pub publish_timeout: Option<Duration>,
}

impl std::fmt::Debug for ConnectionInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ConnectionInfo")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("username", &self.username)
            // Show only whether a password was supplied, never the secret itself.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("vhost", &self.vhost)
            .field("connection_name", &self.connection_name)
            .field("heartbeat", &self.heartbeat)
            .field("enable_tls", &self.enable_tls)
            .field("root_ca_cert", &self.root_ca_cert)
            .field("client_cert", &self.client_cert)
            .field("client_private_key", &self.client_private_key)
            .field("confirm_select", &self.confirm_select)
            .field("pool_size", &self.pool_size)
            .field("connect_timeout", &self.connect_timeout)
            .field("recycle_timeout", &self.recycle_timeout)
            .field("wait_timeout", &self.wait_timeout)
            .field("publish_timeout", &self.publish_timeout)
            .finish()
    }
}
/// `deadpool` manager that builds pooled connection/channel pairs from the wrapped
/// `ConnectionInfo`.
pub struct ConnectionManager(ConnectionInfo);
/// An open AMQP connection paired with the configuration it was created from.
pub struct ConnectionWithInfo {
    connection: Connection,
    // Only ever written (by `ConnectionManager::create`). Nothing in the code shown here
    // reads it back, hence the lint suppression.
    #[allow(unused)]
    info: ConnectionInfo,
}
/// The object stored in each pool: one connection plus the single channel that
/// `ConnectionManager::create` opens on it.
pub struct ConnectionAndChannel {
    connection: ConnectionWithInfo,
    // Channel used for `basic_publish`; checked by `recycle` alongside the connection.
    channel: Channel,
}
/// Pool manager: each pool slot holds one connection with one opened channel.
impl Manager for ConnectionManager {
    type Type = ConnectionAndChannel;
    type Error = anyhow::Error;

    /// Connect using the wrapped `ConnectionInfo`, attach the default connection
    /// and channel callbacks, and open a channel.
    ///
    /// When `confirm_select` is set, the channel is also switched into confirm mode.
    /// Any failing step aborts creation with that error.
    async fn create(&self) -> Result<Self::Type, Self::Error> {
        let info = &self.0;

        let connection = info.connect().await?;
        connection
            .register_callback(DefaultConnectionCallback)
            .await?;

        let channel = connection.open_channel(None).await?;
        channel.register_callback(DefaultChannelCallback).await?;

        if info.confirm_select {
            let confirm_args = ConfirmSelectArguments::default();
            channel.confirm_select(confirm_args).await?;
        }

        let connection = ConnectionWithInfo {
            connection,
            info: info.clone(),
        };
        Ok(ConnectionAndChannel {
            connection,
            channel,
        })
    }

    /// Keep a pooled entry only while both its connection and its channel report
    /// themselves open; otherwise reject it with a recycle error.
    async fn recycle(
        &self,
        conn: &mut Self::Type,
        _metrics: &Metrics,
    ) -> RecycleResult<anyhow::Error> {
        let connection_open = conn.connection.connection.is_open();
        if !connection_open || !conn.channel.is_open() {
            return Err(RecycleError::message("channel/connection is closed"));
        }
        Ok(())
    }
}
impl ConnectionInfo {
    /// Open a new, unpooled AMQP connection described by `self`.
    ///
    /// Unset credentials fall back to `guest`/`guest`, and an unset port falls back to 5672.
    /// `vhost`, `connection_name` and `heartbeat` are applied only when present.
    /// With `enable_tls`, a `TlsAdaptor` is attached using `host` as the domain and the
    /// optional `root_ca_cert` path. Client authentication is configured only when both
    /// `client_cert` and `client_private_key` are given.
    ///
    /// # Errors
    ///
    /// Fails when exactly one of `client_cert` / `client_private_key` is set, when the
    /// TLS adaptor cannot be built, or when `Connection::open` fails.
    pub async fn connect(&self) -> anyhow::Result<Connection> {
        // NOTE(review): 5672 is the fallback even when `enable_tls` is set; confirm that
        // TLS users are expected to always specify `port`.
        let port = self.port.unwrap_or(5672);
        let user = self.username.as_deref().unwrap_or("guest");
        let pass = self.password.as_deref().unwrap_or("guest");
        let mut args = OpenConnectionArguments::new(&self.host, port, user, pass);

        if let Some(vhost) = self.vhost.as_deref() {
            args.virtual_host(vhost);
        }
        if let Some(name) = self.connection_name.as_deref() {
            args.connection_name(name);
        }
        if let Some(heartbeat) = self.heartbeat {
            args.heartbeat(heartbeat);
        }

        if self.enable_tls {
            let root_ca = self.root_ca_cert.as_deref().map(Path::new);
            let tls = match (
                self.client_cert.as_deref(),
                self.client_private_key.as_deref(),
            ) {
                (Some(cert), Some(key)) => TlsAdaptor::with_client_auth(
                    root_ca,
                    Path::new(cert),
                    Path::new(key),
                    self.host.clone(),
                )?,
                (None, None) => TlsAdaptor::without_client_auth(root_ca, self.host.clone())?,
                _ => anyhow::bail!(
                    "Either both client_cert and client_private_key must be specified, or neither"
                ),
            };
            args.tls_adaptor(tls);
        }

        // `?` converts the amqprs error into `anyhow::Error`.
        Ok(Connection::open(&args).await?)
    }

    /// Return the shared pool for exactly this configuration, building and registering
    /// it on first use.
    ///
    /// The registry lock is held across the whole lookup-or-create, so concurrent callers
    /// with an equal `ConnectionInfo` end up sharing one pool.
    ///
    /// # Errors
    ///
    /// Fails if the pool builder rejects the configuration.
    pub fn get_pool(&self) -> anyhow::Result<Arc<Pool<ConnectionManager>>> {
        let mut registry = POOLS.lock();
        if let Some(existing) = registry.get(self) {
            return Ok(Arc::clone(existing));
        }

        let builder = Pool::builder(ConnectionManager(self.clone()))
            .runtime(deadpool::Runtime::Tokio1)
            .create_timeout(self.connect_timeout)
            .recycle_timeout(self.recycle_timeout)
            .wait_timeout(self.wait_timeout);
        let builder = match self.pool_size {
            Some(limit) => builder.max_size(limit),
            None => builder,
        };

        let created = Arc::new(builder.build()?);
        registry.insert(self.clone(), Arc::clone(&created));
        Ok(created)
    }
}
/// Deserializable parameters for a single `publish` call.
///
/// Each optional property field is applied in `publish` through the matching
/// `BasicProperties::with_<field>` setter, and only when present.
/// `deny_unknown_fields` rejects any key not listed here.
/// The derived `Debug` output includes `payload` verbatim.
#[derive(Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct PublishParams {
    /// Routing key passed to `basic_publish`.
    pub routing_key: String,
    /// Message body; published as this string's UTF-8 bytes.
    pub payload: String,
    /// Broker/pool configuration; selects (or creates) the pool the message is sent through.
    pub connection: ConnectionInfo,
    pub app_id: Option<String>,
    pub cluster_id: Option<String>,
    pub content_encoding: Option<String>,
    pub content_type: Option<String>,
    pub correlation_id: Option<String>,
    pub delivery_mode: Option<u8>,
    pub expiration: Option<String>,
    pub headers: Option<FieldTable>,
    pub message_id: Option<String>,
    pub message_type: Option<String>,
    pub priority: Option<u8>,
    pub reply_to: Option<String>,
    pub timestamp: Option<TimeStamp>,
    pub user_id: Option<String>,
    /// Exchange name passed to `basic_publish`; the empty string when omitted.
    #[serde(default)]
    pub exchange: String,
    /// `mandatory` flag passed to `basic_publish`; `false` when omitted.
    #[serde(default)]
    pub mandatory: bool,
    /// `immediate` flag passed to `basic_publish`; `false` when omitted.
    #[serde(default)]
    pub immediate: bool,
}
/// Publish one message to an AMQP broker through a pooled connection.
///
/// The work is spawned onto kumo's main runtime and awaited here. Inside that task:
/// 1. A connection/channel pair is checked out of the pool for `params.connection`
///    (the pool is created on first use).
/// 2. A `BasicProperties` is filled in from whichever optional property fields are present.
/// 3. `basic_publish` is awaited under `publish_timeout` (60 seconds when unset).
///
/// NOTE(review): even with `confirm_select` enabled, this function does not itself wait
/// for a publisher confirm after `basic_publish` returns.
///
/// # Errors
///
/// Fails if the pool cannot be built, no connection can be checked out, the publish
/// fails or exceeds the timeout, or the spawned task cannot be joined.
pub async fn publish(params: PublishParams) -> anyhow::Result<()> {
    kumo_server_runtime::get_main_runtime()
        .spawn(async move {
            let connection = params
                .connection
                .get_pool()?
                .get()
                .await
                .map_err(|err| anyhow::anyhow!("{err:#}"))?;

            let mut props = BasicProperties::default();
            if let Some(v) = params.app_id.as_deref() {
                props.with_app_id(v);
            }
            if let Some(v) = params.cluster_id.as_deref() {
                props.with_cluster_id(v);
            }
            if let Some(v) = params.content_encoding.as_deref() {
                props.with_content_encoding(v);
            }
            if let Some(v) = params.content_type.as_deref() {
                props.with_content_type(v);
            }
            if let Some(v) = params.correlation_id.as_deref() {
                props.with_correlation_id(v);
            }
            if let Some(v) = params.delivery_mode {
                props.with_delivery_mode(v);
            }
            if let Some(v) = params.expiration.as_deref() {
                props.with_expiration(v);
            }
            // `headers` is moved out of `params` here. The other fields stay usable
            // because disjoint partial moves out of a struct are allowed.
            if let Some(v) = params.headers {
                props.with_headers(v);
            }
            if let Some(v) = params.message_id.as_deref() {
                props.with_message_id(v);
            }
            if let Some(v) = params.message_type.as_deref() {
                props.with_message_type(v);
            }
            if let Some(v) = params.priority {
                props.with_priority(v);
            }
            if let Some(v) = params.reply_to.as_deref() {
                props.with_reply_to(v);
            }
            if let Some(v) = params.timestamp {
                props.with_timestamp(v);
            }
            if let Some(v) = params.user_id.as_deref() {
                props.with_user_id(v);
            }

            let args = BasicPublishArguments {
                exchange: params.exchange,
                routing_key: params.routing_key,
                mandatory: params.mandatory,
                immediate: params.immediate,
            };
            let timeout_duration = params
                .connection
                .publish_timeout
                .unwrap_or_else(|| Duration::from_secs(60));
            // Outer `?`: the timeout elapsed. Inner `?`: basic_publish itself failed.
            tokio::time::timeout(
                timeout_duration,
                connection
                    .channel
                    .basic_publish(props, params.payload.into_bytes(), args),
            )
            .await??;
            Ok(())
        })
        .await?
}