Skip to content

Commit

Permalink
Replace remaining uses of check_dims with static_dims
Browse files Browse the repository at this point in the history
Avoid having two different ways for an operator to check that an input has the
expected number of dimensions.

In the process convert `static_dims` to a true crate-internal macro instead of a
public-but-undocumented one.
  • Loading branch information
robertknight committed Oct 15, 2024
1 parent ba13b33 commit 632f0fd
Show file tree
Hide file tree
Showing 14 changed files with 39 additions and 99 deletions.
2 changes: 1 addition & 1 deletion src/ops/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use rten_tensor::{NdTensorView, Tensor, TensorView};

use smallvec::SmallVec;

use crate::ops::static_dims;
use crate::ops::{
resolve_axis, Input, InputList, IntoOpResult, OpError, Operator, Output, OutputList,
};
use crate::static_dims;
use crate::tensor_pool::{AutoReturn, TensorPool};

/// Return the shape formed by concatenating all tensors along a given axis.
Expand Down
11 changes: 5 additions & 6 deletions src/ops/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut, Tensor, TensorView};

use crate::gemm::{GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT, VirtualMatrix};
use crate::ops::pooling::calc_output_size_and_padding;
use crate::ops::{InputList, IntoOpResult, OpError, Operator, OutputList, Padding};
use crate::ops::{static_dims, InputList, IntoOpResult, OpError, Operator, OutputList, Padding};
use crate::tensor_pool::{AutoReturn, TensorPool};
use crate::{check_dims, static_dims};

