From a4c034557bd9f5411b136e3ec7a824edc36e182c Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Tue, 15 Aug 2023 07:00:21 +0000 Subject: [PATCH] fix: pointnet example --- .../pyg/pointnet-classification/02_pointnet_plus_plus.ipynb | 6 ++++-- colabs/pyg/pointnet-classification/03_sweep.ipynb | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/colabs/pyg/pointnet-classification/02_pointnet_plus_plus.ipynb b/colabs/pyg/pointnet-classification/02_pointnet_plus_plus.ipynb index 26b36c74..12eee4b8 100644 --- a/colabs/pyg/pointnet-classification/02_pointnet_plus_plus.ipynb +++ b/colabs/pyg/pointnet-classification/02_pointnet_plus_plus.ipynb @@ -363,8 +363,9 @@ " range(num_train_examples),\n", " desc=f\"Training Epoch {epoch}/{config.epochs}\"\n", " )\n", + " data_iter = iter(train_loader)\n", " for batch_idx in progress_bar:\n", - " data = next(iter(train_loader)).to(device)\n", + " data = next(data_iter).to(device)\n", " \n", " optimizer.zero_grad()\n", " prediction = model(data)\n", @@ -394,8 +395,9 @@ " range(num_val_examples),\n", " desc=f\"Validation Epoch {epoch}/{config.epochs}\"\n", " )\n", + " data_iter = iter(val_loader)\n", " for batch_idx in progress_bar:\n", - " data = next(iter(val_loader)).to(device)\n", + " data = next(data_iter).to(device)\n", " \n", " with torch.no_grad():\n", " prediction = model(data)\n", diff --git a/colabs/pyg/pointnet-classification/03_sweep.ipynb b/colabs/pyg/pointnet-classification/03_sweep.ipynb index 09845a97..d029f42d 100644 --- a/colabs/pyg/pointnet-classification/03_sweep.ipynb +++ b/colabs/pyg/pointnet-classification/03_sweep.ipynb @@ -285,8 +285,9 @@ " range(num_train_examples),\n", " desc=f\"Training Epoch {epoch}/{config.epochs}\"\n", " )\n", + " data_iter = iter(train_loader)\n", " for batch_idx in progress_bar:\n", - " data = next(iter(train_loader)).to(device)\n", + " data = next(data_iter).to(device)\n", "\n", " optimizer.zero_grad()\n", " prediction = model(data)\n", @@ -314,8 +315,9 @@ " range(num_val_examples),\n", " desc=f\"Validation Epoch {epoch}/{config.epochs}\"\n", " )\n", + " data_iter = iter(val_loader)\n", " for batch_idx in progress_bar:\n", - " data = next(iter(val_loader)).to(device)\n", + " data = next(data_iter).to(device)\n", "\n", " with torch.no_grad():\n", " prediction = model(data)\n",