Skip to content

Commit

Permalink
More comments
Browse files Browse the repository at this point in the history
  • Loading branch information
atroyn committed Dec 14, 2022
1 parent 67e5e1f commit e5f508b
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 12 deletions.
3 changes: 3 additions & 0 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def run(
with dt[2]:
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

# Second-stage classifier (optional)
# pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)

# Process predictions
for i, det in enumerate(pred): # per image
seen += 1
Expand Down
4 changes: 2 additions & 2 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
except ImportError:
thop = None


# We modify the original Detect class to output the embeddings along with the predictions
class Detect(nn.Module):
# YOLOv5 Detect head for detection models
stride = None # strides computed during build
Expand Down Expand Up @@ -77,7 +77,7 @@ def forward(self, x):
wh = (wh * 2) ** 2 * self.anchor_grid[i] # wh
y = torch.cat((xy, wh, conf), 4)
z.append(y.view(bs, self.na * nx * ny, self.no))
embeddings.append(x[i].view(bs, self.na * nx * ny, self.no))
embeddings.append(x[i].view(bs, self.na * nx * ny, self.no)) # The embeddings are the raw output of the last conv layer, in the same shape as the predictions

return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x, torch.cat(embeddings, 1)) if self.with_embeddings else (torch.cat(z, 1), x)

Expand Down
3 changes: 2 additions & 1 deletion utils/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def create_dataloader(path,
worker_init_fn=seed_worker,
generator=generator), dataset

# A new dataloader for cases where we have input images but no labels.
def create_imageloader(path, imgsz, batch_size, stride, workers):
dataset = LoadImages(path, imgsz, stride=int(stride), auto=False, n_workers=workers)
batch_size = min(batch_size, len(dataset))
Expand Down Expand Up @@ -241,7 +242,7 @@ def __next__(self):
self.frame += 1
return str(self.screen), im, im0, None, s # screen, img, original img, im0s, s


# An iterable dataset that loads images, compatible with the IterableDataset interface.
class LoadImages(IterableDataset):
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1, n_workers=0):
Expand Down
18 changes: 9 additions & 9 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,8 @@ def clip_segments(boxes, shape):
boxes[:, 0] = boxes[:, 0].clip(0, shape[1]) # x
boxes[:, 1] = boxes[:, 1].clip(0, shape[0]) # y


# We modify the original non_max_suppression function so that it also returns the embeddings.
# In practice, this means the embeddings are extracted and filtered alongside the predictions at every step.
def non_max_suppression(
prediction,
conf_thres=0.25,
Expand All @@ -866,7 +867,6 @@ def non_max_suppression(
embedding = prediction[2] # Last part of the tuple has raw conv. output
prediction = prediction[0] # select only inference output


device = prediction.device
mps = 'mps' in device.type # Apple MPS
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
Expand Down Expand Up @@ -901,7 +901,7 @@ def non_max_suppression(
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
x = x[xc[xi]] # confidence
if with_embeddings:
e = embedding[xi][xc[xi]]
e = embedding[xi][xc[xi]] # Filter to the same indices as the predictions

# Cat apriori labels if autolabelling
if labels and len(labels[xi]):
Expand All @@ -928,18 +928,18 @@ def non_max_suppression(
i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
if with_embeddings:
e = e[i]
e = e[i] # Filter to the same indices as the predictions
else: # best class only
conf, j = x[:, 5:mi].max(1, keepdim=True)
x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
if with_embeddings:
e = e[conf.view(-1) > conf_thres]
e = e[conf.view(-1) > conf_thres] # Filter to the same indices as the predictions. Note that no concatenation is needed here.

# Filter by class
if classes is not None:
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
if with_embeddings:
e = e[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
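e = e[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] # NOTE(review): meant to keep the same rows as the predictions, but x was already reassigned to its filtered rows just above, so this mask is built from the filtered x rather than the x that e still lines up with. Compute the class mask once, before x is reassigned, and apply it to both x and e.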

# Apply finite constraint
# if not torch.isfinite(x).all():
Expand All @@ -952,11 +952,11 @@ def non_max_suppression(
elif n > max_nms: # excess boxes
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
if with_embeddings:
e = e[x[:, 4].argsort(descending=True)[:max_nms]]
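e = e[x[:, 4].argsort(descending=True)[:max_nms]] # NOTE(review): meant to reorder e like the predictions, but x was already sorted and truncated just above, so this argsort is taken over the sorted x and no longer maps to e's row order. Compute the argsort once, before x is reassigned, and index both x and e with it.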
else:
x = x[x[:, 4].argsort(descending=True)] # sort by confidence
if with_embeddings:
e = e[x[:, 4].argsort(descending=True)]
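e = e[x[:, 4].argsort(descending=True)] # NOTE(review): meant to reorder e like the predictions, but x was already sorted just above, so this argsort is taken over the sorted x and does not reproduce the permutation applied to x. Compute the argsort once, before x is reassigned, and index both x and e with it.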

# Batched NMS
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
Expand All @@ -974,7 +974,7 @@ def non_max_suppression(

output[xi] = x[i]
if with_embeddings:
embedding_output[xi] = e[i]
embedding_output[xi] = e[i] # Assign the embeddings to the output
if mps:
output[xi] = output[xi].to(device)
if with_embeddings:
Expand Down

0 comments on commit e5f508b

Please sign in to comment.