diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index f79a9844a12..b29c22ee1f6 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -234,7 +234,6 @@ Conversion v2.PILToTensor v2.ToImageTensor ConvertImageDtype - v2.ConvertDtype v2.ConvertImageDtype v2.ToDtype v2.ConvertBoundingBoxFormat diff --git a/gallery/plot_transforms_v2_e2e.py b/gallery/plot_transforms_v2_e2e.py index 5d8d22dce83..951af514b51 100644 --- a/gallery/plot_transforms_v2_e2e.py +++ b/gallery/plot_transforms_v2_e2e.py @@ -29,7 +29,7 @@ def show(sample): image, target = sample if isinstance(image, PIL.Image.Image): image = F.to_image_tensor(image) - image = F.convert_dtype(image, torch.uint8) + image = F.to_dtype(image, torch.uint8, scale=True) annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3) fig, ax = plt.subplots() diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py index 313faacdb7c..07c98a67ca2 100644 --- a/references/detection/coco_utils.py +++ b/references/detection/coco_utils.py @@ -1,4 +1,3 @@ -import copy import os import torch @@ -7,25 +6,6 @@ import transforms as T from pycocotools import mask as coco_mask from pycocotools.coco import COCO -from torchvision.datasets import wrap_dataset_for_transforms_v2 - - -class FilterAndRemapCocoCategories: - def __init__(self, categories, remap=True): - self.categories = categories - self.remap = remap - - def __call__(self, image, target): - anno = target["annotations"] - anno = [obj for obj in anno if obj["category_id"] in self.categories] - if not self.remap: - target["annotations"] = anno - return image, target - anno = copy.deepcopy(anno) - for obj in anno: - obj["category_id"] = self.categories.index(obj["category_id"]) - target["annotations"] = anno - return image, target def convert_coco_poly_to_mask(segmentations, height, width): @@ -219,7 +199,7 @@ def __getitem__(self, idx): return img, target -def get_coco(root, image_set, transforms, mode="instances", use_v2=False): +def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_masks=False): anno_file_template = "{}_{}2017.json" PATHS = { "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))), @@ -232,10 +212,15 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False): ann_file = os.path.join(root, ann_file) if use_v2: + from torchvision.datasets import wrap_dataset_for_transforms_v2 + dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms) - # TODO: need to update target_keys to handle masks for segmentation! - dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"}) + target_keys = ["boxes", "labels", "image_id"] + if with_masks: + target_keys += ["masks"] + dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys) else: + # TODO: handle with_masks for V1? 
t = [ConvertCocoPolysToMask()] if transforms is not None: t.append(transforms) @@ -249,9 +234,3 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False): # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)]) return dataset - - -def get_coco_kp(root, image_set, transforms, use_v2=False): - if use_v2: - raise ValueError("KeyPoints aren't supported by transforms V2 yet.") - return get_coco(root, image_set, transforms, mode="person_keypoints") diff --git a/references/detection/train.py b/references/detection/train.py index db86f33aaa9..892ffbbbc1c 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -28,7 +28,7 @@ import torchvision.models.detection import torchvision.models.detection.mask_rcnn import utils -from coco_utils import get_coco, get_coco_kp +from coco_utils import get_coco from engine import evaluate, train_one_epoch from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler from torchvision.transforms import InterpolationMode @@ -42,10 +42,16 @@ def copypaste_collate_fn(batch): def get_dataset(is_train, args): image_set = "train" if is_train else "val" - paths = {"coco": (args.data_path, get_coco, 91), "coco_kp": (args.data_path, get_coco_kp, 2)} - p, ds_fn, num_classes = paths[args.dataset] - - ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2) + num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}[args.dataset] + with_masks = "mask" in args.model + ds = get_coco( + root=args.data_path, + image_set=image_set, + transforms=get_transform(is_train, args), + mode=mode, + use_v2=args.use_v2, + with_masks=with_masks, + ) return ds, num_classes @@ -68,7 +74,12 @@ def get_args_parser(add_help=True): parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help) parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path") - parser.add_argument("--dataset", default="coco", type=str, help="dataset name") + parser.add_argument( + "--dataset", + default="coco", + type=str, + help="dataset name. 
Use coco for object detection and instance segmentation and coco_kp for Keypoint detection", + ) parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name") parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") parser.add_argument( @@ -171,6 +182,12 @@ def get_args_parser(add_help=True): def main(args): if args.backend.lower() == "datapoint" and not args.use_v2: raise ValueError("Use --use-v2 if you want to use the datapoint backend.") + if args.dataset not in ("coco", "coco_kp"): + raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}") + if "keypoint" in args.model and args.dataset != "coco_kp": + raise ValueError("Oops, if you want Keypoint detection, set --dataset coco_kp") + if args.dataset == "coco_kp" and args.use_v2: + raise ValueError("KeyPoint detection doesn't support V2 transforms yet") if args.output_dir: utils.mkdir(args.output_dir) diff --git a/references/segmentation/coco_utils.py b/references/segmentation/coco_utils.py index e02434012f1..6a15dbefb52 100644 --- a/references/segmentation/coco_utils.py +++ b/references/segmentation/coco_utils.py @@ -68,11 +68,6 @@ def _has_valid_annotation(anno): # if more than 1k pixels occupied in the image return sum(obj["area"] for obj in anno) > 1000 - if not isinstance(dataset, torchvision.datasets.CocoDetection): - raise TypeError( - f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}" - ) - ids = [] for ds_idx, img_id in enumerate(dataset.ids): ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) @@ -86,7 +81,7 @@ def _has_valid_annotation(anno): return dataset -def get_coco(root, image_set, transforms): +def get_coco(root, image_set, transforms, use_v2=False): PATHS = { "train": ("train2017", os.path.join("annotations", "instances_train2017.json")), "val": ("val2017", os.path.join("annotations", "instances_val2017.json")), @@ -94,13 +89,24 @@ def get_coco(root, image_set, transforms): } CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72] - transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms]) - img_folder, ann_file = PATHS[image_set] img_folder = os.path.join(root, img_folder) ann_file = os.path.join(root, ann_file) - dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms) + # The 2 "Compose" below achieve the same thing: converting coco detection + # samples into segmentation-compatible samples. They just do it with + # slightly different implementations. 
We could refactor and unify, but + # keeping them separate helps keep the v2 version clean + if use_v2: + import v2_extras + from torchvision.datasets import wrap_dataset_for_transforms_v2 + + transforms = Compose([v2_extras.CocoDetectionToVOCSegmentation(), transforms]) + dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms) + dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"masks", "labels"}) + else: + transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms]) + dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms) if image_set == "train": dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST) diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py index ed02ae660e4..abb70d8d0db 100644 --- a/references/segmentation/presets.py +++ b/references/segmentation/presets.py @@ -1,39 +1,106 @@ +from collections import defaultdict + import torch -import transforms as T + + +def get_modules(use_v2): + # We need a protected import to avoid the V2 warning in case just V1 is used + if use_v2: + import torchvision.datapoints + import torchvision.transforms.v2 + import v2_extras + + return torchvision.transforms.v2, torchvision.datapoints, v2_extras + else: + import transforms + + return transforms, None, None class SegmentationPresetTrain: - def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): - min_size = int(0.5 * base_size) - max_size = int(2.0 * base_size) + def __init__( + self, + *, + base_size, + crop_size, + hflip_prob=0.5, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + backend="pil", + use_v2=False, + ): + T, datapoints, v2_extras = get_modules(use_v2) + + transforms = [] + backend = backend.lower() + if backend == "datapoint": + transforms.append(T.ToImageTensor()) + elif backend == "tensor": + transforms.append(T.PILToTensor()) + elif backend != "pil": + raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}") + + transforms += [T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size))] - trans = [T.RandomResize(min_size, max_size)] if hflip_prob > 0: - trans.append(T.RandomHorizontalFlip(hflip_prob)) - trans.extend( - [ - T.RandomCrop(crop_size), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - T.Normalize(mean=mean, std=std), + transforms += [T.RandomHorizontalFlip(hflip_prob)] + + if use_v2: + # We need a custom pad transform here, since the padding we want to perform here is fundamentally + # different from the padding in `RandomCrop` if `pad_if_needed=True`.
+ transforms += [v2_extras.PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255}))] + + transforms += [T.RandomCrop(crop_size)] + + if backend == "pil": + transforms += [T.PILToTensor()] + + if use_v2: + img_type = datapoints.Image if backend == "datapoint" else torch.Tensor + transforms += [ + T.ToDtype(dtype={img_type: torch.float32, datapoints.Mask: torch.int64, "others": None}, scale=True) ] - ) - self.transforms = T.Compose(trans) + else: + # No need to explicitly convert masks as they're magically int64 already + transforms += [T.ConvertImageDtype(torch.float)] + + transforms += [T.Normalize(mean=mean, std=std)] + + self.transforms = T.Compose(transforms) def __call__(self, img, target): return self.transforms(img, target) class SegmentationPresetEval: - def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): - self.transforms = T.Compose( - [ - T.RandomResize(base_size, base_size), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - T.Normalize(mean=mean, std=std), - ] - ) + def __init__( + self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), backend="pil", use_v2=False + ): + T, _, _ = get_modules(use_v2) + + transforms = [] + backend = backend.lower() + if backend == "tensor": + transforms += [T.PILToTensor()] + elif backend == "datapoint": + transforms += [T.ToImageTensor()] + elif backend != "pil": + raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}") + + if use_v2: + transforms += [T.Resize(size=(base_size, base_size))] + else: + transforms += [T.RandomResize(min_size=base_size, max_size=base_size)] + + if backend == "pil": + # Note: we could just convert to pure tensors even in v2? + transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()] + + transforms += [ + T.ConvertImageDtype(torch.float), + T.Normalize(mean=mean, std=std), + ] + self.transforms = T.Compose(transforms) def __call__(self, img, target): return self.transforms(img, target) diff --git a/references/segmentation/train.py b/references/segmentation/train.py index 1aa72a9fe38..7ca4bd1c592 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -14,24 +14,30 @@ from torchvision.transforms import functional as F, InterpolationMode -def get_dataset(dir_path, name, image_set, transform): +def get_dataset(args, is_train): def sbd(*args, **kwargs): + kwargs.pop("use_v2") return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs) + def voc(*args, **kwargs): + kwargs.pop("use_v2") + return torchvision.datasets.VOCSegmentation(*args, **kwargs) + paths = { - "voc": (dir_path, torchvision.datasets.VOCSegmentation, 21), - "voc_aug": (dir_path, sbd, 21), - "coco": (dir_path, get_coco, 21), + "voc": (args.data_path, voc, 21), + "voc_aug": (args.data_path, sbd, 21), + "coco": (args.data_path, get_coco, 21), } - p, ds_fn, num_classes = paths[name] + p, ds_fn, num_classes = paths[args.dataset] - ds = ds_fn(p, image_set=image_set, transforms=transform) + image_set = "train" if is_train else "val" + ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2) return ds, num_classes -def get_transform(train, args): - if train: - return presets.SegmentationPresetTrain(base_size=520, crop_size=480) +def get_transform(is_train, args): + if is_train: + return presets.SegmentationPresetTrain(base_size=520, crop_size=480, backend=args.backend, use_v2=args.use_v2) elif args.weights and args.test_only: weights = 
torchvision.models.get_weight(args.weights) trans = weights.transforms() @@ -44,7 +50,7 @@ def preprocessing(img, target): return preprocessing else: - return presets.SegmentationPresetEval(base_size=520) + return presets.SegmentationPresetEval(base_size=520, backend=args.backend, use_v2=args.use_v2) def criterion(inputs, target): @@ -120,6 +126,12 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi def main(args): + if args.backend.lower() != "pil" and not args.use_v2: + # TODO: Support tensor backend in V1? + raise ValueError("Use --use-v2 if you want to use the datapoint or tensor backend.") + if args.use_v2 and args.dataset != "coco": + raise ValueError("v2 is only supported for coco dataset for now.") + if args.output_dir: utils.mkdir(args.output_dir) @@ -134,8 +146,8 @@ def main(args): else: torch.backends.cudnn.benchmark = True - dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args)) - dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args)) + dataset, num_classes = get_dataset(args, is_train=True) + dataset_test, _ = get_dataset(args, is_train=False) if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) @@ -307,6 +319,8 @@ def get_args_parser(add_help=True): # Mixed precision training parameters parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") + parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive") + parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms") return parser diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py index 518048db2fa..2b3e79b1461 100644 --- a/references/segmentation/transforms.py +++ b/references/segmentation/transforms.py @@ -35,7 +35,7 @@ def __init__(self, min_size, max_size=None): def __call__(self, image, target): size = random.randint(self.min_size, self.max_size) - image = F.resize(image, size) + image = F.resize(image, size, antialias=True) target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST) return image, target diff --git a/references/segmentation/utils.py b/references/segmentation/utils.py index 4ea24db83ed..cb200f23d76 100644 --- a/references/segmentation/utils.py +++ b/references/segmentation/utils.py @@ -267,9 +267,9 @@ def init_distributed_mode(args): args.rank = int(os.environ["RANK"]) args.world_size = int(os.environ["WORLD_SIZE"]) args.gpu = int(os.environ["LOCAL_RANK"]) - elif "SLURM_PROCID" in os.environ: - args.rank = int(os.environ["SLURM_PROCID"]) - args.gpu = args.rank % torch.cuda.device_count() + # elif "SLURM_PROCID" in os.environ: + #     args.rank = int(os.environ["SLURM_PROCID"]) + #     args.gpu = args.rank % torch.cuda.device_count() elif hasattr(args, "rank"): pass else: diff --git a/references/segmentation/v2_extras.py b/references/segmentation/v2_extras.py new file mode 100644 index 00000000000..c69827c22e7 --- /dev/null +++ b/references/segmentation/v2_extras.py @@ -0,0 +1,83 @@ +"""This file only exists to be lazy-imported and avoid V2-related import warnings when just using V1.""" +import torch +from torchvision import datapoints +from torchvision.transforms import v2 + + +class PadIfSmaller(v2.Transform): + def __init__(self, size, fill=0): + super().__init__() + self.size = size + self.fill = v2._geometry._setup_fill_arg(fill) + + def _get_params(self, sample): + _, height, width
= v2.utils.query_chw(sample) + padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)] + needs_padding = any(padding) + return dict(padding=padding, needs_padding=needs_padding) + + def _transform(self, inpt, params): + if not params["needs_padding"]: + return inpt + + fill = self.fill[type(inpt)] + fill = v2._utils._convert_fill_arg(fill) + + return v2.functional.pad(inpt, padding=params["padding"], fill=fill) + + +class CocoDetectionToVOCSegmentation(v2.Transform): + """Turn samples from datasets.CocoDetection into the same format as VOCSegmentation. + + This is achieved in two steps: + + 1. COCO differentiates between 91 categories while VOC only supports 21, including background for both. Fortunately, + the COCO categories are a superset of the VOC ones and thus can be mapped. Instances of the 70 categories not + present in VOC are dropped and replaced by background. + 2. COCO only offers detection masks, i.e. a (N, H, W) bool-ish tensor, where the truthy values in each individual + mask denote the instance. However, a segmentation mask is a (H, W) integer tensor (typically torch.uint8), where + the value of each pixel denotes the category it belongs to. The detection masks are merged into one segmentation + mask while pixels that belong to multiple detection masks are marked as invalid. + """ + + COCO_TO_VOC_LABEL_MAP = dict( + zip( + [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72], + range(21), + ) + ) + INVALID_VALUE = 255 + + def _coco_detection_masks_to_voc_segmentation_mask(self, target): + if "masks" not in target: + return None + + instance_masks, instance_labels_coco = target["masks"], target["labels"] + + valid_labels_voc = [ + (idx, label_voc) + for idx, label_coco in enumerate(instance_labels_coco.tolist()) + if (label_voc := self.COCO_TO_VOC_LABEL_MAP.get(label_coco)) is not None + ] + + if not valid_labels_voc: + return None + + valid_voc_category_idcs, instance_labels_voc = zip(*valid_labels_voc) + + instance_masks = instance_masks[list(valid_voc_category_idcs)].to(torch.uint8) + instance_labels_voc = torch.tensor(instance_labels_voc, dtype=torch.uint8) + + # Calling `.max()` on the stacked detection masks works fine to separate background from foreground as long as + # there is at most a single instance per pixel. Overlapping instances will be filtered out in the next step. 
+ segmentation_mask, _ = (instance_masks * instance_labels_voc.reshape(-1, 1, 1)).max(dim=0) + segmentation_mask[instance_masks.sum(dim=0) > 1] = self.INVALID_VALUE + + return segmentation_mask + + def forward(self, image, target): + segmentation_mask = self._coco_detection_masks_to_voc_segmentation_mask(target) + if segmentation_mask is None: + segmentation_mask = torch.zeros(v2.functional.get_spatial_size(image), dtype=torch.uint8) + + return image, datapoints.Mask(segmentation_mask) diff --git a/test/common_utils.py b/test/common_utils.py index 72ecf104301..af8f5783263 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -27,7 +27,7 @@ from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torchvision import datapoints, io from torchvision.transforms._functional_tensor import _max_value as get_max_value -from torchvision.transforms.v2.functional import convert_dtype_image_tensor, to_image_pil, to_image_tensor +from torchvision.transforms.v2.functional import to_dtype_image_tensor, to_image_pil, to_image_tensor IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"]) @@ -503,12 +503,13 @@ def make_image( device="cpu", memory_format=torch.contiguous_format, ): + dtype = dtype or torch.uint8 max_value = get_max_value(dtype) data = torch.testing.make_tensor( (*batch_dims, get_num_channels(color_space), *size), low=0, high=max_value, - dtype=dtype or torch.uint8, + dtype=dtype, device=device, memory_format=memory_format, ) @@ -601,7 +602,7 @@ def fn(shape, dtype, device, memory_format): image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True) else: image_tensor = image_tensor.to(device=device) - image_tensor = convert_dtype_image_tensor(image_tensor, dtype=dtype) + image_tensor = to_dtype_image_tensor(image_tensor, dtype=dtype, scale=True) return datapoints.Image(image_tensor) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index f17b63757b8..0e311fd65de 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1,7 +1,6 @@ import itertools import pathlib import random -import re import textwrap import warnings from collections import defaultdict @@ -105,7 +104,7 @@ def normalize_adapter(transform, input, device): continue elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)): # normalize doesn't support integer images - value = F.convert_dtype(value, torch.float32) + value = F.to_dtype(value, torch.float32, scale=True) adapted_input[key] = value return adapted_input @@ -146,7 +145,7 @@ class TestSmoke: (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None), (transforms.ClampBoundingBox(), None), (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None), - (transforms.ConvertDtype(), None), + (transforms.ConvertImageDtype(), None), (transforms.GaussianBlur(kernel_size=3), None), ( transforms.LinearTransformation( @@ -1326,61 +1325,6 @@ def test__transform(self, mocker): ) -class TestToDtype: - @pytest.mark.parametrize( - ("dtype", "expected_dtypes"), - [ - ( - torch.float64, - { - datapoints.Video: torch.float64, - datapoints.Image: torch.float64, - datapoints.BoundingBox: torch.float64, - }, - ), - ( - {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64}, - {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64}, - ), - ], - ) - def test_call(self, dtype, 
expected_dtypes): - sample = dict( - video=make_video(dtype=torch.int64), - image=make_image(dtype=torch.uint8), - bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32), - str="str", - int=0, - ) - - transform = transforms.ToDtype(dtype) - transformed_sample = transform(sample) - - for key, value in sample.items(): - value_type = type(value) - transformed_value = transformed_sample[key] - - # make sure the transformation retains the type - assert isinstance(transformed_value, value_type) - - if isinstance(value, torch.Tensor): - assert transformed_value.dtype is expected_dtypes[value_type] - else: - assert transformed_value is value - - @pytest.mark.filterwarnings("error") - def test_plain_tensor_call(self): - tensor = torch.empty((), dtype=torch.float32) - transform = transforms.ToDtype({torch.Tensor: torch.float64}) - - assert transform(tensor).dtype is torch.float64 - - @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video]) - def test_plain_tensor_warning(self, other_type): - with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")): - transforms.ToDtype(dtype={torch.Tensor: torch.float32, other_type: torch.float64}) - - class TestUniformTemporalSubsample: @pytest.mark.parametrize( "inpt", diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py index 102afdb37a9..9b7886f47da 100644 --- a/test/test_transforms_v2_consistency.py +++ b/test/test_transforms_v2_consistency.py @@ -191,7 +191,7 @@ def __init__( closeness_kwargs=dict(rtol=None, atol=None), ), ConsistencyConfig( - v2_transforms.ConvertDtype, + v2_transforms.ConvertImageDtype, legacy_transforms.ConvertImageDtype, [ ArgsKwargs(torch.float16), diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py index 465cc227107..47ea0069474 100644 --- a/test/test_transforms_v2_functional.py +++ b/test/test_transforms_v2_functional.py @@ -283,12 +283,12 @@ def test_float32_vs_uint8(self, test_id, info, args_kwargs): adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs) actual = info.kernel( - F.convert_dtype_image_tensor(input, dtype=torch.float32), + F.to_dtype_image_tensor(input, dtype=torch.float32, scale=True), *adapted_other_args, **adapted_kwargs, ) - expected = F.convert_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32) + expected = F.to_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True) assert_close( actual, @@ -538,7 +538,6 @@ def test_bounding_box_format_consistency(self, info, args_kwargs): (F.get_image_num_channels, F.get_num_channels), (F.to_pil_image, F.to_image_pil), (F.elastic_transform, F.elastic), - (F.convert_image_dtype, F.convert_dtype_image_tensor), (F.to_grayscale, F.rgb_to_grayscale), ] ], @@ -547,24 +546,6 @@ def test_alias(alias, target): assert alias is target -@pytest.mark.parametrize( - ("info", "args_kwargs"), - make_info_args_kwargs_params( - KERNEL_INFOS_MAP[F.convert_dtype_image_tensor], - args_kwargs_fn=lambda info: info.sample_inputs_fn(), - ), -) -@pytest.mark.parametrize("device", cpu_and_cuda()) -def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device): - (input, *other_args), kwargs = args_kwargs.load(device) - dtype = other_args[0] if other_args else kwargs.get("dtype", torch.float32) - - output = info.kernel(input, dtype) - - assert output.dtype == dtype - assert output.device == input.device - - 
@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("num_channels", [1, 3]) def test_normalize_image_tensor_stats(device, num_channels): diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 42825e71c65..0ec3c5f01ee 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -1,4 +1,5 @@ import contextlib +import decimal import inspect import math import re @@ -31,6 +32,7 @@ from torch import nn from torch.testing import assert_close +from torch.utils._pytree import tree_map from torch.utils.data import DataLoader, default_collate from torchvision import datapoints @@ -71,11 +73,12 @@ def _check_kernel_cuda_vs_cpu(kernel, input, *args, rtol, atol, **kwargs): @cache -def _script(fn): +def _script(obj): try: - return torch.jit.script(fn) + return torch.jit.script(obj) except Exception as error: - raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error + name = getattr(obj, "__name__", obj.__class__.__name__) + raise AssertionError(f"Trying to `torch.jit.script` '{name}' raised the error above.") from error def _check_kernel_scripted_vs_eager(kernel, input, *args, rtol, atol, **kwargs): @@ -132,6 +135,7 @@ def check_kernel( check_cuda_vs_cpu=True, check_scripted_vs_eager=True, check_batched_vs_unbatched=True, + expect_same_dtype=True, **kwargs, ): initial_input_version = input._version @@ -144,7 +148,8 @@ def check_kernel( # check that no inplace operation happened assert input._version == initial_input_version - assert output.dtype == input.dtype + if expect_same_dtype: + assert output.dtype == input.dtype assert output.device == input.device if check_cuda_vs_cpu: @@ -281,7 +286,7 @@ def check_dispatcher_signatures_match(dispatcher, *, kernel, input_type): def _check_transform_v1_compatibility(transform, input): """If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static ``get_params`` method, is scriptable, and the scripted version can be called without error.""" - if not hasattr(transform, "_v1_transform_cls"): + if transform._v1_transform_cls is None: return if type(input) is not torch.Tensor: @@ -1704,6 +1709,196 @@ def call_transform(): assert output[1] is label +class TestToDtype: + @pytest.mark.parametrize( + ("kernel", "make_input"), + [ + (F.to_dtype_image_tensor, make_image_tensor), + (F.to_dtype_image_tensor, make_image), + (F.to_dtype_video, make_video), + ], + ) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("scale", (True, False)) + def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, scale): + check_kernel( + kernel, + make_input(dtype=input_dtype, device=device), + expect_same_dtype=input_dtype is output_dtype, + dtype=output_dtype, + scale=scale, + ) + + @pytest.mark.parametrize( + ("kernel", "make_input"), + [ + (F.to_dtype_image_tensor, make_image_tensor), + (F.to_dtype_image_tensor, make_image), + (F.to_dtype_video, make_video), + ], + ) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("scale", (True, False)) + def test_dispatcher(self, 
kernel, make_input, input_dtype, output_dtype, device, scale): + check_dispatcher( + F.to_dtype, + kernel, + make_input(dtype=input_dtype, device=device), + # TODO: we could leave check_dispatch to True but it currently fails + # in _check_dispatcher_dispatch because there is no to_dtype() method on the datapoints. + # We should be able to put this back if we change the dispatch + # mechanism e.g. via https://github.com/pytorch/vision/pull/7733 + check_dispatch=False, + dtype=output_dtype, + scale=scale, + ) + + @pytest.mark.parametrize( + "make_input", + [make_image_tensor, make_image, make_bounding_box, make_segmentation_mask, make_video], + ) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("scale", (True, False)) + @pytest.mark.parametrize("as_dict", (True, False)) + def test_transform(self, make_input, input_dtype, output_dtype, device, scale, as_dict): + input = make_input(dtype=input_dtype, device=device) + if as_dict: + output_dtype = {type(input): output_dtype} + check_transform(transforms.ToDtype, input, dtype=output_dtype, scale=scale) + + def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False): + input_dtype = image.dtype + output_dtype = dtype + + if not scale: + return image.to(dtype) + + if output_dtype == input_dtype: + return image + + def fn(value): + if input_dtype.is_floating_point: + if output_dtype.is_floating_point: + return value + else: + return round(decimal.Decimal(value) * torch.iinfo(output_dtype).max) + else: + input_max_value = torch.iinfo(input_dtype).max + + if output_dtype.is_floating_point: + return float(decimal.Decimal(value) / input_max_value) + else: + output_max_value = torch.iinfo(output_dtype).max + + if input_max_value > output_max_value: + factor = (input_max_value + 1) // (output_max_value + 1) + return value / factor + else: + factor = (output_max_value + 1) // (input_max_value + 1) + return value * factor + + return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype, device=image.device) + + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("scale", (True, False)) + def test_image_correctness(self, input_dtype, output_dtype, device, scale): + if input_dtype.is_floating_point and output_dtype == torch.int64: + pytest.xfail("float to int64 conversion is not supported") + + input = make_image(dtype=input_dtype, device=device) + + out = F.to_dtype(input, dtype=output_dtype, scale=scale) + expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale) + + if input_dtype.is_floating_point and not output_dtype.is_floating_point and scale: + torch.testing.assert_close(out, expected, atol=1, rtol=0) + else: + torch.testing.assert_close(out, expected) + + def was_scaled(self, inpt): + # this assumes the target dtype is float + return inpt.max() <= 1 + + def make_inpt_with_bbox_and_mask(self, make_input): + H, W = 10, 10 + inpt_dtype = torch.uint8 + bbox_dtype = torch.float32 + mask_dtype = torch.bool + sample = { + "inpt": make_input(size=(H, W), dtype=inpt_dtype), + "bbox": make_bounding_box(size=(H, W), dtype=bbox_dtype), + "mask": make_detection_mask(size=(H, W), dtype=mask_dtype), + } 
+ + return sample, inpt_dtype, bbox_dtype, mask_dtype + + @pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video)) + @pytest.mark.parametrize("scale", (True, False)) + def test_dtype_not_a_dict(self, make_input, scale): + # assert only inpt gets transformed when dtype isn't a dict + + sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input) + out = transforms.ToDtype(dtype=torch.float32, scale=scale)(sample) + + assert out["inpt"].dtype != inpt_dtype + assert out["inpt"].dtype == torch.float32 + if scale: + assert self.was_scaled(out["inpt"]) + else: + assert not self.was_scaled(out["inpt"]) + assert out["bbox"].dtype == bbox_dtype + assert out["mask"].dtype == mask_dtype + + @pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video)) + def test_others_catch_all_and_none(self, make_input): + # make sure "others" works as a catch-all and that None means no conversion + + sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input) + out = transforms.ToDtype(dtype={datapoints.Mask: torch.int64, "others": None})(sample) + assert out["inpt"].dtype == inpt_dtype + assert out["bbox"].dtype == bbox_dtype + assert out["mask"].dtype != mask_dtype + assert out["mask"].dtype == torch.int64 + + @pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video)) + def test_typical_use_case(self, make_input): + # Typical use-case: want to convert dtype and scale for inpt and just dtype for masks. + # This just makes sure we now have a decent API for this + + sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input) + out = transforms.ToDtype( + dtype={type(sample["inpt"]): torch.float32, datapoints.Mask: torch.int64, "others": None}, scale=True + )(sample) + assert out["inpt"].dtype != inpt_dtype + assert out["inpt"].dtype == torch.float32 + assert self.was_scaled(out["inpt"]) + assert out["bbox"].dtype == bbox_dtype + assert out["mask"].dtype != mask_dtype + assert out["mask"].dtype == torch.int64 + + @pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video)) + def test_errors_warnings(self, make_input): + sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input) + + with pytest.raises(ValueError, match="No dtype was specified for"): + out = transforms.ToDtype(dtype={datapoints.Mask: torch.float32})(sample) + with pytest.warns(UserWarning, match=re.escape("plain `torch.Tensor` will *not* be transformed")): + transforms.ToDtype(dtype={torch.Tensor: torch.float32, datapoints.Image: torch.float32}) + with pytest.warns(UserWarning, match="no scaling will be done"): + out = transforms.ToDtype(dtype={"others": None}, scale=True)(sample) + assert out["inpt"].dtype == inpt_dtype + assert out["bbox"].dtype == bbox_dtype + assert out["mask"].dtype == mask_dtype + + class TestCutMixMixUp: class DummyDataset: def __init__(self, size, num_classes): diff --git a/test/transforms_v2_dispatcher_infos.py b/test/transforms_v2_dispatcher_infos.py index 6f61526f382..57b905035b4 100644 --- a/test/transforms_v2_dispatcher_infos.py +++ b/test/transforms_v2_dispatcher_infos.py @@ -364,16 +364,6 @@ def fill_sequence_needs_broadcast(args_kwargs): xfail_jit_python_scalar_arg("std"), ], ), - DispatcherInfo( - F.convert_dtype, - kernels={ - datapoints.Image: F.convert_dtype_image_tensor, - datapoints.Video: F.convert_dtype_video, - }, - test_marks=[ - skip_dispatch_datapoint, - ], - ), DispatcherInfo( 
F.uniform_temporal_subsample, kernels={ diff --git a/test/transforms_v2_kernel_infos.py b/test/transforms_v2_kernel_infos.py index dc04fbfc7a9..036b3e4d360 100644 --- a/test/transforms_v2_kernel_infos.py +++ b/test/transforms_v2_kernel_infos.py @@ -1,4 +1,3 @@ -import decimal import functools import itertools @@ -27,7 +26,6 @@ mark_framework_limitation, TestMark, ) -from torch.utils._pytree import tree_map from torchvision import datapoints from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding @@ -1566,7 +1564,7 @@ def multi_crop_pil_reference_wrapper(pil_kernel): def wrapper(input_tensor, *other_args, **kwargs): output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs) return type(output)( - F.convert_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype) + F.to_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype, scale=True) for output_pil in output ) @@ -1667,125 +1665,6 @@ def sample_inputs_normalize_video(): ) -def sample_inputs_convert_dtype_image_tensor(): - for input_dtype, output_dtype in itertools.product( - [torch.uint8, torch.int64, torch.float32, torch.float64], repeat=2 - ): - if input_dtype.is_floating_point and output_dtype == torch.int64: - # conversion cannot be performed safely - continue - - for image_loader in make_image_loaders( - sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], dtypes=[input_dtype] - ): - yield ArgsKwargs(image_loader, dtype=output_dtype) - - -def reference_convert_dtype_image_tensor(image, dtype=torch.float): - input_dtype = image.dtype - output_dtype = dtype - - if output_dtype == input_dtype: - return image - - def fn(value): - if input_dtype.is_floating_point: - if output_dtype.is_floating_point: - return value - else: - return int(decimal.Decimal(value) * torch.iinfo(output_dtype).max) - else: - input_max_value = torch.iinfo(input_dtype).max - - if output_dtype.is_floating_point: - return float(decimal.Decimal(value) / input_max_value) - else: - output_max_value = torch.iinfo(output_dtype).max - - if input_max_value > output_max_value: - factor = (input_max_value + 1) // (output_max_value + 1) - return value // factor - else: - factor = (output_max_value + 1) // (input_max_value + 1) - return value * factor - - return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype) - - -def reference_inputs_convert_dtype_image_tensor(): - for input_dtype, output_dtype in itertools.product( - [ - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.float16, - torch.float32, - torch.float64, - torch.bfloat16, - ], - repeat=2, - ): - if (input_dtype == torch.float32 and output_dtype in {torch.int32, torch.int64}) or ( - input_dtype == torch.float64 and output_dtype == torch.int64 - ): - continue - - if input_dtype.is_floating_point: - data = [0.0, 0.5, 1.0] - else: - max_value = torch.iinfo(input_dtype).max - data = [0, max_value // 2, max_value] - image = torch.tensor(data, dtype=input_dtype) - - yield ArgsKwargs(image, dtype=output_dtype) - - -def sample_inputs_convert_dtype_video(): - for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): - yield ArgsKwargs(video_loader) - - -skip_dtype_consistency = TestMark( - ("TestKernels", "test_dtype_and_device_consistency"), - pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"), - condition=lambda args_kwargs: args_kwargs.args[0].dtype != args_kwargs.kwargs.get("dtype", torch.float32), -) - 
-KERNEL_INFOS.extend( - [ - KernelInfo( - F.convert_dtype_image_tensor, - sample_inputs_fn=sample_inputs_convert_dtype_image_tensor, - reference_fn=reference_convert_dtype_image_tensor, - reference_inputs_fn=reference_inputs_convert_dtype_image_tensor, - test_marks=[ - skip_dtype_consistency, - TestMark( - ("TestKernels", "test_against_reference"), - pytest.mark.xfail(reason="Conversion overflows"), - condition=lambda args_kwargs: ( - args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16} - and not args_kwargs.kwargs["dtype"].is_floating_point - ) - or ( - args_kwargs.args[0].dtype in {torch.int32, torch.int64} - and args_kwargs.kwargs["dtype"] == torch.float16 - ), - ), - ], - ), - KernelInfo( - F.convert_dtype_video, - sample_inputs_fn=sample_inputs_convert_dtype_video, - test_marks=[ - skip_dtype_consistency, - ], - ), - ] -) - - def sample_inputs_uniform_temporal_subsample_video(): for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[4]): yield ArgsKwargs(video_loader, num_samples=2) diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py index a346f6b4c96..7e1080e6982 100644 --- a/torchvision/transforms/v2/__init__.py +++ b/torchvision/transforms/v2/__init__.py @@ -39,7 +39,7 @@ ScaleJitter, TenCrop, ) -from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype +from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertImageDtype from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBox, ToDtype from ._temporal import UniformTemporalSubsample from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage diff --git a/torchvision/transforms/v2/_meta.py b/torchvision/transforms/v2/_meta.py index b7e2a42259f..5299e318f01 100644 --- a/torchvision/transforms/v2/_meta.py +++ b/torchvision/transforms/v2/_meta.py @@ -31,10 +31,13 @@ def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> da return F.convert_format_bounding_box(inpt, new_format=self.format) # type: ignore[return-value] -class ConvertDtype(Transform): - """[BETA] Convert input image or video to the given ``dtype`` and scale the values accordingly. +class ConvertImageDtype(Transform): + """[BETA] Convert input image to the given ``dtype`` and scale the values accordingly. - .. v2betastatus:: ConvertDtype transform + .. v2betastatus:: ConvertImageDtype transform + + .. warning:: + Consider using ``ToDtype(dtype, scale=True)`` instead. See :class:`~torchvision.transforms.v2.ToDtype`. This function does not support PIL Image. @@ -55,21 +58,14 @@ class ConvertDtype(Transform): _v1_transform_cls = _transforms.ConvertImageDtype - _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) + _transformed_types = (is_simple_tensor, datapoints.Image) def __init__(self, dtype: torch.dtype = torch.float32) -> None: super().__init__() self.dtype = dtype - def _transform( - self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any] - ) -> Union[datapoints._TensorImageType, datapoints._TensorVideoType]: - return F.convert_dtype(inpt, self.dtype) - - -# We changed the name to align it with the new naming scheme. Still, `ConvertImageDtype` is -# prevalent and well understood. Thus, we just alias it without deprecating the old name. 
-ConvertImageDtype = ConvertDtype + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.to_dtype(inpt, dtype=self.dtype, scale=True) class ClampBoundingBox(Transform): diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index fc980850fec..b22c61727b3 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -9,7 +9,7 @@ from torchvision import datapoints, transforms as _transforms from torchvision.transforms.v2 import functional as F, Transform -from ._utils import _get_defaultdict, _parse_labels_getter, _setup_float_or_seq, _setup_size +from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size from .utils import has_any, is_simple_tensor, query_bounding_box @@ -223,36 +223,76 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class ToDtype(Transform): - """[BETA] Converts the input to a specific dtype - this does not scale values. + """[BETA] Converts the input to a specific dtype, optionally scaling the values for images or videos. .. v2betastatus:: ToDtype transform + .. note:: + ``ToDtype(dtype, scale=True)`` is the recommended replacement for ``ConvertImageDtype(dtype)``. + Args: dtype (``torch.dtype`` or dict of ``Datapoint`` -> ``torch.dtype``): The dtype to convert to. + If a ``torch.dtype`` is passed, e.g. ``torch.float32``, only images and videos will be converted + to that dtype: this is for compatibility with :class:`~torchvision.transforms.v2.ConvertImageDtype`. A dict can be passed to specify per-datapoint conversions, e.g. - ``dtype={datapoints.Image: torch.float32, datapoints.Video: - torch.float64}``. + ``dtype={datapoints.Image: torch.float32, datapoints.Mask: torch.int64, "others":None}``. The "others" + key can be used as a catch-all for any other datapoint type, and ``None`` means no conversion. + scale (bool, optional): Whether to scale the values for images or videos. Default: ``False``. """ _transformed_types = (torch.Tensor,) - def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None: + def __init__( + self, dtype: Union[torch.dtype, Dict[Union[Type, str], Optional[torch.dtype]]], scale: bool = False + ) -> None: super().__init__() - if not isinstance(dtype, dict): - dtype = _get_defaultdict(dtype) - if torch.Tensor in dtype and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]): + + if not isinstance(dtype, (dict, torch.dtype)): + raise ValueError(f"dtype must be a dict or a torch.dtype, got {type(dtype)} instead") + + if ( + isinstance(dtype, dict) + and torch.Tensor in dtype + and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]) + ): warnings.warn( "Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. " "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) " "in case a `datapoints.Image` or `datapoints.Video` is present in the input." 
) self.dtype = dtype + self.scale = scale def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - dtype = self.dtype[type(inpt)] + if isinstance(self.dtype, torch.dtype): + # For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype + # is a simple torch.dtype + if not is_simple_tensor(inpt) and not isinstance(inpt, (datapoints.Image, datapoints.Video)): + return inpt + + dtype: Optional[torch.dtype] = self.dtype + elif type(inpt) in self.dtype: + dtype = self.dtype[type(inpt)] + elif "others" in self.dtype: + dtype = self.dtype["others"] + else: + raise ValueError( + f"No dtype was specified for type {type(inpt)}. " + "If you only need to convert the dtype of images or videos, you can just pass e.g. dtype=torch.float32. " + "If you're passing a dict as dtype, " + 'you can use "others" as a catch-all key ' + 'e.g. dtype={datapoints.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.' + ) + + supports_scaling = is_simple_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video)) if dtype is None: + if self.scale and supports_scaling: + warnings.warn( + "scale was set to True but no dtype was specified for images or videos: no scaling will be done." + ) return inpt - return inpt.to(dtype=dtype) + + return F.to_dtype(inpt, dtype=dtype, scale=self.scale) class SanitizeBoundingBox(Transform): diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index b4803f4f1b9..4617d1af638 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -5,10 +5,10 @@ from ._meta import ( clamp_bounding_box, convert_format_bounding_box, - convert_dtype_image_tensor, - convert_dtype, - convert_dtype_video, convert_image_dtype, + to_dtype, + to_dtype_image_tensor, + to_dtype_video, get_dimensions_image_tensor, get_dimensions_image_pil, get_dimensions, diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 13417e4a990..c2ee5611224 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -9,7 +9,7 @@ from torchvision.utils import _log_api_usage_once -from ._meta import _num_value_bits, convert_dtype_image_tensor +from ._meta import _num_value_bits, to_dtype_image_tensor from ._utils import is_simple_tensor @@ -351,7 +351,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten return image orig_dtype = image.dtype - image = convert_dtype_image_tensor(image, torch.float32) + image = to_dtype_image_tensor(image, torch.float32, scale=True) image = _rgb_to_hsv(image) h, s, v = image.unbind(dim=-3) @@ -359,7 +359,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten image = torch.stack((h, s, v), dim=-3) image_hue_adj = _hsv_to_rgb(image) - return convert_dtype_image_tensor(image_hue_adj, orig_dtype) + return to_dtype_image_tensor(image_hue_adj, orig_dtype, scale=True) adjust_hue_image_pil = _FP.adjust_hue @@ -393,7 +393,7 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1 # The input image is either assumed to be at [0, 1] scale (if float) or is converted to that scale (if integer). # Since the gamma is non-negative, the output remains at [0, 1] scale. 
if not torch.is_floating_point(image): - output = convert_dtype_image_tensor(image, torch.float32).pow_(gamma) + output = to_dtype_image_tensor(image, torch.float32, scale=True).pow_(gamma) else: output = image.pow(gamma) @@ -402,7 +402,7 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1 # of the output can go beyond [0, 1]. output = output.mul_(gain).clamp_(0.0, 1.0) - return convert_dtype_image_tensor(output, image.dtype) + return to_dtype_image_tensor(output, image.dtype, scale=True) adjust_gamma_image_pil = _FP.adjust_gamma @@ -565,7 +565,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: # Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is # by far the most common, we choose it as base. output_dtype = image.dtype - image = convert_dtype_image_tensor(image, torch.uint8) + image = to_dtype_image_tensor(image, torch.uint8, scale=True) # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image # corresponds to adding 1 to index 127 in the histogram. @@ -616,7 +616,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image) output = torch.where(valid_equalization, equalized_image, image) - return convert_dtype_image_tensor(output, output_dtype) + return to_dtype_image_tensor(output, output_dtype, scale=True) equalize_image_pil = _FP.equalize diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 8ffa3966195..5d0c072d26d 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -296,9 +296,12 @@ def _num_value_bits(dtype: torch.dtype) -> int: raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.") -def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: +def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: + if image.dtype == dtype: return image + elif not scale: + return image.to(dtype) float_input = image.is_floating_point() if torch.jit.is_scripting(): @@ -345,30 +348,28 @@ def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.f return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input) -# We changed the name to align it with the new naming scheme. Still, `convert_image_dtype` is -# prevalent and well understood. Thus, we just alias it without deprecating the old name. 
-convert_image_dtype = convert_dtype_image_tensor +# We encourage users to use to_dtype() instead but we keep this for BC +def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: + return to_dtype_image_tensor(image, dtype=dtype, scale=True) -def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: - return convert_dtype_image_tensor(video, dtype) +def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: + return to_dtype_image_tensor(video, dtype, scale=scale) -def convert_dtype( - inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], dtype: torch.dtype = torch.float -) -> torch.Tensor: +def to_dtype(inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: if not torch.jit.is_scripting(): - _log_api_usage_once(convert_dtype) + _log_api_usage_once(to_dtype) if torch.jit.is_scripting() or is_simple_tensor(inpt): - return convert_dtype_image_tensor(inpt, dtype) + return to_dtype_image_tensor(inpt, dtype, scale=scale) elif isinstance(inpt, datapoints.Image): - output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype) + output = to_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype, scale=scale) return datapoints.Image.wrap_like(inpt, output) elif isinstance(inpt, datapoints.Video): - output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype) + output = to_dtype_video(inpt.as_subclass(torch.Tensor), dtype, scale=scale) return datapoints.Video.wrap_like(inpt, output) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.to(dtype) else: - raise TypeError( - f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead." - ) + raise TypeError(f"Input can either be a plain tensor or a datapoint, but got {type(inpt)} instead.")
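
Usage note (illustrative, not part of the patch): the diff above replaces `ConvertDtype` / `convert_dtype` with `ToDtype(dtype, scale=...)` and `F.to_dtype(..., scale=...)`. The sketch below shows the intended call pattern; the transform names, the `scale` flag, and the `"others"` catch-all key come from the diff itself, while the sample tensors and the specific dtype dict are made-up assumptions for illustration.

import torch
from torchvision import datapoints
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F

# ConvertImageDtype(dtype) converts *and* scales images; ToDtype(dtype, scale=True)
# is the recommended replacement. Passing a dict lets masks keep an integer dtype
# while images are converted to float and rescaled to [0, 1].
transform = v2.ToDtype(
    dtype={datapoints.Image: torch.float32, datapoints.Mask: torch.int64, "others": None},
    scale=True,
)

# Hypothetical inputs, only for illustration.
img = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
mask = datapoints.Mask(torch.randint(0, 21, (32, 32), dtype=torch.uint8))

out_img, out_mask = transform((img, mask))
assert out_img.dtype == torch.float32
assert out_img.max() <= 1  # scaled, since images support scaling
assert out_mask.dtype == torch.int64  # converted only; non-image datapoints are never scaled

# Functional counterpart; scale defaults to False, i.e. a plain dtype cast.
float_img = F.to_dtype(img, dtype=torch.float32, scale=True)

As the last `_meta.py` hunk shows, `F.convert_image_dtype` is kept for backward compatibility as a thin wrapper around `to_dtype_image_tensor(..., scale=True)`.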