Skip to content

Commit

Permalink
[ENH]: replace get_* methods on Arrow blocks with get_range()
Browse files Browse the repository at this point in the history
  • Loading branch information
codetheweb committed Oct 16, 2024
1 parent 78f69e7 commit 843bb3a
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 99 deletions.
155 changes: 61 additions & 94 deletions rust/blockstore/src/arrow/block/types.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::cmp::Ordering::{Equal, Greater, Less};
use std::collections::HashMap;
use std::io::SeekFrom;
use std::ops::{Bound, RangeBounds};

use crate::arrow::types::{ArrowReadableKey, ArrowReadableValue};
use arrow::array::ArrayData;
Expand Down Expand Up @@ -214,32 +215,6 @@ impl Block {
)
}

#[inline]
fn scan_prefix<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
&'me self,
prefix: &str,
range: impl Iterator<Item = usize>,
) -> Vec<(K, V)> {
let prefix_array = self
.data
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.expect("The prefix array should be a string arrary.");
let mut result = Vec::new();
for index in range {
if prefix_array.value(index) == prefix {
result.push((
K::get(self.data.column(1), index),
V::get(self.data.column(2), index),
));
} else {
break;
}
}
result
}

/*
===== Block Queries =====
*/
Expand All @@ -260,80 +235,72 @@ impl Block {
}
}

/// Get all the values for a given prefix in the block
/// Get all the values for a given prefix & key range in the block
/// ### Panics
/// - If the underlying data types are not the same as the types specified in the function signature
pub fn get_prefix<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
/// - If at least one end of the prefix range is excluded (currently unsupported)
pub fn get_range<
'prefix,
'me,
K: ArrowReadableKey<'me>,
V: ArrowReadableValue<'me>,
PrefixRange,
KeyRange,
>(
&'me self,
prefix: &str,
) -> Vec<(K, V)> {
self.scan_prefix(
prefix,
self.binary_search_index(prefix, Option::<&K>::None)..self.len(),
)
}

/// Get all the values for a given prefix in the block where the key is greater than the given key
/// ### Panics
/// - If the underlying data types are not the same as the types specified in the function signature
pub fn get_gt<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
&'me self,
prefix: &str,
key: K,
) -> Vec<(K, V)> {
let index = self.binary_search_index(prefix, Some(&key));
if self.match_prefix_key_at_index(prefix, &key, index) {
self.scan_prefix(prefix, index + 1..self.len())
} else {
self.scan_prefix(prefix, index..self.len())
}
}
prefix_range: PrefixRange,
key_range: KeyRange,
) -> Vec<(K, V)>
where
PrefixRange: RangeBounds<&'prefix str>,
KeyRange: RangeBounds<K>,
{
let start_index = match prefix_range.start_bound() {
Bound::Included(prefix) => match key_range.start_bound() {
Bound::Included(key) => self.binary_search_index(prefix, Some(key)),
Bound::Excluded(key) => {
let index = self.binary_search_index(prefix, Some(key));
if self.match_prefix_key_at_index(prefix, key, index) {
index + 1
} else {
index
}
}
Bound::Unbounded => self.binary_search_index::<K>(prefix, None),
},
Bound::Excluded(_) => {
unimplemented!("Excluded prefix range is not currently supported")
}
Bound::Unbounded => 0,
};

/// Get all the values for a given prefix in the block where the key is greater than or equal to the given key
/// ### Panics
/// - If the underlying data types are not the same as the types specified in the function signature
pub fn get_gte<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
&'me self,
prefix: &str,
key: K,
) -> Vec<(K, V)> {
self.scan_prefix(
prefix,
self.binary_search_index(prefix, Some(&key))..self.len(),
)
}
let end_index = match prefix_range.end_bound() {
Bound::Included(prefix) => match key_range.end_bound() {
Bound::Included(key) => {
let index = self.binary_search_index(prefix, Some(key));
if self.match_prefix_key_at_index(prefix, key, index) {
index + 1
} else {
index
}
}
Bound::Excluded(key) => self.binary_search_index(prefix, Some(key)),
Bound::Unbounded => self.len(),
},
Bound::Excluded(_) => {
unimplemented!("Excluded prefix range is not currently supported")
}
Bound::Unbounded => self.len(),
};

/// Get all the values for a given prefix in the block where the key is less than the given key
/// ### Panics
/// - If the underlying data types are not the same as the types specified in the function signature
pub fn get_lt<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
&'me self,
prefix: &str,
key: K,
) -> Vec<(K, V)> {
let mut result = self.scan_prefix(
prefix,
(0..self.binary_search_index(prefix, Some(&key))).rev(),
);
result.reverse();
result
}
let mut result = Vec::new();

/// Get all the values for a given prefix in the block where the key is less than or equal to the given key
/// ### Panics
/// - If the underlying data types are not the same as the types specified in the function signature
pub fn get_lte<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
&'me self,
prefix: &str,
key: K,
) -> Vec<(K, V)> {
let index = self.binary_search_index(prefix, Some(&key));
let mut result = if self.match_prefix_key_at_index(prefix, &key, index) {
self.scan_prefix(prefix, (0..=index).rev())
} else {
self.scan_prefix(prefix, (0..index).rev())
};
result.reverse();
for index in start_index..end_index {
result.push((
K::get(self.data.column(1), index),
V::get(self.data.column(2), index),
));
}
result
}

Expand Down
14 changes: 9 additions & 5 deletions rust/blockstore/src/arrow/blockfile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use chroma_error::ErrorCodes;
use futures::future::join_all;
use parking_lot::Mutex;
use std::mem::transmute;
use std::ops::Bound;
use std::{collections::HashMap, sync::Arc};
use thiserror::Error;
use uuid::Uuid;
Expand Down Expand Up @@ -449,7 +450,10 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
return Err(Box::new(ArrowBlockfileError::BlockNotFound));
}
};
result.extend(block.get_gt(prefix, key.clone()));
result.extend(block.get_range(
prefix..=prefix,
(Bound::Excluded(key.clone()), Bound::Unbounded),
));
}
Ok(result)
}
Expand Down Expand Up @@ -484,7 +488,7 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
return Err(Box::new(ArrowBlockfileError::BlockNotFound));
}
};
result.extend(block.get_lt(prefix, key.clone()));
result.extend(block.get_range(prefix..=prefix, ..key.clone()));
}
Ok(result)
}
Expand Down Expand Up @@ -519,7 +523,7 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
return Err(Box::new(ArrowBlockfileError::BlockNotFound));
}
};
result.extend(block.get_gte(prefix, key.clone()));
result.extend(block.get_range(prefix..=prefix, key.clone()..));
}
Ok(result)
}
Expand Down Expand Up @@ -554,7 +558,7 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
return Err(Box::new(ArrowBlockfileError::BlockNotFound));
}
};
result.extend(block.get_lte(prefix, key.clone()));
result.extend(block.get_range(prefix..=prefix, ..=key.clone()));
}
Ok(result)
}
Expand Down Expand Up @@ -587,7 +591,7 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
}
};

result.extend(block.get_prefix(prefix));
result.extend(block.get_range(prefix..=prefix, ..));
}
Ok(result)
}
Expand Down

0 comments on commit 843bb3a

Please sign in to comment.