diff --git a/src/baskerville/snps.py b/src/baskerville/snps.py index 66bfb4b..4e69f18 100644 --- a/src/baskerville/snps.py +++ b/src/baskerville/snps.py @@ -68,7 +68,7 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options): # load model if options.tensorrt: - seqnn_model.ensemble = OptimizedModel(model_file, params_model["strand_pair"]) + seqnn_model.ensemble = OptimizedModel(model_file, seqnn_model.strand_pair) input_shape = tuple(seqnn_model.model.loaded_model_fn.inputs[0].shape.as_list()) else: seqnn_model.restore(model_file)