diff --git a/colabs/system_metrics/WandB_Llama3_1_Training_Colab_TPU.ipynb b/colabs/system_metrics/WandB_Llama3_1_Training_Colab_TPU.ipynb index 69ba34b4..3d447ae8 100644 --- a/colabs/system_metrics/WandB_Llama3_1_Training_Colab_TPU.ipynb +++ b/colabs/system_metrics/WandB_Llama3_1_Training_Colab_TPU.ipynb @@ -1027,7 +1027,6 @@ "from jax.sharding import NamedSharding\n", "from jax.sharding import PartitionSpec as PS\n", "\n", - "\n", "def train(trainer, train_dataloader, eval_dataloader):\n", " total_training_time = 0\n", " total_steps = 0\n", @@ -1065,7 +1064,7 @@ " to_log = {}\n", " to_log.update(\n", " {\n", - " \"train_step\": step,\n", + " \"train_step\": total_steps,\n", " \"train/loss\": metrics[\"loss\"],\n", " \"train/accuracy\": metrics[\"accuracy\"],\n", " \"train/step_time\": step_duration,\n",