diff --git a/references/classification/presets.py b/references/classification/presets.py
index 0f2c914be7e..9970ee57730 100644
--- a/references/classification/presets.py
+++ b/references/classification/presets.py
@@ -1,8 +1,19 @@
 import torch
-from torchvision.transforms import autoaugment, transforms
 from torchvision.transforms.functional import InterpolationMode
 
 
+def get_module(use_v2):
+    # We need a protected import to avoid the V2 warning in case just V1 is used
+    if use_v2:
+        import torchvision.transforms.v2
+
+        return torchvision.transforms.v2
+    else:
+        import torchvision.transforms
+
+        return torchvision.transforms
+
+
 class ClassificationPresetTrain:
     def __init__(
         self,
@@ -17,41 +28,44 @@ def __init__(
         augmix_severity=3,
         random_erase_prob=0.0,
         backend="pil",
+        use_v2=False,
     ):
-        trans = []
+        module = get_module(use_v2)
+
+        transforms = []
         backend = backend.lower()
         if backend == "tensor":
-            trans.append(transforms.PILToTensor())
+            transforms.append(module.PILToTensor())
         elif backend != "pil":
             raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
 
-        trans.append(transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
+        transforms.append(module.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
         if hflip_prob > 0:
-            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
+            transforms.append(module.RandomHorizontalFlip(hflip_prob))
         if auto_augment_policy is not None:
             if auto_augment_policy == "ra":
-                trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
+                transforms.append(module.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
             elif auto_augment_policy == "ta_wide":
-                trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
+                transforms.append(module.TrivialAugmentWide(interpolation=interpolation))
             elif auto_augment_policy == "augmix":
-                trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
+                transforms.append(module.AugMix(interpolation=interpolation, severity=augmix_severity))
             else:
-                aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
-                trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
+                aa_policy = module.AutoAugmentPolicy(auto_augment_policy)
+                transforms.append(module.AutoAugment(policy=aa_policy, interpolation=interpolation))
 
         if backend == "pil":
-            trans.append(transforms.PILToTensor())
+            transforms.append(module.PILToTensor())
 
-        trans.extend(
+        transforms.extend(
             [
-                transforms.ConvertImageDtype(torch.float),
-                transforms.Normalize(mean=mean, std=std),
+                module.ConvertImageDtype(torch.float),
+                module.Normalize(mean=mean, std=std),
             ]
         )
         if random_erase_prob > 0:
-            trans.append(transforms.RandomErasing(p=random_erase_prob))
+            transforms.append(module.RandomErasing(p=random_erase_prob))
 
-        self.transforms = transforms.Compose(trans)
+        self.transforms = module.Compose(transforms)
 
     def __call__(self, img):
         return self.transforms(img)
@@ -67,28 +81,30 @@ def __init__(
         std=(0.229, 0.224, 0.225),
         interpolation=InterpolationMode.BILINEAR,
         backend="pil",
+        use_v2=False,
     ):
-        trans = []
+        module = get_module(use_v2)
+        transforms = []
         backend = backend.lower()
         if backend == "tensor":
-            trans.append(transforms.PILToTensor())
+            transforms.append(module.PILToTensor())
         elif backend != "pil":
             raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
 
-        trans += [
-            transforms.Resize(resize_size, interpolation=interpolation, antialias=True),
-            transforms.CenterCrop(crop_size),
+        transforms += [
+            module.Resize(resize_size, interpolation=interpolation, antialias=True),
+            module.CenterCrop(crop_size),
         ]
 
         if backend == "pil":
-            trans.append(transforms.PILToTensor())
+            transforms.append(module.PILToTensor())
 
-        trans += [
-            transforms.ConvertImageDtype(torch.float),
-            transforms.Normalize(mean=mean, std=std),
+        transforms += [
+            module.ConvertImageDtype(torch.float),
+            module.Normalize(mean=mean, std=std),
         ]
 
-        self.transforms = transforms.Compose(trans)
+        self.transforms = module.Compose(transforms)
 
     def __call__(self, img):
         return self.transforms(img)
diff --git a/references/classification/train.py b/references/classification/train.py
index 0c1a301453d..e5347631961 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -145,6 +145,7 @@ def load_data(traindir, valdir, args):
             ra_magnitude=ra_magnitude,
             augmix_severity=augmix_severity,
             backend=args.backend,
+            use_v2=args.use_v2,
         ),
     )
     if args.cache_dataset:
@@ -172,6 +173,7 @@ def load_data(traindir, valdir, args):
         resize_size=val_resize_size,
         interpolation=interpolation,
         backend=args.backend,
+        use_v2=args.use_v2,
     )
 
     dataset_test = torchvision.datasets.ImageFolder(
@@ -516,6 +518,7 @@ def get_args_parser(add_help=True):
     )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
+    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
     return parser
 
 