Add int16 support to OP_RELU6 (#2682)
This PR adds int16 support to OP_RELU6, which TF Lite already supports.

bug=#2681
andresovela authored Sep 26, 2024
1 parent bf2ba11 commit ef2179c
Showing 4 changed files with 99 additions and 24 deletions.
19 changes: 14 additions & 5 deletions tensorflow/lite/micro/kernels/activations.cc
@@ -75,6 +75,7 @@ void* Relu6Init(TfLiteContext* context, const char* buffer, size_t length) {

TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
+
  const Relu6OpData& data = *(static_cast<const Relu6OpData*>(node->user_data));

  const TfLiteEvalTensor* input =
@@ -92,11 +93,19 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
      return kTfLiteOk;
    }
    case kTfLiteInt8: {
-      Relu6Quantized(data.zero_int8, data.six_int8,
-                     tflite::micro::GetTensorShape(input),
-                     tflite::micro::GetTensorData<int8_t>(input),
-                     tflite::micro::GetTensorShape(output),
-                     tflite::micro::GetTensorData<int8_t>(output));
+      Relu6Quantized<int8_t>(data.zero, data.six,
+                             tflite::micro::GetTensorShape(input),
+                             tflite::micro::GetTensorData<int8_t>(input),
+                             tflite::micro::GetTensorShape(output),
+                             tflite::micro::GetTensorData<int8_t>(output));
      return kTfLiteOk;
    }
+    case kTfLiteInt16: {
+      Relu6Quantized<int16_t>(data.zero, data.six,
+                              tflite::micro::GetTensorShape(input),
+                              tflite::micro::GetTensorData<int16_t>(input),
+                              tflite::micro::GetTensorShape(output),
+                              tflite::micro::GetTensorData<int16_t>(output));
+      return kTfLiteOk;
+    }
    default: {
18 changes: 13 additions & 5 deletions tensorflow/lite/micro/kernels/activations.h
@@ -32,8 +32,8 @@ struct ReluOpData {
};

struct Relu6OpData {
-  int8_t six_int8;
-  int8_t zero_int8;
+  int32_t six;
+  int32_t zero;
};

void ReluQuantized(const ReluOpData& data, const RuntimeShape& input_shape,
@@ -50,9 +50,17 @@ void ReluFloat(const RuntimeShape& input_shape, const float* input_data,
void Relu6Float(const RuntimeShape& input_shape, const float* input_data,
                const RuntimeShape& output_shape, float* output_data);

-void Relu6Quantized(int8_t lower, int8_t upper, const RuntimeShape& input_shape,
-                    const int8_t* input_data, const RuntimeShape& output_shape,
-                    int8_t* output_data);
+template <typename T>
+void Relu6Quantized(T lower, T upper, const RuntimeShape& input_shape,
+                    const T* input_data, const RuntimeShape& output_shape,
+                    T* output_data) {
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+  for (int i = 0; i < flat_size; ++i) {
+    const T val = input_data[i];
+    const T clamped = val > upper ? upper : val < lower ? lower : val;
+    output_data[i] = clamped;
+  }
+}

TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node);

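Note: the templated Relu6Quantized above just clamps each quantized element to the [zero, six] interval. A minimal standalone sketch of the same idea, runnable outside TFLM (the names, shapes, and values here are illustrative assumptions, not part of this commit):

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for the Relu6Quantized template: clamp each element
// to [lower, upper], the quantized representations of 0.0f and 6.0f.
template <typename T>
void Relu6ClampSketch(T lower, T upper, const T* input, T* output, int n) {
  for (int i = 0; i < n; ++i) {
    const T val = input[i];
    output[i] = val > upper ? upper : val < lower ? lower : val;
  }
}

int main() {
  // With input scale 0.5 and zero point 0, 0.0f quantizes to 0 and 6.0f to 12.
  const int16_t input[] = {8, 10, 12, 14, 16, -2, -4};
  int16_t output[7];
  Relu6ClampSketch<int16_t>(int16_t{0}, int16_t{12}, input, output, 7);
  for (int16_t v : output) printf("%d ", v);  // prints: 8 10 12 12 12 0 0
  return 0;
}

Dequantized, that output is {4, 5, 6, 6, 6, 0, 0}: values above 6.0 saturate at six and negatives at zero, which is what the int16 test below checks.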
24 changes: 10 additions & 14 deletions tensorflow/lite/micro/kernels/activations_common.cc
@@ -102,17 +102,6 @@ void Relu6Float(const RuntimeShape& input_shape, const float* input_data,
  }
}

-void Relu6Quantized(int8_t lower, int8_t upper, const RuntimeShape& input_shape,
-                    const int8_t* input_data, const RuntimeShape& output_shape,
-                    int8_t* output_data) {
-  const int flat_size = MatchingFlatSize(input_shape, output_shape);
-  for (int i = 0; i < flat_size; ++i) {
-    const int8_t val = input_data[i];
-    const int8_t clamped = val > upper ? upper : val < lower ? lower : val;
-    output_data[i] = clamped;
-  }
-}
-

TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  ReluOpData* data = static_cast<ReluOpData*>(node->user_data);
@@ -137,6 +126,7 @@ TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {

TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
+
  Relu6OpData* data = static_cast<Relu6OpData*>(node->user_data);

  MicroContext* micro_context = GetMicroContext(context);

  TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
  TF_LITE_ENSURE(context, input != nullptr);

  if (input->type == kTfLiteInt8) {
-    data->six_int8 = FloatToQuantizedType<int8_t>(6.0f, input->params.scale,
-                                                  input->params.zero_point);
-    data->zero_int8 = input->params.zero_point;
+    data->zero = input->params.zero_point;
+    data->six = FloatToQuantizedType<int8_t>(6.0f, input->params.scale,
+                                             input->params.zero_point);
+    TF_LITE_ENSURE(context, data->six >= INT8_MIN && data->six <= INT8_MAX);
+  } else if (input->type == kTfLiteInt16) {
+    data->zero = input->params.zero_point;
+    data->six = FloatToQuantizedType<int16_t>(6.0f, input->params.scale,
+                                              input->params.zero_point);
+    TF_LITE_ENSURE(context, data->six >= INT16_MIN && data->six <= INT16_MAX);
  }

  micro_context->DeallocateTempTfLiteTensor(input);
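Note: Relu6Prepare now stores the bounds in int32_t and checks that the quantized upper bound fits the tensor's element type. A hedged sketch of that arithmetic, assuming the usual affine quantization q = round(value / scale) + zero_point (TFLM's FloatToQuantizedType may round or clamp slightly differently):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Assumed affine quantization of a float bound into integer space.
int32_t QuantizeBoundSketch(float value, float scale, int32_t zero_point) {
  return static_cast<int32_t>(std::lround(value / scale)) + zero_point;
}

int main() {
  // With the test's parameters (scale = 0.5, zero_point = 0), 6.0f maps to
  // 12, well inside int16 range.
  printf("%d\n", QuantizeBoundSketch(6.0f, 0.5f, 0));  // 12
  // A very small scale pushes the bound past INT16_MAX (32767); this is the
  // kind of out-of-range value the TF_LITE_ENSURE checks guard against.
  printf("%d\n", QuantizeBoundSketch(6.0f, 0.0001f, 0));  // about 60000
  return 0;
}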
62 changes: 62 additions & 0 deletions tensorflow/lite/micro/kernels/activations_test.cc
@@ -169,6 +169,46 @@ void TestRelu6Int8(int* input_dims_data, const float* input_data,
  }
}

+void TestRelu6Int16(int* input_dims_data, const float* input_data,
+                    int16_t* input_data_quantized, const float input_scale,
+                    const int input_zero_point, const float* golden,
+                    int16_t* golden_quantized, int* output_dims_data,
+                    const float output_scale, const int output_zero_point,
+                    int16_t* output_data) {
+  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+  const int output_elements_count = ElementCount(*output_dims);
+  constexpr int inputs_size = 1;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateQuantizedTensor(input_data, input_data_quantized, input_dims,
+                            input_scale, input_zero_point),
+      CreateQuantizedTensor(output_data, output_dims, output_scale,
+                            output_zero_point),
+  };
+
+  int inputs_array_data[] = {1, 0};
+  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+  int outputs_array_data[] = {1, 1};
+  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+
+  const TFLMRegistration registration = Register_RELU6();
+  micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
+                             outputs_array,
+                             /*builtin_data=*/nullptr);
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
+
+  Quantize(golden, golden_quantized, output_elements_count, output_scale,
+           output_zero_point);
+
+  for (int i = 0; i < output_elements_count; ++i) {
+    TF_LITE_MICRO_EXPECT_EQ(golden_quantized[i], output_data[i]);
+  }
+}
+

} // namespace
} // namespace testing
} // namespace tflite
@@ -247,4 +287,26 @@ TF_LITE_MICRO_TEST(SimpleRelu6TestInt8) {
                                 output_zero_point, output_data);
}

+TF_LITE_MICRO_TEST(SimpleRelu6TestInt16) {
+  const int elements_count = 10;
+
+  int input_shape[] = {2, 1, 5};
+  const float input_data[] = {4, 5, 6, 7, 8, -1, -2, -3, -4, -5};
+  int16_t input_quantized[elements_count];
+  int output_shape[] = {2, 1, 5};
+  const float golden[] = {4, 5, 6, 6, 6, 0, 0, 0, 0, 0};
+  int16_t golden_quantized[elements_count];
+  int16_t output_data[elements_count];
+
+  const float input_scale = 0.5f;
+  const int input_zero_point = 0;
+  const float output_scale = 0.5f;
+  const int output_zero_point = 0;
+
+  tflite::testing::TestRelu6Int16(input_shape, input_data, input_quantized,
+                                  input_scale, input_zero_point, golden,
+                                  golden_quantized, output_shape, output_scale,
+                                  output_zero_point, output_data);
+}
+

TF_LITE_MICRO_TESTS_END
