
Commit

Merge branch 'main' of github.com:pytorch/vision into cutmix-mixup
NicolasHug committed Jul 28, 2023
2 parents 5e02675 + 8071c17 commit 993f693
Showing 23 changed files with 547 additions and 355 deletions.
1 change: 0 additions & 1 deletion docs/source/transforms.rst
@@ -234,7 +234,6 @@ Conversion
     v2.PILToTensor
     v2.ToImageTensor
     ConvertImageDtype
-    v2.ConvertDtype
     v2.ConvertImageDtype
     v2.ToDtype
     v2.ConvertBoundingBoxFormat
2 changes: 1 addition & 1 deletion gallery/plot_transforms_v2_e2e.py
@@ -29,7 +29,7 @@ def show(sample):
     image, target = sample
     if isinstance(image, PIL.Image.Image):
         image = F.to_image_tensor(image)
-    image = F.convert_dtype(image, torch.uint8)
+    image = F.to_dtype(image, torch.uint8, scale=True)
     annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)

     fig, ax = plt.subplots()
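For context: F.convert_dtype has been superseded in the v2 API by F.to_dtype, which only rescales pixel values when scale=True is passed, the same shift behind the removal of the v2.ConvertDtype docs entry above. A minimal sketch of the new call (the random tensor is just a stand-in for a real image):

import torch
from torchvision.transforms.v2 import functional as F

img = torch.rand(3, 224, 224)  # float32 image with values in [0, 1]

# Previously: F.convert_dtype(img, torch.uint8)
# Now: scale=True requests the [0, 1] -> [0, 255] value rescaling explicitly
out = F.to_dtype(img, torch.uint8, scale=True)
print(out.dtype, int(out.max()))  # torch.uint8, roughly 255
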
37 changes: 8 additions & 29 deletions references/detection/coco_utils.py
@@ -1,4 +1,3 @@
-import copy
 import os

 import torch
@@ -7,25 +6,6 @@
 import transforms as T
 from pycocotools import mask as coco_mask
 from pycocotools.coco import COCO
-from torchvision.datasets import wrap_dataset_for_transforms_v2
-
-
-class FilterAndRemapCocoCategories:
-    def __init__(self, categories, remap=True):
-        self.categories = categories
-        self.remap = remap
-
-    def __call__(self, image, target):
-        anno = target["annotations"]
-        anno = [obj for obj in anno if obj["category_id"] in self.categories]
-        if not self.remap:
-            target["annotations"] = anno
-            return image, target
-        anno = copy.deepcopy(anno)
-        for obj in anno:
-            obj["category_id"] = self.categories.index(obj["category_id"])
-        target["annotations"] = anno
-        return image, target


 def convert_coco_poly_to_mask(segmentations, height, width):
@@ -219,7 +199,7 @@ def __getitem__(self, idx):
         return img, target


-def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
+def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_masks=False):
     anno_file_template = "{}_{}2017.json"
     PATHS = {
         "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
@@ -232,10 +212,15 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
     ann_file = os.path.join(root, ann_file)

     if use_v2:
+        from torchvision.datasets import wrap_dataset_for_transforms_v2
+
         dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
-        # TODO: need to update target_keys to handle masks for segmentation!
-        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"})
+        target_keys = ["boxes", "labels", "image_id"]
+        if with_masks:
+            target_keys += ["masks"]
+        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
     else:
+        # TODO: handle with_masks for V1?
         t = [ConvertCocoPolysToMask()]
         if transforms is not None:
             t.append(transforms)
@@ -249,9 +234,3 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
     # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])

     return dataset
-
-
-def get_coco_kp(root, image_set, transforms, use_v2=False):
-    if use_v2:
-        raise ValueError("KeyPoints aren't supported by transforms V2 yet.")
-    return get_coco(root, image_set, transforms, mode="person_keypoints")
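For reference, a rough sketch of what the use_v2 branch now yields when masks are requested. The COCO paths are placeholders; only wrap_dataset_for_transforms_v2 and the target_keys filtering shown above are assumed:

# Illustrative only: mirrors the use_v2 branch above with masks requested (with_masks=True).
from torchvision.datasets import CocoDetection, wrap_dataset_for_transforms_v2

dataset = CocoDetection("path/to/train2017", "path/to/annotations/instances_train2017.json")
dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=["boxes", "labels", "image_id", "masks"])

img, target = dataset[0]
print(sorted(target))  # ['boxes', 'image_id', 'labels', 'masks']
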
29 changes: 23 additions & 6 deletions references/detection/train.py
@@ -28,7 +28,7 @@
 import torchvision.models.detection
 import torchvision.models.detection.mask_rcnn
 import utils
-from coco_utils import get_coco, get_coco_kp
+from coco_utils import get_coco
 from engine import evaluate, train_one_epoch
 from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler
 from torchvision.transforms import InterpolationMode
@@ -42,10 +42,16 @@ def copypaste_collate_fn(batch):

 def get_dataset(is_train, args):
     image_set = "train" if is_train else "val"
-    paths = {"coco": (args.data_path, get_coco, 91), "coco_kp": (args.data_path, get_coco_kp, 2)}
-    p, ds_fn, num_classes = paths[args.dataset]
-
-    ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2)
+    num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}[args.dataset]
+    with_masks = "mask" in args.model
+    ds = get_coco(
+        root=args.data_path,
+        image_set=image_set,
+        transforms=get_transform(is_train, args),
+        mode=mode,
+        use_v2=args.use_v2,
+        with_masks=with_masks,
+    )
     return ds, num_classes

@@ -68,7 +74,12 @@ def get_args_parser(add_help=True):
     parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help)

     parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
-    parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
+    parser.add_argument(
+        "--dataset",
+        default="coco",
+        type=str,
+        help="dataset name. Use coco for object detection and instance segmentation and coco_kp for Keypoint detection",
+    )
     parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name")
     parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
     parser.add_argument(
@@ -171,6 +182,12 @@ def get_args_parser(add_help=True):
 def main(args):
     if args.backend.lower() == "datapoint" and not args.use_v2:
         raise ValueError("Use --use-v2 if you want to use the datapoint backend.")
+    if args.dataset not in ("coco", "coco_kp"):
+        raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}")
+    if "keypoint" in args.model and args.dataset != "coco_kp":
+        raise ValueError("Oops, if you want Keypoint detection, set --dataset coco_kp")
+    if args.dataset == "coco_kp" and args.use_v2:
+        raise ValueError("KeyPoint detection doesn't support V2 transforms yet")

     if args.output_dir:
         utils.mkdir(args.output_dir)
24 changes: 15 additions & 9 deletions references/segmentation/coco_utils.py
@@ -68,11 +68,6 @@ def _has_valid_annotation(anno):
         # if more than 1k pixels occupied in the image
         return sum(obj["area"] for obj in anno) > 1000

-    if not isinstance(dataset, torchvision.datasets.CocoDetection):
-        raise TypeError(
-            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
-        )
-
     ids = []
     for ds_idx, img_id in enumerate(dataset.ids):
         ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
@@ -86,21 +81,32 @@ def _has_valid_annotation(anno):
     return dataset


-def get_coco(root, image_set, transforms):
+def get_coco(root, image_set, transforms, use_v2=False):
     PATHS = {
         "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
         "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
         # "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
     }
     CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

-    transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])
-
     img_folder, ann_file = PATHS[image_set]
     img_folder = os.path.join(root, img_folder)
     ann_file = os.path.join(root, ann_file)

-    dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
+    # The 2 "Compose" below achieve the same thing: converting coco detection
+    # samples into segmentation-compatible samples. They just do it with
+    # slightly different implementations. We could refactor and unify, but
+    # keeping them separate helps keeping the v2 version clean
+    if use_v2:
+        import v2_extras
+        from torchvision.datasets import wrap_dataset_for_transforms_v2
+
+        transforms = Compose([v2_extras.CocoDetectionToVOCSegmentation(), transforms])
+        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
+        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"masks", "labels"})
+    else:
+        transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])
+        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)

     if image_set == "train":
         dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)
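The else branch still goes through FilterAndRemapCocoCategories (its implementation appears as deleted lines in the detection diff above) together with CAT_LIST, keeping only the 21 VOC-style categories and remapping their ids to a contiguous range. A small, self-contained illustration of that remapping with made-up annotation dicts:

# Illustrative: the filter-and-remap step applied to two fake annotations.
CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

anno = [{"category_id": 5, "area": 1200}, {"category_id": 99, "area": 50}]
kept = [obj for obj in anno if obj["category_id"] in CAT_LIST]  # drops category 99
remapped = [{**obj, "category_id": CAT_LIST.index(obj["category_id"])} for obj in kept]
print(remapped)  # COCO category 5 (airplane) becomes contiguous label 1
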
113 changes: 90 additions & 23 deletions references/segmentation/presets.py
@@ -1,39 +1,106 @@
+from collections import defaultdict
+
 import torch
-import transforms as T


+def get_modules(use_v2):
+    # We need a protected import to avoid the V2 warning in case just V1 is used
+    if use_v2:
+        import torchvision.datapoints
+        import torchvision.transforms.v2
+        import v2_extras
+
+        return torchvision.transforms.v2, torchvision.datapoints, v2_extras
+    else:
+        import transforms
+
+        return transforms, None, None
+
+
 class SegmentationPresetTrain:
-    def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
-        min_size = int(0.5 * base_size)
-        max_size = int(2.0 * base_size)
+    def __init__(
+        self,
+        *,
+        base_size,
+        crop_size,
+        hflip_prob=0.5,
+        mean=(0.485, 0.456, 0.406),
+        std=(0.229, 0.224, 0.225),
+        backend="pil",
+        use_v2=False,
+    ):
+        T, datapoints, v2_extras = get_modules(use_v2)
+
+        transforms = []
+        backend = backend.lower()
+        if backend == "datapoint":
+            transforms.append(T.ToImageTensor())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
+
+        transforms += [T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size))]

-        trans = [T.RandomResize(min_size, max_size)]
         if hflip_prob > 0:
-            trans.append(T.RandomHorizontalFlip(hflip_prob))
-        trans.extend(
-            [
-                T.RandomCrop(crop_size),
-                T.PILToTensor(),
-                T.ConvertImageDtype(torch.float),
-                T.Normalize(mean=mean, std=std),
+            transforms += [T.RandomHorizontalFlip(hflip_prob)]
+
+        if use_v2:
+            # We need a custom pad transform here, since the padding we want to perform here is fundamentally
+            # different from the padding in `RandomCrop` if `pad_if_needed=True`.
+            transforms += [v2_extras.PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255}))]
+
+        transforms += [T.RandomCrop(crop_size)]
+
+        if backend == "pil":
+            transforms += [T.PILToTensor()]
+
+        if use_v2:
+            img_type = datapoints.Image if backend == "datapoint" else torch.Tensor
+            transforms += [
+                T.ToDtype(dtype={img_type: torch.float32, datapoints.Mask: torch.int64, "others": None}, scale=True)
             ]
-        )
-        self.transforms = T.Compose(trans)
+        else:
+            # No need to explicitly convert masks as they're magically int64 already
+            transforms += [T.ConvertImageDtype(torch.float)]
+
+        transforms += [T.Normalize(mean=mean, std=std)]
+
+        self.transforms = T.Compose(transforms)

     def __call__(self, img, target):
         return self.transforms(img, target)


 class SegmentationPresetEval:
-    def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
-        self.transforms = T.Compose(
-            [
-                T.RandomResize(base_size, base_size),
-                T.PILToTensor(),
-                T.ConvertImageDtype(torch.float),
-                T.Normalize(mean=mean, std=std),
-            ]
-        )
+    def __init__(
+        self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), backend="pil", use_v2=False
+    ):
+        T, _, _ = get_modules(use_v2)
+
+        transforms = []
+        backend = backend.lower()
+        if backend == "tensor":
+            transforms += [T.PILToTensor()]
+        elif backend == "datapoint":
+            transforms += [T.ToImageTensor()]
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
+
+        if use_v2:
+            transforms += [T.Resize(size=(base_size, base_size))]
+        else:
+            transforms += [T.RandomResize(min_size=base_size, max_size=base_size)]
+
+        if backend == "pil":
+            # Note: we could just convert to pure tensors even in v2?
+            transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+
+        transforms += [
+            T.ConvertImageDtype(torch.float),
+            T.Normalize(mean=mean, std=std),
+        ]
+        self.transforms = T.Compose(transforms)

     def __call__(self, img, target):
         return self.transforms(img, target)
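A hedged usage sketch of the reworked presets: it assumes the references/segmentation folder (transforms.py, v2_extras.py) is importable, and the size values are placeholders in the spirit of the reference defaults:

# Illustrative: building the reworked presets with the v2 pipeline and the datapoint backend.
from presets import SegmentationPresetTrain, SegmentationPresetEval

train_tf = SegmentationPresetTrain(base_size=520, crop_size=480, backend="datapoint", use_v2=True)
val_tf = SegmentationPresetEval(base_size=520, backend="pil", use_v2=False)

# Each preset is applied to an (image, target) pair, as in __call__ above:
# img, target = train_tf(pil_image, pil_mask)
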
