Add sanitize_bounding_boxes kernel/functional (pytorch#8308)

NicolasHug authored Mar 15, 2024
1 parent d1f3a7b commit 53869eb
Showing 6 changed files with 182 additions and 43 deletions.
1 change: 1 addition & 0 deletions docs/source/transforms.rst
@@ -414,6 +414,7 @@ Functionals

v2.functional.normalize
v2.functional.erase
v2.functional.sanitize_bounding_boxes
v2.functional.clamp_bounding_boxes
v2.functional.uniform_temporal_subsample

99 changes: 81 additions & 18 deletions test/test_transforms_v2.py
@@ -5675,18 +5675,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):


class TestSanitizeBoundingBoxes:
@pytest.mark.parametrize("min_size", (1, 10))
@pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None))
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_transform(self, min_size, labels_getter, sample_type):

if sample_type is tuple and not isinstance(labels_getter, str):
# The "lambda inputs: inputs["labels"]" labels_getter used in this test
# doesn't work if the input is a tuple.
return

H, W = 256, 128

def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
boxes_and_validity = [
([0, 1, 10, 1], False), # Y1 == Y2
([0, 1, 0, 20], False), # X1 == X2
@@ -5706,18 +5695,31 @@ def test_transform(self, min_size, labels_getter, sample_type):
]

random.shuffle(boxes_and_validity) # For test robustness: mix order of wrong and correct cases
boxes, is_valid_mask = zip(*boxes_and_validity)
valid_indices = [i for (i, is_valid) in enumerate(is_valid_mask) if is_valid]

boxes = torch.tensor(boxes)
labels = torch.arange(boxes.shape[0])
boxes, expected_valid_mask = zip(*boxes_and_validity)

boxes = tv_tensors.BoundingBoxes(
boxes,
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=(H, W),
)

return boxes, expected_valid_mask

@pytest.mark.parametrize("min_size", (1, 10))
@pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None))
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_transform(self, min_size, labels_getter, sample_type):

if sample_type is tuple and not isinstance(labels_getter, str):
# The "lambda inputs: inputs["labels"]" labels_getter used in this test
# doesn't work if the input is a tuple.
return

H, W = 256, 128
boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size)
valid_indices = [i for (i, is_valid) in enumerate(expected_valid_mask) if is_valid]

labels = torch.arange(boxes.shape[0])
masks = tv_tensors.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
whatever = torch.rand(10)
input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8)
@@ -5763,6 +5765,44 @@ def test_transform(self, min_size, labels_getter, sample_type):
# This works because we conveniently set labels to arange(num_boxes)
assert out_labels.tolist() == valid_indices

@pytest.mark.parametrize("input_type", (torch.Tensor, tv_tensors.BoundingBoxes))
def test_functional(self, input_type):
# Note: the "functional" F.sanitize_bounding_boxes was added after the class, so there is some
# redundancy with test_transform() in terms of correctness checks. But that's OK.

H, W, min_size = 256, 128, 10

boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size)

if input_type is tv_tensors.BoundingBoxes:
format = canvas_size = None
else:
# just passing "XYXY" explicitly to make sure we support strings
format, canvas_size = "XYXY", boxes.canvas_size
boxes = boxes.as_subclass(torch.Tensor)

boxes, valid = F.sanitize_bounding_boxes(boxes, format=format, canvas_size=canvas_size, min_size=min_size)

assert_equal(valid, torch.tensor(expected_valid_mask))
assert type(valid) == torch.Tensor
assert boxes.shape[0] == sum(valid)
assert isinstance(boxes, input_type)

def test_kernel(self):
H, W, min_size = 256, 128, 10
boxes, _ = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size)

format, canvas_size = boxes.format, boxes.canvas_size
boxes = boxes.as_subclass(torch.Tensor)

check_kernel(
F.sanitize_bounding_boxes,
input=boxes,
format=format,
canvas_size=canvas_size,
check_batched_vs_unbatched=False,
)

def test_no_label(self):
# Non-regression test for https://github.com/pytorch/vision/issues/7878

@@ -5776,7 +5816,7 @@ def test_no_label(self):
assert isinstance(out_img, tv_tensors.Image)
assert isinstance(out_boxes, tv_tensors.BoundingBoxes)

def test_errors(self):
def test_errors_transform(self):
good_bbox = tv_tensors.BoundingBoxes(
[[0, 0, 10, 10]],
format=tv_tensors.BoundingBoxFormat.XYXY,
@@ -5799,3 +5839,26 @@ def test_errors(self):
with pytest.raises(ValueError, match="Number of boxes"):
different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
transforms.SanitizeBoundingBoxes()(different_sizes)

def test_errors_functional(self):

good_bbox = tv_tensors.BoundingBoxes(
[[0, 0, 10, 10]],
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=(20, 20),
)

with pytest.raises(ValueError, match="canvas_size cannot be None if bounding_boxes is a pure tensor"):
F.sanitize_bounding_boxes(good_bbox.as_subclass(torch.Tensor), format="XYXY", canvas_size=None)

with pytest.raises(ValueError, match="canvas_size cannot be None if bounding_boxes is a pure tensor"):
F.sanitize_bounding_boxes(good_bbox.as_subclass(torch.Tensor), format=None, canvas_size=(10, 10))

with pytest.raises(ValueError, match="canvas_size must be None when bounding_boxes is a tv_tensors"):
F.sanitize_bounding_boxes(good_bbox, format="XYXY", canvas_size=None)

with pytest.raises(ValueError, match="canvas_size must be None when bounding_boxes is a tv_tensors"):
F.sanitize_bounding_boxes(good_bbox, format="XYXY", canvas_size=None)

with pytest.raises(ValueError, match="bouding_boxes must be a tv_tensors.BoundingBoxes instance or a"):
F.sanitize_bounding_boxes(good_bbox.tolist())
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_augment.py
@@ -123,7 +123,7 @@ def _extract_image_targets(
if not (len(images) == len(bboxes) == len(masks) == len(labels)):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain equal sized list of Images, "
"BoundingBoxeses, Masks and Labels or OneHotLabels."
"BoundingBoxes, Masks and Labels or OneHotLabels."
)

targets = []
30 changes: 8 additions & 22 deletions torchvision/transforms/v2/_misc.py
@@ -1,5 +1,5 @@
import warnings
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

import PIL.Image

@@ -369,28 +369,14 @@ def forward(self, *inputs: Any) -> Any:
f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
)

boxes = cast(
tv_tensors.BoundingBoxes,
F.convert_bounding_box_format(
boxes,
new_format=tv_tensors.BoundingBoxFormat.XYXY,
),
)
valid = F._misc._get_sanitize_bounding_boxes_mask(
boxes,
format=boxes.format,
canvas_size=boxes.canvas_size,
min_size=self.min_size,
)
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
valid = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
# TODO: Do we really need to check for out of bounds here? All
# transforms should be clamping anyway, so this should never happen?
image_h, image_w = boxes.canvas_size
valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)

params = dict(valid=valid.as_subclass(torch.Tensor), labels=labels)
flat_outputs = [
# Even-though it may look like we're transforming all inputs, we don't:
# _transform() will only care about BoundingBoxeses and the labels
self._transform(inpt, params)
for inpt in flat_inputs
]
params = dict(valid=valid, labels=labels)
flat_outputs = [self._transform(inpt, params) for inpt in flat_inputs]

return tree_unflatten(flat_outputs, spec)
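To see the refactored transform end to end, here is a small usage sketch (not part of the diff; the sample data is made up):

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2 as transforms

# SanitizeBoundingBoxes now delegates to the same mask helper as the new functional.
sanitize = transforms.SanitizeBoundingBoxes(min_size=1)
sample = {
    "image": tv_tensors.Image(torch.randint(0, 256, (3, 256, 128), dtype=torch.uint8)),
    "boxes": tv_tensors.BoundingBoxes(
        [[0, 0, 10, 10], [5, 5, 5, 5]],  # second box is degenerate (zero width and height)
        format="XYXY",
        canvas_size=(256, 128),
    ),
    "labels": torch.tensor([0, 1]),
}
out = sanitize(sample)
assert out["boxes"].shape[0] == out["labels"].shape[0] == 1  # degenerate box and its label are dropped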

1 change: 1 addition & 0 deletions torchvision/transforms/v2/functional/__init__.py
@@ -167,6 +167,7 @@
normalize,
normalize_image,
normalize_video,
sanitize_bounding_boxes,
to_dtype,
to_dtype_image,
to_dtype_video,
92 changes: 90 additions & 2 deletions torchvision/transforms/v2/functional/_misc.py
@@ -1,5 +1,5 @@
import math
from typing import List, Optional
from typing import List, Optional, Tuple

import PIL.Image
import torch
@@ -11,7 +11,9 @@

from torchvision.utils import _log_api_usage_once

from ._utils import _get_kernel, _register_kernel_internal
from ._meta import _convert_bounding_box_format

from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor


def normalize(
@@ -275,3 +277,89 @@ def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale:
def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor:
# We don't need to unwrap and rewrap here, since TVTensor.to() preserves the type
return inpt.to(dtype)


def sanitize_bounding_boxes(
bounding_boxes: torch.Tensor,
format: Optional[tv_tensors.BoundingBoxFormat] = None,
canvas_size: Optional[Tuple[int, int]] = None,
min_size: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Remove degenerate/invalid bounding boxes and return the corresponding indexing mask.
This removes bounding boxes that:
- are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
- have any coordinate outside of their corresponding image. You may want to
call :func:`~torchvision.transforms.v2.functional.clamp_bounding_boxes` first to avoid undesired removals.
It is recommended to call it at the end of a pipeline, before passing the
input to the models. It is critical to call this transform if
:class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
If you want to be extra careful, you may call it after all transforms that
may modify bounding boxes but once at the end should be enough in most
cases.
Args:
bounding_boxes (Tensor or :class:`~torchvision.tv_tensors.BoundingBoxes`): The bounding boxes to be sanitized.
format (str or :class:`~torchvision.tv_tensors.BoundingBoxFormat`, optional): The format of the bounding boxes.
Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object.
canvas_size (tuple of int, optional): The canvas_size of the bounding boxes
(size of the corresponding image/video).
Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object.
min_size (float, optional) The size below which bounding boxes are removed. Default is 1.
Returns:
out (tuple of Tensors): The subset of valid bounding boxes, and the corresponding indexing mask.
The mask can then be used to subset other tensors (e.g. labels) that are associated with the bounding boxes.
"""
if torch.jit.is_scripting() or is_pure_tensor(bounding_boxes):
if format is None or canvas_size is None:
raise ValueError(
"format and canvas_size cannot be None if bounding_boxes is a pure tensor. "
f"Got format={format} and canvas_size={canvas_size}."
"Set those to appropriate values or pass bounding_boxes as a tv_tensors.BoundingBoxes object."
)
if isinstance(format, str):
format = tv_tensors.BoundingBoxFormat[format.upper()]
valid = _get_sanitize_bounding_boxes_mask(
bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size
)
bounding_boxes = bounding_boxes[valid]
else:
if not isinstance(bounding_boxes, tv_tensors.BoundingBoxes):
raise ValueError("bouding_boxes must be a tv_tensors.BoundingBoxes instance or a pure tensor.")
if format is not None or canvas_size is not None:
raise ValueError(
"format and canvas_size must be None when bounding_boxes is a tv_tensors.BoundingBoxes instance. "
f"Got format={format} and canvas_size={canvas_size}. "
"Leave those to None or pass bouding_boxes as a pure tensor."
)
valid = _get_sanitize_bounding_boxes_mask(
bounding_boxes, format=bounding_boxes.format, canvas_size=bounding_boxes.canvas_size, min_size=min_size
)
bounding_boxes = tv_tensors.wrap(bounding_boxes[valid], like=bounding_boxes)

return bounding_boxes, valid
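As a quick illustration of the new functional (a sketch, not part of the commit; the boxes below are made-up values):

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

# One valid box and one degenerate box (X1 == X2) on a 256x128 canvas.
boxes = tv_tensors.BoundingBoxes(
    [[10, 10, 50, 50], [20, 20, 20, 40]],
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=(256, 128),
)
labels = torch.tensor([1, 2])

boxes, valid = F.sanitize_bounding_boxes(boxes)
labels = labels[valid]  # the returned mask can subset any associated tensor
assert boxes.shape[0] == 1 and labels.tolist() == [1]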


def _get_sanitize_bounding_boxes_mask(
bounding_boxes: torch.Tensor,
format: tv_tensors.BoundingBoxFormat,
canvas_size: Tuple[int, int],
min_size: float = 1.0,
) -> torch.Tensor:

bounding_boxes = _convert_bounding_box_format(
bounding_boxes, new_format=tv_tensors.BoundingBoxFormat.XYXY, old_format=format
)

image_h, image_w = canvas_size
ws, hs = bounding_boxes[:, 2] - bounding_boxes[:, 0], bounding_boxes[:, 3] - bounding_boxes[:, 1]
valid = (ws >= min_size) & (hs >= min_size) & (bounding_boxes >= 0).all(dim=-1)
# TODO: Do we really need to check for out of bounds here? All
# transforms should be clamping anyway, so this should never happen?
valid &= (bounding_boxes[:, 0] <= image_w) & (bounding_boxes[:, 2] <= image_w)
valid &= (bounding_boxes[:, 1] <= image_h) & (bounding_boxes[:, 3] <= image_h)
return valid
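For the pure-tensor branch, a minimal sketch (assumed values, not from the diff) of how the out-of-bounds check interacts with canvas_size:

import torch
from torchvision.transforms.v2 import functional as F

# Pure tensors require explicit format and canvas_size (H, W).
boxes = torch.tensor([[0, 0, 10, 10], [0, 0, 200, 10]])  # second box exceeds W=128
out, valid = F.sanitize_bounding_boxes(boxes, format="XYXY", canvas_size=(256, 128))
assert valid.tolist() == [True, False]  # the out-of-bounds box is flagged invalid
assert out.shape[0] == 1  # output stays a plain tensor on this path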
