Skip to content

Commit

Permalink
Conditionalize data member access on HEXAGON define
Browse files Browse the repository at this point in the history
  • Loading branch information
rascani committed Sep 26, 2024
1 parent 7b1a6f4 commit 62255cc
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions tensorflow/lite/micro/kernels/fully_connected_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,39 @@ TfLiteStatus CalculateOpDataFullyConnected(
TfLiteType data_type, const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* output,
OpDataFullyConnected* data) {
#ifndef HEXAGON
data->is_per_channel = false;
#endif

if (data_type == kTfLiteFloat32) {
return kTfLiteOk;
}

bool is_per_channel = false;
if (filter->quantization.type == kTfLiteAffineQuantization &&
filter->quantization.params != nullptr) {
const auto* affine_quantization =
reinterpret_cast<TfLiteAffineQuantization*>(
filter->quantization.params);
TF_LITE_ENSURE(context, affine_quantization);
TF_LITE_ENSURE(context, affine_quantization->scale);
data->is_per_channel = affine_quantization->scale->size > 1;
is_per_channel = affine_quantization->scale->size > 1;
}

if (data->is_per_channel) {
if (is_per_channel) {
// Hexagon currently does not support per-channel fully connected, and the
// existing hexagon support library is intolerant of data members being added to
// OpDataFullyConnected. As such, we have to be careful not to reference newer
// data members. This is why we use a local variable is_per_channel in common
// code, and only reference the data->is_per_channel in non-HEXAGON code.
#ifdef HEXAGON
TF_LITE_ENSURE_MSG(
context, !is_per_channel,
"FullyConnected per-channel quantization not yet supported on Hexagon. "
"Please set converter._experimental_disable_per_channel_quantization_"
"for_dense_layers = True.");
#else
data->is_per_channel = is_per_channel;
const auto* affine_quantization =
reinterpret_cast<TfLiteAffineQuantization*>(
filter->quantization.params);
Expand Down Expand Up @@ -111,6 +127,7 @@ TfLiteStatus CalculateOpDataFullyConnected(
data->per_channel_output_multiplier[i] = significand;
data->per_channel_output_shift[i] = channel_shift;
}
#endif
} else {
double real_multiplier = 0.0;
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
Expand Down

0 comments on commit 62255cc

Please sign in to comment.