Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-add changes for ensemble before optimizing #25

Merged
merged 2 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 63 additions & 44 deletions src/baskerville/helpers/tensorrt_helpers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from tensorflow.python.compiler.tensorrt import trt_convert as tf_trt
import tensorflow as tf
import tensorrt as trt
import argparse
import json
import pdb
import time

import numpy as np
import pandas as pd
from baskerville import seqnn, dataset
import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as tf_trt

from baskerville import seqnn


precision_dict = {
Expand All @@ -14,8 +17,6 @@
"INT8": tf_trt.TrtPrecisionMode.INT8,
}

# For TF-TRT:


class ModelOptimizer:
"""
Expand All @@ -27,8 +28,6 @@ class ModelOptimizer:
def __init__(self, input_saved_model_dir, calibration_data=None):
self.input_saved_model_dir = input_saved_model_dir
self.calibration_data = None
self.loaded_model = None

if not calibration_data is None:
self.set_calibration_data(calibration_data)

Expand All @@ -38,81 +37,101 @@ def calibration_input_fn():

self.calibration_data = calibration_input_fn

def convert(
self,
output_saved_model_dir,
precision="FP32",
max_workspace_size_bytes=8000000000,
**kwargs,
):
def convert(self, precision="FP32"):
t0 = time.time()
print("Converting the model.")

if precision == "INT8" and self.calibration_data is None:
raise (Exception("No calibration data set!"))

trt_precision = precision_dict[precision]
conversion_params = tf_trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
precision_mode=trt_precision,
max_workspace_size_bytes=max_workspace_size_bytes,
use_calibration=precision == "INT8",
max_workspace_size_bytes=8000000000,
)
converter = tf_trt.TrtGraphConverterV2(
self.converter = tf_trt.TrtGraphConverterV2(
input_saved_model_dir=self.input_saved_model_dir,
conversion_params=conversion_params,
)

if precision == "INT8":
converter.convert(calibration_input_fn=self.calibration_data)
self.func = self.converter.convert(
calibration_input_fn=self.calibration_data
)
else:
converter.convert()
self.func = self.converter.convert()
print("Done in %ds" % (time.time() - t0))

converter.save(output_saved_model_dir=output_saved_model_dir)
def build(self, seq_length):
input_shape = (1, seq_length, 4)
t0 = time.time()
print("Building TRT engines for shape:", input_shape)

return output_saved_model_dir
def input_fn():
x = np.random.random(input_shape).astype(np.float32)
x = tf.cast(x, tf.float32)
yield x

def predict(self, input_data):
if self.loaded_model is None:
self.load_default_model()
self.converter.build(input_fn)
print("Done in %ds" % (time.time() - t0))

return self.loaded_model.predict(input_data)
def build_func(self, seq_length):
    """Eagerly build TRT engines by tracing the converted function once.

    Feeds a single random sequence of shape (1, seq_length, 4) through
    ``self.func`` (the concrete function produced by ``convert()``) so the
    TensorRT engines are constructed up front, and prints the elapsed time.

    Args:
        seq_length: Length of the input DNA sequence (one-hot over 4 bases).
    """
    shape = (1, seq_length, 4)
    start = time.time()
    print("Building TRT engines for shape:", shape)
    # Random one-hot-ish sample cast to float32, the dtype the engine expects.
    sample = tf.cast(np.random.random(shape), tf.float32)
    self.func(sample)
    print("Done in %ds" % (time.time() - start))

def load_default_model(self):
self.loaded_model = tf.keras.models.load_model("resnet50_saved_model")
def save(self, output_dir):
    """Persist the TF-TRT converted SavedModel.

    Args:
        output_dir: Directory path where the converted SavedModel is written.
    """
    target_dir = output_dir
    self.converter.save(output_saved_model_dir=target_dir)


def main():
parser = argparse.ArgumentParser(
description="Convert a seqnn model to TensorRT model."
)
parser.add_argument("model_fn", type=str, help="Path to the Keras model file (.h5)")
parser.add_argument("params_fn", type=str, help="Path to the JSON parameters file")
parser.add_argument(
"targets_file", type=str, help="Path to the target variants file"
"-t", "--targets_file", default=None, help="Path to the target variants file"
)
parser.add_argument(
"output_dir",
type=str,
"-o",
"--out_dir",
default="trt_out",
help="Output directory for storing saved models (original & converted)",
)
parser.add_argument(
"params_file", type=str, help="Path to the JSON parameters file"
)
parser.add_argument("model_file", help="Trained model HDF5.")
args = parser.parse_args()

# Load target variants
targets_df = pd.read_csv(args.targets_file, sep="\t", index_col=0)

# Load parameters
with open(args.params_fn) as params_open:
with open(args.params_file) as params_open:
params = json.load(params_open)
params_model = params["model"]

# Load keras model into seqnn class
seqnn_model = seqnn.SeqNN(params_model)
seqnn_model.restore(args.model_fn)
seqnn_model.build_slice(np.array(targets_df.index))
# seqnn_model.build_ensemble(True)
seqnn_model = seqnn.SeqNN(params["model"])
seqnn_model.restore(args.model_file)

# Load target variants
if args.targets_file is not None:
targets_df = pd.read_csv(args.targets_file, sep="\t", index_col=0)
seqnn_model.build_slice(np.array(targets_df.index))

# ensemble rc
seqnn_model.build_ensemble(True)

# save this model to a directory
seqnn_model.model.save(f"{args.output_dir}/original_model")
seqnn_model.ensemble.save(f"{args.out_dir}/original")

# Convert the model
opt_model = ModelOptimizer(f"{args.output_dir}/original_model")
opt_model.convert(f"{args.output_dir}/model_FP32", precision="FP32")
opt_model = ModelOptimizer(f"{args.out_dir}/original")
opt_model.convert(precision="FP32")
# opt_model.build(seqnn_model.seq_length)
opt_model.save(f"{args.out_dir}/convert")


if __name__ == "__main__":
Expand Down
18 changes: 2 additions & 16 deletions src/baskerville/helpers/trt_optimized_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
from baskerville import layers


class OptimizedModel:
Expand All @@ -19,7 +18,6 @@ def __init__(self, saved_model_dir=None, strand_pair=[]):
def predict(self, input_data):
if self.loaded_model_fn is None:
raise (Exception("Haven't loaded a model"))
# x = tf.constant(input_data.astype("float32"))
x = tf.cast(input_data, tf.float32)
labeling = self.loaded_model_fn(x)
try:
Expand All @@ -43,17 +41,5 @@ def load_model(self, saved_model_dir):
wrapper_fp32 = saved_model_loaded.signatures["serving_default"]
self.loaded_model_fn = wrapper_fp32

def __call__(self, input_data):
# need to do the prediction for ensemble model here
x = tf.cast(input_data, tf.float32)
sequences_rev = layers.EnsembleReverseComplement()([x])
if len(self.strand_pair) == 0:
strand_pair = None
else:
strand_pair = self.strand_pair[0]
preds = [
layers.SwitchReverse(strand_pair)([self.predict(seq), rp])
for (seq, rp) in sequences_rev
]
preds_avg = tf.keras.layers.Average()(preds)
return preds_avg
def __call__(self, x):
    """Make the wrapper callable like a Keras model; delegates to predict()."""
    return self.predict(x)
Loading