Skip to content

Commit

Permalink
fix prototype tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Jul 17, 2023
1 parent 685e16e commit b629b9d
Showing 1 changed file with 20 additions and 24 deletions.
44 changes: 20 additions & 24 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,13 @@ class TestFixedSizeCrop:
def test__get_params(self, mocker):
crop_size = (7, 7)
batch_shape = (10,)
spatial_size = (11, 5)
canvas_size = (11, 5)

transform = transforms.FixedSizeCrop(size=crop_size)

flat_inputs = [
make_image(size=spatial_size, color_space="RGB"),
make_bounding_box(format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=batch_shape),
make_image(canvas_size, color_space="RGB"),
make_bounding_box(canvas_size, format=BoundingBoxFormat.XYXY, batch_dims=batch_shape),
]
params = transform._get_params(flat_inputs)

Expand Down Expand Up @@ -295,7 +295,7 @@ def test__transform(self, mocker, needs):

def test__transform_culling(self, mocker):
batch_size = 10
spatial_size = (10, 10)
canvas_size = (10, 10)

is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool)
mocker.patch(
Expand All @@ -304,17 +304,15 @@ def test__transform_culling(self, mocker):
needs_crop=True,
top=0,
left=0,
height=spatial_size[0],
width=spatial_size[1],
height=canvas_size[0],
width=canvas_size[1],
is_valid=is_valid,
needs_pad=False,
),
)

bounding_boxes = make_bounding_box(
format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=(batch_size,)
)
masks = make_detection_mask(size=spatial_size, batch_dims=(batch_size,))
bounding_boxes = make_bounding_box(canvas_size, format=BoundingBoxFormat.XYXY, batch_dims=(batch_size,))
masks = make_detection_mask(canvas_size, batch_dims=(batch_size,))
labels = make_label(extra_dims=(batch_size,))

transform = transforms.FixedSizeCrop((-1, -1))
Expand All @@ -334,24 +332,22 @@ def test__transform_culling(self, mocker):

def test__transform_bounding_box_clamping(self, mocker):
batch_size = 3
spatial_size = (10, 10)
canvas_size = (10, 10)

mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
return_value=dict(
needs_crop=True,
top=0,
left=0,
height=spatial_size[0],
width=spatial_size[1],
height=canvas_size[0],
width=canvas_size[1],
is_valid=torch.full((batch_size,), fill_value=True),
needs_pad=False,
),
)

bounding_box = make_bounding_box(
format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=(batch_size,)
)
bounding_box = make_bounding_box(canvas_size, format=BoundingBoxFormat.XYXY, batch_dims=(batch_size,))
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box")

transform = transforms.FixedSizeCrop((-1, -1))
Expand Down Expand Up @@ -496,27 +492,27 @@ def make_datapoints():

pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"boxes": make_bounding_box(size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
"masks": make_detection_mask(size, num_objects=num_objects, dtype=torch.long),
}

yield (pil_image, target)

tensor_image = torch.Tensor(make_image(size=size, color_space="RGB"))
tensor_image = torch.Tensor(make_image(size, color_space="RGB"))
target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"boxes": make_bounding_box(size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
"masks": make_detection_mask(size, num_objects=num_objects, dtype=torch.long),
}

yield (tensor_image, target)

datapoint_image = make_image(size=size, color_space="RGB")
datapoint_image = make_image(size, color_space="RGB")
target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"boxes": make_bounding_box(size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
"masks": make_detection_mask(size, num_objects=num_objects, dtype=torch.long),
}

yield (datapoint_image, target)
Expand Down

0 comments on commit b629b9d

Please sign in to comment.