1use config::{any_err, from_lua_value};
2use lapin::options::BasicPublishOptions;
3use lapin::publisher_confirm::{Confirmation, PublisherConfirm};
4use lapin::{BasicProperties, Channel, Connection, ConnectionProperties};
5use mlua::prelude::LuaUserData;
6use mlua::{Lua, LuaSerdeExt, UserDataMethods, Value};
7use serde::{Deserialize, Serialize};
8use std::sync::{Arc, Mutex};
9use tokio::time::timeout;
10
11#[derive(Deserialize, Debug)]
12struct PublishParams {
13 routing_key: String,
14 payload: String,
15
16 #[serde(default)]
17 exchange: String,
18 #[serde(default)]
19 options: BasicPublishOptions,
20 #[serde(default)]
21 properties: BasicProperties,
22}
23
24struct ChannelHolder {
25 channel: Channel,
26 connection: Connection,
27}
28
29#[derive(Clone)]
30pub struct AMQPClient {
31 holder: Arc<ChannelHolder>,
32}
33
34impl LuaUserData for AMQPClient {
35 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
36 methods.add_async_method("publish", |lua, this, value: Value| async move {
37 let params: PublishParams = from_lua_value(&lua, value)?;
38
39 let confirm = this
40 .holder
41 .channel
42 .basic_publish(
43 ¶ms.exchange,
44 ¶ms.routing_key,
45 params.options,
46 params.payload.as_bytes(),
47 params.properties,
48 )
49 .await
50 .map_err(any_err)?;
51
52 Ok(Confirm {
53 confirm: Arc::new(Mutex::new(Some(confirm))),
54 })
55 });
56
57 methods.add_async_method(
58 "publish_with_timeout",
59 |lua, this, (value, duration_millis): (Value, u64)| async move {
60 let params: PublishParams = from_lua_value(&lua, value)?;
61
62 let publish = async {
63 let confirm = this
64 .holder
65 .channel
66 .basic_publish(
67 ¶ms.exchange,
68 ¶ms.routing_key,
69 params.options,
70 params.payload.as_bytes(),
71 params.properties,
72 )
73 .await
74 .map_err(any_err)?;
75
76 wait_confirmation(&lua, confirm).await
77 };
78
79 let duration = std::time::Duration::from_millis(duration_millis);
80 timeout(duration, publish)
81 .await
82 .map_err(any_err)?
83 .map_err(any_err)
84 },
85 );
86
87 methods.add_async_method("close", |_lua, this, _: ()| async move {
88 this.holder.channel.close(200, "").await.map_err(any_err)?;
89 this.holder
90 .connection
91 .close(200, "")
92 .await
93 .map_err(any_err)?;
94 Ok(())
95 });
96 }
97}
98
99#[derive(Clone)]
100struct Confirm {
101 confirm: Arc<Mutex<Option<PublisherConfirm>>>,
102}
103
104#[derive(Serialize, Debug)]
105enum ConfirmStatus {
106 Ack,
107 Nack,
108 NotRequested,
109}
110
111impl ConfirmStatus {
112 fn from_confirmation(confirm: &Confirmation) -> Self {
113 if confirm.is_ack() {
114 Self::Ack
115 } else if confirm.is_nack() {
116 Self::Nack
117 } else {
118 Self::NotRequested
119 }
120 }
121}
122
123#[derive(Serialize, Debug)]
124struct ConfirmResult {
125 status: ConfirmStatus,
126 reply_code: Option<u64>,
127 reply_text: Option<String>,
128}
129
130async fn wait_confirmation(lua: &Lua, confirm: PublisherConfirm) -> mlua::Result<Value> {
131 let confirmation = confirm.await.map_err(any_err)?;
132 let status = ConfirmStatus::from_confirmation(&confirmation);
133 let (reply_code, reply_text) = if let Some(msg) = confirmation.take_message() {
134 (
135 Some(msg.reply_code.into()),
136 Some(msg.reply_text.as_str().to_string()),
137 )
138 } else {
139 (None, None)
140 };
141
142 let confirmation = ConfirmResult {
143 status,
144 reply_code,
145 reply_text,
146 };
147
148 let result = lua.to_value_with(&confirmation, config::serialize_options())?;
149 Ok(result)
150}
151
152impl LuaUserData for Confirm {
153 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
154 methods.add_async_method("wait", |lua, this, _: ()| async move {
155 let confirm = this
156 .confirm
157 .lock()
158 .unwrap()
159 .take()
160 .ok_or_else(|| mlua::Error::external("confirmation already taken!?"))?;
161
162 wait_confirmation(&lua, confirm).await
163 })
164 }
165}
166
167pub async fn build_client(uri: String) -> anyhow::Result<AMQPClient> {
168 let options = ConnectionProperties::default()
169 .with_executor(
170 tokio_executor_trait::Tokio::default()
171 .with_handle(kumo_server_runtime::get_main_runtime()),
172 )
173 .with_reactor(tokio_reactor_trait::Tokio);
174
175 let connect_timeout = tokio::time::Duration::from_secs(20);
176
177 let connection = timeout(connect_timeout, Connection::connect(&uri, options))
178 .await
179 .map_err(any_err)?
180 .map_err(any_err)?;
181
182 connection.on_error(|err| {
183 tracing::error!("RabbitMQ connection broken {err:#}");
184 });
185
186 let channel = connection.create_channel().await.map_err(any_err)?;
187
188 Ok(AMQPClient {
189 holder: Arc::new(ChannelHolder {
190 connection,
191 channel,
192 }),
193 })
194}