Skip to content

Commit

Permalink
fix ia3
Browse files Browse the repository at this point in the history
  • Loading branch information
hy395 committed Jun 18, 2024
1 parent e1be2c8 commit cdf3634
Show file tree
Hide file tree
Showing 6 changed files with 577 additions and 144 deletions.
44 changes: 43 additions & 1 deletion src/baskerville/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
# transfer learning #
#####################
class IA3(tf.keras.layers.Layer):
# activation-rescale adapter:
# https://arxiv.org/pdf/2205.05638.pdf
# ia3 module for attention layer, scale output.

def __init__(self,
original_layer,
Expand Down Expand Up @@ -69,6 +69,48 @@ def get_config(self):
)
return config

class IA3_ff(tf.keras.layers.Layer):
    """IA3 adapter for a down-projection feed-forward layer: rescales the
    layer *input* with a learned per-feature vector.

    Reference: https://arxiv.org/pdf/2205.05638.pdf
    """

    def __init__(self, original_layer, trainable=False, **kwargs):
        # Reuse the wrapped dense layer's name so saved weights line up
        # with the original model's variable names.
        wrapped_config = original_layer.get_config()
        kwargs.pop("name", None)
        super().__init__(name=wrapped_config["name"], trainable=trainable, **kwargs)

        self.input_dim = original_layer.input_shape[-1]

        # Freeze the wrapped projection; only the IA3 vector trains.
        self.original_layer = original_layer
        self.original_layer.trainable = False

        # IA3 weights. Make it a dense layer to control trainable
        self._ia3_layer = tf.keras.layers.Dense(
            units=self.input_dim,
            use_bias=False,
            kernel_initializer=tf.keras.initializers.Ones(),
            trainable=True,
            name="ia3_ff",
        )

    def call(self, inputs):
        # Materialize the learned scaling vector by pushing a ones probe
        # through the kernel, then rescale the input before the frozen layer.
        scaler = self._ia3_layer(tf.constant([[1]], dtype='float64'))[0]
        return self.original_layer(inputs * scaler)

    def get_config(self):
        # NOTE(review): config carries only "size"; reconstruction from
        # config alone is not supported (original_layer is required).
        config = super().get_config().copy()
        config.update({"size": self.input_dim})
        return config

class Lora(tf.keras.layers.Layer):
# adapted from:
# https://arxiv.org/abs/2106.09685
Expand Down
18 changes: 13 additions & 5 deletions src/baskerville/scripts/hound_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,9 @@ def main():
mode='full')

elif args.att_adapter=='ia3':
transfer_helper.add_ia3(seqnn_model.model)
seqnn_model.model = transfer_helper.add_ia3(seqnn_model.model, strand_pairs[0])

'''
# conv adapter
# assume seqnn_model is appropriately frozen
if args.conv_adapter is not None:
Expand Down Expand Up @@ -314,6 +315,7 @@ def main():
bottleneck_ratio=args.se_ratio,
insert_mode='pre_att',
unfreeze_bn=True)
'''

#################
# final summary #
Expand Down Expand Up @@ -368,10 +370,16 @@ def main():

# merge ia3 weights to original, save weight to: model_best_mergeweight.h5
if args.att_adapter=='ia3':
seqnn_model.model.load_weights('%s/model_best.h5'%args.out_dir)
transfer_helper.merge_ia3(seqnn_model.model)
seqnn_model.save('%s/model_best.mergeW.h5'%args.out_dir)
transfer_helper.var_reorder('%s/model_best.mergeW.h5'%args.out_dir)
# ia3 model
ia3_model = seqnn_model.model
ia3_model.load_weights('%s/model_best.h5'%args.out_dir)
# original model
seqnn_model2 = seqnn.SeqNN(params_model)
seqnn_model2.restore(args.restore, trunk=args.trunk)
original_model = seqnn_model2.model
# merge weights into original model
transfer_helper.merge_ia3(original_model, ia3_model)
original_model.save('%s/model_best.mergeW.h5'%args.out_dir)

else:
########################################
Expand Down
Loading

0 comments on commit cdf3634

Please sign in to comment.