Skip to content

Commit

Permalink
Update to rten v0.11
Browse files Browse the repository at this point in the history
  • Loading branch information
robertknight committed Jul 5, 2024
1 parent fe6de19 commit 249d077
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 21 deletions.
24 changes: 12 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ members = [
]

[workspace.dependencies]
rten = { version = "0.10.0" }
rten-imageproc = { version = "0.10.0" }
rten-tensor = { version = "0.10.0" }
rten = { version = "0.11.0" }
rten-imageproc = { version = "0.11.0" }
rten-tensor = { version = "0.11.0" }
10 changes: 5 additions & 5 deletions ocrs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ impl OcrEngine {
mod tests {
use std::error::Error;

use rten::model_builder::{ModelBuilder, OpType};
use rten::model_builder::{ModelBuilder, ModelFormat, OpType};
use rten::ops::{MaxPool, Transpose};
use rten::Dimension;
use rten::Model;
Expand Down Expand Up @@ -240,7 +240,7 @@ mod tests {
/// Takes a CHW input tensor with values in `[-0.5, 0.5]` and adds a +0.5
/// bias to produce an output "probability map".
fn fake_detection_model() -> Model {
let mut mb = ModelBuilder::new();
let mut mb = ModelBuilder::new(ModelFormat::V1);
let input_id = mb.add_value(
"input",
Some(&[
Expand All @@ -258,7 +258,7 @@ mod tests {
mb.add_output(output_id);

let bias = Tensor::from_scalar(0.5);
let bias_id = mb.add_float_constant(&bias);
let bias_id = mb.add_constant(bias.view());
mb.add_operator(
"add",
OpType::Add,
Expand All @@ -277,7 +277,7 @@ mod tests {
/// log-probability of each class label. In this fake we just re-interpret
/// each column of the input as a one-hot vector of probabilities.
fn fake_recognition_model() -> Model {
let mut mb = ModelBuilder::new();
let mut mb = ModelBuilder::new(ModelFormat::V1);
let input_id = mb.add_value(
"input",
Some(&[
Expand All @@ -304,7 +304,7 @@ mod tests {

// Squeeze to remove the channel dim: NCHW/4 => NHW/4
let squeeze_axes = Tensor::from_vec(vec![1]);
let squeeze_axes_id = mb.add_int_constant(&squeeze_axes);
let squeeze_axes_id = mb.add_constant(squeeze_axes.view());
let squeeze_out = mb.add_value("squeeze_out", None);
mb.add_operator(
"squeeze",
Expand Down
6 changes: 5 additions & 1 deletion ocrs/src/recognition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,11 @@ impl TextRecognizer {
let input: Tensor<f32> = input.into();
let [output] = self
.model
.run_n(&[(self.input_id, (&input).into())], [self.output_id], None)
.run_n(
vec![(self.input_id, (&input).into())],
[self.output_id],
None,
)
.map_err(|err| ModelRunError::RunFailed(err.into()))?;
let mut rec_sequence: NdTensor<f32, 3> =
output.try_into().map_err(|_| ModelRunError::WrongOutput)?;
Expand Down

0 comments on commit 249d077

Please sign in to comment.