diff --git a/src/ops/concat.rs b/src/ops/concat.rs
index c7ff04f7..bf08016b 100644
--- a/src/ops/concat.rs
+++ b/src/ops/concat.rs
@@ -5,10 +5,10 @@ use rten_tensor::{NdTensorView, Tensor, TensorView};
 
 use smallvec::SmallVec;
 
+use crate::ops::static_dims;
 use crate::ops::{
     resolve_axis, Input, InputList, IntoOpResult, OpError, Operator, Output, OutputList,
 };
-use crate::static_dims;
 use crate::tensor_pool::{AutoReturn, TensorPool};
 
 /// Return the shape formed by concatenating all tensors along a given axis.
diff --git a/src/ops/conv.rs b/src/ops/conv.rs
index 324bd8ed..1295ff6d 100644
--- a/src/ops/conv.rs
+++ b/src/ops/conv.rs
@@ -8,9 +8,8 @@ use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut, Tensor, TensorView};
 
 use crate::gemm::{GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT, VirtualMatrix};
 use crate::ops::pooling::calc_output_size_and_padding;
-use crate::ops::{InputList, IntoOpResult, OpError, Operator, OutputList, Padding};
+use crate::ops::{static_dims, InputList, IntoOpResult, OpError, Operator, OutputList, Padding};
 use crate::tensor_pool::{AutoReturn, TensorPool};
-use crate::{check_dims, static_dims};
 
 mod depthwise;
 mod im2col;
@@ -105,7 +104,7 @@ where
     // Handle 1D convolution by expanding to 2D and then removing the extra
     // dimension from the result.
     if let &[_n, _c, _w] = input.shape() {
-        let [_out_c, _k_in_c, _k_w] = check_dims!(kernel, 3, "OCW");
+        let [_out_c, _k_in_c, _k_w] = static_dims!(kernel, 3, "OCW")?.shape();
 
         let mut input_2d = input.clone();
         input_2d.insert_axis(2);
@@ -152,7 +151,7 @@ where
 
     let kernel = static_dims!(kernel, 4, "OCHW")?;
     let [out_c, k_in_c, k_h, k_w] = kernel.shape();
-    check_dims!(bias?, 1);
+    static_dims!(bias?, 1).transpose()?;
 
     let input = input.view();
     let kernel = kernel.view();
@@ -469,7 +468,7 @@ pub fn conv_transpose(
     // Handle 1D transposed convolution by expanding to 2D and then removing
     // the extra dimension from the result.
     if let &[n, c, w] = input.shape() {
-        let [out_c, k_in_c, k_w] = check_dims!(kernel, 3, "OCW");
+        let [out_c, k_in_c, k_w] = static_dims!(kernel, 3, "OCW")?.shape();
 
         let mut input_2d = input.clone();
         input_2d.reshape(&[n, c, 1, w]);
@@ -499,7 +498,7 @@ pub fn conv_transpose(
     let [batch, in_c, in_h, in_w] = input.shape();
     let kernel = static_dims!(kernel, 4, "OCHW")?;
     let [k_in_c, out_c, k_h, k_w] = kernel.shape();
-    check_dims!(bias?, 1);
+    static_dims!(bias?, 1).transpose()?;
 
     let bias = bias.map(|b| b.nd_view());
 
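Note: the change of call style above is the heart of this diff. `check_dims!` evaluated to a `[usize; N]` shape array (early-returning on a rank mismatch), while `static_dims!` evaluates to a `Result` wrapping a static-rank view, so call sites chain `?.shape()` to recover the same destructurable array. A sketch of the new pattern, using only the `rten_tensor` calls visible in this diff (`ndim`, `nd_view`, `shape`); `Tensor::zeros` and the `String` error type are assumptions for the example:

```rust
use rten_tensor::prelude::*;
use rten_tensor::Tensor;

// Hand-expanded analogue of `static_dims!(kernel, 3, "OCW")?.shape()`:
// verify the rank, then return the `[usize; 3]` shape of a static-rank view.
fn kernel_dims(kernel: &Tensor<f32>) -> Result<[usize; 3], String> {
    if kernel.ndim() != 3 {
        return Err("kernel must have 3 dims (OCW)".into());
    }
    Ok(kernel.nd_view::<3>().shape())
}

fn main() {
    let kernel = Tensor::<f32>::zeros(&[8, 4, 3]);
    // The fixed-size array destructures irrefutably, as in the hunks above.
    let [out_c, k_in_c, k_w] = kernel_dims(&kernel).unwrap();
    assert_eq!((out_c, k_in_c, k_w), (8, 4, 3));
}
```
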
diff --git a/src/ops/generate.rs b/src/ops/generate.rs
index 74797fbe..dbf24edb 100644
--- a/src/ops/generate.rs
+++ b/src/ops/generate.rs
@@ -4,10 +4,9 @@ use rten_tensor::prelude::*;
 use rten_tensor::{NdTensorView, Tensor, TensorView};
 
 use crate::ops::{
-    resolve_axis, resolve_index, Input, InputList, IntoOpResult, OpError, Operator, OutputList,
-    Scalar,
+    resolve_axis, resolve_index, static_dims, Input, InputList, IntoOpResult, OpError, Operator,
+    OutputList, Scalar,
 };
-use crate::static_dims;
 use crate::tensor_pool::TensorPool;
 
 pub fn constant_of_shape(
diff --git a/src/ops/layout.rs b/src/ops/layout.rs
index a75d6026..d9db9502 100644
--- a/src/ops/layout.rs
+++ b/src/ops/layout.rs
@@ -7,10 +7,9 @@ use smallvec::SmallVec;
 
 use crate::ops::binary_elementwise::{broadcast_shapes, fast_broadcast_cycles_repeats};
 use crate::ops::{
-    resolve_axes, resolve_axis, Input, InputList, IntoOpResult, OpError, Operator, Output,
-    OutputList,
+    resolve_axes, resolve_axis, static_dims, Input, InputList, IntoOpResult, OpError, Operator,
+    Output, OutputList,
 };
-use crate::static_dims;
 use crate::tensor_pool::TensorPool;
 
 /// Return the tensor shape resulting from broadcasting `input_shape` with `shape`.
diff --git a/src/ops/matmul.rs b/src/ops/matmul.rs
index 5ad296ca..5840e536 100644
--- a/src/ops/matmul.rs
+++ b/src/ops/matmul.rs
@@ -3,12 +3,11 @@ use rayon::prelude::*;
 use rten_tensor::prelude::*;
 use rten_tensor::{Tensor, TensorView};
 
-use crate::check_dims;
 use crate::gemm::{GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT};
 use crate::iter_util::range_chunks;
 use crate::ops::binary_elementwise::broadcast_shapes;
 use crate::ops::layout::expand_to;
-use crate::ops::{InputList, IntoOpResult, OpError, Operator, OutputList};
+use crate::ops::{static_dims, InputList, IntoOpResult, OpError, Operator, OutputList};
 use crate::tensor_pool::{AutoReturn, TensorPool};
 
 /// Compute the General Matrix Multiplication (GEMM) `c = alpha * (ab) + beta * c`.
@@ -30,8 +29,8 @@ pub fn gemm_op(
 where
     GemmExecutor<LhsT, RhsT, OutT>: Default,
 {
-    check_dims!(a, 2);
-    check_dims!(b, 2);
+    let a = static_dims!(a, 2)?;
+    let b = static_dims!(b, 2)?;
 
     let a = if transpose_a { a.transposed() } else { a };
     let b = if transpose_b { b.transposed() } else { b };
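Note: in `gemm_op` the converted rank-2 views are rebound over the original names, which is what lets the subsequent `transposed()` calls run without any further rank checks. A std-only sketch of that validate-then-shadow flow; the `View2` type, function names, and error strings are illustrative stand-ins, not the crate's API:

```rust
// Stand-in for a rank-2 view: just its dimensions.
#[derive(Clone, Copy, Debug, PartialEq)]
struct View2 {
    rows: usize,
    cols: usize,
}

impl View2 {
    fn transposed(self) -> View2 {
        View2 { rows: self.cols, cols: self.rows }
    }
}

// Analogue of `static_dims!(a, 2)?`: validate the rank once, up front.
fn as_view2(shape: &[usize], name: &str) -> Result<View2, String> {
    match *shape {
        [rows, cols] => Ok(View2 { rows, cols }),
        _ => Err(format!("{name} must have 2 dims")),
    }
}

fn gemm_output_shape(
    a: &[usize],
    b: &[usize],
    transpose_a: bool,
    transpose_b: bool,
) -> Result<(usize, usize), String> {
    let a = as_view2(a, "a")?;
    let b = as_view2(b, "b")?;
    // The shadowed names are now statically rank-2, mirroring
    // `let a = if transpose_a { a.transposed() } else { a };` above.
    let a = if transpose_a { a.transposed() } else { a };
    let b = if transpose_b { b.transposed() } else { b };
    if a.cols != b.rows {
        return Err("incompatible matrix shapes".into());
    }
    Ok((a.rows, b.cols))
}

fn main() {
    // (3, 4) x (4, 5) -> (3, 5); passing `b` as (5, 4) transposed also works.
    assert_eq!(gemm_output_shape(&[3, 4], &[4, 5], false, false), Ok((3, 5)));
    assert_eq!(gemm_output_shape(&[3, 4], &[5, 4], false, true), Ok((3, 5)));
}
```
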
diff --git a/src/ops/mod.rs b/src/ops/mod.rs
index a57c74dc..e56bebc4 100644
--- a/src/ops/mod.rs
+++ b/src/ops/mod.rs
@@ -706,71 +706,12 @@ impl Display for OpError {
 
 impl Error for OpError {}
 
-/// Check that a tensor has an expected number of dimensions, or return an
-/// `OpError::InvalidValue`.
+/// Convert a tensor with dynamic dimension count to a view with a static
+/// dimension count.
 ///
-/// Can be used with `check_dims!(input, expected_rank)` if `input` is a
-/// `Tensor` or `check_dims!(input?, expected_rank)` if `input` is an
-/// `Option<Tensor<T>>`.
-///
-/// If `$ndim` is a literal, the macro returns an array of `$ndim` sizes for
-/// each dimension. This conveniently allows checking the rank of a tensor
-/// and extracting the sizes of dimension in one call. For example:
-/// `let [rows, cols] = check_dims!(matrix, 2)`. When `$ndim` is a literal,
-/// a third argument can also be passed to specify the names of the dimensions,
-/// eg. "NCHW" or "dir, batch, seq". This can produce more helpful errors if
-/// the input does not match the expected shape.
-#[doc(hidden)]
-#[macro_export]
-macro_rules! check_dims {
-    ($tensor:ident, $ndim:literal, $dim_names:literal) => {{
-        let shape: [usize; $ndim] = $tensor.shape().try_into().map_err(|_| {
-            OpError::InvalidValue(concat!(
-                stringify!($tensor),
-                " must have ",
-                stringify!($ndim),
-                " dims (",
-                $dim_names,
-                ")"
-            ))
-        })?;
-        shape
-    }};
-
-    ($tensor:ident, $ndim:literal) => {{
-        let shape: [usize; $ndim] = $tensor.shape().try_into().map_err(|_| {
-            OpError::InvalidValue(concat!(
-                stringify!($tensor),
-                " must have ",
-                stringify!($ndim),
-                " dims"
-            ))
-        })?;
-        shape
-    }};
-
-    ($tensor:ident, $ndim:expr) => {
-        if $tensor.ndim() != $ndim {
-            return Err(OpError::InvalidValue(concat!(
-                stringify!($tensor),
-                " must have ",
-                stringify!($ndim),
-                " dims"
-            )));
-        }
-    };
-
-    ($tensor:ident?, $ndim: expr) => {
-        if let Some($tensor) = $tensor.as_ref() {
-            check_dims!($tensor, $ndim);
-        }
-    };
-}
-
-/// Convert a tensor with dynamic dimension count to an `NdTensorView`, or
-/// return an `OpError::InvalidValue` if the dimension count is incorrect.
-#[doc(hidden)]
-#[macro_export]
+/// If the conversion fails, an `OpError::InvalidValue` error will be returned
+/// with a message that includes the name of the tensor and, optionally, the
+/// names of the expected dimensions (eg. "NCHW").
 macro_rules! static_dims {
     ($tensor:ident, $ndim:literal, $dim_names:literal) => {{
         use rten_tensor::prelude::*;
@@ -803,8 +744,18 @@ macro_rules! static_dims {
             Ok($tensor.nd_view::<$ndim>())
         }
     }};
+
+    ($tensor:ident?, $ndim: expr) => {
+        if let Some($tensor) = $tensor.as_ref() {
+            Some(static_dims!($tensor, $ndim))
+        } else {
+            None
+        }
+    };
 }
 
+pub(crate) use static_dims;
+
 /// Outputs from an operator.
 ///
 /// This avoids allocations in the common case where an operator produces
diff --git a/src/ops/non_max_suppression.rs b/src/ops/non_max_suppression.rs
index eefc032d..cb7252e1 100644
--- a/src/ops/non_max_suppression.rs
+++ b/src/ops/non_max_suppression.rs
@@ -1,8 +1,7 @@
 use rten_tensor::prelude::*;
 use rten_tensor::{NdTensor, NdTensorView};
 
-use crate::ops::{InputList, IntoOpResult, OpError, Operator, OutputList};
-use crate::static_dims;
+use crate::ops::{static_dims, InputList, IntoOpResult, OpError, Operator, OutputList};
 use crate::tensor_pool::TensorPool;
 
 #[derive(Copy, Clone, Debug, PartialEq)]
diff --git a/src/ops/norm.rs b/src/ops/norm.rs
index 6017302f..640d0dce 100644
--- a/src/ops/norm.rs
+++ b/src/ops/norm.rs
@@ -6,10 +6,9 @@ use rten_vecmath::vec_softmax_in_place;
 use smallvec::SmallVec;
 
 use crate::ops::reduce::reduce_inverse_rms;
-use crate::ops::{add_in_place, mul_in_place, reduce_mean, sub};
+use crate::ops::{add_in_place, mul_in_place, reduce_mean, static_dims, sub};
 use crate::ops::{resolve_axis, InputList, IntoOpResult, OpError, Operator, Output, OutputList};
 use crate::slice_reductions::{slice_max, slice_sum};
-use crate::static_dims;
 use crate::tensor_pool::{AutoReturn, TensorPool};
 
 /// Perform in-place batch normalization on the `NC*` tensor `out`.
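Note: the new `($tensor?, $ndim)` arm added in the `mod.rs` hunk yields an `Option<Result<...>>`, which explains the `static_dims!(bias?, 1).transpose()?` call sites in `conv.rs`: `Option::transpose` flips `Option<Result<T, E>>` into `Result<Option<T>, E>`, so a rank error on a present bias propagates while an absent bias stays `None`. A minimal std-only sketch; the function names and `String` error are illustrative:

```rust
// Stand-in for the `($tensor:ident?, $ndim:expr)` arm: validate an optional
// value only when it is present, yielding `Option<Result<...>>`.
fn check_optional(bias: Option<&[usize]>, ndim: usize) -> Option<Result<&[usize], String>> {
    bias.map(|b| {
        if b.len() == ndim {
            Ok(b)
        } else {
            Err(format!("bias must have {} dims", ndim))
        }
    })
}

fn demo(bias: Option<&[usize]>) -> Result<(), String> {
    // `Option<Result<T, E>>::transpose()` -> `Result<Option<T>, E>`, so `?`
    // propagates a rank error but tolerates a missing bias.
    let _bias = check_optional(bias, 1).transpose()?;
    Ok(())
}

fn main() {
    assert!(demo(None).is_ok()); // absent bias: fine
    assert!(demo(Some(&[4])).is_ok()); // rank-1 bias: fine
    assert!(demo(Some(&[2, 2])).is_err()); // wrong rank: error propagated
}
```
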
diff --git a/src/ops/pad.rs b/src/ops/pad.rs
index a38211ad..c4e79ab5 100644
--- a/src/ops/pad.rs
+++ b/src/ops/pad.rs
@@ -1,8 +1,7 @@
 use rten_tensor::prelude::*;
 use rten_tensor::{NdTensorView, SliceItem, Tensor, TensorView};
 
-use crate::ops::{Input, InputList, IntoOpResult, OpError, Operator, OutputList};
-use crate::static_dims;
+use crate::ops::{static_dims, Input, InputList, IntoOpResult, OpError, Operator, OutputList};
 use crate::tensor_pool::TensorPool;
 
 #[derive(Copy, Clone, Debug, PartialEq)]
diff --git a/src/ops/pooling.rs b/src/ops/pooling.rs
index e864ae7d..ece42ed1 100644
--- a/src/ops/pooling.rs
+++ b/src/ops/pooling.rs
@@ -5,8 +5,7 @@ use rayon::prelude::*;
 use rten_tensor::prelude::*;
 use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut, Tensor, TensorView, TensorViewMut};
 
-use crate::ops::{InputList, IntoOpResult, OpError, Operator, OutputList, Padding};
-use crate::static_dims;
+use crate::ops::{static_dims, InputList, IntoOpResult, OpError, Operator, OutputList, Padding};
 use crate::tensor_pool::TensorPool;
 
 /// Calculate the output size and padding for a convolution or pooling operation.
diff --git a/src/ops/resize.rs b/src/ops/resize.rs
index dc4c0947..2b0a252a 100644
--- a/src/ops/resize.rs
+++ b/src/ops/resize.rs
@@ -6,9 +6,8 @@ use rten_tensor::prelude::*;
 use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut, Tensor, TensorView};
 
 use crate::iter_util::range_chunks;
-use crate::ops::{Input, InputList, IntoOpResult, OpError, Operator, OutputList};
+use crate::ops::{static_dims, Input, InputList, IntoOpResult, OpError, Operator, OutputList};
 use crate::tensor_pool::TensorPool;
-use crate::{check_dims, static_dims};
 
 /// Specifies an output size for a resize operation.
 pub enum ResizeTarget<'a> {
@@ -226,7 +225,7 @@ fn bilinear_resize(
 ///
 /// This is a simplified API for [resize].
 pub fn resize_image(input: TensorView, size: [usize; 2]) -> Result<Tensor, OpError> {
-    let [batch, chans, _height, _width] = check_dims!(input, 4);
+    let [batch, chans, _height, _width] = static_dims!(input, 4)?.shape();
     let [out_height, out_width] = size;
     let out_shape = [batch, chans, out_height, out_width].map(|x| x as i32);
     resize(
diff --git a/src/ops/rnn.rs b/src/ops/rnn.rs
index 5f6ce525..a89b1692 100644
--- a/src/ops/rnn.rs
+++ b/src/ops/rnn.rs
@@ -6,10 +6,9 @@ use rten_tensor::{NdTensor, Tensor, TensorView};
 
 use crate::gemm::{GemmExecutor, GemmInputA, GemmInputB};
 use crate::ops::{
-    add_in_place, mul_in_place, sigmoid, tanh, InputList, IntoOpResult, OpError, Operator,
-    OutputList,
+    add_in_place, mul_in_place, sigmoid, static_dims, tanh, InputList, IntoOpResult, OpError,
+    Operator, OutputList,
 };
-use crate::static_dims;
 use crate::tensor_pool::{AutoReturn, TensorPool};
 
 /// Direction that an RNN operator will traverse the input sequence in.
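Note: `resize_image` shows the one-line migration shape: `static_dims!(input, 4)?.shape()` converts to a rank-4 view and immediately takes its `[usize; 4]` shape, keeping `check_dims!`'s old destructuring ergonomics. Under the hood this is the same slice-to-array conversion the removed macro performed via `try_into`; a std-only sketch (the function name is illustrative):

```rust
// `shape()` on a static-rank view returns `[usize; 4]`, not `&[usize]`, so
// the call site can destructure irrefutably with no `try_into` of its own.
fn rank4_shape(shape: &[usize]) -> Result<[usize; 4], String> {
    shape
        .try_into()
        .map_err(|_| "input must have 4 dims".to_string())
}

fn main() {
    let [batch, chans, _height, _width] = rank4_shape(&[1, 3, 224, 224]).unwrap();
    assert_eq!((batch, chans), (1, 3));
    assert!(rank4_shape(&[1, 3, 224]).is_err()); // wrong rank is an `Err`
}
```
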
diff --git a/src/ops/slice.rs b/src/ops/slice.rs
index f9c9e23e..07944e07 100644
--- a/src/ops/slice.rs
+++ b/src/ops/slice.rs
@@ -4,9 +4,9 @@ use rten_tensor::{NdTensorView, SliceItem, SliceRange, Tensor, TensorView};
 use smallvec::SmallVec;
 
 use crate::ops::{
-    resolve_axis, Input, InputList, IntoOpResult, OpError, Operator, Output, OutputList,
+    resolve_axis, static_dims, Input, InputList, IntoOpResult, OpError, Operator, Output,
+    OutputList,
 };
-use crate::static_dims;
 use crate::tensor_pool::TensorPool;
 
 /// Compute the effective starts, ends and steps for each input dimension in
diff --git a/src/ops/split.rs b/src/ops/split.rs
index bac24b9b..d9cfc863 100644
--- a/src/ops/split.rs
+++ b/src/ops/split.rs
@@ -1,8 +1,7 @@
 use rten_tensor::prelude::*;
 use rten_tensor::{NdTensorView, SliceItem, Tensor, TensorView};
 
-use crate::ops::{resolve_axis, InputList, OpError, Operator, OutputList};
-use crate::static_dims;
+use crate::ops::{resolve_axis, static_dims, InputList, OpError, Operator, OutputList};
 use crate::tensor_pool::TensorPool;
 
 pub fn split(
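Note: the import churn across all of these files follows from how the macro is now exported. Dropping `#[macro_export]` removes it from the crate root (the old `use crate::static_dims;`), and the new `pub(crate) use static_dims;` re-exports it from `crate::ops` like an ordinary item. A minimal sketch of that module-scoped macro pattern; the module and macro names are illustrative:

```rust
mod ops {
    // Defined without `#[macro_export]`, so by default the macro is only
    // visible textually, below its definition point in this module...
    macro_rules! require_rank {
        ($shape:expr, $ndim:expr) => {
            $shape.len() == $ndim
        };
    }

    // ...but a `pub(crate) use` re-export makes it importable crate-wide as
    // `use crate::ops::require_rank;`, like any ordinary item.
    pub(crate) use require_rank;

    pub mod pad {
        use crate::ops::require_rank;

        pub fn pad_ok(shape: &[usize]) -> bool {
            require_rank!(shape, 4)
        }
    }
}

fn main() {
    assert!(ops::pad::pad_ok(&[1, 3, 8, 8]));
    assert!(!ops::pad::pad_ok(&[1, 3, 8]));
}
```
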