From bb34f6df7f69602aafe4b25d8eea5452feea053d Mon Sep 17 00:00:00 2001 From: wylerz Date: Fri, 18 Oct 2024 11:58:34 -0600 Subject: [PATCH] minor fix to train step logic --- colabs/system_metrics/WandB_Llama3_1_Training_Colab_TPU.ipynb | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/colabs/system_metrics/WandB_Llama3_1_Training_Colab_TPU.ipynb b/colabs/system_metrics/WandB_Llama3_1_Training_Colab_TPU.ipynb index 69ba34b4..3d447ae8 100644 --- a/colabs/system_metrics/WandB_Llama3_1_Training_Colab_TPU.ipynb +++ b/colabs/system_metrics/WandB_Llama3_1_Training_Colab_TPU.ipynb @@ -1027,7 +1027,6 @@ "from jax.sharding import NamedSharding\n", "from jax.sharding import PartitionSpec as PS\n", "\n", - "\n", "def train(trainer, train_dataloader, eval_dataloader):\n", " total_training_time = 0\n", " total_steps = 0\n", @@ -1065,7 +1064,7 @@ " to_log = {}\n", " to_log.update(\n", " {\n", - " \"train_step\": step,\n", + " \"train_step\": total_steps,\n", " \"train/loss\": metrics[\"loss\"],\n", " \"train/accuracy\": metrics[\"accuracy\"],\n", " \"train/step_time\": step_duration,\n",