Skip to content

Commit

Permalink
Add TensorBase::{as_cow, into_cow}
Browse files Browse the repository at this point in the history
This is useful for code which needs to conditionally copy or create a tensor,
and then get the results of the copying and non-copying code paths into the same
type. See robertknight/ocrs#57 for a downstream solution
that can be replaced by this.
  • Loading branch information
robertknight committed May 5, 2024
1 parent 6893c28 commit bf53586
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 12 deletions.
42 changes: 35 additions & 7 deletions rten-tensor/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,18 +350,46 @@ impl<'a, T> StorageMut for ViewMutData<'a, T> {
}
}

impl<'a, T> Storage for Cow<'a, [T]>
where
[T]: ToOwned,
{
/// Tensor storage which may be either owned or borrowed.
///
/// The name is taken from [std::borrow::Cow] in the standard library,
/// which is conceptually similar.
pub enum CowData<'a, T> {
/// A [CowData] that owns its data.
Owned(Vec<T>),
/// A [CowData] that borrows data.
Borrowed(ViewData<'a, T>),
}

impl<'a, T> Storage for CowData<'a, T> {
type Elem = T;

fn len(&self) -> usize {
self.as_ref().len()
match self {
CowData::Owned(vec) => vec.len(),
CowData::Borrowed(view) => view.len(),
}
}

fn as_ptr(&self) -> *const T {
self.as_ref().as_ptr()
match self {
CowData::Owned(vec) => vec.as_ptr(),
CowData::Borrowed(view) => view.as_ptr(),
}
}
}

impl<'a, T> IntoStorage for Cow<'a, [T]>
where
[T]: ToOwned<Owned = Vec<T>>,
{
type Output = CowData<'a, T>;

fn into_storage(self) -> Self::Output {
match self {
Cow::Owned(vec) => CowData::Owned(vec),
Cow::Borrowed(slice) => CowData::Borrowed(slice.into_storage()),
}
}
}

Expand Down Expand Up @@ -401,7 +429,7 @@ mod tests {
let view: ViewData<i32> = data.as_slice().into_storage();
test_storage_impl(view, data);

let cow_view = Cow::Borrowed(data.as_slice());
let cow_view = Cow::Borrowed(data.as_slice()).into_storage();
test_storage_impl(cow_view, data);

let mut_view: ViewMutData<i32> = data.as_mut_slice().into_storage();
Expand Down
54 changes: 49 additions & 5 deletions rten-tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::layout::{
AsIndex, BroadcastLayout, DynLayout, IntoLayout, Layout, MatrixLayout, MutLayout, NdLayout,
OverlapPolicy, ResizeLayout,
};
use crate::storage::{IntoStorage, Storage, StorageMut, ViewData, ViewMutData};
use crate::storage::{CowData, IntoStorage, Storage, StorageMut, ViewData, ViewMutData};
use crate::transpose::copy_contiguous;
use crate::{Alloc, GlobalAlloc, IntoSliceItems, RandomSource, SliceItem};

Expand Down Expand Up @@ -61,6 +61,18 @@ pub trait AsView: Layout {
/// Return the layout of this tensor.
fn layout(&self) -> &Self::Layout;

/// Return a view of this tensor using a borrowed [CowData] for storage.
///
/// Together with [`into_cow`](TensorBase::into_cow), this is useful where
/// code needs to conditionally copy or create a new tensor, and get either
/// the borrowed or owned tensor into the same type.
fn as_cow(&self) -> TensorBase<CowData<Self::Elem>, Self::Layout>
where
[Self::Elem]: ToOwned,
{
self.view().as_cow()
}

/// Return a view of this tensor with a dynamic rank.
fn as_dyn(&self) -> TensorBase<ViewData<Self::Elem>, DynLayout> {
self.view().as_dyn()
Expand Down Expand Up @@ -274,7 +286,7 @@ pub trait AsView: Layout {
/// data into a new buffer otherwise.
///
/// Certain operations require or are faster with contiguous tensors.
fn to_contiguous(&self) -> TensorBase<Cow<[Self::Elem]>, Self::Layout>
fn to_contiguous(&self) -> TensorBase<CowData<Self::Elem>, Self::Layout>
where
Self::Elem: Clone,
{
Expand Down Expand Up @@ -677,6 +689,18 @@ impl<T, L: Clone + MutLayout> TensorBase<Vec<T>, L> {
self.data.truncate(range.end - range.start);
}

/// Convert the storage of this tensor into an owned [CowData].
///
/// This is useful in contexts where code needs to conditionally copy or
/// create a new tensor. See [AsView::as_cow].
pub fn into_cow(self) -> TensorBase<CowData<'static, T>, L> {
let TensorBase { data, layout } = self;
TensorBase {
layout,
data: CowData::Owned(data),
}
}

/// Consume self and return the underlying data as a contiguous tensor.
///
/// See also [TensorBase::to_vec].
Expand Down Expand Up @@ -940,6 +964,16 @@ impl<'a, T, L: Clone + MutLayout> TensorBase<ViewData<'a, T>, L> {
}
}

/// Convert the storage of this view to a borrowed [CowData].
///
/// See [AsView::as_cow].
pub fn as_cow(&self) -> TensorBase<CowData<'a, T>, L> {
TensorBase {
layout: self.layout.clone(),
data: CowData::Borrowed(self.data),
}
}

/// Broadcast this view to another shape.
///
/// See [AsView::broadcast].
Expand Down Expand Up @@ -1118,19 +1152,19 @@ impl<'a, T, L: Clone + MutLayout> TensorBase<ViewData<'a, T>, L> {
///
/// If the data is already contiguous, no copy is made, otherwise the
/// elements are copied into a new buffer in contiguous order.
pub fn to_contiguous(&self) -> TensorBase<Cow<'a, [T]>, L>
pub fn to_contiguous(&self) -> TensorBase<CowData<'a, T>, L>
where
T: Clone,
{
if let Some(data) = self.data() {
TensorBase {
data: Cow::Borrowed(data),
data: CowData::Borrowed(data.into_storage()),
layout: self.layout.clone(),
}
} else {
let data = self.to_vec();
TensorBase {
data: Cow::Owned(data),
data: CowData::Owned(data),
layout: L::from_shape(self.layout.shape()),
}
}
Expand Down Expand Up @@ -1825,6 +1859,16 @@ mod tests {
assert_eq!(y.data(), Some([2, 3, 4, 5].as_slice()));
}

#[test]
fn test_as_cow_into_cow() {
for copy in [true, false] {
let x = Tensor::arange(0, 4, None).into_shape([2, 2]);
let cow_x = if copy { x.into_cow() } else { x.as_cow() };
assert_eq!(cow_x.shape(), [2, 2]);
assert_eq!(cow_x.data().unwrap(), &[0, 1, 2, 3]);
}
}

#[test]
fn test_as_dyn() {
let data = vec![1., 2., 3., 4.];
Expand Down

0 comments on commit bf53586

Please sign in to comment.