-
Notifications
You must be signed in to change notification settings - Fork 6
/
predict.py
155 lines (134 loc) · 5.06 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#!/usr/bin/env python
# coding: utf-8
from PIL import Image
import cv2
from path import Path
from utils.datasets import SlippyMapTilesConcatenation
import collections
import torch
import torch.backends.cudnn
from torch.nn import DataParallel
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data import Dataset
import torchvision.transforms.functional as tf
import segmentation_models_pytorch as smp
from utils.loss import CrossEntropyLoss2d, mIoULoss2d, FocalLoss2d, LovaszLoss2d
from utils.transforms import (
JointCompose,
JointTransform,
JointRandomHorizontalFlip,
JointRandomRotation,
ConvertImageMode,
ImageToTensor,
MaskToTensor,
)
from torchvision.transforms import Resize, CenterCrop, Normalize
from utils.metrics import Metrics
from models.segnet.segnet import segnet
from models.unet.unet import UNet
from modeling.deeplab import DeepLab
import random
import os
import tqdm
import json
import numpy as np
import shutil  # was missing: without it the rmtree below always failed silently

## Scratch directory that holds the 512x512 tiles cut from the input picture.
# It is removed and recreated on every run so tiles from an earlier run are
# never predicted again. Only "does not exist yet" is ignored; any other error
# (e.g. permissions) is reported instead of being swallowed.
try:
    shutil.rmtree('temp_pic')  # recursively delete the folder
except FileNotFoundError:
    pass
os.makedirs('temp_pic')
path = './temp_pic/'
device = 'cuda'
# predict on one model
# NOTE(review): torch.load unpickles the whole model object, which can execute
# arbitrary code -- only load trusted checkpoints. PyTorch >= 2.6 defaults to
# weights_only=True and refuses to load a pickled nn.Module unless
# weights_only=False is passed; confirm against the installed torch version.
model = torch.load('model/0514pspnet_50_epoch.pth')
# give the picture you want to predict
file_name = '/home/shiyi/beshe/gaoxin_map/second_dataset/part1_500.png'
# give the name you want to store the predicted mask under
save_dir = '0514predict1.png'
## use model to predict
def predict(model, loader=None):
    """Run *model* over every batch of a test loader and collect per-tile class maps.

    Args:
        model: segmentation network, called as ``model(images)``; its output is
            reduced with a max over dim 1, so it is expected to carry class
            scores on that axis.
        loader: iterable of image batches. Defaults to the module-level
            ``test_loader`` (the original behaviour), which must be assigned
            before this is called.

    Returns:
        list of numpy arrays in loader order, one per image, each reshaped to
        (512, 512) and holding the index of the highest-scoring class per pixel.
    """
    if loader is None:
        loader = test_loader
    model.eval()
    result = []
    # Inference only: no_grad avoids building an autograd graph for every tile.
    with torch.no_grad():
        for images in tqdm.tqdm(loader):
            images = images.to(device)
            outputs = model(images)
            labels = torch.max(outputs, 1)[1]
            # Split the batch into one (512, 512) map per image; for
            # batch_size=1 this equals the former reshape(512, 512).
            result.extend(labels.cpu().numpy().reshape(-1, 512, 512))
    return result
