1use config::{any_err, get_or_create_sub_module};
2use dashmap::DashMap;
3use mlua::prelude::*;
4use mlua::{Lua, UserDataMethods};
5use mod_memoize::CacheValue;
6use std::sync::{Arc, LazyLock};
7use tokio::sync::Mutex;
8use tokio::sync::mpsc::{
9 Receiver, Sender, UnboundedReceiver, UnboundedSender, channel, unbounded_channel,
10};
11use tokio::time::Duration;
12
13enum SenderWrapper {
14 Bounded(Sender<CacheValue>),
15 Unbounded(UnboundedSender<CacheValue>),
16}
17
18impl SenderWrapper {
19 pub async fn send(&self, value: CacheValue) -> anyhow::Result<()> {
20 match self {
21 Self::Bounded(s) => {
22 s.send(value).await?;
23 }
24 Self::Unbounded(s) => {
25 s.send(value)?;
26 }
27 };
28 Ok(())
29 }
30
31 pub fn try_send(&self, value: CacheValue) -> anyhow::Result<()> {
32 match self {
33 Self::Bounded(s) => {
34 s.try_send(value)?;
35 }
36 Self::Unbounded(s) => {
37 s.send(value)?;
38 }
39 };
40 Ok(())
41 }
42
43 pub async fn send_timeout(&self, value: CacheValue, duration: Duration) -> anyhow::Result<()> {
44 match self {
45 Self::Bounded(s) => {
46 s.send_timeout(value, duration).await?;
47 }
48 Self::Unbounded(s) => {
49 s.send(value)?;
50 }
51 };
52 Ok(())
53 }
54 pub fn is_closed(&self) -> bool {
55 match self {
56 Self::Bounded(r) => r.is_closed(),
57 Self::Unbounded(r) => r.is_closed(),
58 }
59 }
60}
61
62enum ReceiverWrapper {
63 Bounded(Receiver<CacheValue>),
64 Unbounded(UnboundedReceiver<CacheValue>),
65}
66
67impl ReceiverWrapper {
68 pub async fn recv(&mut self) -> Option<CacheValue> {
69 match self {
70 Self::Bounded(r) => r.recv().await,
71 Self::Unbounded(r) => r.recv().await,
72 }
73 }
74
75 pub fn try_recv(&mut self) -> Option<CacheValue> {
76 match self {
77 Self::Bounded(r) => r.try_recv().ok(),
78 Self::Unbounded(r) => r.try_recv().ok(),
79 }
80 }
81
82 pub async fn recv_many(&mut self, limit: usize) -> Vec<CacheValue> {
83 let mut buffer = vec![];
84 match self {
85 Self::Bounded(r) => r.recv_many(&mut buffer, limit).await,
86 Self::Unbounded(r) => r.recv_many(&mut buffer, limit).await,
87 };
88
89 buffer
90 }
91
92 pub fn close(&mut self) {
93 match self {
94 Self::Bounded(r) => r.close(),
95 Self::Unbounded(r) => r.close(),
96 }
97 }
98
99 pub fn is_empty(&self) -> bool {
100 match self {
101 Self::Bounded(r) => r.is_empty(),
102 Self::Unbounded(r) => r.is_empty(),
103 }
104 }
105
106 pub fn len(&self) -> usize {
107 match self {
108 Self::Bounded(r) => r.len(),
109 Self::Unbounded(r) => r.len(),
110 }
111 }
112}
113
114struct Queue {
115 sender: Mutex<Option<Arc<SenderWrapper>>>,
116 receiver: Arc<Mutex<ReceiverWrapper>>,
117}
118
119struct QueueHandle(Arc<Queue>);
120
121impl QueueHandle {
122 async fn get_sender(&self) -> Option<Arc<SenderWrapper>> {
123 let sender = self.0.sender.lock().await;
124 sender.clone()
125 }
126}
127
128impl LuaUserData for QueueHandle {
129 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
130 methods.add_async_method(
131 "send",
132 move |_lua, this: LuaUserDataRef<QueueHandle>, value: CacheValue| async move {
133 this.get_sender()
134 .await
135 .ok_or_else(|| LuaError::external("cannot send to closed channel"))?
136 .send(value)
137 .await
138 .map_err(any_err)?;
139 Ok(())
140 },
141 );
142
143 methods.add_async_method(
144 "send_timeout",
145 move |_lua,
146 this: LuaUserDataRef<QueueHandle>,
147 (value, timeout_seconds): (CacheValue, f32)| async move {
148 this.get_sender()
149 .await
150 .ok_or_else(|| LuaError::external("cannot send to closed channel"))?
151 .send_timeout(value, Duration::from_secs_f32(timeout_seconds))
152 .await
153 .map_err(any_err)?;
154 Ok(())
155 },
156 );
157
158 methods.add_async_method(
159 "try_send",
160 move |_lua, this: LuaUserDataRef<QueueHandle>, value: CacheValue| async move {
161 match this
162 .get_sender()
163 .await
164 .ok_or_else(|| LuaError::external("cannot send to closed channel"))?
165 .try_send(value)
166 {
167 Ok(()) => Ok(true),
168 Err(_) => Ok(false),
169 }
170 },
171 );
172
173 methods.add_async_method(
174 "close",
175 move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
176 let mut sender = this.get_sender().await;
177 sender.take();
178
179 if let Ok(mut rx) = this.0.receiver.try_lock() {
180 rx.close();
181 }
182 Ok(())
183 },
184 );
185
186 methods.add_async_method(
187 "is_closed",
188 move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
189 Ok(match this.get_sender().await {
190 None => true,
191 Some(sender) => sender.is_closed(),
192 })
193 },
194 );
195
196 methods.add_async_method(
197 "is_empty",
198 move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
199 let rx = this.0.receiver.try_lock().map_err(any_err)?;
200 Ok(rx.is_empty())
201 },
202 );
203
204 methods.add_async_method(
205 "len",
206 move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
207 let rx = this.0.receiver.try_lock().map_err(any_err)?;
208 Ok(rx.len())
209 },
210 );
211
212 methods.add_async_method(
213 "recv",
214 move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
215 let mut rx = this.0.receiver.try_lock().map_err(any_err)?;
216 Ok(rx.recv().await)
217 },
218 );
219
220 methods.add_async_method(
221 "try_recv",
222 move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
223 let mut rx = this.0.receiver.try_lock().map_err(any_err)?;
224 Ok(rx.try_recv())
225 },
226 );
227
228 methods.add_async_method(
229 "recv_many",
230 move |_lua, this: LuaUserDataRef<QueueHandle>, limit: usize| async move {
231 let mut rx = this.0.receiver.try_lock().map_err(any_err)?;
232 Ok(rx.recv_many(limit).await)
233 },
234 );
235 }
236}
237
238static QUEUES: LazyLock<DashMap<String, Arc<Queue>>> = LazyLock::new(DashMap::new);
239
240impl Queue {
241 pub fn define_unbounded(name: &str) -> anyhow::Result<QueueHandle> {
242 let queue = QUEUES.entry(name.to_string()).or_insert_with(|| {
243 let (sender, receiver) = unbounded_channel();
244 Arc::new(Queue {
245 sender: Mutex::new(Some(Arc::new(SenderWrapper::Unbounded(sender)))),
246 receiver: Arc::new(Mutex::new(ReceiverWrapper::Unbounded(receiver))),
247 })
248 });
249
250 Ok(QueueHandle(Arc::clone(queue.value())))
251 }
252
253 pub fn define_bounded(name: &str, buffer: usize) -> anyhow::Result<QueueHandle> {
254 let queue = QUEUES.entry(name.to_string()).or_insert_with(|| {
255 let (sender, receiver) = channel(buffer);
256 Arc::new(Queue {
257 sender: Mutex::new(Some(Arc::new(SenderWrapper::Bounded(sender)))),
258 receiver: Arc::new(Mutex::new(ReceiverWrapper::Bounded(receiver))),
259 })
260 });
261
262 Ok(QueueHandle(Arc::clone(queue.value())))
263 }
264}
265
266pub fn register(lua: &Lua) -> anyhow::Result<()> {
267 let kumo_mpsc = get_or_create_sub_module(lua, "mpsc")?;
268
269 kumo_mpsc.set(
270 "define",
271 lua.create_function(|_lua, (name, buffer): (String, Option<usize>)| {
272 match buffer {
273 Some(buffer) => Queue::define_bounded(&name, buffer),
274 None => Queue::define_unbounded(&name),
275 }
276 .map_err(any_err)
277 })?,
278 )?;
279
280 Ok(())
281}