diff --git a/colabs/pyg/pointnet-classification/02_pointnet_plus_plus.ipynb b/colabs/pyg/pointnet-classification/02_pointnet_plus_plus.ipynb index 26b36c74..12eee4b8 100644 --- a/colabs/pyg/pointnet-classification/02_pointnet_plus_plus.ipynb +++ b/colabs/pyg/pointnet-classification/02_pointnet_plus_plus.ipynb @@ -363,8 +363,9 @@ " range(num_train_examples),\n", " desc=f\"Training Epoch {epoch}/{config.epochs}\"\n", " )\n", + " data_iter = iter(train_loader)\n", " for batch_idx in progress_bar:\n", - " data = next(iter(train_loader)).to(device)\n", + " data = next(data_iter).to(device)\n", " \n", " optimizer.zero_grad()\n", " prediction = model(data)\n", @@ -394,8 +395,9 @@ " range(num_val_examples),\n", " desc=f\"Validation Epoch {epoch}/{config.epochs}\"\n", " )\n", + " data_iter = iter(val_loader)\n", " for batch_idx in progress_bar:\n", - " data = next(iter(val_loader)).to(device)\n", + " data = next(data_iter).to(device)\n", " \n", " with torch.no_grad():\n", " prediction = model(data)\n", diff --git a/colabs/pyg/pointnet-classification/03_sweep.ipynb b/colabs/pyg/pointnet-classification/03_sweep.ipynb index 09845a97..d029f42d 100644 --- a/colabs/pyg/pointnet-classification/03_sweep.ipynb +++ b/colabs/pyg/pointnet-classification/03_sweep.ipynb @@ -285,8 +285,9 @@ " range(num_train_examples),\n", " desc=f\"Training Epoch {epoch}/{config.epochs}\"\n", " )\n", + " data_iter = iter(train_loader)\n", " for batch_idx in progress_bar:\n", - " data = next(iter(train_loader)).to(device)\n", + " data = next(data_iter).to(device)\n", "\n", " optimizer.zero_grad()\n", " prediction = model(data)\n", @@ -314,8 +315,9 @@ " range(num_val_examples),\n", " desc=f\"Validation Epoch {epoch}/{config.epochs}\"\n", " )\n", + " data_iter = iter(val_loader)\n", " for batch_idx in progress_bar:\n", - " data = next(iter(val_loader)).to(device)\n", + " data = next(data_iter).to(device)\n", "\n", " with torch.no_grad():\n", " prediction = model(data)\n",