1use config::{any_err, get_or_create_sub_module};
2use futures::stream::FuturesOrdered;
3use futures::StreamExt;
4use mlua::prelude::LuaUserData;
5use mlua::{Lua, LuaSerdeExt, UserDataMethods, Value};
6use rdkafka::message::{Header, OwnedHeaders};
7use rdkafka::producer::future_producer::Delivery;
8use rdkafka::producer::{FutureProducer, FutureRecord};
9use rdkafka::util::Timeout;
10use rdkafka::ClientConfig;
11use serde::Deserialize;
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex};
14use std::time::Duration;
15
16#[derive(Clone)]
17struct Producer {
18 producer: Arc<Mutex<Option<Arc<FutureProducer>>>>,
19}
20
21impl Producer {
22 fn get_producer(&self) -> mlua::Result<Arc<FutureProducer>> {
23 self.producer
24 .lock()
25 .unwrap()
26 .as_ref()
27 .map(Arc::clone)
28 .ok_or_else(|| mlua::Error::external("client was closed"))
29 }
30}
31
32#[derive(Deserialize, Debug)]
33struct Record {
34 topic: String,
36 #[serde(default)]
38 partition: Option<i32>,
39 #[serde(with = "serde_bytes")]
41 payload: Option<Vec<u8>>,
42 #[serde(default)]
44 key: Option<String>,
45
46 #[serde(default)]
48 headers: HashMap<String, String>,
49
50 #[serde(default)]
57 #[serde(with = "duration_serde")]
58 timeout: Option<Duration>,
59}
60
61impl LuaUserData for Producer {
62 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
63 methods.add_async_method("send", |lua, this, value: Value| async move {
64 let record: Record = lua.from_value(value)?;
65
66 let headers = if record.headers.is_empty() {
67 None
68 } else {
69 let mut headers = OwnedHeaders::new();
70 for (key, v) in &record.headers {
71 headers = headers.insert(Header {
72 key,
73 value: Some(v),
74 });
75 }
76 Some(headers)
77 };
78
79 let future_record = FutureRecord {
80 topic: &record.topic,
81 partition: record.partition,
82 payload: record.payload.as_ref(),
83 key: record.key.as_ref(),
84 headers,
85 timestamp: None,
86 };
87
88 let Delivery {
89 partition,
90 offset,
91 timestamp: _,
92 } = this
93 .get_producer()?
94 .send(
95 future_record,
96 Timeout::After(record.timeout.unwrap_or(Duration::from_secs(60))),
97 )
98 .await
99 .map_err(|(code, _msg)| any_err(code))?;
100
101 Ok((partition, offset))
102 });
103
104 methods.add_async_method("send_batch", |lua, this, values: Vec<Value>| async move {
105 let mut tasks = FuturesOrdered::new();
106 let producer = this.get_producer()?;
107
108 for value in values {
109 let record: Record = lua.from_value(value)?;
110
111 let headers = if record.headers.is_empty() {
112 None
113 } else {
114 let mut headers = OwnedHeaders::new();
115 for (key, v) in &record.headers {
116 headers = headers.insert(Header {
117 key,
118 value: Some(v),
119 });
120 }
121 Some(headers)
122 };
123
124 let producer = producer.clone();
125
126 tasks.push_back(tokio::spawn(async move {
127 producer
128 .send(
129 FutureRecord {
130 topic: &record.topic,
131 partition: record.partition,
132 payload: record.payload.as_ref(),
133 key: record.key.as_ref(),
134 headers,
135 timestamp: None,
136 },
137 Timeout::After(record.timeout.unwrap_or(Duration::from_secs(60))),
138 )
139 .await
140 }));
141 }
142
143 let failed_indexes = lua.create_table()?;
144 let errors = lua.create_table()?;
145 let mut index = 1;
146
147 while let Some(result) = tasks.next().await {
148 match result {
149 Ok(Ok(_)) => {}
150 Ok(Err((error, _msg))) => {
151 failed_indexes.push(index)?;
152 errors.push(format!("{error:#}"))?;
153 }
154 Err(error) => {
155 failed_indexes.push(index)?;
156 errors.push(format!("{error:#}"))?;
157 }
158 }
159 index += 1;
160 }
161 Ok((failed_indexes, errors))
162 });
163
164 methods.add_method("close", |_lua, this, _: ()| {
165 this.producer.lock().unwrap().take();
166 Ok(())
167 });
168 }
169}
170
171pub fn register(lua: &Lua) -> anyhow::Result<()> {
172 let kafka_mod = get_or_create_sub_module(lua, "kafka")?;
173
174 kafka_mod.set(
175 "build_producer",
176 lua.create_async_function(|_, config: HashMap<String, String>| async move {
177 let mut builder = ClientConfig::new();
178 for (k, v) in config {
179 builder.set(k, v);
180 }
181
182 let producer = builder.create().map_err(any_err)?;
183
184 Ok(Producer {
185 producer: Arc::new(Mutex::new(Some(Arc::new(producer)))),
186 })
187 })?,
188 )?;
189
190 Ok(())
191}