Skip to content

Commit

Permalink
[proto] Improvements for functional API and tests (#6187)
Browse files Browse the repository at this point in the history
* Added base tests for rotate_image_tensor

* Updated resize_image_tensor API and tests and fixed a bug with max_size

* Refactored and modified private api for resize functional op

* Fixed failures

* More updates

* Updated proto functional op: resize_image_*

* Added max_size arg to resize_bounding_box and updated basic tests

* Update functional.py

* Reverted fill/center order for rotate
Other nits
  • Loading branch information
vfdev-5 authored Jun 23, 2022
1 parent aeafa91 commit 6155808
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 17 deletions.
56 changes: 49 additions & 7 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,32 +201,58 @@ def horizontal_flip_bounding_box():

@register_kernel_info_from_sample_inputs_fn
def resize_image_tensor():
    """Yield sample inputs covering the resize_image_tensor kernel.

    Sweeps interpolation mode, max_size, and antialias over several image
    fixtures and two target sizes (identity and a non-uniform rescale).
    """
    for image, interpolation, max_size, antialias in itertools.product(
        make_images(),
        [F.InterpolationMode.BILINEAR, F.InterpolationMode.NEAREST],  # interpolation
        [None, 34],  # max_size
        [False, True],  # antialias
    ):

        # antialias is not supported for nearest-neighbor interpolation
        if antialias and interpolation == F.InterpolationMode.NEAREST:
            continue

        height, width = image.shape[-2:]
        for size in [
            (height, width),
            (int(height * 0.75), int(width * 1.25)),
        ]:
            # max_size is only valid together with a single-int size
            if max_size is not None:
                size = [size[0]]
            yield SampleInput(image, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)


@register_kernel_info_from_sample_inputs_fn
def resize_bounding_box():
    """Yield sample inputs covering the resize_bounding_box kernel.

    Sweeps max_size over several bounding-box fixtures and two target sizes
    (identity and a non-uniform rescale).
    """
    for bounding_box, max_size in itertools.product(
        make_bounding_boxes(),
        [None, 34],  # max_size
    ):
        height, width = bounding_box.image_size
        for size in [
            (height, width),
            (int(height * 0.75), int(width * 1.25)),
        ]:
            # max_size is only valid together with a single-int size
            if max_size is not None:
                size = [size[0]]
            # Pass max_size through so the max_size code path is actually tested
            # (it was previously computed but dropped).
            yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size, max_size=max_size)


@register_kernel_info_from_sample_inputs_fn
def resize_segmentation_mask():
    """Yield sample inputs covering the resize_segmentation_mask kernel."""
    for mask, max_size in itertools.product(
        make_segmentation_masks(),
        [None, 34],  # max_size
    ):
        height, width = mask.shape[-2:]
        candidate_sizes = [
            (height, width),
            (int(height * 0.75), int(width * 1.25)),
        ]
        for size in candidate_sizes:
            # A max_size constraint requires the single-int form of size.
            if max_size is not None:
                size = [size[0]]
            yield SampleInput(mask, size=size, max_size=max_size)


@register_kernel_info_from_sample_inputs_fn
def affine_image_tensor():
for image, angle, translate, scale, shear in itertools.product(
Expand Down Expand Up @@ -284,6 +310,22 @@ def affine_segmentation_mask():
)


@register_kernel_info_from_sample_inputs_fn
def rotate_image_tensor():
    """Yield sample inputs covering the rotate_image_tensor kernel."""
    angles = [-87, 15, 90]
    expand_flags = [True, False]
    centers = [None, [12, 23]]
    fills = [None, [128]]
    for image, angle, expand, center, fill in itertools.product(
        make_images(extra_dims=((), (4,))), angles, expand_flags, centers, fills
    ):
        # rotate warns that center is ignored when expand=True; skip that combo
        if expand and center is not None:
            continue

        yield SampleInput(image, angle=angle, expand=expand, center=center, fill=fill)


@register_kernel_info_from_sample_inputs_fn
def rotate_bounding_box():
for bounding_box, angle, expand, center in itertools.product(
Expand Down
27 changes: 17 additions & 10 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import torch
from torchvision.prototype import features
from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix, InterpolationMode
from torchvision.transforms.functional import (
pil_modes_mapping,
_get_inverse_affine_matrix,
InterpolationMode,
_compute_output_size,
)

from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil

Expand Down Expand Up @@ -42,14 +47,12 @@ def resize_image_tensor(
max_size: Optional[int] = None,
antialias: Optional[bool] = None,
) -> torch.Tensor:
# TODO: use _compute_output_size to enable max_size option
max_size # ununsed right now
new_height, new_width = size
num_channels, old_height, old_width = get_dimensions_image_tensor(image)
new_height, new_width = _compute_output_size((old_height, old_width), size=size, max_size=max_size)
batch_shape = image.shape[:-3]
return _FT.resize(
image.reshape((-1, num_channels, old_height, old_width)),
size=size,
size=[new_height, new_width],
interpolation=interpolation.value,
antialias=antialias,
).reshape(batch_shape + (num_channels, new_height, new_width))
Expand All @@ -61,8 +64,11 @@ def resize_image_pil(
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
) -> PIL.Image.Image:
# TODO: use _compute_output_size to enable max_size option
max_size # ununsed right now
if isinstance(size, int):
size = [size, size]
# Explicitly cast size to list otherwise mypy issue: incompatible type "Sequence[int]"; expected "List[int]"
size: List[int] = list(size)
size = _compute_output_size(img.size[::-1], size=size, max_size=max_size)
return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation])


Expand All @@ -72,10 +78,11 @@ def resize_segmentation_mask(
return resize_image_tensor(segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)


def resize_bounding_box(
    bounding_box: torch.Tensor, size: List[int], image_size: Tuple[int, int], max_size: Optional[int] = None
) -> torch.Tensor:
    """Resize bounding boxes in xyxy layout to match a resized image.

    Args:
        bounding_box: boxes as a ``(..., 4)`` tensor in xyxy format.
        image_size: ``(height, width)`` of the image the boxes refer to.
        size: requested output size, as accepted by ``_compute_output_size``.
        max_size: optional cap on the longer output edge; only valid with a
            single-int ``size``.

    Returns:
        Boxes scaled by the same ratios the image resize would apply.
    """
    old_height, old_width = image_size
    # Resolve the actual output size (handles single-int size and max_size)
    # exactly as the image kernels do, so boxes stay aligned with pixels.
    new_height, new_width = _compute_output_size(image_size, size=size, max_size=max_size)
    ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
    return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape)

Expand Down

0 comments on commit 6155808

Please sign in to comment.