add --f16 for float16 inference
hy395 committed Apr 24, 2024
1 parent 9fbfad3 commit 21271c9
Showing 5 changed files with 69 additions and 9 deletions.
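Note: the changes below all follow the same Keras mixed-precision recipe. As a minimal standalone sketch of the policy semantics (my illustration, not code from this commit): under mixed_float16, layers compute in float16 while keeping their variables in float32, which is why float32 checkpoints still restore cleanly after the policy is set.

from tensorflow.keras import mixed_precision

# mixed_float16: float16 compute dtype, float32 variable dtype
mixed_precision.set_global_policy('mixed_float16')
policy = mixed_precision.global_policy()
print(policy.compute_dtype)   # float16
print(policy.variable_dtype)  # float32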
3 changes: 1 addition & 2 deletions src/baskerville/HY_helper.py
@@ -4,7 +4,6 @@
 import pyBigWig
 
 
-
 def make_seq_1hot(genome_open, chrm, start, end, seq_len):
     if start < 0:
         seq_dna = 'N'*(-start) + genome_open.fetch(chrm, 0, end)
@@ -18,7 +17,7 @@ def make_seq_1hot(genome_open, chrm, start, end, seq_len):
     seq_1hot = dna_io.dna_1hot(seq_dna)
     return seq_1hot
 
-#Helper function to get (padded) one-hot
+# Helper function to get (padded) one-hot
 def process_sequence(fasta_file, chrom, start, end, seq_len=524288) :
 
     fasta_open = pysam.Fastafile(fasta_file)
23 changes: 21 additions & 2 deletions src/baskerville/scripts/borzoi_test_genes.py
@@ -27,6 +27,7 @@
 from qnorm import quantile_normalize
 from scipy.stats import pearsonr
 from sklearn.metrics import explained_variance_score
+from tensorflow.keras import mixed_precision
 
 from baskerville import pygene
 from baskerville import dataset
@@ -77,6 +78,13 @@ def main():
         action="store_true",
         help="Aggregate entire gene span [Default: %default]",
     )
+    parser.add_option(
+        "--f16",
+        dest="f16",
+        default=False,
+        action="store_true",
+        help="use mixed precision for inference",
+    )
     parser.add_option(
         "-t",
         dest="targets_file",
@@ -155,8 +163,19 @@ def main():
     )
 
     # initialize model
-    seqnn_model = seqnn.SeqNN(params_model)
-    seqnn_model.restore(model_file, options.head_i)
+    ###################
+    # mixed precision #
+    ###################
+    if options.f16:
+        mixed_precision.set_global_policy('mixed_float16')  # first set global policy
+        seqnn_model = seqnn.SeqNN(params_model)  # then create model
+        seqnn_model.restore(model_file, options.head_i)
+        seqnn_model.append_activation()  # add additional activation to cast float16 output to float32
+    else:
+        # initialize model
+        seqnn_model = seqnn.SeqNN(params_model)
+        seqnn_model.restore(model_file, options.head_i)
 
     seqnn_model.build_slice(targets_df.index)
     seqnn_model.build_ensemble(options.rc, options.shifts)
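Why set the policy before constructing SeqNN? A Keras layer picks up the global policy when it is created, so the ordering in the block above matters. A toy illustration (my sketch, using a bare Dense layer as a stand-in for the model's layers):

import tensorflow as tf
from tensorflow.keras import mixed_precision

layer_f32 = tf.keras.layers.Dense(1)  # created under the default float32 policy
print(layer_f32.compute_dtype)  # float32

mixed_precision.set_global_policy('mixed_float16')
layer_f16 = tf.keras.layers.Dense(1)  # created after the policy change
print(layer_f16.compute_dtype)  # float16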
23 changes: 20 additions & 3 deletions src/baskerville/scripts/hound_eval.py
@@ -23,6 +23,7 @@
 from scipy.stats import spearmanr
 import tensorflow as tf
 from tqdm import tqdm
+from tensorflow.keras import mixed_precision
 
 from baskerville import bed
 from baskerville import dataset
@@ -85,6 +86,12 @@
         type=int,
         help="Step across positions [Default: %(default)s]",
     )
+    parser.add_argument(
+        "--f16",
+        default=False,
+        action="store_true",
+        help="use mixed precision for inference",
+    )
     parser.add_argument(
         "-t",
         "--targets_file",
@@ -140,9 +147,19 @@
         tfr_pattern=args.tfr_pattern,
     )
 
-    # initialize model
-    seqnn_model = seqnn.SeqNN(params_model)
-    seqnn_model.restore(args.model_file, args.head_i)
+    ###################
+    # mixed precision #
+    ###################
+    if args.f16:
+        mixed_precision.set_global_policy('mixed_float16')  # first set global policy
+        seqnn_model = seqnn.SeqNN(params_model)  # then create model
+        seqnn_model.restore(args.model_file, args.head_i)
+        seqnn_model.append_activation()  # add additional activation to cast float16 output to float32
+    else:
+        # initialize model
+        seqnn_model = seqnn.SeqNN(params_model)
+        seqnn_model.restore(args.model_file, args.head_i)
 
     seqnn_model.build_ensemble(args.rc, args.shifts)
 
     #######################################################
23 changes: 21 additions & 2 deletions src/baskerville/scripts/hound_eval_spec.py
@@ -25,6 +25,7 @@
 from qnorm import quantile_normalize
 from scipy.stats import pearsonr
 import tensorflow as tf
+from tensorflow.keras import mixed_precision
 
 from baskerville import dataset
 from baskerville import seqnn
@@ -74,6 +75,13 @@
         type="int",
         help="Step across positions [Default: %default]",
     )
+    parser.add_option(
+        "--f16",
+        dest="f16",
+        default=False,
+        action="store_true",
+        help="use mixed precision for inference",
+    )
     parser.add_option(
         "--save",
         dest="save",
@@ -190,8 +198,19 @@
     )
 
     # initialize model
-    seqnn_model = seqnn.SeqNN(params_model)
-    seqnn_model.restore(model_file, options.head_i)
+    ###################
+    # mixed precision #
+    ###################
+    if options.f16:
+        mixed_precision.set_global_policy('mixed_float16')  # set global policy
+        seqnn_model = seqnn.SeqNN(params_model)  # create model
+        seqnn_model.restore(model_file, options.head_i)
+        seqnn_model.append_activation()  # add additional activation to cast float16 output to float32
+    else:
+        # initialize model
+        seqnn_model = seqnn.SeqNN(params_model)
+        seqnn_model.restore(model_file, options.head_i)
 
     seqnn_model.build_slice(targets_df.index)
     if options.step > 1:
         seqnn_model.step(options.step)
6 changes: 6 additions & 0 deletions src/baskerville/seqnn.py
@@ -219,6 +219,12 @@ def build_embed(self, conv_layer_i: int, batch_norm: bool = True):
             inputs=self.model.inputs, outputs=conv_layer.output
         )
 
+    def append_activation(self):
+        """add additional activation to convert float16 output to float32, required for mixed precision"""
+        model_0 = self.model
+        new_outputs = tf.keras.layers.Activation('linear', dtype='float32')(model_0.layers[-1].output)
+        self.model = tf.keras.Model(inputs=model_0.layers[0].input, outputs=new_outputs)
+
     def build_ensemble(self, ensemble_rc: bool = False, ensemble_shifts=[0]):
         """Build ensemble of models computing on augmented input sequences."""
         shift_bool = len(ensemble_shifts) > 1 or ensemble_shifts[0] != 0
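To see what append_activation accomplishes, here is a standalone toy version of the pattern (my sketch; the one-layer model is a stand-in for SeqNN, not the real architecture): under mixed_float16 the last layer emits float16, and the appended linear Activation with dtype='float32' casts predictions back to float32 for downstream metrics.

import tensorflow as tf
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy('mixed_float16')

inputs = tf.keras.Input(shape=(8,))
model = tf.keras.Model(inputs, tf.keras.layers.Dense(4)(inputs))
print(model.output.dtype)  # float16

# same cast as SeqNN.append_activation above
cast = tf.keras.layers.Activation('linear', dtype='float32')(model.layers[-1].output)
model_f32 = tf.keras.Model(inputs=model.layers[0].input, outputs=cast)
print(model_f32.output.dtype)  # float32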