mod depthwise;
mod im2col;
Expand Down Expand Up @@ -105,7 +104,7 @@ where
// Handle 1D convolution by expanding to 2D and then removing the extra
// dimension from the result.
if let &[_n, _c, _w] = input.shape() {
let [_out_c, _k_in_c, _k_w] = check_dims!(kernel, 3, "OCW");
let [_out_c, _k_in_c, _k_w] = static_dims!(kernel, 3, "OCW")?.shape();

let mut input_2d = input.clone();
input_2d.insert_axis(2);
Expand Down Expand Up @@ -152,7 +151,7 @@ where

let kernel = static_dims!(kernel, 4, "OCHW")?;
let [out_c, k_in_c, k_h, k_w] = kernel.shape();
check_dims!(bias?, 1);
static_dims!(bias?, 1).transpose()?;

let input = input.view();
let kernel = kernel.view();
Expand Down Expand Up @@ -469,7 +468,7 @@ pub fn conv_transpose(
// Handle 1D transposed convolution by expanding to 2D and then removing
// the extra dimension from the result.
if let &[n, c, w] = input.shape() {
let [out_c, k_in_c, k_w] = check_dims!(kernel, 3, "OCW");
let [out_c, k_in_c, k_w] = static_dims!(kernel, 3, "OCW")?.shape();

let mut input_2d = input.clone();
input_2d.reshape(&[n, c, 1, w]);
Expand Down Expand Up @@ -499,7 +498,7 @@ pub fn conv_transpose(
let [batch, in_c, in_h, in_w] = input.shape();
let kernel = static_dims!(kernel, 4, "OCHW")?;
let [k_in_c, out_c, k_h, k_w] = kernel.shape();
check_dims!(bias?, 1);
static_dims!(bias?, 1).transpose()?;

let bias = bias.map(|b| b.nd_view());

Expand Down
5 changes: 2 additions & 3 deletions src/ops/generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@ use rten_tensor::prelude::*;
use rten_tensor::{NdTensorView, Tensor, TensorView};

use crate::ops::{
resolve_axis, resolve_index, Input, InputList, IntoOpResult, OpError, Operator, OutputList,
Scalar,
resolve_axis, resolve_index, static_dims, Input, InputList, IntoOpResult, OpError, Operator,
OutputList, Scalar,
};
use crate::static_dims;
use crate::tensor_pool::TensorPool;

pub fn constant_of_shape<T: Copy>(
Expand Down
5 changes: 2 additions & 3 deletions src/ops/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@ use smallvec::SmallVec;

use crate::ops::binary_elementwise::{broadcast_shapes, fast_broadcast_cycles_repeats};
use crate::ops::{
resolve_axes, resolve_axis, Input, InputList, IntoOpResult, OpError, Operator, Output,
OutputList,
resolve_axes, resolve_axis, static_dims, Input, InputList, IntoOpResult, OpError, Operator,
Output, OutputList,
};
use crate::static_dims;
use crate::tensor_pool::TensorPool;

/// Return the tensor shape resulting from broadcasting `input_shape` with `shape`.
Expand Down
7 changes: 3 additions & 4 deletions src/ops/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ use rayon::prelude::*;
use rten_tensor::prelude::*;
use rten_tensor::{Tensor, TensorView};

use crate::check_dims;
use crate::gemm::{GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT};
use crate::iter_util::range_chunks;
use crate::ops::binary_elementwise::broadcast_shapes;
use crate::ops::layout::expand_to;
use crate::ops::{InputList, IntoOpResult, OpError, Operator, OutputList};
use crate::ops::{static_dims, InputList, IntoOpResult, OpError, Operator, OutputList};
use crate::tensor_pool::{AutoReturn, TensorPool};

/// Compute the General Matrix Multiplication (GEMM) `c = alpha * (ab) + beta * c`.
Expand All @@ -30,8 +29,8 @@ pub fn gemm_op<LhsT: GemmInT, RhsT: GemmInT, OutT: GemmOutT>(
where
GemmExecutor<LhsT, RhsT, OutT>: Default,
{
check_dims!(a, 2);
check_dims!(b, 2);
let a = static_dims!(a, 2)?;
let b = static_dims!(b, 2)?;

let a = if transpose_a { a.transposed() } else { a };
let b = if transpose_b { b.transposed() } else { b };
Expand Down
79 changes: 15 additions & 64 deletions src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -706,71 +706,12 @@ impl Display for OpError {

impl Error for OpError {}

/// Check that a tensor has an expected number of dimensions, or return an
/// `OpError::InvalidValue`.
/// Convert a tensor with dynamic dimension count to a view with a static
/// dimension count.
///
/// Can be used with `check_dims!(input, expected_rank)` if `input` is a
/// `Tensor<T>` or `check_dims!(input?, expected_rank)` if `input` is an
/// `Option<Tensor<T>>`.
///
/// If `$ndim` is a literal, the macro returns an array of `$ndim` sizes for
/// each dimension. This conveniently allows checking the rank of a tensor
/// and extracting the sizes of dimension in one call. For example:
/// `let [rows, cols] = check_dims!(matrix, 2)`. When `$ndim` is a literal,
/// a third argument can also be passed to specify the names of the dimensions,
/// eg. "NCHW" or "dir, batch, seq". This can produce more helpful errors if
/// the input does not match the expected shape.
#[doc(hidden)]
#[macro_export]
macro_rules! check_dims {
($tensor:ident, $ndim:literal, $dim_names:literal) => {{
let shape: [usize; $ndim] = $tensor.shape().try_into().map_err(|_| {
OpError::InvalidValue(concat!(
stringify!($tensor),
" must have ",
stringify!($ndim),
" dims (",
$dim_names,
")"
))
})?;
shape
}};

($tensor:ident, $ndim:literal) => {{
let shape: [usize; $ndim] = $tensor.shape().try_into().map_err(|_| {
OpError::InvalidValue(concat!(
stringify!($tensor),
" must have ",
stringify!($ndim),
" dims"
))
})?;
shape
}};

($tensor:ident, $ndim:expr) => {
if $tensor.ndim() != $ndim {
return Err(OpError::InvalidValue(concat!(
stringify!($tensor),
" must have ",
stringify!($ndim),
" dims"
)));
}
};

($tensor:ident?, $ndim: expr) => {
if let Some($tensor) = $tensor.as_ref() {
check_dims!($tensor, $ndim);
}
};
}

/// Convert a tensor with dynamic dimension count to an `NdTensorView`, or
/// return an `OpError::InvalidValue` if the dimension count is incorrect.
#[doc(hidden)]
#[macro_export]
/// If the conversion fails an `OpError::InvalidValue` error will be returned
/// with a message that includes the name of the tensor and, optionally, the
/// names of the expected dimensions (eg. "NCHW").
macro_rules! static_dims {
($tensor:ident, $ndim:literal, $dim_names:literal) => {{
use rten_tensor::prelude::*;
Expand Down Expand Up @@ -803,8 +744,18 @@ macro_rules! static_dims {
Ok($tensor.nd_view::<$ndim>())
}
}};

($tensor:ident?, $ndim: expr) => {
if let Some($tensor) = $tensor.as_ref() {
Some(static_dims!($tensor, $ndim))
} else {
None
}
};
}

pub(crate) use static_dims;

/// Outputs from an operator.
///
/// This avoids allocations in the common case where an operator produces
Expand Down
3 changes: 1 addition & 2 deletions src/ops/non_max_suppression.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView};

use crate::ops::{InputList, IntoOpResult, OpError, Operator, OutputList};
use crate::static_dims;
use crate::ops::{static_dims, InputList, IntoOpResult, OpError, Operator, OutputList};
use crate::tensor_pool::TensorPool;

#[derive(Copy, Clone, Debug, PartialEq)]
Expand Down
3 changes: 1 addition & 2 deletions src/ops/norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ use rten_vecmath::vec_softmax_in_place;
use smallvec::SmallVec;

use crate::ops::reduce::reduce_inverse_rms;
use crate::ops::{add_in_place, mul_in_place, reduce_mean, sub};
use crate::ops::{add_in_place, mul_in_place, reduce_mean, static_dims, sub};
use crate::ops::{resolve_axis, InputList, IntoOpResult, OpError, Operator, Output, OutputList};
use crate::slice_reductions::{slice_max, slice_sum};
use crate::static_dims;
use crate::tensor_pool::{AutoReturn, TensorPool};

/// Perform in-place batch normalization on the `NC*` tensor `out`.
Expand Down
3 changes: 1 addition & 2 deletions src/ops/pad.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use rten_tensor::prelude::*;
use rten_tensor::{NdTensorView, SliceItem, Tensor, TensorView};

use crate::ops::{Input, InputList, IntoOpResult, OpError, Operator, OutputList};
use crate::static_dims;
use crate::ops::{static_dims, Input, InputList, IntoOpResult, OpError, Operator, OutputList};
use crate::tensor_pool::TensorPool;

#[derive(Copy, Clone, Debug, PartialEq)]
Expand Down
3 changes: 1 addition & 2 deletions src/ops/pooling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ use rayon::prelude::*;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut, Tensor, TensorView, TensorViewMut};

use crate::ops::{InputList, IntoOpResult, OpError, Operator, OutputList, Padding};
use crate::static_dims;
use crate::ops::{static_dims, InputList, IntoOpResult, OpError, Operator, OutputList, Padding};
use crate::tensor_pool::TensorPool;

/// Calculate the output size and padding for a convolution or pooling operation.
Expand Down
5 changes: 2 additions & 3 deletions src/ops/resize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@ use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut, Tensor, TensorView};

use crate::iter_util::range_chunks;
use crate::ops::{Input, InputList, IntoOpResult, OpError, Operator, OutputList};
use crate::ops::{static_dims, Input, InputList, IntoOpResult, OpError, Operator, OutputList};
use crate::tensor_pool::TensorPool;
use crate::{check_dims, static_dims};

/// Specifies an output size for a resize operation.
pub enum ResizeTarget<'a> {
Expand Down Expand Up @@ -226,7 +225,7 @@ fn bilinear_resize(
///
/// This is a simplified API for [resize].
pub fn resize_image(input: TensorView, size: [usize; 2]) -> Result<Tensor, OpError> {
let [batch, chans, _height, _width] = check_dims!(input, 4);
let [batch, chans, _height, _width] = static_dims!(input, 4)?.shape();
let [out_height, out_width] = size;
let out_shape = [batch, chans, out_height, out_width].map(|x| x as i32);
resize(
Expand Down
5 changes: 2 additions & 3 deletions src/ops/rnn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ use rten_tensor::{NdTensor, Tensor, TensorView};

use crate::gemm::{GemmExecutor, GemmInputA, GemmInputB};
use crate::ops::{
add_in_place, mul_in_place, sigmoid, tanh, InputList, IntoOpResult, OpError, Operator,
OutputList,
add_in_place, mul_in_place, sigmoid, static_dims, tanh, InputList, IntoOpResult, OpError,
Operator, OutputList,
};
use crate::static_dims;
use crate::tensor_pool::{AutoReturn, TensorPool};

/// Direction that an RNN operator will traverse the input sequence in.
Expand Down
4 changes: 2 additions & 2 deletions src/ops/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use rten_tensor::{NdTensorView, SliceItem, SliceRange, Tensor, TensorView};
use smallvec::SmallVec;

use crate::ops::{
resolve_axis, Input, InputList, IntoOpResult, OpError, Operator, Output, OutputList,
resolve_axis, static_dims, Input, InputList, IntoOpResult, OpError, Operator, Output,
OutputList,
};
use crate::static_dims;
use crate::tensor_pool::TensorPool;

/// Compute the effective starts, ends and steps for each input dimension in
Expand Down
3 changes: 1 addition & 2 deletions src/ops/split.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use rten_tensor::prelude::*;
use rten_tensor::{NdTensorView, SliceItem, Tensor, TensorView};

use crate::ops::{resolve_axis, InputList, OpError, Operator, OutputList};
use crate::static_dims;
use crate::ops::{resolve_axis, static_dims, InputList, OpError, Operator, OutputList};
use crate::tensor_pool::TensorPool;

pub fn split<T: Copy>(
Expand Down

0 comments on commit 632f0fd

Please sign in to comment.