diff --git a/python/tflite_micro/python_ops_resolver.cc b/python/tflite_micro/python_ops_resolver.cc index 6c63f9d76b2..f67f7dd37ed 100644 --- a/python/tflite_micro/python_ops_resolver.cc +++ b/python/tflite_micro/python_ops_resolver.cc @@ -111,6 +111,7 @@ PythonOpsResolver::PythonOpsResolver() { AddSquaredDifference(); AddSqueeze(); AddStridedSlice(); + AddStacker(); AddSub(); AddSum(); AddSvdf(); diff --git a/signal/micro/kernels/BUILD b/signal/micro/kernels/BUILD index f3ec739f425..6357db5b855 100644 --- a/signal/micro/kernels/BUILD +++ b/signal/micro/kernels/BUILD @@ -12,6 +12,7 @@ cc_library( "framer.cc", "overlap_add.cc", "rfft.cc", + "stacker.cc", "window.cc", ], hdrs = [ @@ -167,3 +168,29 @@ cc_test( "//tensorflow/lite/micro/testing:micro_test", ], ) + +cc_library( + name = "stacker_flexbuffers_generated_data", + srcs = [ + "stacker_flexbuffers_generated_data.cc", + ], + hdrs = [ + "stacker_flexbuffers_generated_data.h", + ], +) + +cc_test( + name = "stacker_test", + srcs = [ + "stacker_test.cc", + ], + deps = [ + ":register_signal_ops", + ":stacker_flexbuffers_generated_data", + "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", + "//tensorflow/lite/micro/kernels:kernel_runner", + "//tensorflow/lite/micro/testing:micro_test", + ], +) diff --git a/signal/micro/kernels/stacker.cc b/signal/micro/kernels/stacker.cc new file mode 100644 index 00000000000..42a2ee62fe6 --- /dev/null +++ b/signal/micro/kernels/stacker.cc @@ -0,0 +1,176 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "signal/src/circular_buffer.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/flatbuffer_utils.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/memory_helpers.h" +#include "tensorflow/lite/micro/micro_utils.h" + +namespace tflite { +namespace { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; +constexpr int kOutputValidTensor = 1; + +// Indices into the init flexbuffer's vector. +// The parameter's name is in the comment that follows. +// Elements in the vectors are ordered alphabetically by parameter name. +constexpr int kNumChannelsIndex = 0; // 'num_channels' +constexpr int kStackerLeftContextIndex = 1; // 'stacker_left_context' +constexpr int kStackerRightContextIndex = 2; // 'stacker_right_context' +constexpr int kStackerStepIndex = 3; // 'stacker_step' + +struct TFLMSignalStackerParams { + int32_t num_channels; + int32_t stacker_left_context; + int32_t stacker_right_context; + int32_t stacker_step; + + size_t buffer_size; + size_t step_size; + bool stacker_has_first_frame; + + int8_t* state; + tflm_signal::CircularBuffer* circular_buffer; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + const uint8_t* buffer_t = reinterpret_cast(buffer); + + auto* params = + static_cast(context->AllocatePersistentBuffer( + context, sizeof(TFLMSignalStackerParams))); + if (params == nullptr) { + return nullptr; + } + + tflite::FlexbufferWrapper fbw(buffer_t, length); + params->num_channels = fbw.ElementAsInt32(kNumChannelsIndex); + params->stacker_left_context = fbw.ElementAsInt32(kStackerLeftContextIndex); + params->stacker_right_context = fbw.ElementAsInt32(kStackerRightContextIndex); + params->stacker_step = fbw.ElementAsInt32(kStackerStepIndex); + + params->buffer_size = + params->num_channels * + (params->stacker_left_context + params->stacker_right_context + 1); + params->step_size = params->num_channels * params->stacker_step; + params->stacker_has_first_frame = false; + + size_t state_size = + tflm_signal::CircularBufferGetNeededMemory(params->buffer_size); + params->state = static_cast( + context->AllocatePersistentBuffer(context, sizeof(int8_t) * state_size)); + + if (params->state == nullptr) { + return nullptr; + } + + params->circular_buffer = tflm_signal::CircularBufferInit( + params->buffer_size, params->state, state_size); + return params; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2); + + MicroContext* micro_context = GetMicroContext(context); + + TfLiteTensor* input = + micro_context->AllocateTempInputTensor(node, kInputTensor); + TF_LITE_ENSURE(context, input != nullptr); + TfLiteTensor* output = + micro_context->AllocateTempOutputTensor(node, kOutputTensor); + TF_LITE_ENSURE(context, output != nullptr); + TfLiteTensor* output_valid = + micro_context->AllocateTempOutputTensor(node, kOutputValidTensor); + TF_LITE_ENSURE(context, output_valid != nullptr); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(output), 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(output_valid), 0); + + TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16); + TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt16); + TF_LITE_ENSURE_TYPES_EQ(context, output_valid->type, kTfLiteBool); + + micro_context->DeallocateTempTfLiteTensor(input); + micro_context->DeallocateTempTfLiteTensor(output); + micro_context->DeallocateTempTfLiteTensor(output_valid); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->user_data); + TF_LITE_ENSURE(context, params != nullptr); + + const TfLiteEvalTensor* input = + tflite::micro::GetEvalInput(context, node, kInputTensor); + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kOutputTensor); + TfLiteEvalTensor* output_valid = + tflite::micro::GetEvalOutput(context, node, kOutputValidTensor); + + const int16_t* input_data = tflite::micro::GetTensorData(input); + + tflm_signal::CircularBufferWrite(params->circular_buffer, input_data, + params->num_channels); + + // The first frame is replicated an extra left_context times to pad. + if (params->stacker_has_first_frame == false) { + tflm_signal::CircularBufferExtend(params->circular_buffer, + params->num_channels, + params->stacker_left_context); + params->stacker_has_first_frame = true; + } + + int16_t* output_data = tflite::micro::GetTensorData(output); + bool* output_valid_data = tflite::micro::GetTensorData(output_valid); + if (tflm_signal::CircularBufferAvailable(params->circular_buffer) >= + params->buffer_size) { + tflm_signal::CircularBufferGet(params->circular_buffer, params->buffer_size, + output_data); + tflm_signal::CircularBufferDiscard(params->circular_buffer, + params->step_size); + *output_valid_data = true; + } else { + *output_valid_data = false; + } + return kTfLiteOk; +} + +void Reset(TfLiteContext* context, void* buffer) { + auto* params = static_cast(buffer); + tflm_signal::CircularBufferReset(params->circular_buffer); + params->stacker_has_first_frame = false; +} + +} // namespace + +namespace tflm_signal { +TFLMRegistration* Register_STACKER() { + static TFLMRegistration r = + tflite::micro::RegisterOp(Init, Prepare, Eval, /*Free*/ nullptr, Reset); + return &r; +} +} // namespace tflm_signal + +} // namespace tflite diff --git a/signal/micro/kernels/stacker_flexbuffers_generated_data.cc b/signal/micro/kernels/stacker_flexbuffers_generated_data.cc new file mode 100644 index 00000000000..654e4b7f65b --- /dev/null +++ b/signal/micro/kernels/stacker_flexbuffers_generated_data.cc @@ -0,0 +1,41 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file is generated. See: +// tensorflow/lite/micro/kernels/test_data_generation/README.md + +#include "signal/micro/kernels/stacker_flexbuffers_generated_data.h" + +const int g_gen_data_size_stacker_3_channels_step_1 = 88; +const unsigned char g_gen_data_stacker_3_channels_step_1[] = { + 0x6e, 0x75, 0x6d, 0x5f, 0x63, 0x68, 0x61, 0x6e, 0x6e, 0x65, 0x6c, + 0x73, 0x00, 0x73, 0x74, 0x61, 0x63, 0x6b, 0x65, 0x72, 0x5f, 0x6c, + 0x65, 0x66, 0x74, 0x5f, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, + 0x00, 0x73, 0x74, 0x61, 0x63, 0x6b, 0x65, 0x72, 0x5f, 0x72, 0x69, + 0x67, 0x68, 0x74, 0x5f, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, + 0x00, 0x73, 0x74, 0x61, 0x63, 0x6b, 0x65, 0x72, 0x5f, 0x73, 0x74, + 0x65, 0x70, 0x00, 0x04, 0x46, 0x3a, 0x26, 0x11, 0x04, 0x01, 0x04, + 0x03, 0x01, 0x00, 0x01, 0x04, 0x04, 0x04, 0x04, 0x08, 0x24, 0x01, +}; +const int g_gen_data_size_stacker_10_channels_step_2 = 88; +const unsigned char g_gen_data_stacker_10_channels_step_2[] = { + 0x6e, 0x75, 0x6d, 0x5f, 0x63, 0x68, 0x61, 0x6e, 0x6e, 0x65, 0x6c, + 0x73, 0x00, 0x73, 0x74, 0x61, 0x63, 0x6b, 0x65, 0x72, 0x5f, 0x6c, + 0x65, 0x66, 0x74, 0x5f, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, + 0x00, 0x73, 0x74, 0x61, 0x63, 0x6b, 0x65, 0x72, 0x5f, 0x72, 0x69, + 0x67, 0x68, 0x74, 0x5f, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, + 0x00, 0x73, 0x74, 0x61, 0x63, 0x6b, 0x65, 0x72, 0x5f, 0x73, 0x74, + 0x65, 0x70, 0x00, 0x04, 0x46, 0x3a, 0x26, 0x11, 0x04, 0x01, 0x04, + 0x0a, 0x01, 0x00, 0x02, 0x04, 0x04, 0x04, 0x04, 0x08, 0x24, 0x01, +}; diff --git a/signal/micro/kernels/stacker_flexbuffers_generated_data.h b/signal/micro/kernels/stacker_flexbuffers_generated_data.h new file mode 100644 index 00000000000..47a38277ba3 --- /dev/null +++ b/signal/micro/kernels/stacker_flexbuffers_generated_data.h @@ -0,0 +1,25 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_STACKER_FLEXBUFFERS_DATA_H_ +#define SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_STACKER_FLEXBUFFERS_DATA_H_ + +extern const int g_gen_data_size_stacker_3_channels_step_1; +extern const unsigned char g_gen_data_stacker_3_channels_step_1[]; + +extern const int g_gen_data_size_stacker_10_channels_step_2; +extern const unsigned char g_gen_data_stacker_10_channels_step_2[]; + +#endif // SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_STACKER_FLEXBUFFERS_DATA_H_ diff --git a/signal/micro/kernels/stacker_test.cc b/signal/micro/kernels/stacker_test.cc new file mode 100644 index 00000000000..d236c7539a8 --- /dev/null +++ b/signal/micro/kernels/stacker_test.cc @@ -0,0 +1,243 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "signal/micro/kernels/stacker_flexbuffers_generated_data.h" +#include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" +#include "tensorflow/lite/micro/testing/micro_test.h" + +namespace tflite { +namespace { + +constexpr int kInputsSize = 1; +constexpr int kOutputsSize = 2; +constexpr int kTensorsSize = kInputsSize + kOutputsSize; + +class StackerKernelRunner { + public: + StackerKernelRunner(int* input_dims_data, const int16_t* input_data, + int* output_dims_data, int16_t* output_data, + int* output_ready_dims_data, bool* ouput_ready_data) + : tensors_{testing::CreateTensor( + input_data, + tflite::testing::IntArrayFromInts(input_dims_data)), + testing::CreateTensor( + output_data, + tflite::testing::IntArrayFromInts(output_dims_data)), + testing::CreateTensor( + ouput_ready_data, + testing::IntArrayFromInts(output_ready_dims_data))}, + inputs_array_{testing::IntArrayFromInts(inputs_array_data_)}, + outputs_array_{testing::IntArrayFromInts(outputs_array_data_)}, + kernel_runner_{*registration_, tensors_, kTensorsSize, + inputs_array_, outputs_array_, nullptr} {} + + micro::KernelRunner* kernel_runner() { return &kernel_runner_; } + + private: + int inputs_array_data_[2] = {1, 0}; + int outputs_array_data_[3] = {2, 1, 2}; + TfLiteTensor tensors_[kTensorsSize] = {}; + TfLiteIntArray* inputs_array_ = nullptr; + TfLiteIntArray* outputs_array_ = nullptr; + TFLMRegistration* registration_ = tflm_signal::Register_STACKER(); + micro::KernelRunner kernel_runner_; +}; + +void TestStackerInvoke(int* output_dims_data, int16_t* output_data, + bool* ouput_ready_data, const int16_t* golden, + micro::KernelRunner* kernel_runner) { + TfLiteIntArray* output_dims = testing::IntArrayFromInts(output_dims_data); + + const int output_len = ElementCount(*output_dims); + + TF_LITE_MICRO_EXPECT_EQ(kernel_runner->Invoke(), kTfLiteOk); + TF_LITE_MICRO_EXPECT_EQ(*ouput_ready_data, 1); + + for (int i = 0; i < output_len; ++i) { + TF_LITE_MICRO_EXPECT_EQ(golden[i], output_data[i]); + } +} + +void TestStacker(int* input_dims_data, const int16_t* input_data, + int* output_dims_data, int16_t* output_data, + int* output_ready_dims_data, bool* ouput_ready_data, + const int16_t* golden, const unsigned char* flexbuffers_data, + const unsigned int flexbuffers_data_size) { + StackerKernelRunner stacker_runner(input_dims_data, input_data, + output_dims_data, output_data, + output_ready_dims_data, ouput_ready_data); + + // TfLite uses a char* for the raw bytes whereas flexbuffers use an unsigned + // char*. This small discrepancy results in compiler warnings unless we + // reinterpret_cast right before passing in the flexbuffer bytes to the + // KernelRunner. + TF_LITE_MICRO_EXPECT_EQ(stacker_runner.kernel_runner()->InitAndPrepare( + reinterpret_cast(flexbuffers_data), + flexbuffers_data_size), + kTfLiteOk); + TestStackerInvoke(output_dims_data, output_data, ouput_ready_data, golden, + stacker_runner.kernel_runner()); +} + +// TestStackerReset() runs a test with the given inputs twice with a reset with +// the main purpose of testing the Stacker's Reset functionality. If you just +// want to make sure Stacker's Op output matches a set of golden values for an +// input use TestStacker() instead. +void TestStackerReset(int* input_dims_data, const int16_t* input_data, + int* output_dims_data, int16_t* output_data, + int* output_ready_dims_data, bool* ouput_ready_data, + const int16_t* golden, + const unsigned char* flexbuffers_data, + const unsigned int flexbuffers_data_size) { + StackerKernelRunner stacker_runner(input_dims_data, input_data, + output_dims_data, output_data, + output_ready_dims_data, ouput_ready_data); + + // TfLite uses a char* for the raw bytes whereas flexbuffers use an unsigned + // char*. This small discrepancy results in compiler warnings unless we + // reinterpret_cast right before passing in the flexbuffer bytes to the + // KernelRunner. + TF_LITE_MICRO_EXPECT_EQ(stacker_runner.kernel_runner()->InitAndPrepare( + reinterpret_cast(flexbuffers_data), + flexbuffers_data_size), + kTfLiteOk); + TestStackerInvoke(output_dims_data, output_data, ouput_ready_data, golden, + stacker_runner.kernel_runner()); + stacker_runner.kernel_runner()->Reset(); + TestStackerInvoke(output_dims_data, output_data, ouput_ready_data, golden, + stacker_runner.kernel_runner()); +} + +} // namespace +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(StackerTest3ChannelStep1) { + int input_shape[] = {1, 3}; + int output_shape[] = {1, 6}; + int output_ready_shape[] = {0}; + const int16_t input[] = {0x1234, 0x5678, 0x4321}; + const int16_t golden[] = {0x1234, 0x5678, 0x4321, 0x1234, 0x5678, 0x4321}; + + int16_t output[6]; + bool output_ready = false; + + tflite::TestStacker(input_shape, input, output_shape, output, + output_ready_shape, &output_ready, golden, + g_gen_data_stacker_3_channels_step_1, + g_gen_data_size_stacker_3_channels_step_1); +} + +TF_LITE_MICRO_TEST(StackerTest10ChannelStep2_1stTest) { + int input_shape[] = {1, 10}; + int output_shape[] = {1, 20}; + int output_ready_shape[] = {0}; + + int16_t output[20]; + bool output_ready = false; + + const int16_t input[10] = {252, 477, 1071, 166, 1022, + 312, 1171, 1586, 1491, 145}; + + const int16_t golden[] = {252, 477, 1071, 166, 1022, 312, 1171, + 1586, 1491, 145, 252, 477, 1071, 166, + 1022, 312, 1171, 1586, 1491, 145}; + tflite::TestStacker(input_shape, input, output_shape, output, + output_ready_shape, &output_ready, golden, + g_gen_data_stacker_10_channels_step_2, + g_gen_data_size_stacker_10_channels_step_2); +} + +TF_LITE_MICRO_TEST(StackerTest10ChannelStep2_2ndTest) { + int input_shape[] = {1, 10}; + int output_shape[] = {1, 20}; + int output_ready_shape[] = {0}; + + int16_t output[20]; + bool output_ready = false; + + const int16_t input[10] = {1060, 200, 69, 1519, 883, + 1317, 182, 724, 143, 334}; + + const int16_t golden[] = {1060, 200, 69, 1519, 883, 1317, 182, 724, 143, 334, + 1060, 200, 69, 1519, 883, 1317, 182, 724, 143, 334}; + + tflite::TestStacker(input_shape, input, output_shape, output, + output_ready_shape, &output_ready, golden, + g_gen_data_stacker_10_channels_step_2, + g_gen_data_size_stacker_10_channels_step_2); +} + +TF_LITE_MICRO_TEST(StackerTestReset3ChannelStep1) { + int input_shape[] = {1, 3}; + int output_shape[] = {1, 6}; + int output_ready_shape[] = {0}; + const int16_t input[] = {0x1234, 0x5678, 0x4321}; + const int16_t golden[] = {0x1234, 0x5678, 0x4321, 0x1234, 0x5678, 0x4321}; + + int16_t output[6]; + bool output_ready = false; + + tflite::TestStackerReset(input_shape, input, output_shape, output, + output_ready_shape, &output_ready, golden, + g_gen_data_stacker_3_channels_step_1, + g_gen_data_size_stacker_3_channels_step_1); +} + +TF_LITE_MICRO_TEST(StackerTestReset10ChannelStep2_1stTest) { + int input_shape[] = {1, 10}; + int output_shape[] = {1, 20}; + int output_ready_shape[] = {0}; + + int16_t output[20]; + bool output_ready = false; + + const int16_t input[10] = {252, 477, 1071, 166, 1022, + 312, 1171, 1586, 1491, 145}; + + const int16_t golden[] = {252, 477, 1071, 166, 1022, 312, 1171, + 1586, 1491, 145, 252, 477, 1071, 166, + 1022, 312, 1171, 1586, 1491, 145}; + tflite::TestStackerReset(input_shape, input, output_shape, output, + output_ready_shape, &output_ready, golden, + g_gen_data_stacker_10_channels_step_2, + g_gen_data_size_stacker_10_channels_step_2); +} + +TF_LITE_MICRO_TEST(StackerTestReset10ChannelStep2_2ndTest) { + int input_shape[] = {1, 10}; + int output_shape[] = {1, 20}; + int output_ready_shape[] = {0}; + + int16_t output[20]; + bool output_ready = false; + + const int16_t input[10] = {1060, 200, 69, 1519, 883, + 1317, 182, 724, 143, 334}; + + const int16_t golden[] = {1060, 200, 69, 1519, 883, 1317, 182, 724, 143, 334, + 1060, 200, 69, 1519, 883, 1317, 182, 724, 143, 334}; + + tflite::TestStackerReset(input_shape, input, output_shape, output, + output_ready_shape, &output_ready, golden, + g_gen_data_stacker_10_channels_step_2, + g_gen_data_size_stacker_10_channels_step_2); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/Makefile.inc b/tensorflow/lite/micro/kernels/Makefile.inc index 5f449b2a97c..8fbbb34c251 100644 --- a/tensorflow/lite/micro/kernels/Makefile.inc +++ b/tensorflow/lite/micro/kernels/Makefile.inc @@ -69,6 +69,11 @@ $(eval $(call microlite_test,kernel_signal_overlap_add_test,\ $(TENSORFLOW_ROOT)signal/micro/kernels/overlap_add_flexbuffers_generated_data.cc, \ $(TENSORFLOW_ROOT)signal/micro/kernels/overlap_add_flexbuffers_generated_data.h)) +$(eval $(call microlite_test,kernel_signal_stacker_test,\ + $(TENSORFLOW_ROOT)signal/micro/kernels/stacker_test.cc \ + $(TENSORFLOW_ROOT)signal/micro/kernels/stacker_flexbuffers_generated_data.cc, \ + $(TENSORFLOW_ROOT)signal/micro/kernels/stacker_flexbuffers_generated_data.h)) + $(eval $(call microlite_test,kernel_signal_window_test,\ $(TENSORFLOW_ROOT)signal/micro/kernels/window_test.cc \ $(TENSORFLOW_ROOT)signal/micro/kernels/window_flexbuffers_generated_data.cc, \ diff --git a/tensorflow/lite/micro/kernels/micro_ops.h b/tensorflow/lite/micro/kernels/micro_ops.h index 15a1146d6e3..34822312ece 100644 --- a/tensorflow/lite/micro/kernels/micro_ops.h +++ b/tensorflow/lite/micro/kernels/micro_ops.h @@ -137,6 +137,7 @@ namespace tflm_signal { TFLMRegistration* Register_DELAY(); TFLMRegistration* Register_FRAMER(); TFLMRegistration* Register_OVERLAP_ADD(); +TFLMRegistration* Register_STACKER(); TFLMRegistration* Register_WINDOW(); } // namespace tflm_signal diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h index 76f04c39a63..df80079898a 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h @@ -538,6 +538,11 @@ class MicroMutableOpResolver : public MicroOpResolver { ParseStridedSlice); } + TfLiteStatus AddStacker() { + // TODO(b/286250473): change back name to "Stacker" and remove namespace + return AddCustom("SignalStacker", tflite::tflm_signal::Register_STACKER()); + } + TfLiteStatus AddSub() { return AddBuiltin(BuiltinOperator_SUB, tflite::Register_SUB(), ParseSub); } diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index 1da2b85c81c..e002cf1a284 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -315,6 +315,7 @@ MICROLITE_CC_KERNEL_SRCS := \ $(TENSORFLOW_ROOT)signal/micro/kernels/delay.cc \ $(TENSORFLOW_ROOT)signal/micro/kernels/framer.cc \ $(TENSORFLOW_ROOT)signal/micro/kernels/rfft.cc \ +$(TENSORFLOW_ROOT)signal/micro/kernels/stacker.cc \ $(TENSORFLOW_ROOT)signal/micro/kernels/overlap_add.cc \ $(TENSORFLOW_ROOT)signal/micro/kernels/window.cc \ $(TENSORFLOW_ROOT)signal/src/circular_buffer.cc \