Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
turbotoribio committed Jul 11, 2023
1 parent 8c2b283 commit 4947dad
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 29 deletions.
6 changes: 6 additions & 0 deletions signal/micro/kernels/overlap_add.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
context->AllocatePersistentBuffer(context,
sizeof(TFLMSignalOverlapAddParams<T>)));

if (params == nullptr) {
return nullptr;
}

tflite::FlexbufferWrapper fbw(buffer_t, length);
params->type = typeToTfLiteType<T>();
params->frame_step = fbw.ElementAsInt32(kFrameStepIndex);
Expand Down Expand Up @@ -101,6 +105,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input_shape.FlatSize() / (params->frame_size * params->n_frames);
params->state_buffers = static_cast<T**>(context->AllocatePersistentBuffer(
context, params->outer_dims * sizeof(T*)));
TF_LITE_ENSURE(context, params != nullptr);

for (int i = 0; i < params->outer_dims; i++) {
params->state_buffers[i] =
static_cast<T*>(context->AllocatePersistentBuffer(
Expand Down
62 changes: 33 additions & 29 deletions signal/micro/kernels/overlap_add_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"

namespace {
namespace tflite {

constexpr int kFrameStepIndex = 1;
constexpr int kInputsSize = 1;
Expand All @@ -32,33 +32,33 @@ class OverlapAddKernelRunner {
public:
OverlapAddKernelRunner(int* input_dims_data, T* input_data,
int* output_dims_data, T* output_data)
: inputs_array_{tflite::testing::IntArrayFromInts(inputs_array_data_)},
outputs_array_{tflite::testing::IntArrayFromInts(outputs_array_data_)} {
tensors_[0] = tflite::testing::CreateTensor(
input_data, tflite::testing::IntArrayFromInts(input_dims_data));
: inputs_array_{testing::IntArrayFromInts(inputs_array_data_)},
outputs_array_{testing::IntArrayFromInts(outputs_array_data_)} {
tensors_[0] = testing::CreateTensor(
input_data, testing::IntArrayFromInts(input_dims_data));

tensors_[1] = tflite::testing::CreateTensor(
output_data, tflite::testing::IntArrayFromInts(output_dims_data));
output_data, testing::IntArrayFromInts(output_dims_data));

registration_ = tflite::tflm_signal::Register_OVERLAP_ADD();
registration_ = tflm_signal::Register_OVERLAP_ADD();

// go/tflm-static-cleanups for reasoning new is being used like this
kernel_runner_ = new (kernel_runner_buffer) tflite::micro::KernelRunner(
*registration_, tensors_, kTensorsSize, inputs_array_, outputs_array_,
/*builtin_data=*/nullptr);
}

tflite::micro::KernelRunner& GetKernelRunner() { return *kernel_runner_; }
micro::KernelRunner& GetKernelRunner() { return *kernel_runner_; }

private:
uint8_t kernel_runner_buffer[sizeof(tflite::micro::KernelRunner)];
uint8_t kernel_runner_buffer[sizeof(micro::KernelRunner)];
int inputs_array_data_[kInputsSize + 1] = {1, 0};
int outputs_array_data_[kOutputsSize + 1] = {1, 1};
TfLiteTensor tensors_[kTensorsSize] = {};
TfLiteIntArray* inputs_array_ = nullptr;
TfLiteIntArray* outputs_array_ = nullptr;
TFLMRegistration* registration_ = nullptr;
tflite::micro::KernelRunner* kernel_runner_ = nullptr;
micro::KernelRunner* kernel_runner_ = nullptr;
};

// We can use any of the templated types here - int16_t was picked arbitrarily
Expand Down Expand Up @@ -149,7 +149,7 @@ void TestOverlapAddReset(int* input_dims_data, T* input_data,
&overlap_add_runner->GetKernelRunner());
}

} // namespace
} // namespace tflite

TF_LITE_MICRO_TESTS_BEGIN

Expand All @@ -164,10 +164,11 @@ TF_LITE_MICRO_TEST(OverlapAddTestInt16) {
63, 71, 52, 1, -17, 32};
const int16_t golden_output[] = {125, 988, -767, -140};

TestOverlapAdd(input_dims_data, input_data, output_dims_data, golden_input,
golden_output, sizeof(golden_output) / sizeof(int16_t),
g_gen_data_overlap_add_int16,
g_gen_data_size_overlap_add_int16, &output_data);
tflite::TestOverlapAdd(input_dims_data, input_data, output_dims_data,
golden_input, golden_output,
sizeof(golden_output) / sizeof(int16_t),
g_gen_data_overlap_add_int16,
g_gen_data_size_overlap_add_int16, &output_data);
}

TF_LITE_MICRO_TEST(OverlapAddTestFloat) {
Expand All @@ -181,10 +182,11 @@ TF_LITE_MICRO_TEST(OverlapAddTestFloat) {
6.3, 7.1, 5.2, 0.1, -1.7, 3.2};
const float golden_output[] = {12.5, 98.8, -76.7, -14.0};

TestOverlapAdd(input_dims_data, input_data, output_dims_data, golden_input,
golden_output, sizeof(golden_output) / sizeof(float),
g_gen_data_overlap_add_float,
g_gen_data_size_overlap_add_float, &output_data);
tflite::TestOverlapAdd(input_dims_data, input_data, output_dims_data,
golden_input, golden_output,
sizeof(golden_output) / sizeof(float),
g_gen_data_overlap_add_float,
g_gen_data_size_overlap_add_float, &output_data);
}

TF_LITE_MICRO_TEST(OverlapAddTestNframes4Int16) {
Expand All @@ -201,9 +203,10 @@ TF_LITE_MICRO_TEST(OverlapAddTestNframes4Int16) {

const int kIters =
sizeof(golden_input) / kInputSize / kNFrames / sizeof(int16_t);
TestOverlapAdd(input_dims_data, input_data, output_dims_data, golden_input,
golden_output, kIters, g_gen_data_overlap_add_int16,
g_gen_data_size_overlap_add_int16, output_data);
tflite::TestOverlapAdd(input_dims_data, input_data, output_dims_data,
golden_input, golden_output, kIters,
g_gen_data_overlap_add_int16,
g_gen_data_size_overlap_add_int16, output_data);
}

TF_LITE_MICRO_TEST(OverlapAddTestNframes4OuterDims4Int16) {
Expand All @@ -220,9 +223,10 @@ TF_LITE_MICRO_TEST(OverlapAddTestNframes4OuterDims4Int16) {

const int kIters =
sizeof(golden_input) / kInputSize / kNFrames / sizeof(int16_t);
TestOverlapAdd(input_dims_data, input_data, output_dims_data, golden_input,
golden_output, kIters, g_gen_data_overlap_add_int16,
g_gen_data_size_overlap_add_int16, output_data);
tflite::TestOverlapAdd(input_dims_data, input_data, output_dims_data,
golden_input, golden_output, kIters,
g_gen_data_overlap_add_int16,
g_gen_data_size_overlap_add_int16, output_data);
}

TF_LITE_MICRO_TEST(testReset) {
Expand All @@ -239,10 +243,10 @@ TF_LITE_MICRO_TEST(testReset) {

const int kIters =
sizeof(golden_input) / kInputSize / kNFrames / sizeof(int16_t);
TestOverlapAddReset(input_dims_data, input_data, output_dims_data,
golden_input, golden_output, kIters,
g_gen_data_overlap_add_int16,
g_gen_data_size_overlap_add_int16, output_data);
tflite::TestOverlapAddReset(input_dims_data, input_data, output_dims_data,
golden_input, golden_output, kIters,
g_gen_data_overlap_add_int16,
g_gen_data_size_overlap_add_int16, output_data);
}

TF_LITE_MICRO_TESTS_END

0 comments on commit 4947dad

Please sign in to comment.