diff --git a/src/finn/builder/build_dataflow_steps.py b/src/finn/builder/build_dataflow_steps.py
index bdbcc53d83..ab2280554c 100644
--- a/src/finn/builder/build_dataflow_steps.py
+++ b/src/finn/builder/build_dataflow_steps.py
@@ -121,6 +121,7 @@
 )
 from finn.transformation.streamline import Streamline
 from finn.transformation.streamline.reorder import MakeMaxPoolNHWC
+from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
 from finn.util.basic import (
     get_rtlsim_trace_depth,
     pyverilate_get_liveness_threshold_cycles,
@@ -503,6 +504,7 @@ def step_minimize_bit_width(model: ModelWrapper, cfg: DataflowBuildConfig):
     if cfg.minimize_bit_width:
         model = model.transform(MinimizeWeightBitWidth())
         model = model.transform(MinimizeAccumulatorWidth())
+        model = model.transform(RoundAndClipThresholds())
         # make sure the changed datatypes are propagated through the network
         model = model.transform(InferDataTypes())
     return model
diff --git a/src/finn/transformation/streamline/round_thresholds.py b/src/finn/transformation/streamline/round_thresholds.py
index 5ba5ee0ff5..312db404ac 100644
--- a/src/finn/transformation/streamline/round_thresholds.py
+++ b/src/finn/transformation/streamline/round_thresholds.py
@@ -1,4 +1,5 @@
-# Copyright (c) 2020, Xilinx
+# Copyright (c) 2020-2022, Xilinx
+# Copyright (C) 2022-2024, Advanced Micro Devices, Inc.
 # All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
@@ -27,42 +28,67 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import numpy as np
+from qonnx.core.datatype import DataType
+from qonnx.core.modelwrapper import ModelWrapper
+from qonnx.custom_op.registry import getCustomOp
 from qonnx.transformation.base import Transformation
+from qonnx.transformation.infer_datatypes import InferDataTypes
 
 
 class RoundAndClipThresholds(Transformation):
     """For MultiThreshold nodes operating on integer inputs, round up
     thresholds values to the nearest integer. Additionally, if the input
-    is unsigned, sets negative thresholds to zero."""
+    is unsigned, sets negative thresholds to zero. Type-casts thresholds (back)
+    to the float32 container type (this is separate from the quantization
+    annotation). Runs InferDataTypes() afterward to propagate any changes to the
+    quantization data types."""
 
-    def apply(self, model):
+    def apply(self, model: ModelWrapper):  # noqa
         graph = model.graph
         graph_modified = False
-        for n in graph.node:
-            if n.op_type == "MultiThreshold":
-                idtype = model.get_tensor_datatype(n.input[0])
-                T = model.get_initializer(n.input[1])
-                Tnew = np.ceil(T)
-                if idtype.is_integer() and (T != Tnew).any():
-                    # round up the thresholds to nearest integer
-                    model.set_initializer(n.input[1], Tnew)
-                    # use same datatype as inputs for thresholds
-                    model.set_tensor_datatype(n.input[1], idtype)
+        for index, node in enumerate(graph.node):
+            op_type = node.op_type
+            if op_type == "MultiThreshold" or op_type.startswith("Thresholding"):
+                thresholds = model.get_initializer(node.input[1])
+                if thresholds is None:
+                    continue
+                dtype = model.get_tensor_datatype(node.input[0])
+                # This transformation only applies to thresholding operations
+                # operating on integer inputs
+                if not dtype.is_integer():
+                    continue
+                # Round thresholds up to nearest integer and clip thresholds
+                # outside the input range
+                # Note: This might promote the thresholds to float64 and
+                # introduce extra inaccuracies due to large integers not being
+                # exactly representable in floating-point representation.
+                # See for example: np.ceil(np.float32(16777217)) == 16777216
+                new_thresholds = np.clip(np.ceil(thresholds), dtype.min(), dtype.max() + 1)
+                # Convert back to the preferred float32 container type
+                new_thresholds = new_thresholds.astype(np.float32)
+                # Insert the rounded and clipped thresholds back into the model
+                model.set_initializer(node.input[1], new_thresholds)
+                # The rounded and clipped thresholds now fit into a data type
+                # that is one bit bigger than the input datatype
+                # Determine new max_value
+                max_val = dtype.max() + 1
+                if not dtype.signed():
+                    tdt = DataType.get_smallest_possible(max_val)
+                else:
+                    tdt = DataType.get_smallest_possible(-(max_val) - 1)
+                model.set_tensor_datatype(node.input[1], tdt)
+                # If hw op we need to set the weight data type attribute as well
+                if op_type.startswith("Thresholding"):
+                    inst = getCustomOp(node)
+                    inst.set_nodeattr("weightDataType", tdt.name)
+                # Test whether the new thresholds actually differ from the old
+                # ones
+                if np.any(new_thresholds != thresholds):
+                    # Track the graph has been modified to inform the transform
+                    # container to exhaustively repeat this transformation until
+                    # no changes are possible
                     graph_modified = True
-                if idtype.is_integer() and not idtype.signed() and (Tnew < 0).any():
-                    # clip any negative thresholds if input is unsigned
-                    Tnew = np.clip(Tnew, 0, None)
-                    model.set_initializer(n.input[1], Tnew)
-                    # use same datatype as inputs for thresholds
-                    model.set_tensor_datatype(n.input[1], idtype)
-                    graph_modified = True
-                if idtype.is_integer() and (
-                    (Tnew < (idtype.min() - 1)).any() or (Tnew > (idtype.max() + 1)).any()
-                ):
-                    # clip any large thresholds to input range + 1
-                    Tnew = np.clip(Tnew, idtype.min() - 1, idtype.max() + 1)
-                    model.set_initializer(n.input[1], Tnew)
-                    # use same datatype as inputs for thresholds
-                    model.set_tensor_datatype(n.input[1], idtype)
-                    graph_modified = True
-        return (model, graph_modified)
+                    # Immediately exit here to propagate the data type changes
+                    # before considering the next node
+                    break
+        model = model.transform(InferDataTypes())
+        return model, graph_modified
diff --git a/tests/end2end/test_end2end_bnn_pynq.py b/tests/end2end/test_end2end_bnn_pynq.py
index 81c6316ec1..0d3418624a 100644
--- a/tests/end2end/test_end2end_bnn_pynq.py
+++ b/tests/end2end/test_end2end_bnn_pynq.py
@@ -94,6 +94,7 @@
     MakeMaxPoolNHWC,
     MoveScalarLinearPastInvariants,
 )
+from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
 from finn.util.basic import get_finn_root, make_build_dir, test_board_map
 from finn.util.pytorch import ToTensor
 from finn.util.test import (
@@ -672,6 +673,7 @@ def test_minimize_bit_width(self, topology, wbits, abits, board):
         model = load_test_checkpoint_or_skip(prev_chkpt_name)
         model = model.transform(MinimizeAccumulatorWidth())
         model = model.transform(MinimizeWeightBitWidth())
+        model = model.transform(RoundAndClipThresholds())
         curr_chkpt_name = get_checkpoint_name(topology, wbits, abits, "minimize_bit_width")
         model.save(curr_chkpt_name)
diff --git a/tests/end2end/test_end2end_mobilenet_v1.py b/tests/end2end/test_end2end_mobilenet_v1.py
index 01d995c147..4c52277970 100644
--- a/tests/end2end/test_end2end_mobilenet_v1.py
+++ b/tests/end2end/test_end2end_mobilenet_v1.py
@@ -353,6 +353,7 @@ def test_end2end_mobilenet_minimize_bit_width():
     model = load_test_checkpoint_or_skip(build_dir + "/end2end_mobilenet_folded.onnx")
     model = model.transform(MinimizeAccumulatorWidth())
     model = model.transform(MinimizeWeightBitWidth())
+    model = model.transform(RoundAndClipThresholds())
     model.save(build_dir + "/end2end_mobilenet_minimize_bitwidth.onnx")
"/end2end_mobilenet_minimize_bitwidth.onnx") diff --git a/tests/fpgadataflow/test_fpgadataflow_thresholding.py b/tests/fpgadataflow/test_fpgadataflow_thresholding.py index fe7ba3d9fb..2079fe7fc5 100644 --- a/tests/fpgadataflow/test_fpgadataflow_thresholding.py +++ b/tests/fpgadataflow/test_fpgadataflow_thresholding.py @@ -49,6 +49,7 @@ from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode from finn.transformation.fpgadataflow.specialize_layers import SpecializeLayers +from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds test_fpga_part = "xczu3eg-sbva484-1-e" target_clk_ns = 5 @@ -133,10 +134,8 @@ def make_single_multithresholding_modelwrapper( @pytest.mark.parametrize( "idt_tdt_cfg", [ - (DataType["INT8"], DataType["INT8"]), - (DataType["INT8"], DataType["INT9"]), - (DataType["UINT5"], DataType["UINT5"]), - (DataType["UINT5"], DataType["UINT6"]), + (DataType["INT8"], DataType["INT25"]), + (DataType["UINT5"], DataType["UINT8"]), ], ) @pytest.mark.parametrize("fold", [-1, 1, 2]) @@ -145,6 +144,7 @@ def make_single_multithresholding_modelwrapper( @pytest.mark.parametrize("impl_style", ["hls", "rtl"]) @pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"]) @pytest.mark.parametrize("mem_mode", ["internal_embedded", "internal_decoupled"]) +@pytest.mark.parametrize("round_thresh", [True, False]) @pytest.mark.fpgadataflow @pytest.mark.vivado @pytest.mark.slow @@ -159,6 +159,7 @@ def test_fpgadataflow_thresholding( impl_style, exec_mode, mem_mode, + round_thresh, ): # the mem_mode parameter can only be used for the hls thresholding # so the test will only be executed once for impl_style=rtl and once skipped @@ -234,6 +235,8 @@ def test_fpgadataflow_thresholding( node = model.get_nodes_by_op_type(model.graph.node[0].op_type)[0] inst = getCustomOp(node) inst.set_nodeattr("PE", pe) + if round_thresh is True: + model = model.transform(RoundAndClipThresholds()) model = model.transform(GiveUniqueNodeNames()) if impl_style == "hls": diff --git a/tests/transformation/streamline/test_round_thresholds.py b/tests/transformation/streamline/test_round_thresholds.py index 85c60b37d5..6de82e6750 100644 --- a/tests/transformation/streamline/test_round_thresholds.py +++ b/tests/transformation/streamline/test_round_thresholds.py @@ -1,4 +1,5 @@ -# Copyright (c) 2020, Xilinx +# Copyright (c) 2020-2022, Xilinx, Inc. +# Copyright (C) 2022-2024, Advanced Micro Devices, Inc. # All rights reserved. # # Redistribution and use in source and binary forms, with or without @@ -32,39 +33,238 @@ from onnx import TensorProto, helper from qonnx.core.datatype import DataType from qonnx.core.modelwrapper import ModelWrapper -from qonnx.util.basic import qonnx_make_model +from qonnx.util.basic import gen_finn_dt_tensor import finn.core.onnx_exec as oxe from finn.transformation.streamline import RoundAndClipThresholds +# Tests the RoundAndClipThresholds transformation under various input, output +# data type combinations with purely integer inputs. Without proper rounding, +# this tests only the clipping, range and type-casting behavior of the +# transformation. +@pytest.mark.parametrize( + "i_dtype", + [ + # Explanation for selecting these test configurations: + # 1. Below 24-bit thresholds we will not observe any interesting rounding + # behavior, as all integers < 2^24 can be exactly represented in 32-bit + # floating-point. 
+        #    generate test inputs slightly above and below this.
+        # 2. We want to test out-of-range clipping of thresholds, in particular
+        #    clipping of the negative portion of signed thresholds. Thus, we only
+        #    generate signed thresholds, but test with signed and unsigned
+        #    inputs of smaller, larger and equal range.
+        # 3. Testing proper floating-point thresholds requires a separate test-case
+        "INT23",
+        "UINT23",
+        "INT24",
+        "UINT24",
+        "INT25",
+        "UINT25",
+        "INT26",
+        "UINT26",
+    ],
+)
+@pytest.mark.parametrize(
+    "o_dtype",
+    [
+        # Explanation for selecting these test configurations:
+        # 1. Outputs of MultiThreshold are typically much smaller bit-width than the
+        #    inputs and thresholds.
+        # 2. However, with randomly sampled thresholds from a rather large range due
+        #    to the selected input bit-widths (see above), we risk not adequately
+        #    covering the input range if we sample too few thresholds. The number of
+        #    thresholds sampled depends on the bit-width of the output, thus we use
+        #    rather high bit-width for testing.
+        # 3. For a "real" model, the quantization procedure *should* take care of
+        #    adequately covering the true input range.
+        "INT8",
+        "UINT8",
+    ],
+)
+@pytest.mark.parametrize(
+    "n_elems",
+    [
+        # Explanation for selecting these test configurations:
+        # 1. Small edge cases and quickly running through tests: 1, 2, 3, 4
+        # 2. Large test case 256, hopefully amplifying any rarely occurring errors
+        1,
+        2,
+        3,
+        4,
+        256,
+    ],
+)
 @pytest.mark.streamline
-def test_round_thresholds():
-    v = helper.make_tensor_value_info("v", TensorProto.FLOAT, [1, 4])
-    thresholds = helper.make_tensor_value_info("thresholds", TensorProto.FLOAT, [4, 1])
-    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 4])
-    node_def = helper.make_node(
-        "MultiThreshold", ["v", "thresholds"], ["out"], domain="qonnx.custom_op.general"
+def test_round_and_clip_thresholds_ints(i_dtype, o_dtype, n_elems):
+    i_dtype = DataType[i_dtype]
+    t_dtype = DataType["INT25"]  # Note: Matches configuration above
+    o_dtype = DataType[o_dtype]  # noqa: Duplicate model setup code
+    node = helper.make_node(
+        "MultiThreshold",
+        domain="qonnx.custom_op.general",
+        inputs=["inp", "thresholds"],
+        outputs=["out"],
+        out_dtype=str(o_dtype),
+        out_bias=float(o_dtype.min()),
     )
-    graph_def = helper.make_graph([node_def], "test_model", [v, thresholds], [out])
-    model_def = qonnx_make_model(graph_def)
-    model = ModelWrapper(model_def)
-    threshold_val = np.asarray([[-1.1], [0.7], [2.3], [5.1]], dtype=np.float32)
-    model.set_initializer("thresholds", threshold_val)
-    model.set_tensor_datatype("v", DataType["INT8"])
-    inp_dict_f = {"v": np.floor(threshold_val).T}
-    inp_dict_n = {"v": np.round(threshold_val).T}
-    inp_dict_c = {"v": np.ceil(threshold_val).T}
-    orig_f = oxe.execute_onnx(model, inp_dict_f)["out"]
-    orig_n = oxe.execute_onnx(model, inp_dict_n)["out"]
-    orig_c = oxe.execute_onnx(model, inp_dict_c)["out"]
-    assert model.get_tensor_datatype("thresholds") == DataType["FLOAT32"]
-    new_model = model.transform(RoundAndClipThresholds())
-    # rounded up thresholds should have same dtype as input
-    assert new_model.get_tensor_datatype("thresholds") == DataType["INT8"]
-    new_f = oxe.execute_onnx(new_model, inp_dict_f)["out"]
-    new_n = oxe.execute_onnx(new_model, inp_dict_n)["out"]
-    new_c = oxe.execute_onnx(new_model, inp_dict_c)["out"]
-    assert np.isclose(orig_f, new_f, atol=1e-3).all()
-    assert np.isclose(orig_n, new_n, atol=1e-3).all()
-    assert np.isclose(orig_c, new_c, atol=1e-3).all()
+    n_thresholds = o_dtype.get_num_possible_values() - 1
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, n_elems])
+    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, n_elems])
+    thresholds = helper.make_tensor_value_info(
+        "thresholds", TensorProto.FLOAT, [n_elems, n_thresholds]
+    )
+    graph = helper.make_graph([node], "thresholds", [inp, thresholds], [out])
+    model = ModelWrapper(helper.make_model(graph))
+
+    inp = gen_finn_dt_tensor(i_dtype, [1, n_elems])
+    inp[0][0] = i_dtype.max()
+    thresholds = np.sort(gen_finn_dt_tensor(t_dtype, [n_elems, n_thresholds]))
+    model.set_tensor_datatype("inp", i_dtype)  # noqa: Duplicate model execution
+    model.set_tensor_datatype("thresholds", t_dtype)
+    model.set_tensor_datatype("out", o_dtype)
+    model.set_initializer("thresholds", thresholds)
+
+    # Execute the model before running the RoundAndClipThresholds transformation
+    out_expected = oxe.execute_onnx(model, {"inp": inp})["out"]
+    assert model.get_tensor_datatype("thresholds") == t_dtype
+
+    model = model.transform(RoundAndClipThresholds())
+
+    # After this transformation, the thresholds and output data type should be
+    # inferred correctly
+    if not i_dtype.signed():
+        new_tdt = DataType.get_smallest_possible(i_dtype.max() + 1)
+    else:
+        new_tdt = DataType.get_smallest_possible(-(i_dtype.max() + 1) - 1)
+    assert model.get_tensor_datatype("thresholds") == new_tdt
+    assert model.get_tensor_datatype("out") == o_dtype
+
+    # After this transformation, the container type used to store the thresholds
+    # values must be float32. No other type-cast or type promotion may happen.
+    assert model.get_initializer("thresholds").dtype == np.float32
+
+    # After rounding, all thresholds must be integers represented as float32
+    assert all(x.is_integer() for x in model.get_initializer("thresholds").flatten())
+
+    # Execute the model after running the RoundAndClipThresholds transformation
+    out_produced = oxe.execute_onnx(model, {"inp": inp})["out"]
+
+    assert np.all(out_produced == out_expected)
+
+
+# Tests the RoundAndClipThresholds transformation under various input, output
+# data type combinations with purely integer inputs. This test case tests actual
+# rounding of floating-point thresholds.
+@pytest.mark.parametrize(
+    "i_dtype",
+    [
+        # Explanation for selecting these test configurations:
+        # 1. Below 24-bit thresholds we will not observe any interesting rounding
+        #    behavior, as all integers < 2^24 can be exactly represented in 32-bit
+        #    floating-point. Thus, we test thresholds at 25-bit signed integers and
+        #    generate test inputs slightly above and below this.
+        # 2. We want to test out-of-range clipping of thresholds, in particular
+        #    clipping of the negative portion of signed thresholds. Thus, we only
+        #    generate signed thresholds, but test with signed and unsigned
+        #    inputs of smaller, larger and equal range.
+        # 3. Testing proper floating-point thresholds requires a separate test-case
+        "INT23",
+        "UINT23",
+        "INT24",
+        "UINT24",
+        "INT25",
+        "UINT25",
+        "INT26",
+        "UINT26",
+    ],
+)
+@pytest.mark.parametrize(
+    "o_dtype",
+    [
+        # Explanation for selecting these test configurations:
+        # 1. Outputs of MultiThreshold are typically much smaller bit-width than the
+        #    inputs and thresholds.
+        # 2. However, with randomly sampled thresholds from a rather large range due
+        #    to the selected input bit-widths (see above), we risk not adequately
+        #    covering the input range if we sample too few thresholds. The number of
+        #    thresholds sampled depends on the bit-width of the output, thus we use
+        #    rather high bit-width for testing.
+        # 3. For a "real" model, the quantization procedure *should* take care of
+        #    adequately covering the true input range.
+        "INT8",
+        "UINT8",
+    ],
+)
+@pytest.mark.parametrize(
+    "n_elems",
+    [
+        # Explanation for selecting these test configurations:
+        # 1. Small edge cases and quickly running through tests: 1, 2, 3, 4
+        # 2. Large test case 256, hopefully amplifying any rarely occurring errors
+        1,
+        2,
+        3,
+        4,
+        256,
+    ],
+)
+@pytest.mark.streamline
+def test_round_and_clip_thresholds_floats(i_dtype, o_dtype, n_elems):
+    i_dtype = DataType[i_dtype]
+    t_dtype = DataType["FLOAT32"]
+    o_dtype = DataType[o_dtype]  # noqa: Duplicate model setup code
+    node = helper.make_node(
+        "MultiThreshold",
+        domain="qonnx.custom_op.general",
+        inputs=["inp", "thresholds"],
+        outputs=["out"],
+        out_dtype=str(o_dtype),
+    )
+    n_thresholds = o_dtype.get_num_possible_values() - 1
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, n_elems])
+    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, n_elems])
+    thresholds = helper.make_tensor_value_info(
+        "thresholds", TensorProto.FLOAT, [n_elems, n_thresholds]
+    )
+    graph = helper.make_graph([node], "thresholds", [inp, thresholds], [out])
+    model = ModelWrapper(helper.make_model(graph))
+
+    inp = gen_finn_dt_tensor(i_dtype, [1, n_elems])
+    # Draw uniformly random prototype thresholds in [0,+1] range
+    thresholds = np.random.rand(n_elems, n_thresholds)
+    # Type alias to 25-bit signed integer type used to set the range of the
+    # thresholds
+    INT25 = DataType["INT25"]  # noqa: Variable name not lowercase
+    # Map the prototype thresholds into the test integer range and sort
+    thresholds = np.sort((INT25.max() - INT25.min()) * thresholds + INT25.min())
+    # Set data type annotations for the input and thresholds tensor
+    model.set_tensor_datatype("inp", i_dtype)  # noqa: Duplicate model execution
+    model.set_tensor_datatype("thresholds", t_dtype)
+    model.set_tensor_datatype("out", o_dtype)
+    model.set_initializer("thresholds", thresholds)
+
+    # Execute the model before running the RoundAndClipThresholds transformation
+    out_expected = oxe.execute_onnx(model, {"inp": inp})["out"]
+    # Before rounding the threshold data type must be as annotated
+    assert model.get_tensor_datatype("thresholds") == t_dtype
+
+    model = model.transform(RoundAndClipThresholds())
+
+    if not i_dtype.signed():
+        new_tdt = DataType.get_smallest_possible(i_dtype.max() + 1)
+    else:
+        new_tdt = DataType.get_smallest_possible(-(i_dtype.max() + 1) - 1)
+    assert model.get_tensor_datatype("thresholds") == new_tdt
+    assert model.get_tensor_datatype("out") == o_dtype
+
+    # After this transformation, the container type used to store the thresholds
+    # values must be float32. No other type-cast or type promotion may happen.
+    assert model.get_initializer("thresholds").dtype == np.float32
+    # After rounding, all thresholds must be integers represented as float32
+    assert all(x.is_integer() for x in model.get_initializer("thresholds").flatten())
+
+    out_produced = oxe.execute_onnx(model, {"inp": inp})["out"]
+
+    assert np.allclose(out_produced, out_expected, atol=1.0e-3)
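
Editor's note, not part of the patch: a minimal standalone NumPy check illustrating the float32 representability limit that the comments in RoundAndClipThresholds and the INT25 test configurations refer to. Integers above 2^24 cannot all be represented exactly in float32, which is why the transformation may promote thresholds to float64 during rounding before casting back to the float32 container type.

import numpy as np

# 2**24 + 1 = 16777217 falls between two representable float32 values and
# rounds to 16777216, so taking the ceiling of the float32 value cannot
# recover the original integer.
print(np.float32(16777217))           # 16777216.0
print(np.ceil(np.float32(16777217)))  # 16777216.0
# float64 represents 25-bit integers exactly, so the promoted computation
# keeps the value intact before the result is cast back to float32.
print(np.ceil(np.float64(16777217)))  # 16777217.0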