spool/
rocks.rs
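//! RocksDB-backed implementation of the `Spool` trait.
//!
//! Writes first attempt a non-blocking, no-slowdown commit on the calling
//! task; if RocksDB indicates that it would have to stall, the write is
//! retried on a blocking thread, optionally gated by a semaphore. Memory
//! usage statistics are exported as prometheus gauges for the lifetime of
//! the DB.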

use crate::{Spool, SpoolEntry, SpoolId};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use flume::Sender;
use prometheus::IntGaugeVec;
use rocksdb::perf::get_memory_usage_stats;
use rocksdb::{
    DBCompressionType, ErrorKind, IteratorMode, LogLevel, Options, WriteBatch, WriteOptions, DB,
};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::sync::{Arc, LazyLock, Weak};
use std::time::{Duration, Instant};
use tokio::runtime::Handle;
use tokio::sync::Semaphore;
use tokio::time::timeout_at;

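/// Tuning parameters for the RocksDB-backed spool. Most fields are passed
/// straight through to the corresponding `rocksdb::Options` setter in
/// `RocksSpool::new`; the `limit_concurrent_*` fields instead size the
/// semaphores that bound how many blocking operations may run at once.
///
/// Callers typically override only the fields they care about and take the
/// defaults for the rest; the values below are purely illustrative:
///
/// ```ignore
/// let params = RocksSpoolParams {
///     compression_type: DBCompressionTypeDef::Zstd,
///     limit_concurrent_stores: Some(128),
///     ..Default::default()
/// };
/// ```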
#[derive(Serialize, Deserialize, Debug)]
pub struct RocksSpoolParams {
    pub increase_parallelism: Option<i32>,

    pub optimize_level_style_compaction: Option<usize>,
    pub optimize_universal_style_compaction: Option<usize>,
    #[serde(default)]
    pub paranoid_checks: bool,
    #[serde(default)]
    pub compression_type: DBCompressionTypeDef,

    /// If non-zero, we perform bigger reads when doing compaction. If you’re running RocksDB on
    /// spinning disks, you should set this to at least 2MB. That way RocksDB’s compaction is doing
    /// sequential instead of random reads.
    pub compaction_readahead_size: Option<usize>,

    #[serde(default)]
    pub level_compaction_dynamic_level_bytes: bool,

    #[serde(default)]
    pub max_open_files: Option<usize>,

    #[serde(default)]
    pub log_level: LogLevelDef,

    /// See:
    /// <https://docs.rs/rocksdb/latest/rocksdb/struct.Options.html#method.set_memtable_huge_page_size>
    #[serde(default)]
    pub memtable_huge_page_size: Option<usize>,

    #[serde(
        with = "duration_serde",
        default = "RocksSpoolParams::default_log_file_time_to_roll"
    )]
    pub log_file_time_to_roll: Duration,

    #[serde(
        with = "duration_serde",
        default = "RocksSpoolParams::default_obsolete_files_period"
    )]
    pub obsolete_files_period: Duration,

    /// Maximum number of store operations allowed to run on blocking threads
    /// at any one time
    #[serde(default)]
    pub limit_concurrent_stores: Option<usize>,
    /// Maximum number of load operations allowed to run at any one time
    #[serde(default)]
    pub limit_concurrent_loads: Option<usize>,
    /// Maximum number of remove operations allowed to run on blocking threads
    /// at any one time
    #[serde(default)]
    pub limit_concurrent_removes: Option<usize>,
}

impl Default for RocksSpoolParams {
    fn default() -> Self {
        Self {
            increase_parallelism: None,
            optimize_level_style_compaction: None,
            optimize_universal_style_compaction: None,
            paranoid_checks: false,
            compression_type: DBCompressionTypeDef::default(),
            compaction_readahead_size: None,
            level_compaction_dynamic_level_bytes: false,
            max_open_files: None,
            log_level: LogLevelDef::default(),
            memtable_huge_page_size: None,
            log_file_time_to_roll: Self::default_log_file_time_to_roll(),
            obsolete_files_period: Self::default_obsolete_files_period(),
            limit_concurrent_stores: None,
            limit_concurrent_loads: None,
            limit_concurrent_removes: None,
        }
    }
}

impl RocksSpoolParams {
    fn default_log_file_time_to_roll() -> Duration {
        Duration::from_secs(86400)
    }

    fn default_obsolete_files_period() -> Duration {
        Duration::from_secs(6 * 60 * 60)
    }
}

/// Mirror of `rocksdb::DBCompressionType` so that the value can be used
/// with the serde derives on `RocksSpoolParams`
#[derive(Serialize, Deserialize, Debug)]
pub enum DBCompressionTypeDef {
    None,
    Snappy,
    Zlib,
    Bz2,
    Lz4,
    Lz4hc,
    Zstd,
}

impl From<DBCompressionTypeDef> for DBCompressionType {
    fn from(val: DBCompressionTypeDef) -> Self {
        match val {
            DBCompressionTypeDef::None => DBCompressionType::None,
            DBCompressionTypeDef::Snappy => DBCompressionType::Snappy,
            DBCompressionTypeDef::Zlib => DBCompressionType::Zlib,
            DBCompressionTypeDef::Bz2 => DBCompressionType::Bz2,
            DBCompressionTypeDef::Lz4 => DBCompressionType::Lz4,
            DBCompressionTypeDef::Lz4hc => DBCompressionType::Lz4hc,
            DBCompressionTypeDef::Zstd => DBCompressionType::Zstd,
        }
    }
}

impl Default for DBCompressionTypeDef {
    fn default() -> Self {
        Self::Snappy
    }
}

/// Mirror of `rocksdb::LogLevel` so that the value can be used with the
/// serde derives on `RocksSpoolParams`
#[derive(Serialize, Deserialize, Debug)]
pub enum LogLevelDef {
    Debug,
    Info,
    Warn,
    Error,
    Fatal,
    Header,
}

impl Default for LogLevelDef {
    fn default() -> Self {
        Self::Info
    }
}

impl From<LogLevelDef> for LogLevel {
    fn from(val: LogLevelDef) -> Self {
        match val {
            LogLevelDef::Debug => LogLevel::Debug,
            LogLevelDef::Info => LogLevel::Info,
            LogLevelDef::Warn => LogLevel::Warn,
            LogLevelDef::Error => LogLevel::Error,
            LogLevelDef::Fatal => LogLevel::Fatal,
            LogLevelDef::Header => LogLevel::Header,
        }
    }
}

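/// A spool backed by a single RocksDB database. Blocking RocksDB calls are
/// shipped to the provided tokio runtime's blocking thread pool, with the
/// optional semaphores bounding how many of each kind of operation may be
/// outstanding at a time.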
pub struct RocksSpool {
    db: Arc<DB>,
    runtime: Handle,
    limit_concurrent_stores: Option<Arc<Semaphore>>,
    limit_concurrent_loads: Option<Arc<Semaphore>>,
    limit_concurrent_removes: Option<Arc<Semaphore>>,
}

impl RocksSpool {
    pub fn new(
        path: &Path,
        flush: bool,
        params: Option<RocksSpoolParams>,
        runtime: Handle,
    ) -> anyhow::Result<Self> {
        let mut opts = Options::default();
        opts.set_use_fsync(flush);
        opts.create_if_missing(true);
        // The default is 1000, which is a bit high
        opts.set_keep_log_file_num(10);

        let p = params.unwrap_or_default();
        if let Some(i) = p.increase_parallelism {
            opts.increase_parallelism(i);
        }
        if let Some(i) = p.optimize_level_style_compaction {
            opts.optimize_level_style_compaction(i);
        }
        if let Some(i) = p.optimize_universal_style_compaction {
            opts.optimize_universal_style_compaction(i);
        }
        if let Some(i) = p.compaction_readahead_size {
            opts.set_compaction_readahead_size(i);
        }
        if let Some(i) = p.max_open_files {
            opts.set_max_open_files(i as _);
        }
        if let Some(i) = p.memtable_huge_page_size {
            opts.set_memtable_huge_page_size(i);
        }
        opts.set_paranoid_checks(p.paranoid_checks);
        opts.set_level_compaction_dynamic_level_bytes(p.level_compaction_dynamic_level_bytes);
        opts.set_compression_type(p.compression_type.into());
        opts.set_log_level(p.log_level.into());
        opts.set_log_file_time_to_roll(p.log_file_time_to_roll.as_secs() as usize);
        opts.set_delete_obsolete_files_period_micros(p.obsolete_files_period.as_micros() as u64);

        let limit_concurrent_stores = p
            .limit_concurrent_stores
            .map(|n| Arc::new(Semaphore::new(n)));
        let limit_concurrent_loads = p
            .limit_concurrent_loads
            .map(|n| Arc::new(Semaphore::new(n)));
        let limit_concurrent_removes = p
            .limit_concurrent_removes
            .map(|n| Arc::new(Semaphore::new(n)));

        let db = Arc::new(DB::open(&opts, path)?);

        {
            // Hand the metrics monitor a Weak reference so that it doesn't
            // keep the DB alive; the monitor exits once the spool is dropped.
            let weak = Arc::downgrade(&db);
            tokio::spawn(metrics_monitor(weak, format!("{}", path.display())));
        }

        Ok(Self {
            db,
            runtime,
            limit_concurrent_stores,
            limit_concurrent_loads,
            limit_concurrent_removes,
        })
    }
}

#[async_trait]
impl Spool for RocksSpool {
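    // Loads read the value on a blocking thread; the optional semaphore
    // caps how many loads can be in flight at once.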
    async fn load(&self, id: SpoolId) -> anyhow::Result<Vec<u8>> {
        let permit = match self.limit_concurrent_loads.clone() {
            Some(s) => Some(s.acquire_owned().await?),
            None => None,
        };
        let db = self.db.clone();
        tokio::task::Builder::new()
            .name("rocksdb load")
            .spawn_blocking_on(
                move || {
                    let result = db
                        .get(id.as_bytes())?
                        .ok_or_else(|| anyhow::anyhow!("no such key {id}"))?;
                    drop(permit);
                    Ok(result)
                },
                &self.runtime,
            )?
            .await?
    }

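    // Stores are first attempted inline with `no_slowdown` set, so the common
    // case never leaves the async task. If RocksDB reports `Incomplete` (the
    // write would have had to stall), the write is retried on a blocking
    // thread with `no_slowdown` cleared, gated by the optional store
    // semaphore and, when supplied, the caller's deadline.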
    async fn store(
        &self,
        id: SpoolId,
        data: Arc<Box<[u8]>>,
        force_sync: bool,
        deadline: Option<Instant>,
    ) -> anyhow::Result<()> {
        let mut opts = WriteOptions::default();
        opts.set_sync(force_sync);
        opts.set_no_slowdown(true);
        let mut batch = WriteBatch::default();
        batch.put(id.as_bytes(), &*data);

        match self.db.write_opt(batch, &opts) {
            Ok(()) => Ok(()),
            Err(err) if err.kind() == ErrorKind::Incomplete => {
                let permit = match (self.limit_concurrent_stores.clone(), deadline) {
                    (Some(s), Some(deadline)) => {
                        Some(timeout_at(deadline.into(), s.acquire_owned()).await??)
                    }
                    (Some(s), None) => Some(s.acquire_owned().await?),
                    (None, _) => None,
                };
                let db = self.db.clone();
                tokio::task::Builder::new()
                    .name("rocksdb store")
                    .spawn_blocking_on(
                        move || {
                            opts.set_no_slowdown(false);
                            let mut batch = WriteBatch::default();
                            batch.put(id.as_bytes(), &*data);
                            let result = db.write_opt(batch, &opts)?;
                            drop(permit);
                            Ok(result)
                        },
                        &self.runtime,
                    )?
                    .await?
            }
            Err(err) => Err(err.into()),
        }
    }

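    // Removes follow the same two-phase pattern as stores: a non-blocking
    // delete first, then a blocking retry gated by the remove semaphore.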
    async fn remove(&self, id: SpoolId) -> anyhow::Result<()> {
        let mut opts = WriteOptions::default();
        opts.set_no_slowdown(true);
        let mut batch = WriteBatch::default();
        batch.delete(id.as_bytes());

        match self.db.write_opt(batch, &opts) {
            Ok(()) => Ok(()),
            Err(err) if err.kind() == ErrorKind::Incomplete => {
                let permit = match self.limit_concurrent_removes.clone() {
                    Some(s) => Some(s.acquire_owned().await?),
                    None => None,
                };
                let db = self.db.clone();
                tokio::task::Builder::new()
                    .name("rocksdb remove")
                    .spawn_blocking_on(
                        move || {
                            opts.set_no_slowdown(false);
                            let mut batch = WriteBatch::default();
                            batch.delete(id.as_bytes());
                            let result = db.write_opt(batch, &opts)?;
                            drop(permit);
                            Ok(result)
                        },
                        &self.runtime,
                    )?
                    .await?
            }
            Err(err) => Err(err.into()),
        }
    }

    async fn cleanup(&self) -> anyhow::Result<()> {
        // Nothing to do for the RocksDB-backed spool
        Ok(())
    }

    async fn shutdown(&self) -> anyhow::Result<()> {
        let db = self.db.clone();
        // Stop background compactions/flushes and wait for them to wind down
        tokio::task::spawn_blocking(move || db.cancel_all_background_work(true)).await?;
        Ok(())
    }

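    // Flushes the memtables and reports approximately how many bytes of
    // memory were released, as measured by rocksdb's own usage accounting.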
    async fn advise_low_memory(&self) -> anyhow::Result<isize> {
        let db = self.db.clone();
        tokio::task::spawn_blocking(move || {
            let usage_before = match get_memory_usage_stats(Some(&[&db]), None) {
                Ok(stats) => {
                    let stats: Stats = stats.into();
                    tracing::debug!("pre-flush: {stats:#?}");
                    stats.total()
                }
                Err(err) => {
                    tracing::error!("error getting stats: {err:#}");
                    0
                }
            };

            if let Err(err) = db.flush() {
                tracing::error!("error flushing memory: {err:#}");
            }

            let usage_after = match get_memory_usage_stats(Some(&[&db]), None) {
                Ok(stats) => {
                    let stats: Stats = stats.into();
                    tracing::debug!("post-flush: {stats:#?}");
                    stats.total()
                }
                Err(err) => {
                    tracing::error!("error getting stats: {err:#}");
                    0
                }
            };

            Ok(usage_before - usage_after)
        })
        .await?
    }

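    // Streams every entry created before `start_time` to `sender` from a
    // blocking task; entries created after that point are skipped because
    // they arrived after enumeration began.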
    fn enumerate(
        &self,
        sender: Sender<SpoolEntry>,
        start_time: DateTime<Utc>,
    ) -> anyhow::Result<()> {
        let db = Arc::clone(&self.db);
        tokio::task::Builder::new()
            .name("rocksdb enumerate")
            .spawn_blocking_on(
                move || {
                    let iter = db.iterator(IteratorMode::Start);
                    for entry in iter {
                        let (key, value) = entry?;
                        let id = SpoolId::from_slice(&key)
                            .ok_or_else(|| anyhow::anyhow!("invalid spool id {key:?}"))?;

                        if id.created() >= start_time {
                            // Entries created since we started must have
                            // landed there after we started and are thus
                            // not eligible for discovery via enumeration
                            continue;
                        }

                        sender
                            .send(SpoolEntry::Item {
                                id,
                                data: value.to_vec(),
                            })
                            .map_err(|err| {
                                anyhow::anyhow!("failed to send SpoolEntry for {id}: {err:#}")
                            })?;
                    }
                    Ok::<(), anyhow::Error>(())
                },
                &self.runtime,
            )?;
        Ok(())
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[tokio::test]
    async fn rocks_spool() -> anyhow::Result<()> {
        let location = tempfile::tempdir()?;
        let spool = RocksSpool::new(location.path(), false, None, Handle::current())?;

        {
            let id1 = SpoolId::new();

            // Can't load an entry that doesn't exist
            assert_eq!(
                format!("{:#}", spool.load(id1).await.unwrap_err()),
                format!("no such key {id1}")
            );
        }

        // Insert some entries
        let mut ids = vec![];
        for i in 0..100 {
            let id = SpoolId::new();
            spool
                .store(
                    id,
                    Arc::new(format!("I am {i}").as_bytes().to_vec().into_boxed_slice()),
                    false,
                    None,
                )
                .await?;
            ids.push(id);
        }

        // Verify that we can load those entries
        for (i, &id) in ids.iter().enumerate() {
            let data = spool.load(id).await?;
            let text = String::from_utf8(data)?;
            assert_eq!(text, format!("I am {i}"));
        }

        {
            // Verify that we can enumerate them
            let (tx, rx) = flume::bounded(32);
            spool.enumerate(tx, Utc::now())?;
            let mut count = 0;

            while let Ok(item) = rx.recv_async().await {
                match item {
                    SpoolEntry::Item { id, data } => {
                        let i = ids
                            .iter()
                            .position(|&item| item == id)
                            .ok_or_else(|| anyhow::anyhow!("{id} not found in ids!"))?;

                        let text = String::from_utf8(data)?;
                        assert_eq!(text, format!("I am {i}"));

                        spool.remove(id).await?;
                        // Can't load an entry that we just removed
                        assert_eq!(
                            format!("{:#}", spool.load(id).await.unwrap_err()),
                            format!("no such key {id}")
                        );
                        count += 1;
                    }
                    SpoolEntry::Corrupt { id, error } => {
                        anyhow::bail!("Corrupt: {id}: {error}");
                    }
                }
            }

            assert_eq!(count, 100);
        }

        // Now that we've removed the entries, try enumerating again.
        // We expect to receive no entries.
        // Do it a couple of times to verify that nothing that happens
        // during enumeration misbehaves once the spool is empty.
        for _ in 0..2 {
            // Verify that enumeration yields nothing
            let (tx, rx) = flume::bounded(32);
            spool.enumerate(tx, Utc::now())?;
            let mut unexpected = vec![];

            while let Ok(item) = rx.recv_async().await {
                match item {
                    SpoolEntry::Item { id, .. } | SpoolEntry::Corrupt { id, .. } => {
                        unexpected.push(id)
                    }
                }
            }

            assert_eq!(unexpected.len(), 0);
        }

        Ok(())
    }
}

/// The rocksdb type doesn't impl Debug, so we get to do it
#[allow(unused)]
#[derive(Debug)]
struct Stats {
    pub mem_table_total: u64,
    pub mem_table_unflushed: u64,
    pub mem_table_readers_total: u64,
    pub cache_total: u64,
}

impl Stats {
    fn total(&self) -> isize {
        (self.mem_table_total + self.mem_table_readers_total + self.cache_total) as isize
    }
}

impl From<rocksdb::perf::MemoryUsageStats> for Stats {
    fn from(s: rocksdb::perf::MemoryUsageStats) -> Self {
        Self {
            mem_table_total: s.mem_table_total,
            mem_table_unflushed: s.mem_table_unflushed,
            mem_table_readers_total: s.mem_table_readers_total,
            cache_total: s.cache_total,
        }
    }
}

static MEM_TABLE_TOTAL: LazyLock<IntGaugeVec> = LazyLock::new(|| {
    prometheus::register_int_gauge_vec!(
        "rocks_spool_mem_table_total",
        "Approximate memory usage of all the mem-tables",
        &["path"]
    )
    .unwrap()
});
static MEM_TABLE_UNFLUSHED: LazyLock<IntGaugeVec> = LazyLock::new(|| {
    prometheus::register_int_gauge_vec!(
        "rocks_spool_mem_table_unflushed",
        "Approximate memory usage of un-flushed mem-tables",
        &["path"]
    )
    .unwrap()
});
static MEM_TABLE_READERS_TOTAL: LazyLock<IntGaugeVec> = LazyLock::new(|| {
    prometheus::register_int_gauge_vec!(
        "rocks_spool_mem_table_readers_total",
        "Approximate memory usage of all the table readers",
        &["path"]
    )
    .unwrap()
});
static CACHE_TOTAL: LazyLock<IntGaugeVec> = LazyLock::new(|| {
    prometheus::register_int_gauge_vec!(
        "rocks_spool_cache_total",
        "Approximate memory usage by cache",
        &["path"]
    )
    .unwrap()
});

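/// Polls rocksdb memory usage statistics every 5 seconds and publishes them
/// to the per-path gauges above. Holds only a `Weak` handle to the DB so it
/// cannot keep the spool alive; the loop exits once the spool is dropped.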
async fn metrics_monitor(db: Weak<DB>, path: String) {
    let mem_table_total = MEM_TABLE_TOTAL
        .get_metric_with_label_values(&[path.as_str()])
        .unwrap();
    let mem_table_unflushed = MEM_TABLE_UNFLUSHED
        .get_metric_with_label_values(&[path.as_str()])
        .unwrap();
    let mem_table_readers_total = MEM_TABLE_READERS_TOTAL
        .get_metric_with_label_values(&[path.as_str()])
        .unwrap();
    let cache_total = CACHE_TOTAL
        .get_metric_with_label_values(&[path.as_str()])
        .unwrap();

    loop {
        match db.upgrade() {
            Some(db) => {
                match get_memory_usage_stats(Some(&[&db]), None) {
                    Ok(stats) => {
                        mem_table_total.set(stats.mem_table_total as i64);
                        mem_table_unflushed.set(stats.mem_table_unflushed as i64);
                        mem_table_readers_total.set(stats.mem_table_readers_total as i64);
                        cache_total.set(stats.cache_total as i64);
                    }
                    Err(err) => {
                        tracing::error!("error getting stats: {err:#}");
                    }
                };
            }
            None => {
                // The DB has been dropped; we're done
                return;
            }
        }
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
    }
}