Skip to content

Commit

Permalink
fix black format
Browse files Browse the repository at this point in the history
  • Loading branch information
lruizcalico committed Apr 9, 2024
1 parent 543865f commit 318913d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 12 deletions.
28 changes: 16 additions & 12 deletions src/baskerville/helpers/tensorrt_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"INT8": tf_trt.TrtPrecisionMode.INT8,
}


class ModelOptimizer:
"""
Class of converter for tensorrt
Expand All @@ -38,7 +39,7 @@ def calibration_input_fn():

def convert(self, precision="FP32"):
t0 = time.time()
print('Converting the model.')
print("Converting the model.")

if precision == "INT8" and self.calibration_data is None:
raise (Exception("No calibration data set!"))
Expand All @@ -55,30 +56,34 @@ def convert(self, precision="FP32"):
)

if precision == "INT8":
self.func = self.converter.convert(calibration_input_fn=self.calibration_data)
self.func = self.converter.convert(
calibration_input_fn=self.calibration_data
)
else:
self.func = self.converter.convert()
print('Done in %ds' % (time.time()-t0))
print("Done in %ds" % (time.time() - t0))

def build(self, seq_length):
input_shape = (1, seq_length, 4)
t0 = time.time()
print('Building TRT engines for shape:', input_shape)
print("Building TRT engines for shape:", input_shape)

def input_fn():
x = np.random.random(input_shape).astype(np.float32)
x = tf.cast(x, tf.float32)
yield x

self.converter.build(input_fn)
print('Done in %ds' % (time.time()-t0))
print("Done in %ds" % (time.time() - t0))

def build_func(self, seq_length):
input_shape = (1, seq_length, 4)
t0 = time.time()
print('Building TRT engines for shape:', input_shape)
print("Building TRT engines for shape:", input_shape)
x = np.random.random(input_shape)
x = tf.cast(x, tf.float32)
self.func(x)
print('Done in %ds' % (time.time()-t0))
print("Done in %ds" % (time.time() - t0))

def save(self, output_dir):
self.converter.save(output_saved_model_dir=output_dir)
Expand All @@ -89,18 +94,17 @@ def main():
description="Convert a seqnn model to TensorRT model."
)
parser.add_argument(
"-t",
"--targets_file",
default=None,
help="Path to the target variants file"
"-t", "--targets_file", default=None, help="Path to the target variants file"
)
parser.add_argument(
"-o",
"--out_dir",
default="trt_out",
help="Output directory for storing saved models (original & converted)",
)
parser.add_argument("params_file", type=str, help="Path to the JSON parameters file")
parser.add_argument(
"params_file", type=str, help="Path to the JSON parameters file"
)
parser.add_argument("model_file", help="Trained model HDF5.")
args = parser.parse_args()

Expand Down
1 change: 1 addition & 0 deletions src/baskerville/helpers/trt_optimized_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants


class OptimizedModel:
"""
Class of model optimized with tensorrt
Expand Down

0 comments on commit 318913d

Please sign in to comment.