import torch
import os
import numpy as np
import scipy.io as scio
import logging
import argparse
from torch.utils.data import DataLoader
from data_loader import TestDataset
from main_net import MainNet
from functions import to_device, crop_image, merge_image, generate_samples, extract_batch
# Testing settings
parser = argparse.ArgumentParser(description="Content-aware warping for view synthesis")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
parser.add_argument("--num_source", type=int, default=2, help="Number of source views")
parser.add_argument("--patch_size", type=int, default=46, help="The size of the croped view patch")
parser.add_argument("--band_width", type=int, default=0, help="The width of the epipolar line")
parser.add_argument("--epipolar_length", type=int, default=200, help="The lenghth of the epipolar line")
parser.add_argument("--test_data_path", type=str, default='./Dataset/test_DTU_RGB_18x49_flow_18x49x2x1_6dof_18x49x6_sc_18x49x2.h5', help="Path for loading testing data ")
parser.add_argument("--model_name", type=str, default='dtu_s2.pth', help="loaded model")
# network hyper-parameters
parser.add_argument("--depth_range", nargs='+', type=int, default=[425,900], help="Depth range of the dataset")
parser.add_argument("--D", type=int, default=32, help="The number of depth layers")
parser.add_argument("--cout", type=int, default=256, help="The number of network channels")
opt = parser.parse_args()
# make log dirs
exp_name = 'DTU'
exp_path = os.path.join('./logs', exp_name)
os.makedirs(exp_path, exist_ok=True)
# log information
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
log = logging.getLogger()
fh = logging.FileHandler(os.path.join(exp_path, 'Testing.log'))
log.addHandler(fh)
logging.info(opt)
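# note: basicConfig installs a console handler, so every record is written both to
# the console and to <exp_path>/Testing.log via the FileHandler added above.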
if __name__ == '__main__':
    # load data
    test_dataset = TestDataset(opt)
    test_dataloader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=False)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # load model (keep only the checkpoint entries whose keys exist in the current network)
    model = MainNet(opt)
    pretrained_dict = torch.load(os.path.join(exp_path, opt.model_name), map_location=device)
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    model.eval()
    to_device(model, device)
    with torch.no_grad():
        # inference
        scene_name_list = []
        for ind_scene, test_scene in enumerate(test_dataloader):
            # load samples
            scene_name = test_scene['scene_name']
            scene_name_list.append(scene_name)
            del test_scene['scene_name']
            test_scene = to_device(test_scene, device) # scene_name, views, pose_maps, Ks, Rs, Ts
            # extract current scene and its name
            pose_maps = test_scene['pose_maps'] #[b,t,h,w]
            source_clusters = test_scene['source_clusters'] #[b,t,ns]
            views = test_scene['views'] #[b,t,h,w,3]
            flows = test_scene['flows'] #[b,t,ns,ns-1,2,h,w]
            Ks = test_scene['Ks'] #[b,t,3,3]
            Rs = test_scene['Rs'] #[b,t,3,3]
            Ts = test_scene['Ts'] #[b,t,3]
            b, t, h, w = views.shape[:4]
            ns = source_clusters.shape[2]
            ps = opt.patch_size
            neighbor_size = (2*opt.band_width+1)*opt.epipolar_length
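            # neighbor_size is the number of candidate pixels per epipolar search window:
            # (2*band_width+1) parallel lines of epipolar_length samples each. It is not
            # referenced again in this script, presumably kept for parity with training code.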
            ref_esti_views = views.clone() #[b,t,h,w,c]
            blended_esti_views = views.clone() #[b,t,h,w,c]
            interp_esti_views = views.unsqueeze(2).expand(-1,-1,ns,-1,-1,-1).clone() #[b,t,ns,h,w,c]
            target_view_indexes = list(np.arange(1, t-1)) #[1,2,...,t-2]
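            # the boundary views (indices 0 and t-1) are never synthesized: they stay
            # ground truth in the cloned *_esti_views tensors and presumably serve as
            # source views for the interior targets.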
            ################################################################################################################################################################################################################
            for ind_t in target_view_indexes:
                # extract input
                target_sample, source_sample = generate_samples(views, flows, pose_maps, Ks, Rs, Ts, source_clusters, index=ind_t, mode='test') #{view, posemap, K, R, T}
                test_samples = [target_sample, source_sample]
                # calculate patch positions
                _, left_top_xy, coordinate = crop_image(test_samples[0]['view'], ps) #[b,3,patch_size,patch_size,n]
                n = coordinate[0]*coordinate[1]
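                # n is the total number of patches in the crop grid (rows x cols);
                # each patch is synthesized independently and merged back below.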
                # inference for patches
                ref_novel_stack = []
                blended_novel_stack = []
                interp_novel_stack = []
                for ind_n in range(n):
                    # crop the views into smaller patches
                    test_batch = extract_batch(test_samples, opt, left_top_xy[ind_n], mode='test')
                    ref_novel_patch, blended_novel_patch, interp_novel_patch = model(test_batch, opt.band_width)[:3]
                    ref_novel_stack.append(ref_novel_patch)
                    blended_novel_stack.append(blended_novel_patch)
                    interp_novel_stack.append(interp_novel_patch)
                ref_novel_stack = torch.stack(ref_novel_stack, dim=4) #[b,3,patch_size,patch_size,n]
                blended_novel_stack = torch.stack(blended_novel_stack, dim=4) #[b,3,patch_size,patch_size,n]
                interp_novel_stack = torch.stack(interp_novel_stack, dim=5) #[b,ns,3,patch_size,patch_size,n]
                # merge the patches into an intact image
                ref_novel_view = merge_image(ref_novel_stack, coordinate) #[b,3,h_crop,w_crop]
                blended_novel_view = merge_image(blended_novel_stack, coordinate) #[b,3,h_crop,w_crop]
                interp_novel_view = merge_image(interp_novel_stack.reshape(b*ns,3,ps,ps,n), coordinate) #[b*ns,3,h_crop,w_crop]
                # write the synthesized novel views back into the sequences
                h_crop, w_crop = ref_novel_view.shape[2:4]
                ref_esti_views[:,ind_t,0:h_crop,0:w_crop,:] = ref_novel_view.permute(0,2,3,1) #[b,t,h_crop,w_crop,3]
                blended_esti_views[:,ind_t,0:h_crop,0:w_crop,:] = blended_novel_view.permute(0,2,3,1) #[b,t,h_crop,w_crop,3]
                interp_esti_views[:,ind_t,:,0:h_crop,0:w_crop,:] = interp_novel_view.reshape(b,ns,3,h_crop,w_crop).permute(0,1,3,4,2) #[b,t,ns,h_crop,w_crop,3]
                print('View:', ind_t)
            ################################################################################################################################################################################################################
            ref_esti_views = ref_esti_views[0,:,0:h_crop,0:w_crop,:].cpu().numpy() #[t,h_crop,w_crop,c]
            blended_esti_views = blended_esti_views[0,:,0:h_crop,0:w_crop,:].cpu().numpy() #[t,h_crop,w_crop,c]
            interp_esti_views = interp_esti_views[0,:,:,0:h_crop,0:w_crop,:].cpu().numpy() #[t,ns,h_crop,w_crop,c]
            gt_views = views[0,:,0:h_crop,0:w_crop,:].cpu().numpy() #[t,h_crop,w_crop,c]
            # save
            scio.savemat(os.path.join(exp_path, scene_name[0]+'_ref.mat'),
                         {'lf_recons': ref_esti_views}) #[t,h_crop,w_crop,c]
            scio.savemat(os.path.join(exp_path, scene_name[0]+'_blended.mat'),
                         {'lf_recons': blended_esti_views}) #[t,h_crop,w_crop,c]
            scio.savemat(os.path.join(exp_path, scene_name[0]+'_interp.mat'),
                         {'lf_recons': interp_esti_views}) #[t,h_crop,w_crop,c]
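            # not in the original script: also dump the cropped ground-truth views, which
            # are computed above but otherwise unused, so the .mat outputs are self-contained
            # for offline PSNR/SSIM evaluation (the 'lf_gt' key and '_gt.mat' suffix are
            # assumptions, not the authors' convention).
            scio.savemat(os.path.join(exp_path, scene_name[0]+'_gt.mat'),
                         {'lf_gt': gt_views}) #[t,h_crop,w_crop,c]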