diff --git a/ocrs/src/detection.rs b/ocrs/src/detection.rs index 9337e94..1941ebf 100644 --- a/ocrs/src/detection.rs +++ b/ocrs/src/detection.rs @@ -5,6 +5,7 @@ use rten_tensor::prelude::*; use rten_tensor::{NdTensor, NdTensorView, Tensor}; use crate::preprocess::BLACK_VALUE; +use crate::tensor_util::IntoCow; /// Parameters that control post-processing of text detection model outputs. #[derive(Clone, Debug, PartialEq)] @@ -165,22 +166,28 @@ impl TextDetector { // inputs, within some limits. let pad_bottom = (in_height as i32 - img_height as i32).max(0); let pad_right = (in_width as i32 - img_width as i32).max(0); - let grey_img = if pad_bottom > 0 || pad_right > 0 { - let pads = &[0, 0, 0, 0, 0, 0, pad_bottom, pad_right]; - image.pad(pads.into(), BLACK_VALUE)? - } else { - image.as_dyn().to_tensor() - }; + let image = (pad_bottom > 0 || pad_right > 0) + .then(|| { + let pads = &[0, 0, 0, 0, 0, 0, pad_bottom, pad_right]; + image.pad(pads.into(), BLACK_VALUE) + }) + .transpose()? + .map(|t| t.into_cow()) + .unwrap_or(image.into_dyn().into_cow()); // Resize images to the text detection model's input size. - let resized_grey_img = grey_img.resize_image([in_height, in_width])?; + let image = (image.size(2) != in_height || image.size(3) != in_width) + .then(|| image.resize_image([in_height, in_width])) + .transpose()? + .map(|t| t.into_cow()) + .unwrap_or(image); // Run text detection model to compute a probability mask indicating whether // each pixel is part of a text word or not. 
let text_mask: Tensor = self .model .run_one( - (&resized_grey_img).into(), + image.view().into(), if debug { Some(RunOptions { timing: true, diff --git a/ocrs/src/lib.rs b/ocrs/src/lib.rs index 4640faa..25a5a1c 100644 --- a/ocrs/src/lib.rs +++ b/ocrs/src/lib.rs @@ -11,6 +11,8 @@ mod log; mod preprocess; mod recognition; +mod tensor_util; + #[cfg(test)] mod test_util; diff --git a/ocrs/src/tensor_util.rs b/ocrs/src/tensor_util.rs new file mode 100644 index 0000000..7934a3d --- /dev/null +++ b/ocrs/src/tensor_util.rs @@ -0,0 +1,37 @@ +use std::borrow::Cow; + +use rten_tensor::prelude::*; +use rten_tensor::{MutLayout, TensorBase}; + +/// Convert an owned tensor or view into one which uses a [Cow] for storage. +/// +/// This is useful for code that wants to conditionally copy a tensor, as this +/// trait can be used to convert either an owned copy or view to the same type. +pub trait IntoCow { + type Cow; + + fn into_cow(self) -> Self::Cow; +} + +impl<'a, T, L: MutLayout> IntoCow for TensorBase<T, &'a [T], L> +where + [T]: ToOwned, +{ + type Cow = TensorBase<T, Cow<'a, [T]>, L>; + + fn into_cow(self) -> Self::Cow { + TensorBase::from_data(self.shape(), Cow::Borrowed(self.non_contiguous_data())) + } +} + +impl<T, L: MutLayout> IntoCow for TensorBase<T, Vec<T>, L> +where + [T]: ToOwned<Owned = Vec<T>>, +{ + type Cow = TensorBase<T, Cow<'static, [T]>, L>; + + fn into_cow(self) -> Self::Cow { + let layout = self.layout().clone(); + TensorBase::from_data(layout.shape(), Cow::Owned(self.into_data())) + } +}