remove custom type definitions from datapoints module (#7814)
pmeier authored Aug 9, 2023
1 parent 6b02079 commit 641fdd9
Showing 21 changed files with 141 additions and 178 deletions.
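
Summary: this commit deletes the private type aliases that torchvision.datapoints previously exposed (_FillType, _FillTypeJIT, _InputType, _InputTypeJIT, _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, and the _VideoType family). The fill aliases move to torchvision.transforms.v2.functional._utils; every other use is rewritten as an explicit union or widened to Any/torch.Tensor. A minimal before/after sketch of the annotation pattern (process is a hypothetical function, not part of this diff):

# Before: signatures leaned on private aliases re-exported by the datapoints module.
#     def process(image: datapoints._TensorImageType) -> torch.Tensor: ...
# After: the equivalent union is spelled out explicitly at the call site.
from typing import Union

import torch
from torchvision import datapoints

def process(image: Union[torch.Tensor, datapoints.Image]) -> torch.Tensor:
    # datapoints.Image is a Tensor subclass, so the same tensor ops cover both cases.
    return image.as_subclass(torch.Tensor)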
6 changes: 3 additions & 3 deletions torchvision/datapoints/__init__.py
@@ -1,10 +1,10 @@
 from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS
 
 from ._bounding_box import BoundingBoxes, BoundingBoxFormat
-from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT, Datapoint
-from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
+from ._datapoint import Datapoint
+from ._image import Image
 from ._mask import Mask
-from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video
+from ._video import Video
 
 if _WARN_ABOUT_BETA_TRANSFORMS:
     import warnings
9 changes: 1 addition & 8 deletions torchvision/datapoints/_datapoint.py
@@ -1,16 +1,13 @@
 from __future__ import annotations
 
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
 
 import PIL.Image
 import torch
 from torch._C import DisableTorchFunctionSubclass
 from torch.types import _device, _dtype, _size
 
-
 D = TypeVar("D", bound="Datapoint")
-_FillType = Union[int, float, Sequence[int], Sequence[float], None]
-_FillTypeJIT = Optional[List[float]]
 

class Datapoint(torch.Tensor):
@@ -132,7 +129,3 @@ def __deepcopy__(self: D, memo: Dict[int, Any]) -> D:
         # `BoundingBoxes.format` and `BoundingBoxes.canvas_size`, which are immutable and thus implicitly deep-copied by
         # `BoundingBoxes.clone()`.
         return self.detach().clone().requires_grad_(self.requires_grad)  # type: ignore[return-value]
-
-
-_InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
-_InputTypeJIT = torch.Tensor
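
Note: _FillType and _FillTypeJIT are relocated rather than dropped; later files in this diff import them from torchvision.transforms.v2.functional._utils (see _auto_augment.py, _geometry.py, and _utils.py below). A minimal sketch of the relocated aliases, assuming they keep exactly the shapes deleted here:

from typing import List, Optional, Sequence, Union

# Same shapes as the aliases removed from _datapoint.py above.
_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]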
6 changes: 0 additions & 6 deletions torchvision/datapoints/_image.py
@@ -45,9 +45,3 @@ def __new__(
 
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
         return self._make_repr()
-
-
-_ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
-_ImageTypeJIT = torch.Tensor
-_TensorImageType = Union[torch.Tensor, Image]
-_TensorImageTypeJIT = torch.Tensor
6 changes: 0 additions & 6 deletions torchvision/datapoints/_video.py
@@ -35,9 +35,3 @@ def __new__(
 
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
         return self._make_repr()
-
-
-_VideoType = Union[torch.Tensor, Video]
-_VideoTypeJIT = torch.Tensor
-_TensorVideoType = Union[torch.Tensor, Video]
-_TensorVideoTypeJIT = torch.Tensor
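
With the image and video aliases gone, the hunks below substitute each use with its expanded form. A reference sketch of the substitution rule, derived from the alias definitions deleted above (the names on the left of each comment no longer exist after this commit, and the aliases defined here are hypothetical illustrations):

from typing import Union

import PIL.Image
import torch
from torchvision import datapoints

# Hypothetical local aliases that spell out what the diff now writes inline:
ImageLike = Union[torch.Tensor, PIL.Image.Image, datapoints.Image]  # was datapoints._ImageType
TensorImageLike = Union[torch.Tensor, datapoints.Image]             # was datapoints._TensorImageType
VideoLike = Union[torch.Tensor, datapoints.Video]                   # was datapoints._VideoType
JITImageOrVideo = torch.Tensor                                      # was any of the *_TypeJIT aliases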
10 changes: 5 additions & 5 deletions torchvision/prototype/transforms/_augment.py
@@ -26,15 +26,15 @@ def __init__(
 
     def _copy_paste(
         self,
-        image: datapoints._TensorImageType,
+        image: Union[torch.Tensor, datapoints.Image],
         target: Dict[str, Any],
-        paste_image: datapoints._TensorImageType,
+        paste_image: Union[torch.Tensor, datapoints.Image],
         paste_target: Dict[str, Any],
         random_selection: torch.Tensor,
         blending: bool,
         resize_interpolation: F.InterpolationMode,
         antialias: Optional[bool],
-    ) -> Tuple[datapoints._TensorImageType, Dict[str, Any]]:
+    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
 
         paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection])
         paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection])
@@ -106,7 +106,7 @@ def _copy_paste(
 
     def _extract_image_targets(
         self, flat_sample: List[Any]
-    ) -> Tuple[List[datapoints._TensorImageType], List[Dict[str, Any]]]:
+    ) -> Tuple[List[Union[torch.Tensor, datapoints.Image]], List[Dict[str, Any]]]:
         # fetch all images, bboxes, masks and labels from unstructured input
         # with List[image], List[BoundingBoxes], List[Mask], List[Label]
         images, bboxes, masks, labels = [], [], [], []
@@ -137,7 +137,7 @@ def _extract_image_targets(
     def _insert_outputs(
         self,
         flat_sample: List[Any],
-        output_images: List[datapoints._TensorImageType],
+        output_images: List[torch.Tensor],
         output_targets: List[Dict[str, Any]],
     ) -> None:
         c0, c1, c2, c3 = 0, 0, 0, 0
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/_geometry.py
@@ -6,15 +6,15 @@
 from torchvision import datapoints
 from torchvision.prototype.datapoints import Label, OneHotLabel
 from torchvision.transforms.v2 import functional as F, Transform
-from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size
+from torchvision.transforms.v2._utils import _FillType, _get_fill, _setup_fill_arg, _setup_size
 from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size
 
 
 class FixedSizeCrop(Transform):
     def __init__(
         self,
         size: Union[int, Sequence[int]],
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         padding_mode: str = "constant",
     ) -> None:
         super().__init__()
8 changes: 2 additions & 6 deletions torchvision/prototype/transforms/_misc.py
@@ -39,9 +39,7 @@ def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]
         )
         self.dims = dims
 
-    def _transform(
-        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
-    ) -> torch.Tensor:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
         dims = self.dims[type(inpt)]
         if dims is None:
             return inpt.as_subclass(torch.Tensor)
@@ -63,9 +61,7 @@ def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, i
         )
         self.dims = dims
 
-    def _transform(
-        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
-    ) -> torch.Tensor:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
         dims = self.dims[type(inpt)]
         if dims is None:
             return inpt.as_subclass(torch.Tensor)
24 changes: 14 additions & 10 deletions torchvision/transforms/v2/_auto_augment.py
@@ -10,17 +10,21 @@
 from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
 from torchvision.transforms.v2.functional._meta import get_size
+from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
 
 from ._utils import _get_fill, _setup_fill_arg
 from .utils import check_type, is_simple_tensor
 
 
+ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]
+
+
 class _AutoAugmentBase(Transform):
     def __init__(
         self,
         *,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ) -> None:
         super().__init__()
         self.interpolation = _check_interpolation(interpolation)
@@ -35,7 +39,7 @@ def _flatten_and_extract_image_or_video(
         self,
         inputs: Any,
         unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBoxes, datapoints.Mask),
-    ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints._ImageType, datapoints._VideoType]]:
+    ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
         flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
         needs_transform_list = self._needs_transform_list(flat_inputs)
 
@@ -68,20 +72,20 @@ def _flatten_and_extract_image_or_video(
     def _unflatten_and_insert_image_or_video(
         self,
         flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
-        image_or_video: Union[datapoints._ImageType, datapoints._VideoType],
+        image_or_video: ImageOrVideo,
     ) -> Any:
         flat_inputs, spec, idx = flat_inputs_with_spec
         flat_inputs[idx] = image_or_video
         return tree_unflatten(flat_inputs, spec)
 
     def _apply_image_or_video_transform(
         self,
-        image: Union[datapoints._ImageType, datapoints._VideoType],
+        image: ImageOrVideo,
         transform_id: str,
         magnitude: float,
         interpolation: Union[InterpolationMode, int],
-        fill: Dict[Union[Type, str], datapoints._FillTypeJIT],
-    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
+        fill: Dict[Union[Type, str], _FillTypeJIT],
+    ) -> ImageOrVideo:
         fill_ = _get_fill(fill, type(image))
 
         if transform_id == "Identity":
@@ -214,7 +218,7 @@ def __init__(
         self,
         policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.policy = policy
@@ -394,7 +398,7 @@ def __init__(
         magnitude: int = 9,
         num_magnitude_bins: int = 31,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_ops = num_ops
@@ -467,7 +471,7 @@ def __init__(
         self,
         num_magnitude_bins: int = 31,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ):
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_magnitude_bins = num_magnitude_bins
@@ -550,7 +554,7 @@ def __init__(
         alpha: float = 1.0,
         all_ops: bool = True,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self._PARAMETER_MAX = 10
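
The module-level ImageOrVideo alias added in the hunk above keeps _auto_augment.py from repeating the expanded union in every signature. A usage sketch (describe is hypothetical, and the import assumes the alias stays module-level in _auto_augment.py):

from torchvision.transforms.v2._auto_augment import ImageOrVideo

def describe(x: ImageOrVideo) -> str:
    # Accepts plain tensors, PIL images, and Image/Video datapoints alike.
    return type(x).__name__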
4 changes: 1 addition & 3 deletions torchvision/transforms/v2/_color.py
@@ -261,9 +261,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
         return params
 
-    def _transform(
-        self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
-    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["brightness_factor"] is not None:
             inpt = F.adjust_brightness(inpt, brightness_factor=params["brightness_factor"])
         if params["contrast_factor"] is not None and params["contrast_before"]:
18 changes: 8 additions & 10 deletions torchvision/transforms/v2/_geometry.py
@@ -11,6 +11,7 @@
 from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
+from torchvision.transforms.v2.functional._utils import _FillType
 
 from ._transform import _RandomApplyTransform
 from ._utils import (
@@ -311,9 +312,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         )
 
 
-ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]
-
-
 class FiveCrop(Transform):
     """[BETA] Crop the image or video into four corners and the central crop.
@@ -459,7 +457,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
     def __init__(
         self,
         padding: Union[int, Sequence[int]],
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
@@ -514,7 +512,7 @@ class RandomZoomOut(_RandomApplyTransform):
 
     def __init__(
         self,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         side_range: Sequence[float] = (1.0, 4.0),
         p: float = 0.5,
     ) -> None:
@@ -592,7 +590,7 @@ def __init__(
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         expand: bool = False,
         center: Optional[List[float]] = None,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
     ) -> None:
         super().__init__()
         self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
@@ -674,7 +672,7 @@ def __init__(
         scale: Optional[Sequence[float]] = None,
         shear: Optional[Union[int, float, Sequence[float]]] = None,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         center: Optional[List[float]] = None,
     ) -> None:
         super().__init__()
@@ -812,7 +810,7 @@ def __init__(
         size: Union[int, Sequence[int]],
         padding: Optional[Union[int, Sequence[int]]] = None,
         pad_if_needed: bool = False,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
@@ -931,7 +929,7 @@ def __init__(
         distortion_scale: float = 0.5,
         p: float = 0.5,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
     ) -> None:
         super().__init__(p=p)
 
@@ -1033,7 +1031,7 @@ def __init__(
         alpha: Union[float, Sequence[float]] = 50.0,
         sigma: Union[float, Sequence[float]] = 5.0,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
     ) -> None:
         super().__init__()
         self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
4 changes: 1 addition & 3 deletions torchvision/transforms/v2/_misc.py
@@ -169,9 +169,7 @@ def _check_inputs(self, sample: Any) -> Any:
         if has_any(sample, PIL.Image.Image):
             raise TypeError(f"{type(self).__name__}() does not support PIL images.")
 
-    def _transform(
-        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
-    ) -> Any:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
 
3 changes: 1 addition & 2 deletions torchvision/transforms/v2/_temporal.py
@@ -1,7 +1,6 @@
 from typing import Any, Dict
 
 import torch
-from torchvision import datapoints
 from torchvision.transforms.v2 import functional as F, Transform
 
 
@@ -25,5 +24,5 @@ def __init__(self, num_samples: int):
         super().__init__()
         self.num_samples = num_samples
 
-    def _transform(self, inpt: datapoints._VideoType, params: Dict[str, Any]) -> datapoints._VideoType:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.uniform_temporal_subsample(inpt, self.num_samples)
5 changes: 2 additions & 3 deletions torchvision/transforms/v2/_utils.py
@@ -5,9 +5,8 @@
 
 import torch
 
-from torchvision import datapoints
-from torchvision.datapoints._datapoint import _FillType, _FillTypeJIT
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
+from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
 
 
 def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]:
@@ -36,7 +35,7 @@ def _check_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -
         raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")
 
 
-def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
+def _convert_fill_arg(fill: _FillType) -> _FillTypeJIT:
     # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
     # So, we can't reassign fill to 0
     # if fill is None:
6 changes: 2 additions & 4 deletions torchvision/transforms/v2/functional/_augment.py
@@ -1,5 +1,3 @@
-from typing import Union
-
 import PIL.Image
 
 import torch
@@ -12,14 +10,14 @@
 
 @_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True)
 def erase(
-    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT],
+    inpt: torch.Tensor,
     i: int,
     j: int,
     h: int,
     w: int,
     v: torch.Tensor,
     inplace: bool = False,
-) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
 
