Skip to content

Commit

Permalink
Fix test + address commnet
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Jul 31, 2023
1 parent 0e135b8 commit e3f6f54
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
4 changes: 2 additions & 2 deletions references/segmentation/v2_extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class PadIfSmaller(v2.Transform):
def __init__(self, size, fill=0):
super().__init__()
self.size = size
self.fill = v2._geometry._setup_fill_arg(fill)
self.fill = v2._utils._setup_fill_arg(fill)

def _get_params(self, sample):
_, height, width = v2.utils.query_chw(sample)
Expand All @@ -20,7 +20,7 @@ def _transform(self, inpt, params):
if not params["needs_padding"]:
return inpt

fill = v2._geometry._get_fill(self.fill, type(inpt))
fill = v2._utils._get_fill(self.fill, type(inpt))
fill = v2._utils._convert_fill_arg(fill)

return v2.functional.pad(inpt, padding=params["padding"], fill=fill)
Expand Down
3 changes: 2 additions & 1 deletion test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from torchvision.transforms.v2 import functional as prototype_F
from torchvision.transforms.v2.functional import to_image_pil
from torchvision.transforms.v2.utils import query_spatial_size
from torchvision.transforms.v2._utils import _get_fill

DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])

Expand Down Expand Up @@ -1180,7 +1181,7 @@ def _transform(self, inpt, params):
if not params["needs_padding"]:
return inpt

fill = self.fill[type(inpt)]
fill = _get_fill(self.fill, type(inpt))
return prototype_F.pad(inpt, padding=params["padding"], fill=fill)


Expand Down

0 comments on commit e3f6f54

Please sign in to comment.