diff --git a/tensorflow/lite/micro/kernels/activations.cc b/tensorflow/lite/micro/kernels/activations.cc
index 1086325ca84..6772e4765af 100644
--- a/tensorflow/lite/micro/kernels/activations.cc
+++ b/tensorflow/lite/micro/kernels/activations.cc
@@ -75,6 +75,7 @@ void* Relu6Init(TfLiteContext* context, const char* buffer, size_t length) {
 
 TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
+
   const Relu6OpData& data = *(static_cast<const Relu6OpData*>(node->user_data));
 
   const TfLiteEvalTensor* input =
@@ -92,11 +93,19 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
       return kTfLiteOk;
     }
     case kTfLiteInt8: {
-      Relu6Quantized(data.zero_int8, data.six_int8,
-                     tflite::micro::GetTensorShape(input),
-                     tflite::micro::GetTensorData<int8_t>(input),
-                     tflite::micro::GetTensorShape(output),
-                     tflite::micro::GetTensorData<int8_t>(output));
+      Relu6Quantized<int8_t>(data.zero, data.six,
+                             tflite::micro::GetTensorShape(input),
+                             tflite::micro::GetTensorData<int8_t>(input),
+                             tflite::micro::GetTensorShape(output),
+                             tflite::micro::GetTensorData<int8_t>(output));
+      return kTfLiteOk;
+    }
+    case kTfLiteInt16: {
+      Relu6Quantized<int16_t>(data.zero, data.six,
+                              tflite::micro::GetTensorShape(input),
+                              tflite::micro::GetTensorData<int16_t>(input),
+                              tflite::micro::GetTensorShape(output),
+                              tflite::micro::GetTensorData<int16_t>(output));
       return kTfLiteOk;
     }
     default: {
diff --git a/tensorflow/lite/micro/kernels/activations.h b/tensorflow/lite/micro/kernels/activations.h
index e953f0e0daf..eaf93c2df26 100644
--- a/tensorflow/lite/micro/kernels/activations.h
+++ b/tensorflow/lite/micro/kernels/activations.h
@@ -32,8 +32,8 @@ struct ReluOpData {
 };
 
 struct Relu6OpData {
-  int8_t six_int8;
-  int8_t zero_int8;
+  int32_t six;
+  int32_t zero;
 };
 
 void ReluQuantized(const ReluOpData& data, const RuntimeShape& input_shape,
@@ -50,9 +50,17 @@ void ReluFloat(const RuntimeShape& input_shape, const float* input_data,
 void Relu6Float(const RuntimeShape& input_shape, const float* input_data,
                 const RuntimeShape& output_shape, float* output_data);
 
-void Relu6Quantized(int8_t lower, int8_t upper, const RuntimeShape& input_shape,
-                    const int8_t* input_data, const RuntimeShape& output_shape,
-                    int8_t* output_data);
+template <typename T>
+void Relu6Quantized(T lower, T upper, const RuntimeShape& input_shape,
+                    const T* input_data, const RuntimeShape& output_shape,
+                    T* output_data) {
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+  for (int i = 0; i < flat_size; ++i) {
+    const T val = input_data[i];
+    const T clamped = val > upper ? upper : val < lower ? lower : val;
+    output_data[i] = clamped;
+  }
+}
 
 TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node);
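Note on the two hunks above: since `six` and `zero` are widened to `int32_t` while the templated `Relu6Quantized` uses `T` for both the limits and the tensor data, deduction from the call sites would conflict; `Relu6Eval` therefore instantiates the template explicitly (`Relu6Quantized<int8_t>` / `Relu6Quantized<int16_t>`) and lets the `int32_t` limits narrow to `T`, which the range checks added in `Relu6Prepare` (next file) keep lossless. Below is a minimal standalone sketch of the clamp the template performs, with the `RuntimeShape` plumbing stripped out; the file and function names are hypothetical, not part of the patch.

    // relu6_clamp_sketch.cc -- hypothetical, not part of this patch.
    #include <cstdint>
    #include <cstdio>

    // Mirrors the loop body of the templated Relu6Quantized: clamp each
    // element into [lower, upper].
    template <typename T>
    void Relu6ClampSketch(T lower, T upper, const T* input, int size,
                          T* output) {
      for (int i = 0; i < size; ++i) {
        const T val = input[i];
        output[i] = val > upper ? upper : val < lower ? lower : val;
      }
    }

    int main() {
      // With scale 0.5 and zero_point 0 (the parameters used by the new
      // test), the quantized 0.0f and 6.0f are 0 and 12.
      const int16_t input[4] = {8, 14, -3, 12};
      int16_t output[4];
      Relu6ClampSketch<int16_t>(/*lower=*/0, /*upper=*/12, input, 4, output);
      for (int i = 0; i < 4; ++i) printf("%d ", output[i]);  // 8 12 0 12
      printf("\n");
      return 0;
    }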
diff --git a/tensorflow/lite/micro/kernels/activations_common.cc b/tensorflow/lite/micro/kernels/activations_common.cc
index 2ec3a1bf59f..90062447791 100644
--- a/tensorflow/lite/micro/kernels/activations_common.cc
+++ b/tensorflow/lite/micro/kernels/activations_common.cc
@@ -102,17 +102,6 @@ void Relu6Float(const RuntimeShape& input_shape, const float* input_data,
   }
 }
 
-void Relu6Quantized(int8_t lower, int8_t upper, const RuntimeShape& input_shape,
-                    const int8_t* input_data, const RuntimeShape& output_shape,
-                    int8_t* output_data) {
-  const int flat_size = MatchingFlatSize(input_shape, output_shape);
-  for (int i = 0; i < flat_size; ++i) {
-    const int8_t val = input_data[i];
-    const int8_t clamped = val > upper ? upper : val < lower ? lower : val;
-    output_data[i] = clamped;
-  }
-}
-
 TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
   ReluOpData* data = static_cast<ReluOpData*>(node->user_data);
@@ -137,6 +126,7 @@ TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
 
 TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
+
   Relu6OpData* data = static_cast<Relu6OpData*>(node->user_data);
 
   MicroContext* micro_context = GetMicroContext(context);
@@ -145,9 +135,15 @@ TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE(context, input != nullptr);
 
   if (input->type == kTfLiteInt8) {
-    data->six_int8 = FloatToQuantizedType<int8_t>(6.0f, input->params.scale,
-                                                  input->params.zero_point);
-    data->zero_int8 = input->params.zero_point;
+    data->zero = input->params.zero_point;
+    data->six = FloatToQuantizedType<int32_t>(6.0f, input->params.scale,
+                                              input->params.zero_point);
+    TF_LITE_ENSURE(context, data->six >= INT8_MIN && data->six <= INT8_MAX);
+  } else if (input->type == kTfLiteInt16) {
+    data->zero = input->params.zero_point;
+    data->six = FloatToQuantizedType<int32_t>(6.0f, input->params.scale,
+                                              input->params.zero_point);
+    TF_LITE_ENSURE(context, data->six >= INT16_MIN && data->six <= INT16_MAX);
   }
 
   micro_context->DeallocateTempTfLiteTensor(input);
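On the `Relu6Prepare` hunk above: `FloatToQuantizedType<int32_t>` maps the real value 6.0 into the input's quantized domain, and computing it in `int32_t` lets the new `TF_LITE_ENSURE` lines reject models whose scale is so small that the quantized 6.0 no longer fits the element type. A sketch of that arithmetic, assuming the helper behaves like round(value / scale) + zero_point (its exact rounding may differ):

    // quantize_six_sketch.cc -- hypothetical, not part of this patch.
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int32_t QuantizeSix(float scale, int32_t zero_point) {
      return static_cast<int32_t>(std::round(6.0f / scale)) + zero_point;
    }

    int main() {
      // scale = 0.5, zero_point = 0, as in the new int16 test below.
      printf("%d\n", QuantizeSix(0.5f, 0));   // 12: within int16 range
      // A tiny scale overflows int16; Relu6Prepare now fails such models
      // at prepare time instead of clamping against a wrapped-around limit.
      printf("%d\n", QuantizeSix(1e-4f, 0));  // 60000: exceeds INT16_MAX
      return 0;
    }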
diff --git a/tensorflow/lite/micro/kernels/activations_test.cc b/tensorflow/lite/micro/kernels/activations_test.cc
index 25402a80284..479668a164b 100644
--- a/tensorflow/lite/micro/kernels/activations_test.cc
+++ b/tensorflow/lite/micro/kernels/activations_test.cc
@@ -169,6 +169,46 @@ void TestRelu6Int8(int* input_dims_data, const float* input_data,
   }
 }
 
+void TestRelu6Int16(int* input_dims_data, const float* input_data,
+                    int16_t* input_data_quantized, const float input_scale,
+                    const int input_zero_point, const float* golden,
+                    int16_t* golden_quantized, int* output_dims_data,
+                    const float output_scale, const int output_zero_point,
+                    int16_t* output_data) {
+  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+  const int output_elements_count = ElementCount(*output_dims);
+  constexpr int inputs_size = 1;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateQuantizedTensor(input_data, input_data_quantized, input_dims,
+                            input_scale, input_zero_point),
+      CreateQuantizedTensor(output_data, output_dims, output_scale,
+                            output_zero_point),
+  };
+
+  int inputs_array_data[] = {1, 0};
+  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+  int outputs_array_data[] = {1, 1};
+  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+
+  const TFLMRegistration registration = Register_RELU6();
+  micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
+                             outputs_array,
+                             /*builtin_data=*/nullptr);
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
+
+  Quantize(golden, golden_quantized, output_elements_count, output_scale,
+           output_zero_point);
+
+  for (int i = 0; i < output_elements_count; ++i) {
+    TF_LITE_MICRO_EXPECT_EQ(golden_quantized[i], output_data[i]);
+  }
+}
+
 }  // namespace
 }  // namespace testing
 }  // namespace tflite
@@ -247,4 +287,26 @@ TF_LITE_MICRO_TEST(SimpleRelu6TestInt8) {
                                  output_zero_point, output_data);
 }
 
+TF_LITE_MICRO_TEST(SimpleRelu6TestInt16) {
+  const int elements_count = 10;
+
+  int input_shape[] = {2, 1, elements_count};
+  const float input_data[] = {4, 5, 6, 7, 8, -1, -2, -3, -4, -5};
+  int16_t input_quantized[elements_count];
+  int output_shape[] = {2, 1, elements_count};
+  const float golden[] = {4, 5, 6, 6, 6, 0, 0, 0, 0, 0};
+  int16_t golden_quantized[elements_count];
+  int16_t output_data[elements_count];
+
+  const float input_scale = 0.5f;
+  const int input_zero_point = 0;
+  const float output_scale = 0.5f;
+  const int output_zero_point = 0;
+
+  tflite::testing::TestRelu6Int16(input_shape, input_data, input_quantized,
+                                  input_scale, input_zero_point, golden,
+                                  golden_quantized, output_shape, output_scale,
+                                  output_zero_point, output_data);
+}
+
 TF_LITE_MICRO_TESTS_END
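For reference, the arithmetic behind the new test's golden values: with scale 0.5 and zero_point 0, the quantized limits are 0 and 12, so the inputs {4, 5, 6, 7, 8} quantize to {8, 10, 12, 14, 16} and clamp to {8, 10, 12, 12, 12}, while the negative inputs clamp to 0. A hypothetical standalone check of that computation:

    // golden_check_sketch.cc -- hypothetical, not part of this patch.
    #include <cstdint>
    #include <cstdio>

    int main() {
      const float input[10] = {4, 5, 6, 7, 8, -1, -2, -3, -4, -5};
      for (int i = 0; i < 10; ++i) {
        // Quantize with scale 0.5, zero_point 0, then clamp to [0, 12],
        // the quantized [0.0f, 6.0f] interval.
        const int16_t q = static_cast<int16_t>(input[i] / 0.5f);
        const int16_t r = q > 12 ? 12 : (q < 0 ? 0 : q);
        printf("%d ", r);  // prints: 8 10 12 12 12 0 0 0 0 0
      }
      printf("\n");
      return 0;
    }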