From 28293dc5df923907f64fe00a07d5a18c486f005c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 13 Jul 2023 15:55:15 +0000 Subject: [PATCH] Properly handle maskrcnn in detection ref --- references/detection/coco_utils.py | 17 +++++++--------- references/detection/train.py | 31 ++++++++++++++++++++++++------ 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py index 313faacdb7c..88b8c069e41 100644 --- a/references/detection/coco_utils.py +++ b/references/detection/coco_utils.py @@ -219,7 +219,7 @@ def __getitem__(self, idx): return img, target -def get_coco(root, image_set, transforms, mode="instances", use_v2=False): +def get_coco(root, image_set, transforms, args, mode="instances"): anno_file_template = "{}_{}2017.json" PATHS = { "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))), @@ -231,11 +231,14 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False): img_folder = os.path.join(root, img_folder) ann_file = os.path.join(root, ann_file) - if use_v2: + if args.use_v2: dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms) - # TODO: need to update target_keys to handle masks for segmentation! - dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"}) + target_keys = ["boxes", "labels", "image_id"] + if args.with_masks: + target_keys += ["masks"] + dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys) else: + # TODO: handle with_masks for V1? t = [ConvertCocoPolysToMask()] if transforms is not None: t.append(transforms) @@ -249,9 +252,3 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False): # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)]) return dataset - - -def get_coco_kp(root, image_set, transforms, use_v2=False): - if use_v2: - raise ValueError("KeyPoints aren't supported by transforms V2 yet.") - return get_coco(root, image_set, transforms, mode="person_keypoints") diff --git a/references/detection/train.py b/references/detection/train.py index db86f33aaa9..d722c63a13a 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -28,7 +28,7 @@ import torchvision.models.detection import torchvision.models.detection.mask_rcnn import utils -from coco_utils import get_coco, get_coco_kp +from coco_utils import get_coco from engine import evaluate, train_one_epoch from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler from torchvision.transforms import InterpolationMode @@ -42,10 +42,10 @@ def copypaste_collate_fn(batch): def get_dataset(is_train, args): image_set = "train" if is_train else "val" - paths = {"coco": (args.data_path, get_coco, 91), "coco_kp": (args.data_path, get_coco_kp, 2)} - p, ds_fn, num_classes = paths[args.dataset] - - ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2) + num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}[args.dataset] + ds = get_coco( + root=args.data_path, image_set=image_set, transforms=get_transform(is_train, args), args=args, mode=mode + ) return ds, num_classes @@ -68,7 +68,7 @@ def get_args_parser(add_help=True): parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help) parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path") - parser.add_argument("--dataset", default="coco", type=str, help="dataset name") + parser.add_argument("--dataset", default="coco", type=str, help="dataset name. Use coco_kp for Keypoint detection") parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name") parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") parser.add_argument( @@ -164,6 +164,14 @@ def get_args_parser(add_help=True): parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive") parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms") + parser.add_argument( + "--with-masks", + action="store_true", + help=( + "Whether the dataset should return masks. Only relevant when --use-v2 is passed. " + "True by default when using mask_rcnn." + ), + ) return parser @@ -171,6 +179,17 @@ def get_args_parser(add_help=True): def main(args): if args.backend.lower() == "datapoint" and not args.use_v2: raise ValueError("Use --use-v2 if you want to use the datapoint backend.") + if args.dataset not in ("coco", "coco_kp"): + raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}") + if "keypoint" in args.model and args.dataset != "coco_kp": + raise ValueError("Oops, if you want Keypoint detection, set --dataset coco_kp") + if args.dataset == "coco_kp" and args.use_v2: + raise ValueError("KeyPoint detection doesn't support V2 transforms yet") + + if "mask" in args.model: + args.with_masks = True + + print(args.model) if args.output_dir: utils.mkdir(args.output_dir)