Skip to content

Commit

Permalink
Merge pull request #1192 from Xilinx/fix/preserve_onnx_dtype
Browse files Browse the repository at this point in the history
Preserve onnx tensor dtype when inserting FIFOs and DWCs
  • Loading branch information
auphelia authored Sep 19, 2024
2 parents d575f4c + fb60055 commit 71b546b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
8 changes: 5 additions & 3 deletions src/finn/transformation/fpgadataflow/insert_dwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from onnx import TensorProto
from onnx import helper as oh
from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.base import Transformation
Expand Down Expand Up @@ -110,12 +109,15 @@ def apply(self, model):
# determine shape for dwc
dwc_shape = n0.get_normal_output_shape()

# determine dtype for dwc
# determine FINN dtype for dwc
dtype = n0.get_output_datatype()
# determine onnx tensor dtype for dwc
n0_otensor = model.get_tensor_valueinfo(output_name)
n0_tensor_dtype = n0_otensor.type.tensor_type.elem_type

dwc_output_tensor = oh.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
n0_tensor_dtype,
dwc_shape,
)
graph.value_info.append(dwc_output_tensor)
Expand Down
13 changes: 9 additions & 4 deletions src/finn/transformation/fpgadataflow/insert_fifo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

import numpy as np
import warnings
from onnx import TensorProto
from onnx import helper as oh
from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.base import Transformation
Expand Down Expand Up @@ -114,6 +113,8 @@ def apply(self, model):
# determine fifo node attributes
fld_shape = n0.get_folded_output_shape()
dtype = n0.get_output_datatype()
n0_otensor = model.get_tensor_valueinfo(output_name)
n0_tensor_dtype = n0_otensor.type.tensor_type.elem_type

# check if folded_shape of output of first node and
# input of the second node is equal
Expand Down Expand Up @@ -145,7 +146,7 @@ def apply(self, model):
# or unless create_shallow_fifos is specified
fifo_output_tensor = oh.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
n0_tensor_dtype,
n0.get_normal_output_shape(),
)
graph.value_info.append(fifo_output_tensor)
Expand Down Expand Up @@ -196,13 +197,15 @@ def apply(self, model):
fld_shape = n0.get_folded_input_shape(inp_ind)
n_shape = n0.get_normal_input_shape(inp_ind)
dtype = n0.get_input_datatype(inp_ind)
n0_itensor = model.get_tensor_valueinfo(graph_in_name)
n0_tensor_dtype = n0_itensor.type.tensor_type.elem_type
fifo_depth = n0.get_nodeattr("inFIFODepths")[inp_ind]

if fifo_depth > 2 or self.create_shallow_fifos:
# create fifo node
fifo_output_tensor = oh.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
n0_tensor_dtype,
n0.get_normal_input_shape(inp_ind),
)
graph.value_info.append(fifo_output_tensor)
Expand Down Expand Up @@ -256,13 +259,15 @@ def apply(self, model):
fld_shape = n0.get_folded_output_shape(out_ind)
n_shape = n0.get_normal_output_shape(out_ind)
dtype = n0.get_output_datatype(out_ind)
n0_otensor = model.get_tensor_valueinfo(graph_out_name)
n0_tensor_dtype = n0_otensor.type.tensor_type.elem_type
fifo_depth = n0.get_nodeattr("outFIFODepths")[out_ind]

if fifo_depth > 2 or self.create_shallow_fifos:
# create fifo node
fifo_input_tensor = oh.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
n0_tensor_dtype,
n0.get_normal_output_shape(),
)
graph.value_info.append(fifo_input_tensor)
Expand Down

0 comments on commit 71b546b

Please sign in to comment.