Skip to content

Commit

Permalink
Merge pull request #387 from igor-yusupov/ops_int
Browse files Browse the repository at this point in the history
int8 and uint8 support in Pad and other ops

Support u8 and i8 tensors in:

- Cast
- Gather
- GatherElements
- GatherND
- ScatterElements
- ScatterND
- Expand
- Flatten
- Reshape
- Squeeze
- Transpose
- Unsqueeze
- Pad
  • Loading branch information
robertknight authored Oct 17, 2024
2 parents b557d7a + 34afdab commit b76661c
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 20 deletions.
19 changes: 16 additions & 3 deletions src/ops/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,27 @@ fn cast(pool: &TensorPool, input: Input, dtype: DataType) -> Result<Output, OpEr
DataType::Int32 => match input {
Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x).into()),
Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x as i32).into()),
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x as i32).into()),
Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x as i32).into()),
},
DataType::Float => match input {
Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x).into()),
Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x as f32).into()),
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x as f32).into()),
Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x as f32).into()),
},
DataType::Int8 => match input {
Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x as i8).into()),
Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x as i8).into()),
Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x).into()),
Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x as i8).into()),
},
DataType::UInt8 => match input {
Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x as u8).into()),
Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x as u8).into()),
Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x as u8).into()),
Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x).into()),
},
_ => Err(OpError::UnsupportedValue("Unsupported cast")),
}
}

Expand Down
30 changes: 27 additions & 3 deletions src/ops/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl Operator for Gather {
Input::Int32Tensor(input) => gather(pool, input, self.axis, indices).into_op_result(),
Input::FloatTensor(input) => gather(pool, input, self.axis, indices).into_op_result(),
Input::UInt8Tensor(input) => gather(pool, input, self.axis, indices).into_op_result(),
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(input) => gather(pool, input, self.axis, indices).into_op_result(),
}
}
}
Expand Down Expand Up @@ -238,7 +238,12 @@ impl Operator for GatherElements {
Input::FloatTensor(input) => {
gather_elements(pool, input, indices, self.axis).into_op_result()
}
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(input) => {
gather_elements(pool, input, indices, self.axis).into_op_result()
}
Input::UInt8Tensor(input) => {
gather_elements(pool, input, indices, self.axis).into_op_result()
}
}
}
}
Expand Down Expand Up @@ -336,7 +341,12 @@ impl Operator for GatherND {
Input::FloatTensor(input) => {
gather_nd(pool, input, indices, self.batch_dims).into_op_result()
}
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(input) => {
gather_nd(pool, input, indices, self.batch_dims).into_op_result()
}
Input::UInt8Tensor(input) => {
gather_nd(pool, input, indices, self.batch_dims).into_op_result()
}
}
}
}
Expand Down Expand Up @@ -451,6 +461,14 @@ impl Operator for ScatterElements {
scatter_elements(pool, data, indices, updates, self.axis, self.reduction)
.into_op_result()
}
(Input::Int8Tensor(data), Input::Int8Tensor(updates)) => {
scatter_elements(pool, data, indices, updates, self.axis, self.reduction)
.into_op_result()
}
(Input::UInt8Tensor(data), Input::UInt8Tensor(updates)) => {
scatter_elements(pool, data, indices, updates, self.axis, self.reduction)
.into_op_result()
}
_ => Err(OpError::UnsupportedType),
}
}
Expand Down Expand Up @@ -547,6 +565,12 @@ impl Operator for ScatterND {
(Input::FloatTensor(data), Input::FloatTensor(updates)) => {
scatter_nd(pool, data, indices, updates, self.reduction).into_op_result()
}
(Input::Int8Tensor(data), Input::Int8Tensor(updates)) => {
scatter_nd(pool, data, indices, updates, self.reduction).into_op_result()
}
(Input::UInt8Tensor(data), Input::UInt8Tensor(updates)) => {
scatter_nd(pool, data, indices, updates, self.reduction).into_op_result()
}
_ => Err(OpError::UnsupportedType),
}
}
Expand Down
51 changes: 40 additions & 11 deletions src/ops/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ impl Operator for Expand {
match input {
Input::FloatTensor(input) => expand(pool, input, &shape).into_op_result(),
Input::Int32Tensor(input) => expand(pool, input, &shape).into_op_result(),
_ => Err(OpError::UnsupportedType),
Input::UInt8Tensor(input) => expand(pool, input, &shape).into_op_result(),
Input::Int8Tensor(input) => expand(pool, input, &shape).into_op_result(),
}
}

Expand All @@ -122,7 +123,8 @@ impl Operator for Expand {
let output: Output = match input {
Output::FloatTensor(input) => expand_to(pool, input.view(), &out_shape).into(),
Output::Int32Tensor(input) => expand_to(pool, input.view(), &out_shape).into(),
_ => return Err(OpError::UnsupportedType),
Output::Int8Tensor(input) => expand_to(pool, input.view(), &out_shape).into(),
Output::UInt8Tensor(input) => expand_to(pool, input.view(), &out_shape).into(),
};
Ok(output)
}
Expand Down Expand Up @@ -172,7 +174,8 @@ impl Operator for Flatten {
match input {
Input::FloatTensor(input) => flatten(pool, input, self.axis).into_op_result(),
Input::Int32Tensor(input) => flatten(pool, input, self.axis).into_op_result(),
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(input) => flatten(pool, input, self.axis).into_op_result(),
Input::UInt8Tensor(input) => flatten(pool, input, self.axis).into_op_result(),
}
}

