Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly handle maskrcnn and keypoints w.r.t. V2 in detection references #7742

Merged
merged 8 commits into from
Jul 27, 2023
34 changes: 6 additions & 28 deletions references/detection/coco_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
import os

Check warning on line 1 in references/detection/coco_utils.py

View workflow job for this annotation

GitHub Actions / bc

Function get_coco_kp: function deleted

import torch
import torch.utils.data
Expand All @@ -10,24 +9,6 @@
from torchvision.datasets import wrap_dataset_for_transforms_v2


class FilterAndRemapCocoCategories:
pmeier marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, categories, remap=True):
self.categories = categories
self.remap = remap

def __call__(self, image, target):
anno = target["annotations"]
anno = [obj for obj in anno if obj["category_id"] in self.categories]
if not self.remap:
target["annotations"] = anno
return image, target
anno = copy.deepcopy(anno)
for obj in anno:
obj["category_id"] = self.categories.index(obj["category_id"])
target["annotations"] = anno
return image, target


def convert_coco_poly_to_mask(segmentations, height, width):
masks = []
for polygons in segmentations:
Expand Down Expand Up @@ -219,7 +200,7 @@
return img, target


def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_masks=False):
anno_file_template = "{}_{}2017.json"
PATHS = {
"train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
Expand All @@ -233,9 +214,12 @@

if use_v2:
dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
# TODO: need to update target_keys to handle masks for segmentation!
dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"})
target_keys = ["boxes", "labels", "image_id"]
if with_masks:
target_keys += ["masks"]
dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
else:
# TODO: handle with_masks for V1?
t = [ConvertCocoPolysToMask()]
if transforms is not None:
t.append(transforms)
Expand All @@ -249,9 +233,3 @@
# dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])

return dataset


def get_coco_kp(root, image_set, transforms, use_v2=False):
if use_v2:
raise ValueError("KeyPoints aren't supported by transforms V2 yet.")
return get_coco(root, image_set, transforms, mode="person_keypoints")
29 changes: 23 additions & 6 deletions references/detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import torchvision.models.detection
import torchvision.models.detection.mask_rcnn
import utils
from coco_utils import get_coco, get_coco_kp
from coco_utils import get_coco
from engine import evaluate, train_one_epoch
from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler
from torchvision.transforms import InterpolationMode
Expand All @@ -42,10 +42,16 @@ def copypaste_collate_fn(batch):

def get_dataset(is_train, args):
image_set = "train" if is_train else "val"
paths = {"coco": (args.data_path, get_coco, 91), "coco_kp": (args.data_path, get_coco_kp, 2)}
p, ds_fn, num_classes = paths[args.dataset]

ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2)
num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}[args.dataset]
with_masks = "mask" in args.model
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should remove the get_masks parameter of get_coco() and just make it part of the mode parameter, i.e. allow "instances", "instances_masks", and "keypoints".

ooooor just completely change the parametrization of get_coco() to something more friendly.

In addition to that, the way we magically specify keypoints or masks (through dataset name or through model name) is pretty terrible right now.

I feel like we should address that for good, but I'd prefer doing it in another PR

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree that it would be better that args.dataset could parametrize data type instead of the model name.
args.dataset can be "coco, "coco_kp" and "coco_masks", for example.

ds = get_coco(
root=args.data_path,
image_set=image_set,
transforms=get_transform(is_train, args),
mode=mode,
use_v2=args.use_v2,
with_masks=with_masks,
)
return ds, num_classes


Expand All @@ -68,7 +74,12 @@ def get_args_parser(add_help=True):
parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help)

parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
parser.add_argument(
"--dataset",
default="coco",
type=str,
help="dataset name. Use coco for object detection and instance segmentation and coco_kp for Keypoint detection",
)
parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
parser.add_argument(
Expand Down Expand Up @@ -171,6 +182,12 @@ def get_args_parser(add_help=True):
def main(args):
if args.backend.lower() == "datapoint" and not args.use_v2:
raise ValueError("Use --use-v2 if you want to use the datapoint backend.")
if args.dataset not in ("coco", "coco_kp"):
raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}")
if "keypoint" in args.model and args.dataset != "coco_kp":
raise ValueError("Oops, if you want Keypoint detection, set --dataset coco_kp")
if args.dataset == "coco_kp" and args.use_v2:
raise ValueError("KeyPoint detection doesn't support V2 transforms yet")

if args.output_dir:
utils.mkdir(args.output_dir)
Expand Down
Loading