-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
107 lines (88 loc) · 3.32 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from utils import save_checkpoint, load_checkpoint, print_examples
from get_loader import get_loader
from model import CNNtoRNN
def train(
    root_folder="/content/drive/MyDrive/image_captioning/bangla_dataset/final_image_dataset_7468",
    annotation_file="/content/drive/MyDrive/image_captioning/bangla_dataset/all_annotation.csv",
):
    """Train the CNNtoRNN captioning model.

    Builds the image transform and data loader, creates the model, loss and
    Adam optimizer, optionally resumes from ``my_checkpoint.pth.tar``, then
    runs ``num_epochs`` passes over the data. The per-batch training loss is
    logged to TensorBoard under ``runs/flickr``. When ``save_model`` is set,
    a checkpoint is saved after every epoch.

    Args:
        root_folder: Image directory passed to ``get_loader``. The default is
            presumably a Google Colab Drive mount. When Drive was mounted from
            a different account, the same path lived under ``/content/gdrive``.
        annotation_file: Annotation CSV passed to ``get_loader``.
    """
    transform = transforms.Compose(
        [
            # 299x299 presumably matches the `inception` encoder's expected
            # input size (Inception v3) -- confirm against model.CNNtoRNN.
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            # Shifts each channel from ToTensor's [0, 1] range to [-1, 1].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    train_loader, dataset = get_loader(
        root_folder=root_folder,
        annotation_file=annotation_file,
        transform=transform,
        num_workers=2,
    )

    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    load_model = False
    save_model = True
    train_CNN = False

    # Hyperparameters
    embed_size = 448
    hidden_size = 448
    vocab_size = len(dataset.vocab)
    num_layers = 4
    learning_rate = 3e-6
    num_epochs = 30

    step = 0

    # Initialize model, loss and optimizer.
    model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
    # Padding positions must not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Freeze the `inception` encoder's parameters unless train_CNN is set.
    # Parameters whose names contain "fc.weight"/"fc.bias" always stay
    # trainable. Parameters outside encoderCNN.inception are not touched here.
    for name, param in model.encoderCNN.inception.named_parameters():
        param.requires_grad = train_CNN or "fc.weight" in name or "fc.bias" in name

    if load_model:
        # map_location lets a checkpoint saved on GPU be resumed on a CPU-only host.
        state = torch.load("my_checkpoint.pth.tar", map_location=device)
        step = load_checkpoint(state, model, optimizer)

    model.train()

    # Context manager guarantees buffered TensorBoard events are flushed/closed.
    with SummaryWriter("runs/flickr") as writer:
        for epoch in range(num_epochs):
            # Print a couple of generated captions to monitor progress.
            print_examples(model, device, dataset)

            for imgs, captions in tqdm(
                train_loader, total=len(train_loader), leave=False
            ):
                imgs = imgs.to(device)
                captions = captions.to(device)

                # captions[:-1] drops the last position along dim 0 (presumably
                # the time axis, i.e. captions are (seq_len, batch) -- confirm
                # against get_loader's collate function).
                outputs = model(imgs, captions[:-1])
                loss = criterion(
                    outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
                )

                writer.add_scalar("Training loss", loss.item(), global_step=step)
                step += 1

                optimizer.zero_grad()
                # Plain backward(): passing `loss` as the `gradient` argument
                # would scale every gradient by the loss value (optimizing
                # loss**2 / 2 instead of loss).
                loss.backward()
                optimizer.step()

            if save_model:
                # Save after the epoch's updates so the final epoch is kept too.
                checkpoint = {
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "step": step,
                }
                save_checkpoint(checkpoint)
if __name__ == "__main__":
    # Script entry point: running this file directly starts training.
    train()