From 6443e6ace85d89a27a8ea699f2a3d42b07262bbd Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Tue, 11 Jul 2023 11:32:31 +0000
Subject: [PATCH 1/7] Add --backend support to detection refs

---
 references/classification/presets.py |  47 ++++-----
 references/detection/presets.py      | 142 +++++++++++++++++----------
 references/detection/train.py        |  12 ++-
 references/detection/transforms.py   |   9 +-
 4 files changed, 131 insertions(+), 79 deletions(-)

diff --git a/references/classification/presets.py b/references/classification/presets.py
index 9970ee57730..9b53f0ccd5d 100644
--- a/references/classification/presets.py
+++ b/references/classification/presets.py
@@ -15,6 +15,9 @@ def get_module(use_v2):
 
 
 class ClassificationPresetTrain:
+    # Note: this transform assumes that the inputs to forward() are always PIL
+    # images, regardless of the backend parameter. We may change that in the
+    # future though, if we change the output type from the dataset.
     def __init__(
         self,
         *,
@@ -30,42 +33,42 @@ def __init__(
         backend="pil",
         use_v2=False,
     ):
-        module = get_module(use_v2)
+        T = get_module(use_v2)
 
         transforms = []
         backend = backend.lower()
         if backend == "tensor":
-            transforms.append(module.PILToTensor())
+            transforms.append(T.PILToTensor())
         elif backend != "pil":
             raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
 
-        transforms.append(module.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
+        transforms.append(T.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
         if hflip_prob > 0:
-            transforms.append(module.RandomHorizontalFlip(hflip_prob))
+            transforms.append(T.RandomHorizontalFlip(hflip_prob))
         if auto_augment_policy is not None:
             if auto_augment_policy == "ra":
-                transforms.append(module.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
+                transforms.append(T.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
             elif auto_augment_policy == "ta_wide":
-                transforms.append(module.TrivialAugmentWide(interpolation=interpolation))
+                transforms.append(T.TrivialAugmentWide(interpolation=interpolation))
             elif auto_augment_policy == "augmix":
-                transforms.append(module.AugMix(interpolation=interpolation, severity=augmix_severity))
+                transforms.append(T.AugMix(interpolation=interpolation, severity=augmix_severity))
             else:
-                aa_policy = module.AutoAugmentPolicy(auto_augment_policy)
-                transforms.append(module.AutoAugment(policy=aa_policy, interpolation=interpolation))
+                aa_policy = T.AutoAugmentPolicy(auto_augment_policy)
+                transforms.append(T.AutoAugment(policy=aa_policy, interpolation=interpolation))
 
         if backend == "pil":
-            transforms.append(module.PILToTensor())
+            transforms.append(T.PILToTensor())
 
         transforms.extend(
             [
-                module.ConvertImageDtype(torch.float),
-                module.Normalize(mean=mean, std=std),
+                T.ConvertImageDtype(torch.float),
+                T.Normalize(mean=mean, std=std),
             ]
         )
         if random_erase_prob > 0:
-            transforms.append(module.RandomErasing(p=random_erase_prob))
+            transforms.append(T.RandomErasing(p=random_erase_prob))
 
-        self.transforms = module.Compose(transforms)
+        self.transforms = T.Compose(transforms)
 
     def __call__(self, img):
         return self.transforms(img)
@@ -83,28 +86,28 @@ def __init__(
         backend="pil",
         use_v2=False,
     ):
-        module = get_module(use_v2)
+        T = get_module(use_v2)
         transforms = []
         backend = backend.lower()
         if backend == "tensor":
-            transforms.append(module.PILToTensor())
+            transforms.append(T.PILToTensor())
         elif backend != "pil":
             raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
 
         transforms += [
-            module.Resize(resize_size, interpolation=interpolation, antialias=True),
-            module.CenterCrop(crop_size),
+            T.Resize(resize_size, interpolation=interpolation, antialias=True),
+            T.CenterCrop(crop_size),
         ]
 
         if backend == "pil":
-            transforms.append(module.PILToTensor())
+            transforms.append(T.PILToTensor())
 
         transforms += [
-            module.ConvertImageDtype(torch.float),
-            module.Normalize(mean=mean, std=std),
+            T.ConvertImageDtype(torch.float),
+            T.Normalize(mean=mean, std=std),
         ]
 
-        self.transforms = module.Compose(transforms)
+        self.transforms = T.Compose(transforms)
 
     def __call__(self, img):
         return self.transforms(img)
diff --git a/references/detection/presets.py b/references/detection/presets.py
index 779f3f218ca..f2098630985 100644
--- a/references/detection/presets.py
+++ b/references/detection/presets.py
@@ -1,73 +1,109 @@
+from collections import defaultdict
+
 import torch
-import transforms as T
+import transforms as reference_transforms
+
+
+def get_modules(use_v2):
+    # We need a protected import to avoid the V2 warning in case just V1 is used
+    if use_v2:
+        import torchvision.datapoints
+        import torchvision.transforms.v2
+
+        return torchvision.transforms.v2, torchvision.datapoints
+    else:
+        return reference_transforms, None
 
 
 class DetectionPresetTrain:
-    def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
+    def __init__(
+        self,
+        *,
+        data_augmentation,
+        hflip_prob=0.5,
+        mean=(123.0, 117.0, 104.0),
+        backend="pil",
+        use_v2=False,
+    ):
+
+        T, datapoints = get_modules(use_v2)
+
+        transforms = []
+        backend = backend.lower()
+        if backend == "datapoint":
+            transforms.append(T.ToImageTensor())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
+
         if data_augmentation == "hflip":
-            self.transforms = T.Compose(
-                [
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
         elif data_augmentation == "lsj":
-            self.transforms = T.Compose(
-                [
-                    T.ScaleJitter(target_size=(1024, 1024)),
-                    T.FixedSizeCrop(size=(1024, 1024), fill=mean),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [
+                T.ScaleJitter(target_size=(1024, 1024), antialias=True),
+                # TODO: FixedSizeCrop below doesn't work on tensors!
+ reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean), + T.RandomHorizontalFlip(p=hflip_prob), + ] elif data_augmentation == "multiscale": - self.transforms = T.Compose( - [ - T.RandomShortestSize( - min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333 - ), - T.RandomHorizontalFlip(p=hflip_prob), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - ] - ) + transforms += [ + T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333), + T.RandomHorizontalFlip(p=hflip_prob), + ] elif data_augmentation == "ssd": - self.transforms = T.Compose( - [ - T.RandomPhotometricDistort(), - T.RandomZoomOut(fill=list(mean)), - T.RandomIoUCrop(), - T.RandomHorizontalFlip(p=hflip_prob), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - ] - ) + fill = defaultdict(lambda: mean, {datapoints.Mask: 0}) if use_v2 else list(mean) + transforms += [ + T.RandomPhotometricDistort(), + T.RandomZoomOut(fill=fill), + T.RandomIoUCrop(), + T.RandomHorizontalFlip(p=hflip_prob), + ] elif data_augmentation == "ssdlite": - self.transforms = T.Compose( - [ - T.RandomIoUCrop(), - T.RandomHorizontalFlip(p=hflip_prob), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - ] - ) + transforms += [ + T.RandomIoUCrop(), + T.RandomHorizontalFlip(p=hflip_prob), + ] else: raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"') + if backend == "pil": + # Note: we could just convert to pure tensors even in v2. + transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()] + + transforms += [T.ConvertImageDtype(torch.float)] + + if use_v2: + transforms += [ + T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY), + T.SanitizeBoundingBox(), + ] + + self.transforms = T.Compose(transforms) + def __call__(self, img, target): return self.transforms(img, target) class DetectionPresetEval: - def __init__(self): - self.transforms = T.Compose( - [ - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - ] - ) + def __init__(self, backend="pil", use_v2=False): + T, _ = get_modules(use_v2) + transforms = [] + backend = backend.lower() + # Conversion may look a bit weird but the assumption of this transform is that the input is always a PIL image + # TODO: Is that still true when using v2, from the dataset??????? + if backend == "pil": + # Note: we could just convert to pure tensors even in v2? 
+            transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+        elif backend == "tensor":
+            transforms += [T.PILToTensor()]
+        elif backend == "datapoint":
+            transforms += [T.ToImageTensor()]
+        else:
+            raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
+
+        transforms += [T.ConvertImageDtype(torch.float)]
+        self.transforms = T.Compose(transforms)
 
     def __call__(self, img, target):
         return self.transforms(img, target)
diff --git a/references/detection/train.py b/references/detection/train.py
index dea483c5f75..6a4069e2fca 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -50,13 +50,15 @@ def get_dataset(name, image_set, transform, data_path):
 
 def get_transform(train, args):
     if train:
-        return presets.DetectionPresetTrain(data_augmentation=args.data_augmentation)
+        return presets.DetectionPresetTrain(
+            data_augmentation=args.data_augmentation, backend=args.backend, use_v2=args.use_v2
+        )
     elif args.weights and args.test_only:
         weights = torchvision.models.get_weight(args.weights)
         trans = weights.transforms()
         return lambda img, target: (trans(img), target)
     else:
-        return presets.DetectionPresetEval()
+        return presets.DetectionPresetEval(backend=args.backend, use_v2=args.use_v2)
 
 
 def get_args_parser(add_help=True):
@@ -159,10 +161,16 @@ def get_args_parser(add_help=True):
         help="Use CopyPaste data augmentation. Works only with data-augmentation='lsj'.",
     )
 
+    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL, tensor or datapoint - case insensitive")
+    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
+
     return parser
 
 
 def main(args):
+    if args.backend.lower() == "datapoint" and not args.use_v2:
+        raise ValueError("Use --use-v2 if you want to use the datapoint backend.")
+
     if args.output_dir:
         utils.mkdir(args.output_dir)
 
diff --git a/references/detection/transforms.py b/references/detection/transforms.py
index d26bf6eac85..65cf4e83592 100644
--- a/references/detection/transforms.py
+++ b/references/detection/transforms.py
@@ -293,11 +293,13 @@ def __init__(
         target_size: Tuple[int, int],
         scale_range: Tuple[float, float] = (0.1, 2.0),
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        antialias=True,
     ):
         super().__init__()
         self.target_size = target_size
         self.scale_range = scale_range
         self.interpolation = interpolation
+        self.antialias = antialias
 
     def forward(
         self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
@@ -315,14 +317,17 @@ def forward(
         new_width = int(orig_width * r)
         new_height = int(orig_height * r)
 
-        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)
+        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation, antialias=self.antialias)
 
         if target is not None:
             target["boxes"][:, 0::2] *= new_width / orig_width
             target["boxes"][:, 1::2] *= new_height / orig_height
             if "masks" in target:
                 target["masks"] = F.resize(
-                    target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
+                    target["masks"],
+                    [new_height, new_width],
+                    interpolation=InterpolationMode.NEAREST,
+                    antialias=self.antialias,
                 )
 
         return image, target

From d10dd565544b5b526f40333df89f623f24f73adc Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Tue, 11 Jul 2023 13:28:39 +0000
Subject: [PATCH 2/7] Add --use-v2 support to detection refs

---
 references/detection/coco_utils.py            | 35 +++++++++++--------
 references/detection/engine.py                |  4 +--
 references/detection/group_by_aspect_ratio.py |  4 ++-
 references/detection/presets.py               |  2 --
references/detection/train.py | 8 ++--- 5 files changed, 29 insertions(+), 24 deletions(-) diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py index 38c8279c35e..7b42faad1e8 100644 --- a/references/detection/coco_utils.py +++ b/references/detection/coco_utils.py @@ -7,6 +7,7 @@ import transforms as T from pycocotools import mask as coco_mask from pycocotools.coco import COCO +from torchvision.datasets import wrap_dataset_for_transforms_v2 class FilterAndRemapCocoCategories: @@ -49,7 +50,6 @@ def __call__(self, image, target): w, h = image.size image_id = target["image_id"] - image_id = torch.tensor([image_id]) anno = target["annotations"] @@ -126,10 +126,6 @@ def _has_valid_annotation(anno): return True return False - if not isinstance(dataset, torchvision.datasets.CocoDetection): - raise TypeError( - f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}" - ) ids = [] for ds_idx, img_id in enumerate(dataset.ids): ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) @@ -196,12 +192,15 @@ def convert_to_coco_api(ds): def get_coco_api_from_dataset(dataset): + # FIXME: This is... awful? for _ in range(10): if isinstance(dataset, torchvision.datasets.CocoDetection): break if isinstance(dataset, torch.utils.data.Subset): dataset = dataset.dataset - if isinstance(dataset, torchvision.datasets.CocoDetection): + if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance( + getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection + ): return dataset.coco return convert_to_coco_api(dataset) @@ -220,7 +219,7 @@ def __getitem__(self, idx): return img, target -def get_coco(root, image_set, transforms, mode="instances"): +def get_coco(root, image_set, transforms, mode="instances", use_v2=False): anno_file_template = "{}_{}2017.json" PATHS = { "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))), @@ -228,17 +227,22 @@ def get_coco(root, image_set, transforms, mode="instances"): # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))) } - t = [ConvertCocoPolysToMask()] - - if transforms is not None: - t.append(transforms) - transforms = T.Compose(t) - img_folder, ann_file = PATHS[image_set] img_folder = os.path.join(root, img_folder) ann_file = os.path.join(root, ann_file) - dataset = CocoDetection(img_folder, ann_file, transforms=transforms) + # TODO: cleanup + if use_v2: + dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms) + # TODO: need to update target_keys to handle masks for segmentation! 
+ dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"}) + else: + t = [ConvertCocoPolysToMask()] + if transforms is not None: + t.append(transforms) + transforms = T.Compose(t) + + dataset = CocoDetection(img_folder, ann_file, transforms=transforms) if image_set == "train": dataset = _coco_remove_images_without_annotations(dataset) @@ -248,5 +252,6 @@ def get_coco(root, image_set, transforms, mode="instances"): return dataset -def get_coco_kp(root, image_set, transforms): +def get_coco_kp(root, image_set, transforms, use_v2): + # TODO: handle use_v2 return get_coco(root, image_set, transforms, mode="person_keypoints") diff --git a/references/detection/engine.py b/references/detection/engine.py index 0e5d55f189d..0e9bfffdf8a 100644 --- a/references/detection/engine.py +++ b/references/detection/engine.py @@ -26,7 +26,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc for images, targets in metric_logger.log_every(data_loader, print_freq, header): images = list(image.to(device) for image in images) - targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets] with torch.cuda.amp.autocast(enabled=scaler is not None): loss_dict = model(images, targets) losses = sum(loss for loss in loss_dict.values()) @@ -97,7 +97,7 @@ def evaluate(model, data_loader, device): outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs] model_time = time.time() - model_time - res = {target["image_id"].item(): output for target, output in zip(targets, outputs)} + res = {target["image_id"]: output for target, output in zip(targets, outputs)} evaluator_time = time.time() coco_evaluator.update(res) evaluator_time = time.time() - evaluator_time diff --git a/references/detection/group_by_aspect_ratio.py b/references/detection/group_by_aspect_ratio.py index d12e14b540c..d4a44724899 100644 --- a/references/detection/group_by_aspect_ratio.py +++ b/references/detection/group_by_aspect_ratio.py @@ -164,7 +164,9 @@ def compute_aspect_ratios(dataset, indices=None): if hasattr(dataset, "get_height_and_width"): return _compute_aspect_ratios_custom_dataset(dataset, indices) - if isinstance(dataset, torchvision.datasets.CocoDetection): + if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance( + getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection + ): return _compute_aspect_ratios_coco_dataset(dataset, indices) if isinstance(dataset, torchvision.datasets.VOCDetection): diff --git a/references/detection/presets.py b/references/detection/presets.py index f2098630985..cc1afe9cfef 100644 --- a/references/detection/presets.py +++ b/references/detection/presets.py @@ -90,8 +90,6 @@ def __init__(self, backend="pil", use_v2=False): T, _ = get_modules(use_v2) transforms = [] backend = backend.lower() - # Conversion may look a bit weird but the assumption of this transform is that the input is always a PIL image - # TODO: Is that still true when using v2, from the dataset??????? if backend == "pil": # Note: we could just convert to pure tensors even in v2? 
transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()] diff --git a/references/detection/train.py b/references/detection/train.py index 6a4069e2fca..b874dfb49f3 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -40,11 +40,11 @@ def copypaste_collate_fn(batch): return copypaste(*utils.collate_fn(batch)) -def get_dataset(name, image_set, transform, data_path): +def get_dataset(name, image_set, transform, data_path, use_v2): paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)} p, ds_fn, num_classes = paths[name] - ds = ds_fn(p, image_set=image_set, transforms=transform) + ds = ds_fn(p, image_set=image_set, transforms=transform, use_v2=use_v2) return ds, num_classes @@ -185,8 +185,8 @@ def main(args): # Data loading code print("Loading data") - dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path) - dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path) + dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path, args.use_v2) + dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path, args.use_v2) print("Creating data loaders") if args.distributed: From f956d012ce54ad0907e6c15443ff99058157305e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 11 Jul 2023 13:29:24 +0000 Subject: [PATCH 3/7] remove comment --- references/detection/coco_utils.py | 1 - references/detection/utils.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py index 7b42faad1e8..06b21021b84 100644 --- a/references/detection/coco_utils.py +++ b/references/detection/coco_utils.py @@ -231,7 +231,6 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False): img_folder = os.path.join(root, img_folder) ann_file = os.path.join(root, ann_file) - # TODO: cleanup if use_v2: dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms) # TODO: need to update target_keys to handle masks for segmentation! 
diff --git a/references/detection/utils.py b/references/detection/utils.py index f73915580f7..f67226ffedb 100644 --- a/references/detection/utils.py +++ b/references/detection/utils.py @@ -262,9 +262,9 @@ def init_distributed_mode(args): args.rank = int(os.environ["RANK"]) args.world_size = int(os.environ["WORLD_SIZE"]) args.gpu = int(os.environ["LOCAL_RANK"]) - elif "SLURM_PROCID" in os.environ: - args.rank = int(os.environ["SLURM_PROCID"]) - args.gpu = args.rank % torch.cuda.device_count() + # elif "SLURM_PROCID" in os.environ: + # args.rank = int(os.environ["SLURM_PROCID"]) + # args.gpu = args.rank % torch.cuda.device_count() else: print("Not using distributed mode") args.distributed = False From 06ab751b2f43fae8ff8e9dd73a87c4638aae73d1 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 11 Jul 2023 13:29:51 +0000 Subject: [PATCH 4/7] uuguuguguguuuu --- references/detection/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/references/detection/utils.py b/references/detection/utils.py index f67226ffedb..f73915580f7 100644 --- a/references/detection/utils.py +++ b/references/detection/utils.py @@ -262,9 +262,9 @@ def init_distributed_mode(args): args.rank = int(os.environ["RANK"]) args.world_size = int(os.environ["WORLD_SIZE"]) args.gpu = int(os.environ["LOCAL_RANK"]) - # elif "SLURM_PROCID" in os.environ: - # args.rank = int(os.environ["SLURM_PROCID"]) - # args.gpu = args.rank % torch.cuda.device_count() + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + args.gpu = args.rank % torch.cuda.device_count() else: print("Not using distributed mode") args.distributed = False From c6913d254a62c22887b4c47ebd2c14c2211f106d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 11 Jul 2023 13:39:32 +0000 Subject: [PATCH 5/7] remove TODO --- references/detection/coco_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py index 06b21021b84..313faacdb7c 100644 --- a/references/detection/coco_utils.py +++ b/references/detection/coco_utils.py @@ -251,6 +251,7 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False): return dataset -def get_coco_kp(root, image_set, transforms, use_v2): - # TODO: handle use_v2 +def get_coco_kp(root, image_set, transforms, use_v2=False): + if use_v2: + raise ValueError("KeyPoints aren't supported by transforms V2 yet.") return get_coco(root, image_set, transforms, mode="person_keypoints") From 72da6553e95d451ba17463436e452aaf6a6caca8 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 13 Jul 2023 10:40:31 +0200 Subject: [PATCH 6/7] use antialias=True and consistency test --- test/test_transforms_v2_consistency.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py index bf297473bc2..102afdb37a9 100644 --- a/test/test_transforms_v2_consistency.py +++ b/test/test_transforms_v2_consistency.py @@ -1133,7 +1133,7 @@ def make_label(extra_dims, categories): {"with_mask": False}, ), (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}), - (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024)), {}), + (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024), antialias=True), {}), ( det_transforms.RandomShortestSize( min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333 From 
56dc4314cf443c9724da67b6bd711eaaa4109da6 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 13 Jul 2023 14:45:18 +0000 Subject: [PATCH 7/7] clean up parameter passing --- references/detection/presets.py | 2 ++ references/detection/train.py | 17 +++++++++-------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/references/detection/presets.py b/references/detection/presets.py index cc1afe9cfef..120f079afc8 100644 --- a/references/detection/presets.py +++ b/references/detection/presets.py @@ -16,6 +16,8 @@ def get_modules(use_v2): class DetectionPresetTrain: + # Note: this transform assumes that the input to forward() are always PIL + # images, regardless of the backend parameter. def __init__( self, *, diff --git a/references/detection/train.py b/references/detection/train.py index b874dfb49f3..db86f33aaa9 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -40,16 +40,17 @@ def copypaste_collate_fn(batch): return copypaste(*utils.collate_fn(batch)) -def get_dataset(name, image_set, transform, data_path, use_v2): - paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)} - p, ds_fn, num_classes = paths[name] +def get_dataset(is_train, args): + image_set = "train" if is_train else "val" + paths = {"coco": (args.data_path, get_coco, 91), "coco_kp": (args.data_path, get_coco_kp, 2)} + p, ds_fn, num_classes = paths[args.dataset] - ds = ds_fn(p, image_set=image_set, transforms=transform, use_v2=use_v2) + ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2) return ds, num_classes -def get_transform(train, args): - if train: +def get_transform(is_train, args): + if is_train: return presets.DetectionPresetTrain( data_augmentation=args.data_augmentation, backend=args.backend, use_v2=args.use_v2 ) @@ -185,8 +186,8 @@ def main(args): # Data loading code print("Loading data") - dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path, args.use_v2) - dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path, args.use_v2) + dataset, num_classes = get_dataset(is_train=True, args=args) + dataset_test, _ = get_dataset(is_train=False, args=args) print("Creating data loaders") if args.distributed:
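
Usage sketch (illustrative, not part of the patches; the flag names come from
the diffs above, while the dataset path, model choice, and sample inputs are
placeholders). With the series applied, the detection references would be
driven like so; note that the 'datapoint' backend is only valid together with
--use-v2, per the check added to main() in PATCH 1/7:

    torchrun --nproc_per_node=8 train.py \
        --dataset coco --data-path /path/to/coco \
        --model fasterrcnn_resnet50_fpn --data-augmentation hflip \
        --backend datapoint --use-v2

The presets can also be built directly; a minimal sketch, assuming
references/detection is on sys.path and a torchvision with transforms v2
support (0.15+) is installed:

    import PIL.Image
    import torch
    from torchvision import datapoints

    from presets import DetectionPresetTrain

    # Mirrors what get_transform(is_train=True, args) builds after PATCH 7/7:
    # ToImageTensor -> RandomHorizontalFlip -> ConvertImageDtype ->
    # ConvertBoundingBoxFormat(XYXY) -> SanitizeBoundingBox.
    preset = DetectionPresetTrain(data_augmentation="hflip", backend="datapoint", use_v2=True)

    # Placeholder image and target; SanitizeBoundingBox expects a "labels" key.
    img = PIL.Image.new("RGB", (640, 480))
    target = {
        "boxes": datapoints.BoundingBox(
            [[10.0, 10.0, 100.0, 100.0]],
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=(480, 640),
        ),
        "labels": torch.tensor([1]),
    }
    img, target = preset(img, target)  # float32 image; boxes are flipped in sync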