Skip to content

Commit

Permalink
fix: pointnet example
Browse files Browse the repository at this point in the history
  • Loading branch information
soumik12345 committed Aug 15, 2023
1 parent 7f540af commit a4c0345
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,9 @@
" range(num_train_examples),\n",
" desc=f\"Training Epoch {epoch}/{config.epochs}\"\n",
" )\n",
" data_iter = iter(train_loader)\n",
" for batch_idx in progress_bar:\n",
" data = next(iter(train_loader)).to(device)\n",
" data = next(data_iter).to(device)\n",
" \n",
" optimizer.zero_grad()\n",
" prediction = model(data)\n",
Expand Down Expand Up @@ -394,8 +395,9 @@
" range(num_val_examples),\n",
" desc=f\"Validation Epoch {epoch}/{config.epochs}\"\n",
" )\n",
" data_iter = iter(val_loader)\n",
" for batch_idx in progress_bar:\n",
" data = next(iter(val_loader)).to(device)\n",
" data = next(data_iter).to(device)\n",
" \n",
" with torch.no_grad():\n",
" prediction = model(data)\n",
Expand Down
6 changes: 4 additions & 2 deletions colabs/pyg/pointnet-classification/03_sweep.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,9 @@
" range(num_train_examples),\n",
" desc=f\"Training Epoch {epoch}/{config.epochs}\"\n",
" )\n",
" data_iter = iter(train_loader)\n",
" for batch_idx in progress_bar:\n",
" data = next(iter(train_loader)).to(device)\n",
" data = next(data_iter).to(device)\n",
"\n",
" optimizer.zero_grad()\n",
" prediction = model(data)\n",
Expand Down Expand Up @@ -314,8 +315,9 @@
" range(num_val_examples),\n",
" desc=f\"Validation Epoch {epoch}/{config.epochs}\"\n",
" )\n",
" data_iter = iter(val_loader)\n",
" for batch_idx in progress_bar:\n",
" data = next(iter(val_loader)).to(device)\n",
" data = next(data_iter).to(device)\n",
"\n",
" with torch.no_grad():\n",
" prediction = model(data)\n",
Expand Down

0 comments on commit a4c0345

Please sign in to comment.