Skip to content

Commit

Permalink
Merge pull request #1035 from Xilinx/bugfix/threshold_weight_padding
Browse files Browse the repository at this point in the history
[Threshold RTL] padd threshold steps based on activation bitwidth
  • Loading branch information
azizb-xlnx authored Apr 10, 2024
2 parents 2fffa76 + b02f72e commit e188b4c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/finn/custom_op/fpgadataflow/rtl/thresholding_rtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,8 @@ def make_weight_file(self, weights, weight_file_mode, weight_file_name):
if weights.shape == (1, 1):
weights = np.broadcast_to(weights, expected_shape)

width_padded = roundup_to_integer_multiple(weights.shape[1], 4)
odt = self.get_output_datatype().bitwidth()
width_padded = roundup_to_integer_multiple(weights.shape[1], 2**odt)
weight_padded = np.zeros((weights.shape[0], width_padded))
weight_padded[: weights.shape[0], :n_thres_steps] = weights
weight_stream = []
Expand Down
3 changes: 1 addition & 2 deletions tests/end2end/test_end2end_bnn_pynq.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,7 @@ def fold_tfc(model):
inp_qnt_node = model.get_nodes_by_op_type("Thresholding_rtl")[0]
inp_qnt = getCustomOp(inp_qnt_node)
inp_qnt.set_nodeattr("PE", 49)
# TODO: update PYNQ driver to support runtime writeable weights for RTL Thresholding
# inp_qnt.set_nodeattr("runtime_writeable_weights", 1)
inp_qnt.set_nodeattr("runtime_writeable_weights", 1)
return model


Expand Down

0 comments on commit e188b4c

Please sign in to comment.