-
Notifications
You must be signed in to change notification settings - Fork 46
/
train.py
executable file
·43 lines (35 loc) · 1.4 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
import argparse
import tensorboardX
import os
import random
import numpy as np
from utils import AverageMeter, calculate_accuracy
def train_epoch(model, data_loader, criterion, optimizer, epoch, log_interval, device):
    """Train ``model`` for one epoch and return the epoch's mean loss and accuracy.

    For every ``(data, targets)`` batch the function:
    1. moves the batch to ``device``;
    2. runs a forward pass and computes ``criterion`` and ``calculate_accuracy``;
    3. takes one optimizer step.

    Every ``log_interval`` batches it prints a progress line with the mean loss
    over that window. A one-line epoch summary is printed at the end.

    Args:
        model: Module to train. It is put in training mode via ``model.train()``.
        data_loader: Iterable of ``(data, targets)`` batches. It must support
            ``len()`` and expose a sized ``.dataset``, since the log messages use both.
        criterion: Callable ``criterion(outputs, targets)`` returning a scalar loss tensor.
        optimizer: Optimizer stepping ``model``'s parameters.
        epoch: Epoch number. Used only in log output.
        log_interval: Number of batches between progress lines. A value <= 0
            disables the periodic progress lines; the epoch summary is still printed.
        device: Target passed to ``Tensor.to`` for each batch.

    Returns:
        tuple: ``(losses.avg, accuracies.avg)`` as accumulated by ``AverageMeter``,
        with each batch weighted by its sample count. NOTE(review): the summary
        prints ``accuracies.avg * 100`` as a percentage, which presumes
        ``calculate_accuracy`` returns a fraction in [0, 1]. Confirm in utils.
    """
    model.train()
    losses = AverageMeter()
    accuracies = AverageMeter()
    # Loss summed over the current logging window; reset after each progress line.
    window_loss = 0.0
    # Running count of samples actually processed. Multiplying the batch index by
    # the current batch size over-counts when the final batch is smaller.
    samples_seen = 0

    for batch_idx, (data, targets) in enumerate(data_loader):
        data, targets = data.to(device), targets.to(device)
        batch_size = data.size(0)

        outputs = model(data)
        loss = criterion(outputs, targets)
        acc = calculate_accuracy(outputs, targets)

        # .item() copies the scalar to the host (a sync point on GPU); do it once per batch.
        loss_value = loss.item()
        window_loss += loss_value
        losses.update(loss_value, batch_size)
        accuracies.update(acc, batch_size)
        samples_seen += batch_size

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Guard against log_interval <= 0, which would otherwise raise
        # ZeroDivisionError in the modulo below.
        if log_interval > 0 and (batch_idx + 1) % log_interval == 0:
            avg_loss = window_loss / log_interval
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, samples_seen, len(data_loader.dataset),
                100. * (batch_idx + 1) / len(data_loader), avg_loss))
            window_loss = 0.0

    print('Train set ({:d} samples): Average loss: {:.4f}\tAcc: {:.4f}%'.format(
        len(data_loader.dataset), losses.avg, accuracies.avg * 100))
    return losses.avg, accuracies.avg