
Commit

Properly handle maskrcnn in detection ref
NicolasHug committed Jul 13, 2023
1 parent bb3aae7 commit 28293dc
Showing 2 changed files with 32 additions and 16 deletions.
17 changes: 7 additions & 10 deletions references/detection/coco_utils.py
@@ -219,7 +219,7 @@ def __getitem__(self, idx):
return img, target


def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
def get_coco(root, image_set, transforms, args, mode="instances"):

Check warning on line 222 in references/detection/coco_utils.py (GitHub Actions / bc): Function get_coco: args was added and is now required
Check warning on line 222 in references/detection/coco_utils.py (GitHub Actions / bc): Function get_coco: use_v2 was removed
anno_file_template = "{}_{}2017.json"
PATHS = {
"train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
@@ -231,11 +231,14 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
img_folder = os.path.join(root, img_folder)
ann_file = os.path.join(root, ann_file)

if use_v2:
if args.use_v2:
dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
# TODO: need to update target_keys to handle masks for segmentation!
dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"})
target_keys = ["boxes", "labels", "image_id"]
if args.with_masks:
target_keys += ["masks"]
dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
else:
# TODO: handle with_masks for V1?
t = [ConvertCocoPolysToMask()]
if transforms is not None:
t.append(transforms)
@@ -249,9 +252,3 @@
# dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])

return dataset


def get_coco_kp(root, image_set, transforms, use_v2=False):
if use_v2:
raise ValueError("KeyPoints aren't supported by transforms V2 yet.")
return get_coco(root, image_set, transforms, mode="person_keypoints")
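
For context, a minimal sketch of how the reworked get_coco interface could be driven outside of train.py. The namespace below is hypothetical and only sets the two fields get_coco now reads from args; the data path is the default --data-path from train.py, and transforms is left as None for brevity.

import argparse

from coco_utils import get_coco

# Hypothetical stand-in for the parsed CLI arguments; only the fields that
# get_coco actually reads (use_v2, with_masks) are provided here.
args = argparse.Namespace(use_v2=True, with_masks=True)

dataset = get_coco(
    root="/datasets01/COCO/022719/",  # default --data-path in train.py; assumes COCO 2017 is present
    image_set="train",
    transforms=None,  # the reference script passes get_transform(is_train, args) here
    args=args,
    mode="instances",
)
# With use_v2 and with_masks both set, the CocoDetection dataset is wrapped so
# that each target dict carries "boxes", "labels", "image_id" and "masks".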
31 changes: 25 additions & 6 deletions references/detection/train.py
@@ -28,7 +28,7 @@
import torchvision.models.detection
import torchvision.models.detection.mask_rcnn
import utils
from coco_utils import get_coco, get_coco_kp
from coco_utils import get_coco
from engine import evaluate, train_one_epoch
from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler
from torchvision.transforms import InterpolationMode
@@ -42,10 +42,10 @@ def copypaste_collate_fn(batch):

def get_dataset(is_train, args):
image_set = "train" if is_train else "val"
paths = {"coco": (args.data_path, get_coco, 91), "coco_kp": (args.data_path, get_coco_kp, 2)}
p, ds_fn, num_classes = paths[args.dataset]

ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2)
num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}[args.dataset]
ds = get_coco(
root=args.data_path, image_set=image_set, transforms=get_transform(is_train, args), args=args, mode=mode
)
return ds, num_classes
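
As a quick illustration, the new dispatch in get_dataset reduces to a single dict lookup; the values below are copied from the diff, and the snippet only shows how a dataset name resolves.

# The mapping mirrors the one now used inside get_dataset.
num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}["coco_kp"]
assert (num_classes, mode) == (2, "person_keypoints")  # an unknown name would raise KeyError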


@@ -68,7 +68,7 @@ def get_args_parser(add_help=True):
parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help)

parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
parser.add_argument("--dataset", default="coco", type=str, help="dataset name. Use coco_kp for Keypoint detection")
parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
parser.add_argument(
@@ -164,13 +164,32 @@ def get_args_parser(add_help=True):

parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
parser.add_argument(
"--with-masks",
action="store_true",
help=(
"Whether the dataset should return masks. Only relevant when --use-v2 is passed. "
"True by default when using mask_rcnn."
),
)

return parser


def main(args):
if args.backend.lower() == "datapoint" and not args.use_v2:
raise ValueError("Use --use-v2 if you want to use the datapoint backend.")
if args.dataset not in ("coco", "coco_kp"):
raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}")
if "keypoint" in args.model and args.dataset != "coco_kp":
raise ValueError("Oops, if you want Keypoint detection, set --dataset coco_kp")
if args.dataset == "coco_kp" and args.use_v2:
raise ValueError("KeyPoint detection doesn't support V2 transforms yet")

if "mask" in args.model:
args.with_masks = True

print(args.model)

if args.output_dir:
utils.mkdir(args.output_dir)
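Putting the new flag and the main() guard together, a hedged sketch of the resulting behaviour, assuming it is run from references/detection so that train.py is importable:

from train import get_args_parser

# Parse only flags that appear in this diff; every other option keeps its default.
args = get_args_parser().parse_args(["--model", "maskrcnn_resnet50_fpn", "--use-v2"])

# Mirror of the new check in main(): mask models always request mask targets,
# so --with-masks does not need to be passed explicitly for Mask R-CNN.
if "mask" in args.model:
    args.with_masks = True

assert args.with_masks and args.use_v2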
