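"""AdaFace-based face recognition.

Two modes (see the argparse flags at the bottom):
  --mode 0  compute and store embeddings for a known-face dataset
  --mode 1  run recognition on a webcam stream or a video file

Example invocations (paths are illustrative):
  python adaface.py --mode 0 --dataset face_dataset/test
  python adaface.py --mode 1 --video 0
"""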
import net
import torch
import os
from face_alignment import align
import numpy as np
import cv2
from PIL import Image
import argparse
import sys
import time
sys_path = os.path.dirname(__file__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)  # disable autograd globally to save memory
adaface_models = {
    'ir_50': os.path.join(sys_path, "pretrained/adaface_ir50_ms1mv2.ckpt"),
}
def load_pretrained_model(architecture='ir_50'):
    # build the network and load the pretrained state dict
    assert architecture in adaface_models.keys()
    model = net.build_model(architecture)
    statedict = torch.load(adaface_models[architecture], map_location='cpu')['state_dict']
    model_statedict = {key[6:]: val for key, val in statedict.items() if key.startswith('model.')}
    model.load_state_dict(model_statedict)
    model.eval()
    return model
def to_input(pil_rgb_image, RGB=False):
    # normalize a PIL image to [-1, 1] and return an NCHW float tensor
    np_img = np.array(pil_rgb_image)
    if RGB:
        normed_img = ((np_img / 255.) - 0.5) / 0.5  # keep RGB channel order
    else:
        normed_img = ((np_img[:, :, ::-1] / 255.) - 0.5) / 0.5  # convert RGB to BGR channel order
    tensor = torch.tensor(np.expand_dims(normed_img.transpose(2, 0, 1), axis=0)).float()
    return tensor
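# Example (illustrative; "face.jpg" is a hypothetical aligned crop):
# a 112x112 RGB image becomes a (1, 3, 112, 112) tensor normalized to [-1, 1].
#   t = to_input(Image.open("face.jpg").convert("RGB").resize((112, 112)))
#   assert t.shape == (1, 3, 112, 112)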
def load_dataset():
    # load the stored face database (embeddings and matching ids)
    database_dir = os.path.join(sys_path, 'embed/features.pt')
    if os.path.exists(database_dir):
        known_face_encodings = torch.load(database_dir).to(device)
        known_face_names = torch.load(os.path.join(sys_path, 'embed/ids.pt'))
        print("known face list: ", len(known_face_names))
        return known_face_encodings, known_face_names
    else:
        print("Error: Face database not found")
        sys.exit(1)
model = load_pretrained_model('ir_50')  # .to(device)
print('face recognition model loaded')
# load the known-face database
known_face_encodings, known_face_names = load_dataset()
def store_embedding(opt):
    # compute an embedding for every face image in the dataset;
    # the dataset is a flat directory of one image per identity, and each
    # id is derived from the filename (e.g. alice.jpg -> "alice")
    model = load_pretrained_model(opt.fr_weight)  # .to(device)
    features = []
    ids = []
    for fname in sorted(os.listdir(opt.dataset)):
        path = os.path.join(opt.dataset, fname)
        aligned_rgb_img = align.get_aligned_face(path)
        bgr_tensor_input = to_input(aligned_rgb_img)
        feature, _ = model(bgr_tensor_input)
        features.append(feature)
        ids.append(fname.split('.')[0])
    embed_dir = os.path.join(sys_path, 'embed')
    if not os.path.exists(embed_dir):
        os.makedirs(embed_dir)
    # stack the per-image (1, D) features into a single (N, D) tensor and save with the ids
    features = torch.squeeze(torch.stack(features), dim=1)
    torch.save(features, os.path.join(embed_dir, 'features.pt'))
    torch.save(ids, os.path.join(embed_dir, 'ids.pt'))
    print(f"Face embeddings saved (number of known faces: {len(ids)})")
    return features, ids
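# Illustrative shapes of the stored artifacts for N known faces
# (the embedding size is typically 512 for the ir_50 backbone):
#   embed/features.pt -> FloatTensor of shape (N, 512)
#   embed/ids.pt      -> list of N id strings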
def inference(opt, frame):
    pil_im = Image.fromarray(frame).convert('RGB')
    face_encodings = []
    ## 1. extract face features
    aligned_rgb_img, bboxes = align.get_aligned_face_for_webcam('', pil_im, opt.max_obj)
    bboxes = [[int(xy) for xy in bbox] for bbox in bboxes]
    for img in aligned_rgb_img:
        bgr_tensor_input = to_input(img)
        face_encoding, _ = model(bgr_tensor_input)
        face_encodings.append(face_encoding)
    face_names = []
    if len(face_encodings) > 0:
        ## 2. measure face similarity with tensors
        # start_time = time.time()  # start timing the matching step
        face_encodings = torch.squeeze(torch.stack(face_encodings), dim=1).to(device)
        with torch.no_grad():
            face_distances = torch.matmul(face_encodings, known_face_encodings.T)
        best_match_index = torch.argmax(face_distances, dim=1)
        thresh = opt.thresh
        face_names = ["unknown" if face_distances[i][idx] < thresh else known_face_names[idx]
                      for i, idx in enumerate(best_match_index)]
        # face_names = [known_face_names[idx] for idx in best_match_index]  # without thresholding (no "unknown" handling)
        # end_time = time.time()  # stop timing the matching step
        # print("Execution Time:", (end_time - start_time), "sec")  # roughly 0.0003 sec and up
    return bboxes, face_names
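# A minimal sketch of the matching step above, assuming unit-norm embeddings
# so the matmul yields cosine similarities:
#   sims = queries @ known.T    # (n_detected, n_known)
#   best = sims.argmax(dim=1)   # closest known face per detection
#   name_i = "unknown" if sims[i, best[i]] < opt.thresh else known_face_names[best[i]]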
def run_video(opt):
    source = 0 if opt.video == '0' else opt.video  # '0' selects the webcam, anything else is a file path
    video_capture = cv2.VideoCapture(source)
    while True:
        # Grab a single frame of video
        ret, frame = video_capture.read()
        if not ret:
            print("Warning: no frame")
            break
        frame = cv2.flip(frame, 1)
        bboxes, face_names = inference(opt, frame)
        ## 3. visualize bounding boxes
        if len(face_names) > 0:
            for (x1, y1, x2, y2, _), f_name in zip(bboxes, face_names):
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 1)
                cv2.rectangle(frame, (x1, y2 - 30), (x2, y2), (0, 0, 255), cv2.FILLED)
                font = cv2.FONT_HERSHEY_DUPLEX
                cv2.putText(frame, f_name, (x1 + 6, y2 - 6), font, .5, (0, 0, 0), 1)
        # Display the resulting image
        cv2.imshow('Video', frame)
        # Hit 'q' on the keyboard to quit!
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    # Release handle to the capture device
    video_capture.release()
    cv2.destroyAllWindows()
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=int, default=1, help='0: store embeddings, 1: run_video on a webcam/video')
    parser.add_argument('--video', type=str, default='0', help='"0" for the webcam, or a video file path such as "video/iAM.mp4"')
    parser.add_argument('--fr_weight', type=str, default='ir_50', help='face recognition weights')
    parser.add_argument('--thresh', type=float, default=0.2, help='similarity below this is labeled "unknown"')
    parser.add_argument('--max_obj', type=int, default=6, help='maximum number of detectable faces')
    parser.add_argument('--dataset', type=str, default='face_dataset/test', help='path to the known-face dataset')
    opt = parser.parse_args()
    # dispatch on --mode (0: store embeddings, 1: run on webcam/video)
    if opt.mode == 0:
        store_embedding(opt)
    else:
        run_video(opt)