diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py index 323e391df4..1f9eba94d4 100644 --- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py +++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py @@ -455,10 +455,6 @@ def valid_predecessor_op_types(self): def _check_compatibility(self): # Gather parameters to check if self._q_node.op_type == "Quant": - q_inst = getCustomOp(self._q_node) - signed = q_inst.get_nodeattr("signed") - if not signed: - raise ValueError("FINN only supports signed Quant nodes for identity activations.") if not self._model.get_initializer(self._q_node.input[2]) == 0: raise ValueError( "Only Quant nodes with zero-point == 0 " @@ -480,6 +476,7 @@ def _calculate_act_bias(self): if self._q_node.op_type == "Quant": bit_width = self._model.get_initializer(self._q_node.input[3]) narrow = q_inst.get_nodeattr("narrow") + signed = q_inst.get_nodeattr("signed") elif self._q_node.op_type == "BipolarQuant": bit_width = 1.0 else: @@ -490,10 +487,13 @@ def _calculate_act_bias(self): if bit_width == 1.0: bias = np.array([-0.5], dtype=np_default_dtype) else: - if narrow: - min_non_scaled_val = -(2 ** (bit_width - 1) - 1) + if not signed: + min_non_scaled_val = 0 else: - min_non_scaled_val = -(2 ** (bit_width - 1)) + if narrow: + min_non_scaled_val = -(2 ** (bit_width - 1) - 1) + else: + min_non_scaled_val = -(2 ** (bit_width - 1)) bias = np.array([min_non_scaled_val], dtype=np_default_dtype) return bias @@ -504,6 +504,7 @@ def _calculate_thresholds(self): if self._q_node.op_type == "Quant": bit_width = self._model.get_initializer(self._q_node.input[3]) narrow = q_inst.get_nodeattr("narrow") + signed = q_inst.get_nodeattr("signed") elif self._q_node.op_type == "BipolarQuant": bit_width = 1.0 else: @@ -533,6 +534,8 @@ def _calculate_thresholds(self): min_threshold = -half_step - step * ((num_thresholds // 2) - 1) if not narrow: min_threshold -= step + if not 
signed: + min_threshold = half_step for c in range(num_scale_channels): for t in range(num_thresholds): thresholds[c][t] = min_threshold[c] + step[c] * t diff --git a/tests/brevitas/test_brevitas_quant_identity_export.py b/tests/brevitas/test_brevitas_quant_identity_export.py new file mode 100644 index 0000000000..c420a161e8 --- /dev/null +++ b/tests/brevitas/test_brevitas_quant_identity_export.py @@ -0,0 +1,74 @@ +# Copyright (c) 2024, Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of Xilinx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 

import pytest

import numpy as np
import onnx  # noqa
import os
import torch
from brevitas.export import export_qonnx
from brevitas.nn import QuantIdentity
from brevitas.quant.scaled_int import Int8ActPerTensorFloat, Uint8ActPerTensorFloat
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.util.basic import get_preferred_onnx_opset
from qonnx.util.cleanup import cleanup as qonnx_cleanup

import finn.core.onnx_exec as oxe
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN


@pytest.mark.brevitas_export
@pytest.mark.parametrize("abits", [2, 4, 8])
@pytest.mark.parametrize("ishape", [(1, 15), (1, 32, 1, 1)])
@pytest.mark.parametrize("narrow", [True, False])
@pytest.mark.parametrize("quant", [Int8ActPerTensorFloat, Uint8ActPerTensorFloat])
def test_brevitas_quant_identity_export(abits, ishape, narrow, quant):
    """Export a Brevitas QuantIdentity (signed and unsigned act quantizers) to
    QONNX, convert to FINN, and check that FINN execution matches the Brevitas
    forward pass.

    Covers both narrow and non-narrow ranges and both a 2-D and a 4-D input
    shape, exercising the signed/unsigned handling of the QONNX-to-FINN
    activation conversion.
    """
    # Use the quantizer's class name (not the class repr, which contains
    # characters like "<" and "'" that are illegal in file names on some
    # filesystems) to build a unique, readable export path per parametrization.
    export_path = f"test_brevitas_quant_identity_export_{abits}_{narrow}_{quant.__name__}.onnx"
    b_act = QuantIdentity(act_quant=quant, bit_width=abits, narrow=narrow)

    try:
        export_qonnx(
            b_act,
            torch.randn(ishape),
            export_path,
            opset_version=get_preferred_onnx_opset(),
        )
        qonnx_cleanup(export_path, out_file=export_path)
        model = ModelWrapper(export_path)
        model = model.transform(ConvertQONNXtoFINN())

        # Random input well outside the quantizer's representable range on both
        # sides, so clipping at both the min and max thresholds is exercised.
        inp_tensor = np.random.uniform(low=-10.0, high=10.0, size=ishape).astype(np.float32)
        idict = {model.graph.input[0].name: inp_tensor}
        odict = oxe.execute_onnx(model, idict, True)
        produced = odict[model.graph.output[0].name]

        # Reference: the Brevitas module itself, in eval mode so no statistics
        # are updated during the forward pass.
        inp_tensor = torch.from_numpy(inp_tensor).float()
        b_act.eval()
        expected = b_act.forward(inp_tensor).detach().numpy()

        assert np.isclose(produced, expected, atol=1e-3).all()
    finally:
        # Clean up the exported model even when the assertion fails, so a red
        # test run does not leave .onnx artifacts behind.
        if os.path.exists(export_path):
            os.remove(export_path)