-
Notifications
You must be signed in to change notification settings - Fork 0
/
export_to_onnx.py
31 lines (24 loc) · 1.01 KB
/
export_to_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Description: Convert a Keras model (SavedModel directory or .h5 file) to ONNX
# Imports
import os
import tensorflow as tf
import tf2onnx
from utilities.tools import suppress_tf_warnings
# Suppress TensorFlow warnings. This runs once at import time, before the
# conversion function below is defined or called.
# NOTE(review): suppress_tf_warnings comes from utilities.tools; exactly which
# warnings/log levels it silences is not visible here - confirm there.
suppress_tf_warnings()
def convert_saved_model_to_onnx(saved_model_path, onnx_output_path, opset=None):
    """Load a Keras model from disk, convert it to ONNX and write it out.

    Args:
        saved_model_path: Path accepted by ``tf.keras.models.load_model``
            (e.g. a SavedModel directory or an ``.h5`` file).
        onnx_output_path: Destination file for the serialized ONNX model.
            Missing parent directories are created; an existing file at this
            path is overwritten.
        opset: Optional ONNX opset version forwarded to
            ``tf2onnx.convert.from_keras``. ``None`` (the default) leaves the
            choice to tf2onnx, matching the previous behaviour.

    Returns:
        The ONNX model object produced by ``tf2onnx.convert.from_keras``
        (the same object whose serialized bytes are written to disk).
    """
    model = tf.keras.models.load_model(saved_model_path)

    # from_keras returns a pair; only the model itself is needed here.
    onnx_model, _ = tf2onnx.convert.from_keras(model, opset=opset)

    # Opening in 'wb' mode creates (or truncates) the file on its own, but it
    # fails if the parent directory does not exist, so create that first.
    # dirname() is '' for a bare filename, and makedirs('') would raise.
    output_dir = os.path.dirname(onnx_output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if not os.path.exists(onnx_output_path):
        print('Creating new file')

    with open(onnx_output_path, 'wb') as f:
        f.write(onnx_model.SerializeToString())

    return onnx_model
if __name__ == '__main__':
    import argparse

    # Both paths are optional positional arguments. Their defaults are the
    # paths this script has always used, so running it with no arguments
    # behaves exactly as before.
    parser = argparse.ArgumentParser(
        description='Convert a Keras model file to ONNX.')
    parser.add_argument(
        'saved_model_path',
        nargs='?',
        default='pre_filter/efficientnet-pre-filter-refactored-dataset-2_best_model.h5',
        help='Keras model to load (default: %(default)s)')
    parser.add_argument(
        'onnx_output_path',
        nargs='?',
        default='onnx/pre_filter/efficientnet-pre-filter-refactored-dataset-2.onnx',
        help='Destination .onnx file (default: %(default)s)')
    args = parser.parse_args()

    convert_saved_model_to_onnx(args.saved_model_path, args.onnx_output_path)