1use config::{any_err, get_or_create_sub_module};
2use dashmap::DashMap;
3use mlua::prelude::*;
4use mlua::{Lua, UserDataMethods};
5use mod_memoize::CacheValue;
6use std::sync::{Arc, LazyLock};
7use tokio::sync::Mutex;
8use tokio::sync::mpsc::{
9 Receiver, Sender, UnboundedReceiver, UnboundedSender, channel, unbounded_channel,
10};
11use tokio::time::Duration;
12
/// Sending half of a named queue: either a bounded or an unbounded tokio mpsc
/// channel carrying `CacheValue`s. Lets the rest of this module treat both
/// channel flavours uniformly.
enum SenderWrapper {
    /// Created by `Queue::define_bounded` from `tokio::sync::mpsc::channel`.
    Bounded(Sender<CacheValue>),
    /// Created by `Queue::define_unbounded` from `tokio::sync::mpsc::unbounded_channel`.
    Unbounded(UnboundedSender<CacheValue>),
}
17
18impl SenderWrapper {
19 pub async fn send(&self, value: CacheValue) -> anyhow::Result<()> {
20 match self {
21 Self::Bounded(s) => {
22 s.send(value).await?;
23 }
24 Self::Unbounded(s) => {
25 s.send(value)?;
26 }
27 };
28 Ok(())
29 }
30
31 pub fn try_send(&self, value: CacheValue) -> anyhow::Result<()> {
32 match self {
33 Self::Bounded(s) => {
34 s.try_send(value)?;
35 }
36 Self::Unbounded(s) => {
37 s.send(value)?;
38 }
39 };
40 Ok(())
41 }
42
43 pub async fn send_timeout(&self, value: CacheValue, duration: Duration) -> anyhow::Result<()> {
44 match self {
45 Self::Bounded(s) => {
46 s.send_timeout(value, duration).await?;
47 }
48 Self::Unbounded(s) => {
49 s.send(value)?;
50 }
51 };
52 Ok(())
53 }
54 pub fn is_closed(&self) -> bool {
55 match self {
56 Self::Bounded(r) => r.is_closed(),
57 Self::Unbounded(r) => r.is_closed(),
58 }
59 }
60}
61
/// Receiving half of a named queue, matching the flavour of the corresponding
/// `SenderWrapper` (both halves are created together in `Queue::define_*`).
enum ReceiverWrapper {
    /// Receiver from `tokio::sync::mpsc::channel`.
    Bounded(Receiver<CacheValue>),
    /// Receiver from `tokio::sync::mpsc::unbounded_channel`.
    Unbounded(UnboundedReceiver<CacheValue>),
}
66
67impl ReceiverWrapper {
68 pub async fn recv(&mut self) -> Option<CacheValue> {
69 match self {
70 Self::Bounded(r) => r.recv().await,
71 Self::Unbounded(r) => r.recv().await,
72 }
73 }
74
75 pub fn try_recv(&mut self) -> Option<CacheValue> {
76 match self {
77 Self::Bounded(r) => r.try_recv().ok(),
78 Self::Unbounded(r) => r.try_recv().ok(),
79 }
80 }
81
82 pub async fn recv_many(&mut self, limit: usize) -> Vec<CacheValue> {
83 let mut buffer = vec![];
84 match self {
85 Self::Bounded(r) => r.recv_many(&mut buffer, limit).await,
86 Self::Unbounded(r) => r.recv_many(&mut buffer, limit).await,
87 };
88
89 buffer
90 }
91
92 pub fn close(&mut self) {
93 match self {
94 Self::Bounded(r) => r.close(),
95 Self::Unbounded(r) => r.close(),
96 }
97 }
98
99 pub fn is_empty(&self) -> bool {
100 match self {
101 Self::Bounded(r) => r.is_empty(),
102 Self::Unbounded(r) => r.is_empty(),
103 }
104 }
105
106 pub fn len(&self) -> usize {
107 match self {
108 Self::Bounded(r) => r.len(),
109 Self::Unbounded(r) => r.len(),
110 }
111 }
112}
113
/// A named queue stored in the global `QUEUES` registry.
///
/// The single receiving half is wrapped in a tokio `Mutex` so that every
/// `QueueHandle` sharing this queue can consume from it, one caller at a time.
struct Queue {
    sender: Arc<SenderWrapper>,
    receiver: Arc<Mutex<ReceiverWrapper>>,
}
118
/// Lua userdata wrapper around a shared `Queue`. This is the value returned to
/// scripts by `kumo.mpsc.define`.
struct QueueHandle(Arc<Queue>);
120
121impl LuaUserData for QueueHandle {
122 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
123 methods.add_async_method(
124 "send",
125 move |_lua, this: LuaUserDataRef<QueueHandle>, value: CacheValue| async move {
126 this.0.sender.send(value).await.map_err(any_err)?;
127 Ok(())
128 },
129 );
130
131 methods.add_async_method(
132 "send_timeout",
133 move |_lua,
134 this: LuaUserDataRef<QueueHandle>,
135 (value, timeout_seconds): (CacheValue, f32)| async move {
136 this.0
137 .sender
138 .send_timeout(value, Duration::from_secs_f32(timeout_seconds))
139 .await
140 .map_err(any_err)?;
141 Ok(())
142 },
143 );
144
145 methods.add_method(
146 "try_send",
147 move |_lua, this: &QueueHandle, value: CacheValue| match this.0.sender.try_send(value) {
148 Ok(()) => Ok(true),
149 Err(_) => Ok(false),
150 },
151 );
152
153 methods.add_async_method(
154 "close",
155 move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
156 let mut rx = this.0.receiver.lock().await;
157 Ok(rx.close())
158 },
159 );
160
161 methods.add_method("is_closed", move |_lua, this: &QueueHandle, ()| {
162 Ok(this.0.sender.is_closed())
163 });
164
165 methods.add_async_method(
166 "is_empty",
167 move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
168 let rx = this.0.receiver.lock().await;
169 Ok(rx.is_empty())
170 },
171 );
172
173 methods.add_async_method(
174 "len",
175 move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
176 let rx = this.0.receiver.lock().await;
177 Ok(rx.len())
178 },
179 );
180
181 methods.add_async_method(
182 "recv",
183 move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
184 let mut rx = this.0.receiver.lock().await;
185 Ok(rx.recv().await)
186 },
187 );
188
189 methods.add_async_method(
190 "try_recv",
191 move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
192 let mut rx = this.0.receiver.lock().await;
193 Ok(rx.try_recv())
194 },
195 );
196
197 methods.add_async_method(
198 "recv_many",
199 move |_lua, this: LuaUserDataRef<QueueHandle>, limit: usize| async move {
200 let mut rx = this.0.receiver.lock().await;
201 Ok(rx.recv_many(limit).await)
202 },
203 );
204 }
205}
206
/// Process-wide registry of queues, keyed by the name passed to
/// `kumo.mpsc.define`. Nothing in this section removes entries, so a queue
/// defined once persists for the life of the process.
static QUEUES: LazyLock<DashMap<String, Arc<Queue>>> = LazyLock::new(DashMap::new);
208
209impl Queue {
210 pub fn define_unbounded(name: &str) -> anyhow::Result<QueueHandle> {
211 let queue = QUEUES.entry(name.to_string()).or_insert_with(|| {
212 let (sender, receiver) = unbounded_channel();
213 Arc::new(Queue {
214 sender: Arc::new(SenderWrapper::Unbounded(sender)),
215 receiver: Arc::new(Mutex::new(ReceiverWrapper::Unbounded(receiver))),
216 })
217 });
218
219 Ok(QueueHandle(Arc::clone(queue.value())))
220 }
221
222 pub fn define_bounded(name: &str, buffer: usize) -> anyhow::Result<QueueHandle> {
223 let queue = QUEUES.entry(name.to_string()).or_insert_with(|| {
224 let (sender, receiver) = channel(buffer);
225 Arc::new(Queue {
226 sender: Arc::new(SenderWrapper::Bounded(sender)),
227 receiver: Arc::new(Mutex::new(ReceiverWrapper::Bounded(receiver))),
228 })
229 });
230
231 Ok(QueueHandle(Arc::clone(queue.value())))
232 }
233}
234
235pub fn register(lua: &Lua) -> anyhow::Result<()> {
236 let kumo_mpsc = get_or_create_sub_module(lua, "mpsc")?;
237
238 kumo_mpsc.set(
239 "define",
240 lua.create_function(|_lua, (name, buffer): (String, Option<usize>)| {
241 match buffer {
242 Some(buffer) => Queue::define_bounded(&name, buffer),
243 None => Queue::define_unbounded(&name),
244 }
245 .map_err(any_err)
246 })?,
247 )?;
248
249 Ok(())
250}