def input_and_output(pic_path, model, generate_data):
    """Cut a picture into tiles, or stitch the tiles' predictions back together.

    The picture is zero-padded on the bottom/right up to a multiple of 512
    and split into 512x512 tiles named ``{row}_{col}.png``.

    Args:
        pic_path: path of the picture to process (read with ``cv2.imread``).
        model: the model used for prediction; passed to ``predict``.
        generate_data: if true, write every tile to ``temp_pic/`` (step one)
            and return an all-zero mask; if false, run ``predict(model)`` on
            the tiles and stitch their class maps into one mask (step two).

    Returns:
        uint8 numpy array with the picture's height and width.

    Raises:
        FileNotFoundError: if the picture cannot be read.

    Note:
        step one : generate tile pictures from one picture
        step two : predict on the tiles generated by step one
        ``predict`` reads the module-level ``test_loader``, which must be
        built over ``temp_pic/`` between the two steps.
    """
    stride = 512
    image_size = 512
    image = cv2.imread(f'{pic_path}')
    if image is None:
        # cv2.imread signals an unreadable/missing file by returning None.
        raise FileNotFoundError(f'cannot read picture: {pic_path}')
    h, w, _ = image.shape
    # Round up to the next multiple of stride; sizes that are already a
    # multiple no longer get an extra all-black row/column of tiles.
    padding_h = -(-h // stride) * stride
    padding_w = -(-w // stride) * stride
    padding_img = np.zeros((padding_h, padding_w, 3), dtype=np.uint8)
    padding_img[:h, :w, :] = image
    mask_whole = np.zeros((padding_h, padding_w), dtype=np.uint8)
    if not generate_data:
        result = predict(model)
        # NOTE(review): assumes the test loader yields tiles in the same order
        # that Path('temp_pic').files() lists them -- confirm in the dataset.
        tile_index = {
            str(p.name): idx for idx, p in enumerate(Path('temp_pic').files())
        }
    for i in range(padding_h // stride):
        for j in range(padding_w // stride):
            rows = slice(i * stride, i * stride + image_size)
            cols = slice(j * stride, j * stride + image_size)
            if generate_data:
                cv2.imwrite(f'temp_pic/{i}_{j}.png', padding_img[rows, cols, :])
            else:
                mask_whole[rows, cols] = result[tile_index[f'{i}_{j}.png']]
    # Drop the padding so the mask matches the original picture size.
    return mask_whole[:h, :w]
def get_dataset_loaders(workers):
    """Build the DataLoader over the tiles stored in the module-level ``path``.

    Args:
        workers: number of DataLoader worker processes.

    Returns:
        DataLoader (batch_size 1, no shuffling) over a
        ``SlippyMapTilesConcatenation`` dataset created with ``test=True``.
    """
    target_size = 512
    batch_size = 1
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    # Deterministic preprocessing only. The random flips/rotations that used to
    # be here (training augmentations) would, if the dataset applies this
    # transform to test tiles, flip/rotate a tile's prediction relative to the
    # position it is pasted back into in the stitched mask.
    transform = JointCompose(
        [
            JointTransform(Resize(target_size, Image.BILINEAR), Resize(target_size, Image.NEAREST)),
            JointTransform(CenterCrop(target_size), CenterCrop(target_size)),
            JointTransform(ImageToTensor(), MaskToTensor()),
            JointTransform(Normalize(mean=mean, std=std), None),
        ]
    )
    test_dataset = SlippyMapTilesConcatenation(
        path, './', transform, debug=False, test=True
    )
    # shuffle=False (the DataLoader default) keeps tiles in dataset order,
    # which the stitching step relies on.
    return DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=workers
    )
def data_for_vote(model_names=('pspnet', 'segnet', 'unet'), picture_ids=range(9, 14)):
    """Predict several pictures with several models so the results can be voted on.

    This is only needed for voting; ignore it when predicting with one model.
    For every checkpoint ``model/{name}.pth`` and every picture
    ``part{i}_500.png``, the picture is tiled, predicted and stitched, and the
    mask is written to ``{name}_predict/0514predict{i}.png``.

    Args:
        model_names: checkpoint base names under ``model/``.
        picture_ids: the numbers ``i`` of the ``part{i}_500.png`` pictures.

    Raises:
        IOError: if a predicted mask cannot be written.
    """
    import shutil

    # predict() reads the module-level test_loader, so it must be rebound
    # globally; a plain local assignment would leave predict() on a stale loader.
    global test_loader
    for model_name in model_names:
        os.makedirs(f'{model_name}_predict', exist_ok=True)
    for model_name in model_names:
        # Load each checkpoint once rather than once per picture.
        model = torch.load(f'model/{model_name}.pth')
        for i in picture_ids:
            file_name = f'/home/shiyi/beshe/gaoxin_map/second_dataset/part{i}_500.png'
            # Start from an empty tile folder so the previous picture's tiles
            # are not predicted again.
            shutil.rmtree('temp_pic', ignore_errors=True)
            os.makedirs('temp_pic', exist_ok=True)
            input_and_output(file_name, model, generate_data=True)
            test_loader = get_dataset_loaders(5)
            mask_result = input_and_output(file_name, model, generate_data=False)
            # Write into the *_predict directory created above (the original
            # wrote to a non-existent '{model_name}/' directory).
            out_path = f'{model_name}_predict/0514predict{i}.png'
            # cv2.imwrite returns False instead of raising on failure.
            if not cv2.imwrite(out_path, mask_result):
                raise IOError(f'failed to write {out_path}')
### predict on one picture
# Step one: cut the picture into tiles under temp_pic/.
input_and_output(file_name, model, generate_data=True)
# The loader must be built after the tiles exist; predict() reads this global.
test_loader = get_dataset_loaders(5)
# Step two: predict every tile and stitch the class maps back together.
mask_result = input_and_output(file_name, model, generate_data=False)
# cv2.imwrite reports failure (e.g. a missing output directory) by returning
# False rather than raising, so check it explicitly.
if not cv2.imwrite(save_dir, mask_result):
    raise IOError(f'failed to write {save_dir}')