Expand All @@ -195,7 +198,14 @@ impl Operator for Flatten {
flatten_in_place(pool, &mut output, self.axis)?;
Ok(output.into())
}
_ => Err(OpError::UnsupportedType),
Output::Int8Tensor(mut output) => {
flatten_in_place(pool, &mut output, self.axis)?;
Ok(output.into())
}
Output::UInt8Tensor(mut output) => {
flatten_in_place(pool, &mut output, self.axis)?;
Ok(output.into())
}
}
}
}
Expand Down Expand Up @@ -314,7 +324,8 @@ impl Operator for Reshape {
match input {
Input::Int32Tensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(),
Input::FloatTensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(),
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(),
Input::UInt8Tensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(),
}
}

Expand All @@ -340,7 +351,14 @@ impl Operator for Reshape {
reshape_in_place(pool, &mut output, &shape, self.allow_zero)?;
Ok(output.into())
}
_ => Err(OpError::UnsupportedType),
Output::Int8Tensor(mut output) => {
reshape_in_place(pool, &mut output, &shape, self.allow_zero)?;
Ok(output.into())
}
Output::UInt8Tensor(mut output) => {
reshape_in_place(pool, &mut output, &shape, self.allow_zero)?;
Ok(output.into())
}
}
}
}
Expand Down Expand Up @@ -449,7 +467,8 @@ impl Operator for Squeeze {
match input {
Input::FloatTensor(t) => squeeze(pool, t, axes).into_op_result(),
Input::Int32Tensor(t) => squeeze(pool, t, axes).into_op_result(),
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(t) => squeeze(pool, t, axes).into_op_result(),
Input::UInt8Tensor(t) => squeeze(pool, t, axes).into_op_result(),
}
}

Expand All @@ -475,7 +494,14 @@ impl Operator for Squeeze {
squeeze_in_place(&mut t, axes)?;
Ok(t.into())
}
_ => Err(OpError::UnsupportedType),
Output::UInt8Tensor(mut t) => {
squeeze_in_place(&mut t, axes)?;
Ok(t.into())
}
Output::Int8Tensor(mut t) => {
squeeze_in_place(&mut t, axes)?;
Ok(t.into())
}
}
}
}
Expand Down Expand Up @@ -519,7 +545,8 @@ impl Operator for Transpose {
match input {
Input::FloatTensor(input) => transpose(pool, input, perm_slice).into_op_result(),
Input::Int32Tensor(input) => transpose(pool, input, perm_slice).into_op_result(),
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(input) => transpose(pool, input, perm_slice).into_op_result(),
Input::UInt8Tensor(input) => transpose(pool, input, perm_slice).into_op_result(),
}
}
}
Expand Down Expand Up @@ -577,7 +604,8 @@ impl Operator for Unsqueeze {
match input {
Input::FloatTensor(input) => unsqueeze(pool, input, &axes).into_op_result(),
Input::Int32Tensor(input) => unsqueeze(pool, input, &axes).into_op_result(),
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(input) => unsqueeze(pool, input, &axes).into_op_result(),
Input::UInt8Tensor(input) => unsqueeze(pool, input, &axes).into_op_result(),
}
}

Expand All @@ -597,7 +625,8 @@ impl Operator for Unsqueeze {
match output {
Output::FloatTensor(t) => unsqueeze_in_place(t, &axes).map(Output::FloatTensor),
Output::Int32Tensor(t) => unsqueeze_in_place(t, &axes).map(Output::Int32Tensor),
_ => Err(OpError::UnsupportedType),
Output::Int8Tensor(t) => unsqueeze_in_place(t, &axes).map(Output::Int8Tensor),
Output::UInt8Tensor(t) => unsqueeze_in_place(t, &axes).map(Output::UInt8Tensor),
}
}
}
Expand Down
9 changes: 8 additions & 1 deletion src/ops/pad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,18 @@ impl Operator for Pad {
let const_val = inputs.get_as_scalar::<i32>(2)?.unwrap_or(0);
pad(pool, t, &pads, self.mode, const_val).into_op_result()
}
Input::Int8Tensor(t) => {
let const_val = inputs.get_as_scalar::<i8>(2)?.unwrap_or(0);
pad(pool, t, &pads, self.mode, const_val).into_op_result()
}
Input::UInt8Tensor(t) => {
let const_val = inputs.get_as_scalar::<u8>(2)?.unwrap_or(0);
pad(pool, t, &pads, self.mode, const_val).into_op_result()
}
Input::FloatTensor(t) => {
let const_val = inputs.get_as_scalar::<f32>(2)?.unwrap_or(0.);
pad(pool, t, &pads, self.mode, const_val).into_op_result()
}
_ => Err(OpError::UnsupportedType),
}
}
}
Expand Down
16 changes: 14 additions & 2 deletions src/ops/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,12 @@ impl Operator for Slice {
Input::Int32Tensor(input) => {
slice(pool, input, &starts, &ends, axes.as_ref(), steps.as_ref()).map(|t| t.into())
}
_ => Err(OpError::UnsupportedType),
Input::Int8Tensor(input) => {
slice(pool, input, &starts, &ends, axes.as_ref(), steps.as_ref()).map(|t| t.into())
}
Input::UInt8Tensor(input) => {
slice(pool, input, &starts, &ends, axes.as_ref(), steps.as_ref()).map(|t| t.into())
}
};
result.into_op_result()
}
Expand Down Expand Up @@ -168,7 +173,14 @@ impl Operator for Slice {
slice_in_place(&mut output, &starts, &ends, axes.as_ref())?;
Ok(output.into())
}
_ => Err(OpError::UnsupportedType),
Output::Int8Tensor(mut output) => {
slice_in_place(&mut output, &starts, &ends, axes.as_ref())?;
Ok(output.into())
}
Output::UInt8Tensor(mut output) => {
slice_in_place(&mut output, &starts, &ends, axes.as_ref())?;
Ok(output.into())
}
}
}
}
Expand Down

0 comments on commit b76661c

Please sign in to comment.