Skip to content

Commit

Permalink
Propagate
Browse files Browse the repository at this point in the history
  • Loading branch information
codetheweb committed Oct 16, 2024
1 parent 843bb3a commit 20ca271
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 157 deletions.
164 changes: 12 additions & 152 deletions rust/blockstore/src/arrow/blockfile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use chroma_error::ErrorCodes;
use futures::future::join_all;
use parking_lot::Mutex;
use std::mem::transmute;
use std::ops::Bound;
use std::ops::RangeBounds;
use std::{collections::HashMap, sync::Arc};
use thiserror::Error;
use uuid::Uuid;
Expand Down Expand Up @@ -417,95 +417,22 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
}
}

/// Returns all arrow records whose key > supplied key.
pub(crate) async fn get_gt<'a>(
// Returns all Arrow records in the specified range.
pub(crate) async fn get_range<'prefix, PrefixRange, KeyRange>(
&'me self,
prefix: &'a str,
key: K,
) -> Result<Vec<(K, V)>, Box<dyn ChromaError>> {
// Get all block ids that contain keys > key from sparse index for this prefix.
let block_ids = self.root.sparse_index.get_block_ids_range(
prefix..=prefix,
(
std::ops::Bound::Excluded(key.clone()),
std::ops::Bound::Unbounded,
),
);
let mut result: Vec<(K, V)> = vec![];
// Read all the blocks individually to get keys > key.
for block_id in block_ids {
let block_opt = match self.get_block(block_id).await {
Ok(Some(block)) => Some(block),
Ok(None) => {
return Err(Box::new(ArrowBlockfileError::BlockNotFound));
}
Err(e) => {
return Err(Box::new(e));
}
};

let block = match block_opt {
Some(b) => b,
None => {
return Err(Box::new(ArrowBlockfileError::BlockNotFound));
}
};
result.extend(block.get_range(
prefix..=prefix,
(Bound::Excluded(key.clone()), Bound::Unbounded),
));
}
Ok(result)
}

/// Returns all arrow records whose key < supplied key.
pub(crate) async fn get_lt(
&'me self,
prefix: &str,
key: K,
) -> Result<Vec<(K, V)>, Box<dyn ChromaError>> {
// Get all block ids that contain keys < key from sparse index.
prefix_range: PrefixRange,
key_range: KeyRange,
) -> Result<Vec<(K, V)>, Box<dyn ChromaError>>
where
PrefixRange: RangeBounds<&'prefix str> + Clone,
KeyRange: RangeBounds<K> + Clone,
{
let block_ids = self
.root
.sparse_index
.get_block_ids_range(prefix..=prefix, ..key.clone());
let mut result: Vec<(K, V)> = vec![];
// Read all the blocks individually to get keys < key.
for block_id in block_ids {
let block_opt = match self.get_block(block_id).await {
Ok(Some(block)) => Some(block),
Ok(None) => {
return Err(Box::new(ArrowBlockfileError::BlockNotFound));
}
Err(e) => {
return Err(Box::new(e));
}
};
.get_block_ids_range(prefix_range.clone(), key_range.clone());

let block = match block_opt {
Some(b) => b,
None => {
return Err(Box::new(ArrowBlockfileError::BlockNotFound));
}
};
result.extend(block.get_range(prefix..=prefix, ..key.clone()));
}
Ok(result)
}

/// Returns all arrow records whose key >= supplied key.
pub(crate) async fn get_gte(
&'me self,
prefix: &str,
key: K,
) -> Result<Vec<(K, V)>, Box<dyn ChromaError>> {
// Get all block ids that contain keys >= key from sparse index.
let block_ids = self
.root
.sparse_index
.get_block_ids_range(prefix..=prefix, key.clone()..);
let mut result: Vec<(K, V)> = vec![];
// Read all the blocks individually to get keys >= key.
for block_id in block_ids {
let block_opt = match self.get_block(block_id).await {
Ok(Some(block)) => Some(block),
Expand All @@ -523,76 +450,9 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
return Err(Box::new(ArrowBlockfileError::BlockNotFound));
}
};
result.extend(block.get_range(prefix..=prefix, key.clone()..));
result.extend(block.get_range(prefix_range.clone(), key_range.clone()));
}
Ok(result)
}

