Skip to content

Commit

Permalink
Fix some Ruff rule B008 violations.
Browse files Browse the repository at this point in the history
For remaining B008 violations, see [Issue 6945](tensorflow#6945)
  • Loading branch information
smokestacklightnin committed Oct 27, 2024
1 parent 9724828 commit ad10616
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions tfx/examples/bert/utils/bert_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Configurable fine-tuning BERT models for various tasks."""

from __future__ import annotations
from typing import Optional, List, Union

import tensorflow as tf
Expand Down Expand Up @@ -59,8 +60,7 @@ def build_bert_classifier(bert_layer: tf.keras.layers.Layer,

def compile_bert_classifier(
model: tf.keras.Model,
loss: tf.keras.losses.Loss = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True),
loss: tf.keras.losses.Loss | None = None,
learning_rate: float = 2e-5,
metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None):
"""Compile the BERT classifier using suggested parameters.
Expand All @@ -79,6 +79,9 @@ def compile_bert_classifier(
Returns:
None.
"""
if loss is None:
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

if metrics is None:
metrics = ["sparse_categorical_accuracy"]

Expand Down

0 comments on commit ad10616

Please sign in to comment.