Merge branch 'main' into please_dont_modify_this_branch_unless_you_are_just_merging_with_main
NicolasHug authored Jul 28, 2023
2 parents ee79463 + 3966f95 commit a9cdd93
Showing 31 changed files with 1,118 additions and 575 deletions.
17 changes: 16 additions & 1 deletion docs/source/transforms.rst
@@ -234,7 +234,6 @@ Conversion
     v2.PILToTensor
     v2.ToImageTensor
     ConvertImageDtype
-    v2.ConvertDtype
     v2.ConvertImageDtype
     v2.ToDtype
     v2.ConvertBoundingBoxFormat
@@ -262,6 +261,22 @@ The new transform can be used standalone or mixed-and-matched with existing transforms
     AugMix
     v2.AugMix

+Cutmix - Mixup
+--------------
+
+Cutmix and Mixup are special transforms that are meant to be used on batches
+rather than on individual images, because they combine pairs of images
+together. They can be used after the dataloader, or as part of a collation
+function. See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed
+usage examples.
+
+.. autosummary::
+    :toctree: generated/
+    :template: class.rst
+
+    v2.Cutmix
+    v2.Mixup
+
 .. _functional_transforms:

 Functional Transforms
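Note: the gallery example referenced by the new docs is still a TODO in this commit (see gallery/plot_cutmix_mixup.py below). A minimal sketch of the collate-function pattern the docs describe, assuming the Cutmix/Mixup constructor signature used in references/classification/transforms.py later in this diff (later torchvision releases rename these classes to CutMix/MixUp):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.dataloader import default_collate
    from torchvision.transforms import v2

    NUM_CATEGORIES = 10  # hypothetical number of classes

    # Cutmix/Mixup act on whole batches, so they run inside the collate
    # function, after default_collate has stacked the individual samples.
    cutmix_or_mixup = v2.RandomChoice(
        [
            v2.Cutmix(alpha=1.0, num_categories=NUM_CATEGORIES),
            v2.Mixup(alpha=0.2, num_categories=NUM_CATEGORIES),
        ]
    )

    def collate_fn(batch):
        return cutmix_or_mixup(*default_collate(batch))

    # Toy data just to exercise the pipeline.
    dataset = TensorDataset(torch.rand(8, 3, 32, 32), torch.randint(0, NUM_CATEGORIES, (8,)))
    loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
    images, labels = next(iter(loader))  # labels come back as soft (4, NUM_CATEGORIES) floats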
8 changes: 8 additions & 0 deletions gallery/plot_cutmix_mixup.py
@@ -0,0 +1,8 @@
+
+"""
+===========================
+How to use Cutmix and Mixup
+===========================
+TODO
+"""
2 changes: 1 addition & 1 deletion gallery/plot_transforms_v2_e2e.py
@@ -29,7 +29,7 @@ def show(sample):
     image, target = sample
     if isinstance(image, PIL.Image.Image):
         image = F.to_image_tensor(image)
-    image = F.convert_dtype(image, torch.uint8)
+    image = F.to_dtype(image, torch.uint8, scale=True)
     annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)

     fig, ax = plt.subplots()
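Note: F in this gallery script is torchvision's v2 functional namespace; the new call makes the value scaling explicit rather than implied by the dtype conversion. A minimal sketch of the semantics this change relies on:

    import torch
    from torchvision.transforms.v2 import functional as F

    img = torch.rand(3, 4, 4)                       # float32, values in [0, 1]
    out = F.to_dtype(img, torch.uint8, scale=True)  # values rescaled to [0, 255]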
19 changes: 9 additions & 10 deletions references/classification/train.py
@@ -8,12 +8,12 @@
 import torch.utils.data
 import torchvision
 import torchvision.transforms
-import transforms
 import utils
 from sampler import RASampler
 from torch import nn
 from torch.utils.data.dataloader import default_collate
 from torchvision.transforms.functional import InterpolationMode
+from transforms import get_mixup_cutmix


 def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
@@ -218,18 +218,17 @@ def main(args):
     val_dir = os.path.join(args.data_path, "val")
     dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

-    collate_fn = None
     num_classes = len(dataset.classes)
-    mixup_transforms = []
-    if args.mixup_alpha > 0.0:
-        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
-    if args.cutmix_alpha > 0.0:
-        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
-    if mixup_transforms:
-        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
+    mixup_cutmix = get_mixup_cutmix(
+        mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_categories=num_classes, use_v2=args.use_v2
+    )
+    if mixup_cutmix is not None:

         def collate_fn(batch):
-            return mixupcutmix(*default_collate(batch))
+            return mixup_cutmix(*default_collate(batch))

+    else:
+        collate_fn = default_collate
+
     data_loader = torch.utils.data.DataLoader(
         dataset,
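Note: a sketch (with hypothetical shapes) of the dataflow through the new collate_fn: default_collate stacks the individual (image, label) samples into batch tensors, and mixup_cutmix then blends pairs within the batch and produces soft labels:

    batch = [(torch.rand(3, 224, 224), 3), (torch.rand(3, 224, 224), 7)]
    images, labels = default_collate(batch)        # images: (2, 3, 224, 224); labels: (2,)
    images, labels = mixup_cutmix(images, labels)  # labels become (2, num_categories) floats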
23 changes: 23 additions & 0 deletions references/classification/transforms.py
@@ -2,10 +2,33 @@
 from typing import Tuple

 import torch
+from presets import get_module
 from torch import Tensor
 from torchvision.transforms import functional as F


+def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_categories, use_v2):
+    transforms_module = get_module(use_v2)
+
+    mixup_cutmix = []
+    if mixup_alpha > 0:
+        mixup_cutmix.append(
+            transforms_module.Mixup(alpha=mixup_alpha, num_categories=num_categories)
+            if use_v2
+            else RandomMixup(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
+        )
+    if cutmix_alpha > 0:
+        mixup_cutmix.append(
+            transforms_module.Cutmix(alpha=cutmix_alpha, num_categories=num_categories)
+            if use_v2
+            else RandomCutmix(num_classes=num_categories, p=1.0, alpha=cutmix_alpha)
+        )
+    if not mixup_cutmix:
+        return None
+
+    return transforms_module.RandomChoice(mixup_cutmix)
+
+
 class RandomMixup(torch.nn.Module):
     """Randomly apply Mixup to the provided batch and targets.
     The class implements the data augmentations as described in the paper
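Note: get_module is imported from the references' presets.py, which is not among the hunks shown on this page. Based on how it is used above, it is assumed to return the v2 or v1 transforms namespace, roughly:

    def get_module(use_v2):
        # Assumed behavior: pick the transforms backend so callers stay agnostic.
        if use_v2:
            import torchvision.transforms.v2

            return torchvision.transforms.v2
        else:
            import torchvision.transforms

            return torchvision.transforms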
37 changes: 8 additions & 29 deletions references/detection/coco_utils.py
@@ -1,4 +1,3 @@
-import copy
 import os

 import torch
@@ -7,25 +6,6 @@
 import transforms as T
 from pycocotools import mask as coco_mask
 from pycocotools.coco import COCO
-from torchvision.datasets import wrap_dataset_for_transforms_v2
-
-
-class FilterAndRemapCocoCategories:
-    def __init__(self, categories, remap=True):
-        self.categories = categories
-        self.remap = remap
-
-    def __call__(self, image, target):
-        anno = target["annotations"]
-        anno = [obj for obj in anno if obj["category_id"] in self.categories]
-        if not self.remap:
-            target["annotations"] = anno
-            return image, target
-        anno = copy.deepcopy(anno)
-        for obj in anno:
-            obj["category_id"] = self.categories.index(obj["category_id"])
-        target["annotations"] = anno
-        return image, target


 def convert_coco_poly_to_mask(segmentations, height, width):
@@ -219,7 +199,7 @@ def __getitem__(self, idx):
         return img, target


-def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
+def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_masks=False):
     anno_file_template = "{}_{}2017.json"
     PATHS = {
         "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
@@ -232,10 +212,15 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
     ann_file = os.path.join(root, ann_file)

     if use_v2:
+        from torchvision.datasets import wrap_dataset_for_transforms_v2
+
         dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
-        # TODO: need to update target_keys to handle masks for segmentation!
-        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"})
+        target_keys = ["boxes", "labels", "image_id"]
+        if with_masks:
+            target_keys += ["masks"]
+        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
     else:
+        # TODO: handle with_masks for V1?
         t = [ConvertCocoPolysToMask()]
         if transforms is not None:
             t.append(transforms)
@@ -249,9 +234,3 @@
     # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])

     return dataset
-
-
-def get_coco_kp(root, image_set, transforms, use_v2=False):
-    if use_v2:
-        raise ValueError("KeyPoints aren't supported by transforms V2 yet.")
-    return get_coco(root, image_set, transforms, mode="person_keypoints")
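Note: a hypothetical inspection of a v2-wrapped sample (the dataset root is a placeholder); the wrapper restricts the target dict to the requested keys:

    dataset = get_coco("/path/to/coco", image_set="val", transforms=None, use_v2=True, with_masks=True)
    img, target = dataset[0]
    print(sorted(target.keys()))  # expected: ['boxes', 'image_id', 'labels', 'masks']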
29 changes: 23 additions & 6 deletions references/detection/train.py
@@ -28,7 +28,7 @@
 import torchvision.models.detection
 import torchvision.models.detection.mask_rcnn
 import utils
-from coco_utils import get_coco, get_coco_kp
+from coco_utils import get_coco
 from engine import evaluate, train_one_epoch
 from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler
 from torchvision.transforms import InterpolationMode
@@ -42,10 +42,16 @@ def copypaste_collate_fn(batch):

 def get_dataset(is_train, args):
     image_set = "train" if is_train else "val"
-    paths = {"coco": (args.data_path, get_coco, 91), "coco_kp": (args.data_path, get_coco_kp, 2)}
-    p, ds_fn, num_classes = paths[args.dataset]
-
-    ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2)
+    num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}[args.dataset]
+    with_masks = "mask" in args.model
+    ds = get_coco(
+        root=args.data_path,
+        image_set=image_set,
+        transforms=get_transform(is_train, args),
+        mode=mode,
+        use_v2=args.use_v2,
+        with_masks=with_masks,
+    )
     return ds, num_classes


@@ -68,7 +74,12 @@ def get_args_parser(add_help=True):
     parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help)

     parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
-    parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
+    parser.add_argument(
+        "--dataset",
+        default="coco",
+        type=str,
+        help="dataset name. Use coco for object detection and instance segmentation, and coco_kp for keypoint detection",
+    )
     parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name")
     parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
     parser.add_argument(
@@ -171,6 +182,12 @@ def get_args_parser(add_help=True):
 def main(args):
     if args.backend.lower() == "datapoint" and not args.use_v2:
         raise ValueError("Use --use-v2 if you want to use the datapoint backend.")
+    if args.dataset not in ("coco", "coco_kp"):
+        raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}")
+    if "keypoint" in args.model and args.dataset != "coco_kp":
+        raise ValueError("For keypoint detection, set --dataset coco_kp")
+    if args.dataset == "coco_kp" and args.use_v2:
+        raise ValueError("Keypoint detection doesn't support V2 transforms yet")

     if args.output_dir:
         utils.mkdir(args.output_dir)
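Note: a summary sketch of the dataset/model pairings these checks enforce (model names are the standard torchvision detection reference models):

    PAIRINGS = {
        "fasterrcnn_resnet50_fpn": "coco",       # boxes only -> with_masks=False
        "maskrcnn_resnet50_fpn": "coco",         # "mask" in name -> with_masks=True
        "keypointrcnn_resnet50_fpn": "coco_kp",  # keypoints; v1 transforms only for now
    }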
24 changes: 15 additions & 9 deletions references/segmentation/coco_utils.py
@@ -68,11 +68,6 @@ def _has_valid_annotation(anno):
         # if more than 1k pixels occupied in the image
         return sum(obj["area"] for obj in anno) > 1000

-    if not isinstance(dataset, torchvision.datasets.CocoDetection):
-        raise TypeError(
-            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
-        )
-
     ids = []
     for ds_idx, img_id in enumerate(dataset.ids):
         ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
@@ -86,21 +81,32 @@ def _has_valid_annotation(anno):
     return dataset


-def get_coco(root, image_set, transforms):
+def get_coco(root, image_set, transforms, use_v2=False):
     PATHS = {
         "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
         "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
         # "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
     }
     CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

-    transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])
-
     img_folder, ann_file = PATHS[image_set]
     img_folder = os.path.join(root, img_folder)
     ann_file = os.path.join(root, ann_file)

-    dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
+    # The two "Compose" branches below achieve the same thing: converting coco
+    # detection samples into segmentation-compatible samples. They just do it
+    # with slightly different implementations. We could refactor and unify, but
+    # keeping them separate helps keep the v2 version clean.
+    if use_v2:
+        import v2_extras
+        from torchvision.datasets import wrap_dataset_for_transforms_v2
+
+        transforms = Compose([v2_extras.CocoDetectionToVOCSegmentation(), transforms])
+        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
+        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"masks", "labels"})
+    else:
+        transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])
+        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)

     if image_set == "train":
         dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)
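Note: v2_extras is a helper module added alongside these reference scripts; per its name and the comment above, CocoDetectionToVOCSegmentation is assumed to convert COCO detection targets into VOC-style segmentation masks. A hypothetical call of the v2 path (the root path and transform pipeline are placeholders):

    dataset = get_coco("/datasets01/COCO/022719/", image_set="train", transforms=my_v2_transforms, use_v2=True)
    image, target = dataset[0]  # target carries "masks" and "labels", per target_keys above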
(Diff for the remaining 23 of the 31 changed files not shown.)