/// Returns all arrow records whose key <= supplied key.
pub(crate) async fn get_lte(
&'me self,
prefix: &str,
key: K,
) -> Result<Vec<(K, V)>, Box<dyn ChromaError>> {
// Get all block ids that contain keys <= key from sparse index.
let block_ids = self
.root
.sparse_index
.get_block_ids_range(prefix..=prefix, ..=key.clone());
let mut result: Vec<(K, V)> = vec![];
// Read all the blocks individually to get keys <= key.
for block_id in block_ids {
let block_opt = match self.get_block(block_id).await {
Ok(Some(block)) => Some(block),
Ok(None) => {
return Err(Box::new(ArrowBlockfileError::BlockNotFound));
}
Err(e) => {
return Err(Box::new(e));
}
};

let block = match block_opt {
Some(b) => b,
None => {
return Err(Box::new(ArrowBlockfileError::BlockNotFound));
}
};
result.extend(block.get_range(prefix..=prefix, ..=key.clone()));
}
Ok(result)
}

/// Returns all arrow records whose prefix is same as supplied prefix.
pub(crate) async fn get_by_prefix(
&'me self,
prefix: &str,
) -> Result<Vec<(K, V)>, Box<dyn ChromaError>> {
let block_ids = self
.root
.sparse_index
.get_block_ids_range::<K, _, _>(prefix..=prefix, ..);
let mut result: Vec<(K, V)> = vec![];
for block_id in block_ids {
let block_opt = match self.get_block(block_id).await {
Ok(Some(block)) => Some(block),
Ok(None) => {
return Err(Box::new(ArrowBlockfileError::BlockNotFound));
}
Err(e) => {
return Err(Box::new(e));
}
};

let block = match block_opt {
Some(b) => b,
None => {
return Err(Box::new(ArrowBlockfileError::BlockNotFound));
}
};

result.extend(block.get_range(prefix..=prefix, ..));
}
Ok(result)
}

Expand Down
23 changes: 18 additions & 5 deletions rust/blockstore/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use chroma_types::DataRecord;
use roaring::RoaringBitmap;
use std::fmt::{Debug, Display};
use std::mem::size_of;
use std::ops::Bound;
use thiserror::Error;

#[derive(Debug, Error)]
Expand Down Expand Up @@ -282,7 +283,9 @@ impl<
) -> Result<Vec<(K, V)>, Box<dyn ChromaError>> {
match self {
BlockfileReader::MemoryBlockfileReader(reader) => reader.get_by_prefix(prefix),
BlockfileReader::ArrowBlockfileReader(reader) => reader.get_by_prefix(prefix).await,
BlockfileReader::ArrowBlockfileReader(reader) => {
reader.get_range(prefix..=prefix, ..).await
}
}
}

Expand All @@ -293,7 +296,11 @@ impl<
) -> Result<Vec<(K, V)>, Box<dyn ChromaError>> {
match self {
BlockfileReader::MemoryBlockfileReader(reader) => reader.get_gt(prefix, key),
BlockfileReader::ArrowBlockfileReader(reader) => reader.get_gt(prefix, key).await,
BlockfileReader::ArrowBlockfileReader(reader) => {
reader
.get_range(prefix..=prefix, (Bound::Excluded(key), Bound::Unbounded))
.await
}
}
}

Expand All @@ -304,7 +311,9 @@ impl<
) -> Result<Vec<(K, V)>, Box<dyn ChromaError>> {
match self {
BlockfileReader::MemoryBlockfileReader(reader) => reader.get_lt(prefix, key),
BlockfileReader::ArrowBlockfileReader(reader) => reader.get_lt(prefix, key).await,
BlockfileReader::ArrowBlockfileReader(reader) => {
reader.get_range(prefix..=prefix, ..key).await
}
}
}

Expand All @@ -315,7 +324,9 @@ impl<
) -> Result<Vec<(K, V)>, Box<dyn ChromaError>> {
match self {
BlockfileReader::MemoryBlockfileReader(reader) => reader.get_gte(prefix, key),
BlockfileReader::ArrowBlockfileReader(reader) => reader.get_gte(prefix, key).await,
BlockfileReader::ArrowBlockfileReader(reader) => {
reader.get_range(prefix..=prefix, key..).await
}
}
}

Expand All @@ -326,7 +337,9 @@ impl<
) -> Result<Vec<(K, V)>, Box<dyn ChromaError>> {
match self {
BlockfileReader::MemoryBlockfileReader(reader) => reader.get_lte(prefix, key),
BlockfileReader::ArrowBlockfileReader(reader) => reader.get_lte(prefix, key).await,
BlockfileReader::ArrowBlockfileReader(reader) => {
reader.get_range(prefix..=prefix, ..=key).await
}
}
}

Expand Down

0 comments on commit 20ca271

Please sign in to comment.