diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py
index 313faacdb7c..7cf19d39dc9 100644
--- a/references/detection/coco_utils.py
+++ b/references/detection/coco_utils.py
@@ -1,4 +1,3 @@
-import copy
 import os
 
 import torch
@@ -10,24 +9,6 @@
 from torchvision.datasets import wrap_dataset_for_transforms_v2
 
 
-class FilterAndRemapCocoCategories:
-    def __init__(self, categories, remap=True):
-        self.categories = categories
-        self.remap = remap
-
-    def __call__(self, image, target):
-        anno = target["annotations"]
-        anno = [obj for obj in anno if obj["category_id"] in self.categories]
-        if not self.remap:
-            target["annotations"] = anno
-            return image, target
-        anno = copy.deepcopy(anno)
-        for obj in anno:
-            obj["category_id"] = self.categories.index(obj["category_id"])
-        target["annotations"] = anno
-        return image, target
-
-
 def convert_coco_poly_to_mask(segmentations, height, width):
     masks = []
     for polygons in segmentations:
@@ -219,7 +200,7 @@ def __getitem__(self, idx):
         return img, target
 
 
-def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
+def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_masks=False):
     anno_file_template = "{}_{}2017.json"
     PATHS = {
         "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
@@ -233,9 +214,12 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
 
     if use_v2:
         dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
-        # TODO: need to update target_keys to handle masks for segmentation!
-        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"})
+        target_keys = ["boxes", "labels", "image_id"]
+        if with_masks:
+            target_keys += ["masks"]
+        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
     else:
+        # TODO: handle with_masks for V1?
        t = [ConvertCocoPolysToMask()]
         if transforms is not None:
             t.append(transforms)
@@ -249,9 +233,3 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
 
     # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])
     return dataset
-
-
-def get_coco_kp(root, image_set, transforms, use_v2=False):
-    if use_v2:
-        raise ValueError("KeyPoints aren't supported by transforms V2 yet.")
-    return get_coco(root, image_set, transforms, mode="person_keypoints")
diff --git a/references/detection/train.py b/references/detection/train.py
index db86f33aaa9..892ffbbbc1c 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -28,7 +28,7 @@
 import torchvision.models.detection
 import torchvision.models.detection.mask_rcnn
 import utils
-from coco_utils import get_coco, get_coco_kp
+from coco_utils import get_coco
 from engine import evaluate, train_one_epoch
 from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler
 from torchvision.transforms import InterpolationMode
@@ -42,10 +42,16 @@ def copypaste_collate_fn(batch):
 
 
 def get_dataset(is_train, args):
     image_set = "train" if is_train else "val"
-    paths = {"coco": (args.data_path, get_coco, 91), "coco_kp": (args.data_path, get_coco_kp, 2)}
-    p, ds_fn, num_classes = paths[args.dataset]
-
-    ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2)
+    num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}[args.dataset]
+    with_masks = "mask" in args.model
+    ds = get_coco(
+        root=args.data_path,
+        image_set=image_set,
+        transforms=get_transform(is_train, args),
+        mode=mode,
+        use_v2=args.use_v2,
+        with_masks=with_masks,
+    )
     return ds, num_classes
 
@@ -68,7 +74,12 @@ def get_args_parser(add_help=True):
     parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help)
 
     parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
-    parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
+    parser.add_argument(
+        "--dataset",
+        default="coco",
+        type=str,
+        help="dataset name. Use coco for object detection and instance segmentation and coco_kp for Keypoint detection",
+    )
     parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name")
     parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
     parser.add_argument(
@@ -171,6 +182,12 @@ def get_args_parser(add_help=True):
 def main(args):
     if args.backend.lower() == "datapoint" and not args.use_v2:
         raise ValueError("Use --use-v2 if you want to use the datapoint backend.")
+    if args.dataset not in ("coco", "coco_kp"):
+        raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}")
+    if "keypoint" in args.model and args.dataset != "coco_kp":
+        raise ValueError("Oops, if you want Keypoint detection, set --dataset coco_kp")
+    if args.dataset == "coco_kp" and args.use_v2:
+        raise ValueError("KeyPoint detection doesn't support V2 transforms yet")
 
     if args.output_dir:
         utils.mkdir(args.output_dir)
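
Usage sketch (not part of the diff; the COCO root below is a placeholder and the import assumes references/detection is on the path): after this change, masks are requested explicitly through get_coco's new with_masks flag instead of being patched in later, so a Mask R-CNN run and a detection-only run share the same entry point.

    # Hypothetical standalone check of the reworked get_coco; mirrors what
    # get_dataset() in train.py now does via `with_masks = "mask" in args.model`.
    from coco_utils import get_coco

    dataset = get_coco(
        root="/path/to/coco",  # placeholder; point at a local COCO 2017 layout
        image_set="val",
        transforms=None,
        mode="instances",
        use_v2=True,           # v2 wrapper path; per the diff, coco_kp must stay on v1
        with_masks=True,       # adds "masks" to target_keys for the v2 wrapper
    )

    img, target = dataset[0]
    # With use_v2=True the wrapped target is a dict limited to the requested keys:
    # "boxes", "labels", "image_id", plus "masks" when with_masks=True.
    print(sorted(target.keys()))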