diff --git a/rust/blockstore/src/arrow/blockfile.rs b/rust/blockstore/src/arrow/blockfile.rs index 7f856d9ac6b..e46cd7eb840 100644 --- a/rust/blockstore/src/arrow/blockfile.rs +++ b/rust/blockstore/src/arrow/blockfile.rs @@ -16,7 +16,7 @@ use chroma_error::ErrorCodes; use futures::future::join_all; use parking_lot::Mutex; use std::mem::transmute; -use std::ops::Bound; +use std::ops::RangeBounds; use std::{collections::HashMap, sync::Arc}; use thiserror::Error; use uuid::Uuid; @@ -417,95 +417,22 @@ impl<'me, K: ArrowReadableKey<'me> + Into, V: ArrowReadableValue<'me } } - /// Returns all arrow records whose key > supplied key. - pub(crate) async fn get_gt<'a>( + // Returns all Arrow records in the specified range. + pub(crate) async fn get_range<'prefix, PrefixRange, KeyRange>( &'me self, - prefix: &'a str, - key: K, - ) -> Result, Box> { - // Get all block ids that contain keys > key from sparse index for this prefix. - let block_ids = self.root.sparse_index.get_block_ids_range( - prefix..=prefix, - ( - std::ops::Bound::Excluded(key.clone()), - std::ops::Bound::Unbounded, - ), - ); - let mut result: Vec<(K, V)> = vec![]; - // Read all the blocks individually to get keys > key. - for block_id in block_ids { - let block_opt = match self.get_block(block_id).await { - Ok(Some(block)) => Some(block), - Ok(None) => { - return Err(Box::new(ArrowBlockfileError::BlockNotFound)); - } - Err(e) => { - return Err(Box::new(e)); - } - }; - - let block = match block_opt { - Some(b) => b, - None => { - return Err(Box::new(ArrowBlockfileError::BlockNotFound)); - } - }; - result.extend(block.get_range( - prefix..=prefix, - (Bound::Excluded(key.clone()), Bound::Unbounded), - )); - } - Ok(result) - } - - /// Returns all arrow records whose key < supplied key. - pub(crate) async fn get_lt( - &'me self, - prefix: &str, - key: K, - ) -> Result, Box> { - // Get all block ids that contain keys < key from sparse index. + prefix_range: PrefixRange, + key_range: KeyRange, + ) -> Result, Box> + where + PrefixRange: RangeBounds<&'prefix str> + Clone, + KeyRange: RangeBounds + Clone, + { let block_ids = self .root .sparse_index - .get_block_ids_range(prefix..=prefix, ..key.clone()); - let mut result: Vec<(K, V)> = vec![]; - // Read all the blocks individually to get keys < key. - for block_id in block_ids { - let block_opt = match self.get_block(block_id).await { - Ok(Some(block)) => Some(block), - Ok(None) => { - return Err(Box::new(ArrowBlockfileError::BlockNotFound)); - } - Err(e) => { - return Err(Box::new(e)); - } - }; + .get_block_ids_range(prefix_range.clone(), key_range.clone()); - let block = match block_opt { - Some(b) => b, - None => { - return Err(Box::new(ArrowBlockfileError::BlockNotFound)); - } - }; - result.extend(block.get_range(prefix..=prefix, ..key.clone())); - } - Ok(result) - } - - /// Returns all arrow records whose key >= supplied key. - pub(crate) async fn get_gte( - &'me self, - prefix: &str, - key: K, - ) -> Result, Box> { - // Get all block ids that contain keys >= key from sparse index. - let block_ids = self - .root - .sparse_index - .get_block_ids_range(prefix..=prefix, key.clone()..); let mut result: Vec<(K, V)> = vec![]; - // Read all the blocks individually to get keys >= key. for block_id in block_ids { let block_opt = match self.get_block(block_id).await { Ok(Some(block)) => Some(block), @@ -523,76 +450,9 @@ impl<'me, K: ArrowReadableKey<'me> + Into, V: ArrowReadableValue<'me return Err(Box::new(ArrowBlockfileError::BlockNotFound)); } }; - result.extend(block.get_range(prefix..=prefix, key.clone()..)); + result.extend(block.get_range(prefix_range.clone(), key_range.clone())); } - Ok(result) - } - - /// Returns all arrow records whose key <= supplied key. - pub(crate) async fn get_lte( - &'me self, - prefix: &str, - key: K, - ) -> Result, Box> { - // Get all block ids that contain keys <= key from sparse index. - let block_ids = self - .root - .sparse_index - .get_block_ids_range(prefix..=prefix, ..=key.clone()); - let mut result: Vec<(K, V)> = vec![]; - // Read all the blocks individually to get keys <= key. - for block_id in block_ids { - let block_opt = match self.get_block(block_id).await { - Ok(Some(block)) => Some(block), - Ok(None) => { - return Err(Box::new(ArrowBlockfileError::BlockNotFound)); - } - Err(e) => { - return Err(Box::new(e)); - } - }; - - let block = match block_opt { - Some(b) => b, - None => { - return Err(Box::new(ArrowBlockfileError::BlockNotFound)); - } - }; - result.extend(block.get_range(prefix..=prefix, ..=key.clone())); - } - Ok(result) - } - /// Returns all arrow records whose prefix is same as supplied prefix. - pub(crate) async fn get_by_prefix( - &'me self, - prefix: &str, - ) -> Result, Box> { - let block_ids = self - .root - .sparse_index - .get_block_ids_range::(prefix..=prefix, ..); - let mut result: Vec<(K, V)> = vec![]; - for block_id in block_ids { - let block_opt = match self.get_block(block_id).await { - Ok(Some(block)) => Some(block), - Ok(None) => { - return Err(Box::new(ArrowBlockfileError::BlockNotFound)); - } - Err(e) => { - return Err(Box::new(e)); - } - }; - - let block = match block_opt { - Some(b) => b, - None => { - return Err(Box::new(ArrowBlockfileError::BlockNotFound)); - } - }; - - result.extend(block.get_range(prefix..=prefix, ..)); - } Ok(result) } diff --git a/rust/blockstore/src/types.rs b/rust/blockstore/src/types.rs index 89ec4180700..f52960ee8bc 100644 --- a/rust/blockstore/src/types.rs +++ b/rust/blockstore/src/types.rs @@ -13,6 +13,7 @@ use chroma_types::DataRecord; use roaring::RoaringBitmap; use std::fmt::{Debug, Display}; use std::mem::size_of; +use std::ops::Bound; use thiserror::Error; #[derive(Debug, Error)] @@ -282,7 +283,9 @@ impl< ) -> Result, Box> { match self { BlockfileReader::MemoryBlockfileReader(reader) => reader.get_by_prefix(prefix), - BlockfileReader::ArrowBlockfileReader(reader) => reader.get_by_prefix(prefix).await, + BlockfileReader::ArrowBlockfileReader(reader) => { + reader.get_range(prefix..=prefix, ..).await + } } } @@ -293,7 +296,11 @@ impl< ) -> Result, Box> { match self { BlockfileReader::MemoryBlockfileReader(reader) => reader.get_gt(prefix, key), - BlockfileReader::ArrowBlockfileReader(reader) => reader.get_gt(prefix, key).await, + BlockfileReader::ArrowBlockfileReader(reader) => { + reader + .get_range(prefix..=prefix, (Bound::Excluded(key), Bound::Unbounded)) + .await + } } } @@ -304,7 +311,9 @@ impl< ) -> Result, Box> { match self { BlockfileReader::MemoryBlockfileReader(reader) => reader.get_lt(prefix, key), - BlockfileReader::ArrowBlockfileReader(reader) => reader.get_lt(prefix, key).await, + BlockfileReader::ArrowBlockfileReader(reader) => { + reader.get_range(prefix..=prefix, ..key).await + } } } @@ -315,7 +324,9 @@ impl< ) -> Result, Box> { match self { BlockfileReader::MemoryBlockfileReader(reader) => reader.get_gte(prefix, key), - BlockfileReader::ArrowBlockfileReader(reader) => reader.get_gte(prefix, key).await, + BlockfileReader::ArrowBlockfileReader(reader) => { + reader.get_range(prefix..=prefix, key..).await + } } } @@ -326,7 +337,9 @@ impl< ) -> Result, Box> { match self { BlockfileReader::MemoryBlockfileReader(reader) => reader.get_lte(prefix, key), - BlockfileReader::ArrowBlockfileReader(reader) => reader.get_lte(prefix, key).await, + BlockfileReader::ArrowBlockfileReader(reader) => { + reader.get_range(prefix..=prefix, ..=key).await + } } }