diff --git a/tfx/examples/bert/utils/bert_models.py b/tfx/examples/bert/utils/bert_models.py index d67fa1c6b0..cf59e9ac41 100644 --- a/tfx/examples/bert/utils/bert_models.py +++ b/tfx/examples/bert/utils/bert_models.py @@ -13,6 +13,7 @@ # limitations under the License. """Configurable fine-tuning BERT models for various tasks.""" +from __future__ import annotations from typing import Optional, List, Union import tensorflow as tf @@ -59,8 +60,7 @@ def build_bert_classifier(bert_layer: tf.keras.layers.Layer, def compile_bert_classifier( model: tf.keras.Model, - loss: tf.keras.losses.Loss = tf.keras.losses.SparseCategoricalCrossentropy( - from_logits=True), + loss: tf.keras.losses.Loss | None = None, learning_rate: float = 2e-5, metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None): """Compile the BERT classifier using suggested parameters. @@ -79,6 +79,9 @@ def compile_bert_classifier( Returns: None. """ + if loss is None: + loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + if metrics is None: metrics = ["sparse_categorical_accuracy"]