Skip to content

Commit

Permalink
[InsertDWC] Preserve onnx tensor dtype when inserting DWCs
Browse files Browse the repository at this point in the history
  • Loading branch information
auphelia committed Sep 19, 2024
1 parent ec5613c commit fb60055
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/finn/transformation/fpgadataflow/insert_dwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from onnx import TensorProto
from onnx import helper as oh
from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.base import Transformation
Expand Down Expand Up @@ -110,12 +109,15 @@ def apply(self, model):
# determine shape for dwc
dwc_shape = n0.get_normal_output_shape()

# determine dtype for dwc
# determine FINN dtype for dwc
dtype = n0.get_output_datatype()
# determine onnx tensor dtype for dwc
n0_otensor = model.get_tensor_valueinfo(output_name)
n0_tensor_dtype = n0_otensor.type.tensor_type.elem_type

dwc_output_tensor = oh.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
n0_tensor_dtype,
dwc_shape,
)
graph.value_info.append(dwc_output_tensor)
Expand Down

0 comments on commit fb60055

Please sign in to comment.