use crate::{Spool, SpoolEntry, SpoolId};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use flume::Sender;
use prometheus::IntGaugeVec;
use rocksdb::perf::get_memory_usage_stats;
use rocksdb::{
    DBCompressionType, ErrorKind, IteratorMode, LogLevel, Options, WriteBatch, WriteOptions, DB,
};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::sync::{Arc, LazyLock, Weak};
use std::time::{Duration, Instant};
use tokio::runtime::Handle;
use tokio::sync::Semaphore;
use tokio::time::timeout_at;

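/// Tuning parameters for the RocksDB-backed spool. Every field is optional
/// (or has a serde default). Most fields map onto the corresponding setter on
/// `rocksdb::Options` when the spool is opened; the `limit_concurrent_*`
/// fields are used to build semaphores that bound concurrent operations.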
#[derive(Serialize, Deserialize, Debug)]
pub struct RocksSpoolParams {
    pub increase_parallelism: Option<i32>,

    pub optimize_level_style_compaction: Option<usize>,
    pub optimize_universal_style_compaction: Option<usize>,
    #[serde(default)]
    pub paranoid_checks: bool,
    #[serde(default)]
    pub compression_type: DBCompressionTypeDef,

    pub compaction_readahead_size: Option<usize>,

    #[serde(default)]
    pub level_compaction_dynamic_level_bytes: bool,

    #[serde(default)]
    pub max_open_files: Option<usize>,

    #[serde(default)]
    pub log_level: LogLevelDef,

    #[serde(default)]
    pub memtable_huge_page_size: Option<usize>,

    #[serde(
        with = "duration_serde",
        default = "RocksSpoolParams::default_log_file_time_to_roll"
    )]
    pub log_file_time_to_roll: Duration,

    #[serde(
        with = "duration_serde",
        default = "RocksSpoolParams::default_obsolete_files_period"
    )]
    pub obsolete_files_period: Duration,

    #[serde(default)]
    pub limit_concurrent_stores: Option<usize>,
    #[serde(default)]
    pub limit_concurrent_loads: Option<usize>,
    #[serde(default)]
    pub limit_concurrent_removes: Option<usize>,
}

impl Default for RocksSpoolParams {
    fn default() -> Self {
        Self {
            increase_parallelism: None,
            optimize_level_style_compaction: None,
            optimize_universal_style_compaction: None,
            paranoid_checks: false,
            compression_type: DBCompressionTypeDef::default(),
            compaction_readahead_size: None,
            level_compaction_dynamic_level_bytes: false,
            max_open_files: None,
            log_level: LogLevelDef::default(),
            memtable_huge_page_size: None,
            log_file_time_to_roll: Self::default_log_file_time_to_roll(),
            obsolete_files_period: Self::default_obsolete_files_period(),
            limit_concurrent_stores: None,
            limit_concurrent_loads: None,
            limit_concurrent_removes: None,
        }
    }
}

impl RocksSpoolParams {
    fn default_log_file_time_to_roll() -> Duration {
        Duration::from_secs(86400)
    }

    fn default_obsolete_files_period() -> Duration {
        Duration::from_secs(6 * 60 * 60)
    }
}

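/// Serializable stand-in for `rocksdb::DBCompressionType`, so the compression
/// choice can be expressed in configuration; converted via `From` when the
/// options are built.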
#[derive(Serialize, Deserialize, Debug)]
pub enum DBCompressionTypeDef {
    None,
    Snappy,
    Zlib,
    Bz2,
    Lz4,
    Lz4hc,
    Zstd,
}

impl From<DBCompressionTypeDef> for DBCompressionType {
    fn from(val: DBCompressionTypeDef) -> Self {
        match val {
            DBCompressionTypeDef::None => DBCompressionType::None,
            DBCompressionTypeDef::Snappy => DBCompressionType::Snappy,
            DBCompressionTypeDef::Zlib => DBCompressionType::Zlib,
            DBCompressionTypeDef::Bz2 => DBCompressionType::Bz2,
            DBCompressionTypeDef::Lz4 => DBCompressionType::Lz4,
            DBCompressionTypeDef::Lz4hc => DBCompressionType::Lz4hc,
            DBCompressionTypeDef::Zstd => DBCompressionType::Zstd,
        }
    }
}

impl Default for DBCompressionTypeDef {
    fn default() -> Self {
        Self::Snappy
    }
}

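/// Serializable stand-in for `rocksdb::LogLevel`, converted via `From` when
/// the options are built.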
#[derive(Serialize, Deserialize, Debug)]
pub enum LogLevelDef {
    Debug,
    Info,
    Warn,
    Error,
    Fatal,
    Header,
}

impl Default for LogLevelDef {
    fn default() -> Self {
        Self::Info
    }
}

impl From<LogLevelDef> for LogLevel {
    fn from(val: LogLevelDef) -> Self {
        match val {
            LogLevelDef::Debug => LogLevel::Debug,
            LogLevelDef::Info => LogLevel::Info,
            LogLevelDef::Warn => LogLevel::Warn,
            LogLevelDef::Error => LogLevel::Error,
            LogLevelDef::Fatal => LogLevel::Fatal,
            LogLevelDef::Header => LogLevel::Header,
        }
    }
}

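/// A `Spool` implementation backed by a RocksDB database. Blocking RocksDB
/// calls are executed on the provided runtime's blocking thread pool, and the
/// optional semaphores can be used to limit concurrency for loads and for the
/// blocking store/remove fallback paths.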
pub struct RocksSpool {
    db: Arc<DB>,
    runtime: Handle,
    limit_concurrent_stores: Option<Arc<Semaphore>>,
    limit_concurrent_loads: Option<Arc<Semaphore>>,
    limit_concurrent_removes: Option<Arc<Semaphore>>,
}

impl RocksSpool {
    pub fn new(
        path: &Path,
        flush: bool,
        params: Option<RocksSpoolParams>,
        runtime: Handle,
    ) -> anyhow::Result<Self> {
        let mut opts = Options::default();
        opts.set_use_fsync(flush);
        opts.create_if_missing(true);
        opts.set_keep_log_file_num(10);

        let p = params.unwrap_or_default();
        if let Some(i) = p.increase_parallelism {
            opts.increase_parallelism(i);
        }
        if let Some(i) = p.optimize_level_style_compaction {
            opts.optimize_level_style_compaction(i);
        }
        if let Some(i) = p.optimize_universal_style_compaction {
            opts.optimize_universal_style_compaction(i);
        }
        if let Some(i) = p.compaction_readahead_size {
            opts.set_compaction_readahead_size(i);
        }
        if let Some(i) = p.max_open_files {
            opts.set_max_open_files(i as _);
        }
        if let Some(i) = p.memtable_huge_page_size {
            opts.set_memtable_huge_page_size(i);
        }
        opts.set_paranoid_checks(p.paranoid_checks);
        opts.set_level_compaction_dynamic_level_bytes(p.level_compaction_dynamic_level_bytes);
        opts.set_compression_type(p.compression_type.into());
        opts.set_log_level(p.log_level.into());
        opts.set_log_file_time_to_roll(p.log_file_time_to_roll.as_secs() as usize);
        opts.set_delete_obsolete_files_period_micros(p.obsolete_files_period.as_micros() as u64);

        let limit_concurrent_stores = p
            .limit_concurrent_stores
            .map(|n| Arc::new(Semaphore::new(n)));
        let limit_concurrent_loads = p
            .limit_concurrent_loads
            .map(|n| Arc::new(Semaphore::new(n)));
        let limit_concurrent_removes = p
            .limit_concurrent_removes
            .map(|n| Arc::new(Semaphore::new(n)));

        let db = Arc::new(DB::open(&opts, path)?);

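        // Spawn a task that periodically exports memory usage metrics for
        // this database. It holds only a `Weak` reference, so it will not
        // keep the DB alive once the spool has been dropped.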
        {
            let weak = Arc::downgrade(&db);
            tokio::spawn(metrics_monitor(weak, format!("{}", path.display())));
        }

        Ok(Self {
            db,
            runtime,
            limit_concurrent_stores,
            limit_concurrent_loads,
            limit_concurrent_removes,
        })
    }
}

#[async_trait]
impl Spool for RocksSpool {
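    // Load the data for `id`, running the blocking RocksDB read on the
    // runtime's blocking thread pool. A missing key is reported as an error.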
    async fn load(&self, id: SpoolId) -> anyhow::Result<Vec<u8>> {
        let permit = match self.limit_concurrent_loads.clone() {
            Some(s) => Some(s.acquire_owned().await?),
            None => None,
        };
        let db = self.db.clone();
        tokio::task::Builder::new()
            .name("rocksdb load")
            .spawn_blocking_on(
                move || {
                    let result = db
                        .get(id.as_bytes())?
                        .ok_or_else(|| anyhow::anyhow!("no such key {id}"))?;
                    drop(permit);
                    Ok(result)
                },
                &self.runtime,
            )?
            .await?
    }

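    // Writes first attempt a non-blocking put with `no_slowdown` enabled so
    // that the common case never blocks this async task. If RocksDB reports
    // that the write would stall (`ErrorKind::Incomplete`), the write is
    // retried on the blocking thread pool, optionally gated by the
    // concurrency-limit semaphore and the caller's deadline.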
    async fn store(
        &self,
        id: SpoolId,
        data: Arc<Box<[u8]>>,
        force_sync: bool,
        deadline: Option<Instant>,
    ) -> anyhow::Result<()> {
        let mut opts = WriteOptions::default();
        opts.set_sync(force_sync);
        opts.set_no_slowdown(true);
        let mut batch = WriteBatch::default();
        batch.put(id.as_bytes(), &*data);

        match self.db.write_opt(batch, &opts) {
            Ok(()) => Ok(()),
            Err(err) if err.kind() == ErrorKind::Incomplete => {
                let permit = match (self.limit_concurrent_stores.clone(), deadline) {
                    (Some(s), Some(deadline)) => {
                        Some(timeout_at(deadline.into(), s.acquire_owned()).await??)
                    }
                    (Some(s), None) => Some(s.acquire_owned().await?),
                    (None, _) => None,
                };
                let db = self.db.clone();
                tokio::task::Builder::new()
                    .name("rocksdb store")
                    .spawn_blocking_on(
                        move || {
                            opts.set_no_slowdown(false);
                            let mut batch = WriteBatch::default();
                            batch.put(id.as_bytes(), &*data);
                            let result = db.write_opt(batch, &opts)?;
                            drop(permit);
                            Ok(result)
                        },
                        &self.runtime,
                    )?
                    .await?
            }
            Err(err) => Err(err.into()),
        }
    }

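    // Removal mirrors `store`: try a non-blocking delete first, and fall back
    // to the blocking thread pool if RocksDB reports that it would stall.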
    async fn remove(&self, id: SpoolId) -> anyhow::Result<()> {
        let mut opts = WriteOptions::default();
        opts.set_no_slowdown(true);
        let mut batch = WriteBatch::default();
        batch.delete(id.as_bytes());

        match self.db.write_opt(batch, &opts) {
            Ok(()) => Ok(()),
            Err(err) if err.kind() == ErrorKind::Incomplete => {
                let permit = match self.limit_concurrent_removes.clone() {
                    Some(s) => Some(s.acquire_owned().await?),
                    None => None,
                };
                let db = self.db.clone();
                tokio::task::Builder::new()
                    .name("rocksdb remove")
                    .spawn_blocking_on(
                        move || {
                            opts.set_no_slowdown(false);
                            let mut batch = WriteBatch::default();
                            batch.delete(id.as_bytes());
                            let result = db.write_opt(batch, &opts)?;
                            drop(permit);
                            Ok(result)
                        },
                        &self.runtime,
                    )?
                    .await?
            }
            Err(err) => Err(err.into()),
        }
    }

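    // There is no periodic cleanup work to perform for the RocksDB spool.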
    async fn cleanup(&self) -> anyhow::Result<()> {
        Ok(())
    }

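    // Stop RocksDB background work (and wait for it to finish) as part of
    // shutdown.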
    async fn shutdown(&self) -> anyhow::Result<()> {
        let db = self.db.clone();
        tokio::task::spawn_blocking(move || db.cancel_all_background_work(true)).await?;
        Ok(())
    }

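    // Flush memtables in an attempt to release memory, returning an estimate
    // of how many bytes were freed (negative if usage grew).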
    async fn advise_low_memory(&self) -> anyhow::Result<isize> {
        let db = self.db.clone();
        tokio::task::spawn_blocking(move || {
            let usage_before = match get_memory_usage_stats(Some(&[&db]), None) {
                Ok(stats) => {
                    let stats: Stats = stats.into();
                    tracing::debug!("pre-flush: {stats:#?}");
                    stats.total()
                }
                Err(err) => {
                    tracing::error!("error getting stats: {err:#}");
                    0
                }
            };

            if let Err(err) = db.flush() {
                tracing::error!("error flushing memory: {err:#}");
            }

            let usage_after = match get_memory_usage_stats(Some(&[&db]), None) {
                Ok(stats) => {
                    let stats: Stats = stats.into();
                    tracing::debug!("post-flush: {stats:#?}");
                    stats.total()
                }
                Err(err) => {
                    tracing::error!("error getting stats: {err:#}");
                    0
                }
            };

            Ok(usage_before - usage_after)
        })
        .await?
    }

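    // Walk the entire database on a blocking task, sending every entry that
    // was created before `start_time` to the provided channel.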
    fn enumerate(
        &self,
        sender: Sender<SpoolEntry>,
        start_time: DateTime<Utc>,
    ) -> anyhow::Result<()> {
        let db = Arc::clone(&self.db);
        tokio::task::Builder::new()
            .name("rocksdb enumerate")
            .spawn_blocking_on(
                move || {
                    let iter = db.iterator(IteratorMode::Start);
                    for entry in iter {
                        let (key, value) = entry?;
                        let id = SpoolId::from_slice(&key)
                            .ok_or_else(|| anyhow::anyhow!("invalid spool id {key:?}"))?;

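                        // Entries created at or after `start_time` were
                        // spooled after this enumeration began, so they are
                        // not reported here.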
                        if id.created() >= start_time {
                            continue;
                        }

                        sender
                            .send(SpoolEntry::Item {
                                id,
                                data: value.to_vec(),
                            })
                            .map_err(|err| {
                                anyhow::anyhow!("failed to send SpoolEntry for {id}: {err:#}")
                            })?;
                    }
                    Ok::<(), anyhow::Error>(())
                },
                &self.runtime,
            )?;
        Ok(())
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[tokio::test]
    async fn rocks_spool() -> anyhow::Result<()> {
        let location = tempfile::tempdir()?;
        let spool = RocksSpool::new(location.path(), false, None, Handle::current())?;

        {
            let id1 = SpoolId::new();

            assert_eq!(
                format!("{:#}", spool.load(id1).await.unwrap_err()),
                format!("no such key {id1}")
            );
        }

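        // Store 100 entries.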
        let mut ids = vec![];
        for i in 0..100 {
            let id = SpoolId::new();
            spool
                .store(
                    id,
                    Arc::new(format!("I am {i}").as_bytes().to_vec().into_boxed_slice()),
                    false,
                    None,
                )
                .await?;
            ids.push(id);
        }

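        // Verify that each stored entry can be loaded back intact.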
        for (i, &id) in ids.iter().enumerate() {
            let data = spool.load(id).await?;
            let text = String::from_utf8(data)?;
            assert_eq!(text, format!("I am {i}"));
        }

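        // Enumerate the spool, removing each entry as it is yielded and
        // confirming that it can no longer be loaded.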
        {
            let (tx, rx) = flume::bounded(32);
            spool.enumerate(tx, Utc::now())?;
            let mut count = 0;

            while let Ok(item) = rx.recv_async().await {
                match item {
                    SpoolEntry::Item { id, data } => {
                        let i = ids
                            .iter()
                            .position(|&item| item == id)
                            .ok_or_else(|| anyhow::anyhow!("{id} not found in ids!"))?;

                        let text = String::from_utf8(data)?;
                        assert_eq!(text, format!("I am {i}"));

                        spool.remove(id).await?;
                        assert_eq!(
                            format!("{:#}", spool.load(id).await.unwrap_err()),
                            format!("no such key {id}")
                        );
                        count += 1;
                    }
                    SpoolEntry::Corrupt { id, error } => {
                        anyhow::bail!("Corrupt: {id}: {error}");
                    }
                }
            }

            assert_eq!(count, 100);
        }

        // The spool should now be empty; enumerate a couple of times and
        // confirm that nothing is returned.
        for _ in 0..2 {
            let (tx, rx) = flume::bounded(32);
            spool.enumerate(tx, Utc::now())?;
            let mut unexpected = vec![];

            while let Ok(item) = rx.recv_async().await {
                match item {
                    SpoolEntry::Item { id, .. } | SpoolEntry::Corrupt { id, .. } => {
                        unexpected.push(id)
                    }
                }
            }

            assert_eq!(unexpected.len(), 0);
        }

        Ok(())
    }
}

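/// Snapshot of the approximate memory usage numbers reported by RocksDB,
/// used for debug logging and for estimating how much memory a flush freed.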
#[allow(unused)]
#[derive(Debug)]
struct Stats {
    pub mem_table_total: u64,
    pub mem_table_unflushed: u64,
    pub mem_table_readers_total: u64,
    pub cache_total: u64,
}

impl Stats {
    fn total(&self) -> isize {
        (self.mem_table_total + self.mem_table_readers_total + self.cache_total) as isize
    }
}

impl From<rocksdb::perf::MemoryUsageStats> for Stats {
    fn from(s: rocksdb::perf::MemoryUsageStats) -> Self {
        Self {
            mem_table_total: s.mem_table_total,
            mem_table_unflushed: s.mem_table_unflushed,
            mem_table_readers_total: s.mem_table_readers_total,
            cache_total: s.cache_total,
        }
    }
}

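// Prometheus gauges tracking approximate RocksDB memory usage, labelled by
// the spool path.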
static MEM_TABLE_TOTAL: LazyLock<IntGaugeVec> = LazyLock::new(|| {
    prometheus::register_int_gauge_vec!(
        "rocks_spool_mem_table_total",
        "Approximate memory usage of all the mem-tables",
        &["path"]
    )
    .unwrap()
});
static MEM_TABLE_UNFLUSHED: LazyLock<IntGaugeVec> = LazyLock::new(|| {
    prometheus::register_int_gauge_vec!(
        "rocks_spool_mem_table_unflushed",
        "Approximate memory usage of un-flushed mem-tables",
        &["path"]
    )
    .unwrap()
});
static MEM_TABLE_READERS_TOTAL: LazyLock<IntGaugeVec> = LazyLock::new(|| {
    prometheus::register_int_gauge_vec!(
        "rocks_spool_mem_table_readers_total",
        "Approximate memory usage of all the table readers",
        &["path"]
    )
    .unwrap()
});
static CACHE_TOTAL: LazyLock<IntGaugeVec> = LazyLock::new(|| {
    prometheus::register_int_gauge_vec!(
        "rocks_spool_cache_total",
        "Approximate memory usage by cache",
        &["path"]
    )
    .unwrap()
});

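/// Periodically samples RocksDB memory usage statistics and publishes them to
/// the gauges above. Holds only a `Weak` reference to the DB, so it exits once
/// the owning spool has been dropped.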
async fn metrics_monitor(db: Weak<DB>, path: String) {
    let mem_table_total = MEM_TABLE_TOTAL
        .get_metric_with_label_values(&[path.as_str()])
        .unwrap();
    let mem_table_unflushed = MEM_TABLE_UNFLUSHED
        .get_metric_with_label_values(&[path.as_str()])
        .unwrap();
    let mem_table_readers_total = MEM_TABLE_READERS_TOTAL
        .get_metric_with_label_values(&[path.as_str()])
        .unwrap();
    let cache_total = CACHE_TOTAL
        .get_metric_with_label_values(&[path.as_str()])
        .unwrap();

    loop {
        match db.upgrade() {
            Some(db) => {
                match get_memory_usage_stats(Some(&[&db]), None) {
                    Ok(stats) => {
                        mem_table_total.set(stats.mem_table_total as i64);
                        mem_table_unflushed.set(stats.mem_table_unflushed as i64);
                        mem_table_readers_total.set(stats.mem_table_readers_total as i64);
                        cache_total.set(stats.cache_total as i64);
                    }
                    Err(err) => {
                        tracing::error!("error getting stats: {err:#}");
                    }
                };
            }
            None => {
                // The DB (and the spool that owned it) has been dropped;
                // stop monitoring.
                return;
            }
        }
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
    }
}