1use config::{any_err, get_or_create_sub_module};
2use futures::stream::FuturesOrdered;
3use futures::StreamExt;
4use mlua::prelude::LuaUserData;
5use mlua::{Lua, LuaSerdeExt, UserDataMethods, Value};
6use rdkafka::message::{Header, OwnedHeaders};
7use rdkafka::producer::{FutureProducer, FutureRecord};
8use rdkafka::util::Timeout;
9use rdkafka::ClientConfig;
10use serde::Deserialize;
11use std::collections::HashMap;
12use std::sync::{Arc, Mutex};
13use std::time::Duration;
14
15#[derive(Clone)]
16struct Producer {
17 producer: Arc<Mutex<Option<Arc<FutureProducer>>>>,
18}
19
20impl Producer {
21 fn get_producer(&self) -> mlua::Result<Arc<FutureProducer>> {
22 self.producer
23 .lock()
24 .unwrap()
25 .as_ref()
26 .map(Arc::clone)
27 .ok_or_else(|| mlua::Error::external("client was closed"))
28 }
29}
30
31#[derive(Deserialize, Debug)]
32struct Record {
33 topic: String,
35 #[serde(default)]
37 partition: Option<i32>,
38 #[serde(with = "serde_bytes")]
40 payload: Option<Vec<u8>>,
41 #[serde(default)]
43 key: Option<String>,
44
45 #[serde(default)]
47 headers: HashMap<String, String>,
48
49 #[serde(default)]
56 #[serde(with = "duration_serde")]
57 timeout: Option<Duration>,
58}
59
60impl LuaUserData for Producer {
61 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
62 methods.add_async_method("send", |lua, this, value: Value| async move {
63 let record: Record = lua.from_value(value)?;
64
65 let headers = if record.headers.is_empty() {
66 None
67 } else {
68 let mut headers = OwnedHeaders::new();
69 for (key, v) in &record.headers {
70 headers = headers.insert(Header {
71 key,
72 value: Some(v),
73 });
74 }
75 Some(headers)
76 };
77
78 let future_record = FutureRecord {
79 topic: &record.topic,
80 partition: record.partition,
81 payload: record.payload.as_ref(),
82 key: record.key.as_ref(),
83 headers,
84 timestamp: None,
85 };
86
87 let (partition, offset) = this
88 .get_producer()?
89 .send(
90 future_record,
91 Timeout::After(record.timeout.unwrap_or(Duration::from_secs(60))),
92 )
93 .await
94 .map_err(|(code, _msg)| any_err(code))?;
95
96 Ok((partition, offset))
97 });
98
99 methods.add_async_method("send_batch", |lua, this, values: Vec<Value>| async move {
100 let mut tasks = FuturesOrdered::new();
101 let producer = this.get_producer()?;
102
103 for value in values {
104 let record: Record = lua.from_value(value)?;
105
106 let headers = if record.headers.is_empty() {
107 None
108 } else {
109 let mut headers = OwnedHeaders::new();
110 for (key, v) in &record.headers {
111 headers = headers.insert(Header {
112 key,
113 value: Some(v),
114 });
115 }
116 Some(headers)
117 };
118
119 let producer = producer.clone();
120
121 tasks.push_back(tokio::spawn(async move {
122 producer
123 .send(
124 FutureRecord {
125 topic: &record.topic,
126 partition: record.partition,
127 payload: record.payload.as_ref(),
128 key: record.key.as_ref(),
129 headers,
130 timestamp: None,
131 },
132 Timeout::After(record.timeout.unwrap_or(Duration::from_secs(60))),
133 )
134 .await
135 }));
136 }
137
138 let failed_indexes = lua.create_table()?;
139 let errors = lua.create_table()?;
140 let mut index = 1;
141
142 while let Some(result) = tasks.next().await {
143 match result {
144 Ok(Ok(_)) => {}
145 Ok(Err((error, _msg))) => {
146 failed_indexes.push(index)?;
147 errors.push(format!("{error:#}"))?;
148 }
149 Err(error) => {
150 failed_indexes.push(index)?;
151 errors.push(format!("{error:#}"))?;
152 }
153 }
154 index += 1;
155 }
156 Ok((failed_indexes, errors))
157 });
158
159 methods.add_method("close", |_lua, this, _: ()| {
160 this.producer.lock().unwrap().take();
161 Ok(())
162 });
163 }
164}
165
166pub fn register(lua: &Lua) -> anyhow::Result<()> {
167 let kafka_mod = get_or_create_sub_module(lua, "kafka")?;
168
169 kafka_mod.set(
170 "build_producer",
171 lua.create_async_function(|_, config: HashMap<String, String>| async move {
172 let mut builder = ClientConfig::new();
173 for (k, v) in config {
174 builder.set(k, v);
175 }
176
177 let producer = builder.create().map_err(any_err)?;
178
179 Ok(Producer {
180 producer: Arc::new(Mutex::new(Some(Arc::new(producer)))),
181 })
182 })?,
183 )?;
184
185 Ok(())
186}