diff --git a/src/baskerville/helpers/tensorrt_helpers.py b/src/baskerville/helpers/tensorrt_helpers.py index 504c2f3..633a50f 100644 --- a/src/baskerville/helpers/tensorrt_helpers.py +++ b/src/baskerville/helpers/tensorrt_helpers.py @@ -1,11 +1,14 @@ -from tensorflow.python.compiler.tensorrt import trt_convert as tf_trt -import tensorflow as tf -import tensorrt as trt import argparse import json +import pdb +import time + import numpy as np import pandas as pd -from baskerville import seqnn, dataset +import tensorflow as tf +from tensorflow.python.compiler.tensorrt import trt_convert as tf_trt + +from baskerville import seqnn precision_dict = { @@ -14,8 +17,6 @@ "INT8": tf_trt.TrtPrecisionMode.INT8, } -# For TF-TRT: - class ModelOptimizer: """ @@ -27,8 +28,6 @@ class ModelOptimizer: def __init__(self, input_saved_model_dir, calibration_data=None): self.input_saved_model_dir = input_saved_model_dir self.calibration_data = None - self.loaded_model = None - if not calibration_data is None: self.set_calibration_data(calibration_data) @@ -38,81 +37,101 @@ def calibration_input_fn(): self.calibration_data = calibration_input_fn - def convert( - self, - output_saved_model_dir, - precision="FP32", - max_workspace_size_bytes=8000000000, - **kwargs, - ): + def convert(self, precision="FP32"): + t0 = time.time() + print("Converting the model.") + if precision == "INT8" and self.calibration_data is None: raise (Exception("No calibration data set!")) trt_precision = precision_dict[precision] conversion_params = tf_trt.DEFAULT_TRT_CONVERSION_PARAMS._replace( precision_mode=trt_precision, - max_workspace_size_bytes=max_workspace_size_bytes, use_calibration=precision == "INT8", + max_workspace_size_bytes=8000000000, ) - converter = tf_trt.TrtGraphConverterV2( + self.converter = tf_trt.TrtGraphConverterV2( input_saved_model_dir=self.input_saved_model_dir, conversion_params=conversion_params, ) if precision == "INT8": - converter.convert(calibration_input_fn=self.calibration_data) + self.func = self.converter.convert( + calibration_input_fn=self.calibration_data + ) else: - converter.convert() + self.func = self.converter.convert() + print("Done in %ds" % (time.time() - t0)) - converter.save(output_saved_model_dir=output_saved_model_dir) + def build(self, seq_length): + input_shape = (1, seq_length, 4) + t0 = time.time() + print("Building TRT engines for shape:", input_shape) - return output_saved_model_dir + def input_fn(): + x = np.random.random(input_shape).astype(np.float32) + x = tf.cast(x, tf.float32) + yield x - def predict(self, input_data): - if self.loaded_model is None: - self.load_default_model() + self.converter.build(input_fn) + print("Done in %ds" % (time.time() - t0)) - return self.loaded_model.predict(input_data) + def build_func(self, seq_length): + input_shape = (1, seq_length, 4) + t0 = time.time() + print("Building TRT engines for shape:", input_shape) + x = np.random.random(input_shape) + x = tf.cast(x, tf.float32) + self.func(x) + print("Done in %ds" % (time.time() - t0)) - def load_default_model(self): - self.loaded_model = tf.keras.models.load_model("resnet50_saved_model") + def save(self, output_dir): + self.converter.save(output_saved_model_dir=output_dir) def main(): parser = argparse.ArgumentParser( description="Convert a seqnn model to TensorRT model." ) - parser.add_argument("model_fn", type=str, help="Path to the Keras model file (.h5)") - parser.add_argument("params_fn", type=str, help="Path to the JSON parameters file") parser.add_argument( - "targets_file", type=str, help="Path to the target variants file" + "-t", "--targets_file", default=None, help="Path to the target variants file" ) parser.add_argument( - "output_dir", - type=str, + "-o", + "--out_dir", + default="trt_out", help="Output directory for storing saved models (original & converted)", ) + parser.add_argument( + "params_file", type=str, help="Path to the JSON parameters file" + ) + parser.add_argument("model_file", help="Trained model HDF5.") args = parser.parse_args() - # Load target variants - targets_df = pd.read_csv(args.targets_file, sep="\t", index_col=0) - # Load parameters - with open(args.params_fn) as params_open: + with open(args.params_file) as params_open: params = json.load(params_open) - params_model = params["model"] + # Load keras model into seqnn class - seqnn_model = seqnn.SeqNN(params_model) - seqnn_model.restore(args.model_fn) - seqnn_model.build_slice(np.array(targets_df.index)) - # seqnn_model.build_ensemble(True) + seqnn_model = seqnn.SeqNN(params["model"]) + seqnn_model.restore(args.model_file) + + # Load target variants + if args.targets_file is not None: + targets_df = pd.read_csv(args.targets_file, sep="\t", index_col=0) + seqnn_model.build_slice(np.array(targets_df.index)) + + # ensemble rc + seqnn_model.build_ensemble(True) # save this model to a directory - seqnn_model.model.save(f"{args.output_dir}/original_model") + seqnn_model.model.save(f"{args.out_dir}/original") # Convert the model - opt_model = ModelOptimizer(f"{args.output_dir}/original_model") - opt_model.convert(f"{args.output_dir}/model_FP32", precision="FP32") + opt_model = ModelOptimizer(f"{args.out_dir}/original") + opt_model.convert(precision="FP32") + # opt_model.build(seqnn_model.seq_length) + opt_model.save(f"{args.out_dir}/convert") if __name__ == "__main__": diff --git a/src/baskerville/helpers/trt_optimized_model.py b/src/baskerville/helpers/trt_optimized_model.py index 2032b72..ad90325 100644 --- a/src/baskerville/helpers/trt_optimized_model.py +++ b/src/baskerville/helpers/trt_optimized_model.py @@ -1,6 +1,5 @@ import tensorflow as tf from tensorflow.python.saved_model import tag_constants -from baskerville import layers class OptimizedModel: @@ -19,7 +18,6 @@ def __init__(self, saved_model_dir=None, strand_pair=[]): def predict(self, input_data): if self.loaded_model_fn is None: raise (Exception("Haven't loaded a model")) - # x = tf.constant(input_data.astype("float32")) x = tf.cast(input_data, tf.float32) labeling = self.loaded_model_fn(x) try: @@ -43,17 +41,5 @@ def load_model(self, saved_model_dir): wrapper_fp32 = saved_model_loaded.signatures["serving_default"] self.loaded_model_fn = wrapper_fp32 - def __call__(self, input_data): - # need to do the prediction for ensemble model here - x = tf.cast(input_data, tf.float32) - sequences_rev = layers.EnsembleReverseComplement()([x]) - if len(self.strand_pair) == 0: - strand_pair = None - else: - strand_pair = self.strand_pair[0] - preds = [ - layers.SwitchReverse(strand_pair)([self.predict(seq), rp]) - for (seq, rp) in sequences_rev - ] - preds_avg = tf.keras.layers.Average()(preds) - return preds_avg + def __call__(self, x): + return self.predict(x)