Skip to content

Commit

Permalink
Add transparency support to webp decoder (#8610)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Aug 27, 2024
1 parent 4c5ae78 commit bdf354e
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 15 deletions.
27 changes: 26 additions & 1 deletion test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ def test_decode_gif_webp_errors(decode_fun):
if decode_fun is decode_gif:
expected_match = re.escape("DGifOpenFileName() failed - 103")
elif decode_fun is decode_webp:
expected_match = "WebPDecodeRGB failed."
expected_match = "WebPGetFeatures failed."
with pytest.raises(RuntimeError, match=expected_match):
decode_fun(encoded_data)

Expand All @@ -891,6 +891,31 @@ def test_decode_webp(decode_fun, scripted):
assert img[None].is_contiguous(memory_format=torch.channels_last)


# This test is skipped because it requires webp images that we're not including
# within the repo. The test images were downloaded from the different pages of
# https://developers.google.com/speed/webp/gallery
# Note that converting an RGBA image to RGB leads to bad results because the
# transparent pixels aren't necessarily set to "black" or "white", they can be
# random stuff. This is consistent with PIL results.
@pytest.mark.skip(reason="Need to download test images first")
@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize(
"mode, pil_mode", ((ImageReadMode.RGB, "RGB"), (ImageReadMode.RGB_ALPHA, "RGBA"), (ImageReadMode.UNCHANGED, None))
)
@pytest.mark.parametrize("filename", Path("/home/nicolashug/webp_samples").glob("*.webp"))
def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename):
encoded_bytes = read_file(filename)
if scripted:
decode_fun = torch.jit.script(decode_fun)
img = decode_fun(encoded_bytes, mode=mode)
assert img[None].is_contiguous(memory_format=torch.channels_last)

pil_img = Image.open(filename).convert(pil_mode)
from_pil = F.pil_to_tensor(pil_img)
assert_equal(img, from_pil)


@pytest.mark.xfail(reason="AVIF support not enabled yet.")
@pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/io/image/cpu/decode_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ torch::Tensor decode_image(
TORCH_CHECK(data.numel() >= 15, err_msg);
if ((memcmp(webp_signature_begin, datap, 4) == 0) &&
(memcmp(webp_signature_end, datap + 8, 7) == 0)) {
return decode_webp(data);
return decode_webp(data, mode);
}

TORCH_CHECK(false, err_msg);
Expand Down
46 changes: 39 additions & 7 deletions torchvision/csrc/io/image/cpu/decode_webp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@ namespace vision {
namespace image {

#if !WEBP_FOUND
torch::Tensor decode_webp(const torch::Tensor& data) {
torch::Tensor decode_webp(
const torch::Tensor& encoded_data,
ImageReadMode mode) {
TORCH_CHECK(
false, "decode_webp: torchvision not compiled with libwebp support");
}
#else

torch::Tensor decode_webp(const torch::Tensor& encoded_data) {
torch::Tensor decode_webp(
const torch::Tensor& encoded_data,
ImageReadMode mode) {
TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
TORCH_CHECK(
encoded_data.dtype() == torch::kU8,
Expand All @@ -26,13 +30,41 @@ torch::Tensor decode_webp(const torch::Tensor& encoded_data) {
encoded_data.dim(),
" dims.");

auto encoded_data_p = encoded_data.data_ptr<uint8_t>();
auto encoded_data_size = encoded_data.numel();

WebPBitstreamFeatures features;
auto res = WebPGetFeatures(encoded_data_p, encoded_data_size, &features);
TORCH_CHECK(
res == VP8_STATUS_OK, "WebPGetFeatures failed with error code ", res);
TORCH_CHECK(
!features.has_animation, "Animated webp files are not supported.");

auto decoding_func = WebPDecodeRGB;
int num_channels = 0;
if (mode == IMAGE_READ_MODE_RGB) {
decoding_func = WebPDecodeRGB;
num_channels = 3;
} else if (mode == IMAGE_READ_MODE_RGB_ALPHA) {
decoding_func = WebPDecodeRGBA;
num_channels = 4;
} else {
// Assume mode is "unchanged"
decoding_func = features.has_alpha ? WebPDecodeRGBA : WebPDecodeRGB;
num_channels = features.has_alpha ? 4 : 3;
}

int width = 0;
int height = 0;
auto decoded_data = WebPDecodeRGB(
encoded_data.data_ptr<uint8_t>(), encoded_data.numel(), &width, &height);
TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB failed.");
auto out = torch::from_blob(decoded_data, {height, width, 3}, torch::kUInt8);
return out.permute({2, 0, 1}); // return CHW, channels-last

auto decoded_data =
decoding_func(encoded_data_p, encoded_data_size, &width, &height);
TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB[A] failed.");

auto out = torch::from_blob(
decoded_data, {height, width, num_channels}, torch::kUInt8);

return out.permute({2, 0, 1});
}
#endif // WEBP_FOUND

Expand Down
5 changes: 4 additions & 1 deletion torchvision/csrc/io/image/cpu/decode_webp.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
#pragma once

#include <torch/types.h>
#include "../image_read_mode.h"

namespace vision {
namespace image {

C10_EXPORT torch::Tensor decode_webp(const torch::Tensor& data);
C10_EXPORT torch::Tensor decode_webp(
const torch::Tensor& encoded_data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);

} // namespace image
} // namespace vision
3 changes: 2 additions & 1 deletion torchvision/csrc/io/image/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ static auto registry =
.op("image::encode_png", &encode_png)
.op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
&decode_jpeg)
.op("image::decode_webp", &decode_webp)
.op("image::decode_webp(Tensor encoded_data, int mode) -> Tensor",
&decode_webp)
.op("image::decode_avif", &decode_avif)
.op("image::encode_jpeg", &encode_jpeg)
.op("image::read_file", &read_file)
Expand Down
16 changes: 12 additions & 4 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ class ImageReadMode(Enum):
``ImageReadMode.GRAY_ALPHA`` for grayscale with transparency,
``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for
RGB with transparency.
.. note::
Some decoders won't support all possible values, e.g. a decoder may only
support "RGB" and "RGBA" mode.
"""

UNCHANGED = 0
Expand Down Expand Up @@ -365,23 +370,26 @@ def decode_gif(input: torch.Tensor) -> torch.Tensor:

def decode_webp(
input: torch.Tensor,
mode: ImageReadMode = ImageReadMode.UNCHANGED,
) -> torch.Tensor:
"""
Decode a WEBP image into a 3 dimensional RGB Tensor.
Decode a WEBP image into a 3 dimensional RGB[A] Tensor.
The values of the output tensor are uint8 between 0 and 255. If the input
image is RGBA, the transparency is ignored.
The values of the output tensor are uint8 between 0 and 255.
Args:
input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
the raw bytes of the WEBP image.
mode (ImageReadMode): The read mode used for optionally
converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
Returns:
Decoded image (Tensor[image_channels, image_height, image_width])
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(decode_webp)
return torch.ops.image.decode_webp(input)
return torch.ops.image.decode_webp(input, mode.value)


def _decode_avif(
Expand Down

0 comments on commit bdf354e

Please sign in to comment.