Skip to content

Commit

Permalink
support map and formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
lhoestq committed Oct 16, 2024
1 parent fe601ba commit 36a49b1
Show file tree
Hide file tree
Showing 12 changed files with 221 additions and 30 deletions.
4 changes: 2 additions & 2 deletions src/datasets/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,9 +1230,9 @@ def save_to_disk(
"""
Saves a dataset dict to a filesystem using `fsspec.spec.AbstractFileSystem`.
For [`Image`] and [`Audio`] data:
For [`Image`], [`Audio`] and [`Video`] data:
All the Image() and Audio() data are stored in the arrow files.
All the Image(), Audio() and Video() data are stored in the arrow files.
If you want to store paths or urls, please use the Value("string") type.
Args:
Expand Down
74 changes: 47 additions & 27 deletions src/datasets/features/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,6 @@
from .features import FeatureType


def _patched_init(self: "VideoReader", uri: Union[str, BytesIO], *args, **kwargs) -> None:
if hasattr(uri, "read"):
self._hf_encoded = {"bytes": uri.read(), "path": None}
uri.seek(0)
elif isinstance(uri, str):
self._hf_encoded = {"bytes": None, "path": uri}
self._original_init(uri, *args, **kwargs)


def patch_decord():
    """Install `_patched_init` as `VideoReader.__init__`, at most once.

    The original constructor is kept as `VideoReader._original_init` and the
    class is flagged with `_hf_patched` so that later calls are no-ops.

    Raises:
        ImportError: if `config.DECORD_AVAILABLE` is false.
    """
    if not config.DECORD_AVAILABLE:
        raise ImportError("To support decoding videos, please install 'decord'.")
    from decord import VideoReader

    if hasattr(VideoReader, "_hf_patched"):
        # Already patched: doing it again would save the patched method as the "original" one.
        return
    VideoReader._original_init = VideoReader.__init__
    VideoReader.__init__ = _patched_init
    VideoReader._hf_patched = True


# Patch at import time when decord is installed.
if config.DECORD_AVAILABLE:
    patch_decord()


@dataclass
class Video:
"""Video [`Feature`] to read video data from an video file.
Expand Down Expand Up @@ -104,7 +80,6 @@ def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "VideoReader
if config.DECORD_AVAILABLE:
from decord import VideoReader

patch_decord()
else:
raise ImportError("To support encoding videos, please install 'decord'.")

Expand Down Expand Up @@ -154,7 +129,6 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "VideoReader":
raise RuntimeError("Decoding is disabled for this feature. Please use Video(decode=True) instead.")

if config.DECORD_AVAILABLE:
patch_decord()
from decord import VideoReader
else:
raise ImportError("To support decoding videos, please install 'decord'.")
Expand Down Expand Up @@ -298,4 +272,50 @@ def encode_decord_video(video: "VideoReader") -> dict:

def encode_np_array(array: np.ndarray) -> dict:
    """Encode a numpy array as a video example (not implemented).

    Args:
        array: the array that would be encoded.

    Raises:
        NotImplementedError: always.
    """
    # The former `return` placed after this statement could never run and
    # referenced names that are not defined in this function, so it was dropped.
    raise NotImplementedError()


# Patching decord a little bit to:
# 1. store the encoded video data {"path": ..., "bytes": ...} in `video._hf_encoded`
# 2. set the decord bridge to numpy/torch/tf/jax using `video._hf_bridge_out` (per video instance) instead of decord.bridge.bridge_out (global)
# This doesn't affect the normal usage of decord.


def _patched_init(self: "VideoReader", uri: Union[str, BytesIO], *args, **kwargs) -> None:
    """Replacement for `VideoReader.__init__` that records source and output bridge.

    Before delegating to the original constructor (`self._original_init`), the
    instance receives:
      * `_hf_encoded`: a dict with "bytes" and "path" keys describing the source
        (only when `uri` is file-like or a string);
      * `_hf_bridge_out`: initialised to `decord.bridge.bridge_out`.

    Args:
        uri: either an object exposing `read()` or a string.
        *args, **kwargs: forwarded untouched to the original constructor.
    """
    from decord.bridge import bridge_out

    is_file_like = hasattr(uri, "read")
    if is_file_like:
        content = uri.read()
        self._hf_encoded = {"bytes": content, "path": None}
        # Rewind the stream after consuming it.
        uri.seek(0)
    elif isinstance(uri, str):
        self._hf_encoded = {"bytes": None, "path": uri}
    # Per-instance output conversion, used by the patched `next` / `get_batch`.
    self._hf_bridge_out = bridge_out
    self._original_init(uri, *args, **kwargs)


def _patched_next(self: "VideoReader", *args, **kwargs):
return self._hf_bridge_out(self._original_next(*args, **kwargs))


def _patched_get_batch(self: "VideoReader", *args, **kwargs):
return self._hf_bridge_out(self._original_get_batch(*args, **kwargs))


def patch_decord():
    """Monkey-patch decord's `VideoReader`, at most once.

    The original `__init__`, `next` and `get_batch` are saved under
    `_original_init`, `_original_next` and `_original_get_batch`, then replaced
    by the `_patched_*` functions of this module. The class is flagged with
    `_hf_patched` so that later calls are no-ops.

    Raises:
        ImportError: if `config.DECORD_AVAILABLE` is false.
    """
    if not config.DECORD_AVAILABLE:
        raise ImportError("To support decoding videos, please install 'decord'.")
    import decord.video_reader
    from decord import VideoReader

    if hasattr(VideoReader, "_hf_patched"):
        # Already patched: doing it again would save the patched methods as the "original" ones.
        return

    # Replace the bridge used inside `decord.video_reader` by the identity:
    # conversion is applied per instance through `_hf_bridge_out` instead.
    decord.video_reader.bridge_out = lambda x: x

    # (method to override, attribute keeping the original, replacement)
    overrides = (
        ("__init__", "_original_init", _patched_init),
        ("next", "_original_next", _patched_next),
        ("get_batch", "_original_get_batch", _patched_get_batch),
    )
    for method_name, backup_name, replacement in overrides:
        setattr(VideoReader, backup_name, getattr(VideoReader, method_name))
        setattr(VideoReader, method_name, replacement)
    VideoReader._hf_patched = True


# Patch at import time when decord is installed.
if config.DECORD_AVAILABLE:
    patch_decord()
6 changes: 6 additions & 0 deletions src/datasets/formatting/jax_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ def _tensorize(self, value):

if isinstance(value, PIL.Image.Image):
value = np.asarray(value)
elif config.DECORD_AVAILABLE and "decord" in sys.modules:
from decord import VideoReader

if isinstance(value, VideoReader):
value._hf_bridge_out = lambda x: jnp.array(np.asarray(x))
return value

# using global variable since `jaxlib.xla_extension.Device` is not serializable neither
# with `pickle` nor with `dill`, so we need to use a global variable instead
Expand Down
6 changes: 6 additions & 0 deletions src/datasets/formatting/np_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ def _tensorize(self, value):

if isinstance(value, PIL.Image.Image):
return np.asarray(value, **self.np_array_kwargs)
elif config.DECORD_AVAILABLE and "decord" in sys.modules:
from decord import VideoReader

if isinstance(value, VideoReader):
value._hf_bridge_out = np.asarray
return value

return np.asarray(value, **{**default_dtype, **self.np_array_kwargs})

Expand Down
7 changes: 7 additions & 0 deletions src/datasets/formatting/tf_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ def _tensorize(self, value):

if isinstance(value, PIL.Image.Image):
value = np.asarray(value)
elif config.DECORD_AVAILABLE and "decord" in sys.modules:
from decord import VideoReader
from decord.bridge import to_tensorflow

if isinstance(value, VideoReader):
value._hf_bridge_out = to_tensorflow
return value

return tf.convert_to_tensor(value, **{**default_dtype, **self.tf_tensor_kwargs})

Expand Down
8 changes: 8 additions & 0 deletions src/datasets/formatting/torch_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@ def _tensorize(self, value):
value = value[:, :, np.newaxis]

value = value.transpose((2, 0, 1))
elif config.DECORD_AVAILABLE and "decord" in sys.modules:
from decord import VideoReader
from decord.bridge import to_torch

if isinstance(value, VideoReader):
value._hf_bridge_out = to_torch
return value

return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})

def _recursive_tensorize(self, data_struct):
Expand Down
6 changes: 5 additions & 1 deletion src/datasets/packaged_modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .parquet import parquet
from .sql import sql
from .text import text
from .videofolder import videofolder
from .webdataset import webdataset


Expand All @@ -40,6 +41,7 @@ def _hash_python_lines(lines: List[str]) -> str:
"text": (text.__name__, _hash_python_lines(inspect.getsource(text).splitlines())),
"imagefolder": (imagefolder.__name__, _hash_python_lines(inspect.getsource(imagefolder).splitlines())),
"audiofolder": (audiofolder.__name__, _hash_python_lines(inspect.getsource(audiofolder).splitlines())),
"videofolder": (videofolder.__name__, _hash_python_lines(inspect.getsource(videofolder).splitlines())),
"webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())),
}

Expand Down Expand Up @@ -74,7 +76,9 @@ def _hash_python_lines(lines: List[str]) -> str:
_EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
_EXTENSION_TO_MODULE.update({ext: ("audiofolder", {}) for ext in audiofolder.AudioFolder.EXTENSIONS})
_EXTENSION_TO_MODULE.update({ext.upper(): ("audiofolder", {}) for ext in audiofolder.AudioFolder.EXTENSIONS})
_MODULE_SUPPORTS_METADATA = {"imagefolder", "audiofolder"}
_EXTENSION_TO_MODULE.update({ext: ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS})
_EXTENSION_TO_MODULE.update({ext.upper(): ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS})
_MODULE_SUPPORTS_METADATA = {"imagefolder", "audiofolder", "videofolder"}

# Used to filter data files based on extensions given a module name
_MODULE_TO_EXTENSIONS: Dict[str, List[str]] = {}
Expand Down
Empty file.
36 changes: 36 additions & 0 deletions src/datasets/packaged_modules/videofolder/videofolder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import List

import datasets

from ..folder_based_builder import folder_based_builder


# Module-level logger named after this module.
logger = datasets.utils.logging.get_logger(__name__)


class VideoFolderConfig(folder_based_builder.FolderBasedBuilderConfig):
    """BuilderConfig for VideoFolder."""

    # Both options default to None (i.e. not set explicitly).
    drop_labels: bool = None
    drop_metadata: bool = None

    def __post_init__(self):
        # Nothing specific to videos: defer to the folder-based base config.
        super().__post_init__()


class VideoFolder(folder_based_builder.FolderBasedBuilder):
    """Folder-based builder whose base feature is `datasets.Video`, in the "video" column."""

    EXTENSIONS: List[str]  # definition at the bottom of the script
    BUILDER_CONFIG_CLASS = VideoFolderConfig
    BASE_COLUMN_NAME = "video"
    BASE_FEATURE = datasets.Video


# TODO: initial list, we should check the compatibility of other formats
VIDEO_EXTENSIONS = [".mkv", ".mp4", ".avi", ".mpeg", ".mov"]
# Give the builder class its extension list now that it is defined.
VideoFolder.EXTENSIONS = VIDEO_EXTENSIONS
Binary file added tests/features/data/test_video_66x50.mov
Binary file not shown.
92 changes: 92 additions & 0 deletions tests/features/test_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import pytest

from datasets import Dataset, Features, Video

from ..utils import require_decord


def _read_bytes(path: str) -> bytes:
    """Return the content of the file at `path`, closing the file afterwards."""
    with open(path, "rb") as f:
        return f.read()


@require_decord
@pytest.mark.parametrize(
    "build_example",
    [
        lambda video_path: video_path,
        _read_bytes,
        lambda video_path: {"path": video_path},
        lambda video_path: {"path": video_path, "bytes": None},
        lambda video_path: {"path": video_path, "bytes": _read_bytes(video_path)},
        lambda video_path: {"path": None, "bytes": _read_bytes(video_path)},
        lambda video_path: {"bytes": _read_bytes(video_path)},
    ],
)
def test_video_feature_encode_example(shared_datadir, build_example):
    """Every accepted input form encodes to a {"bytes", "path"} dict that decodes to a VideoReader.

    `build_example` turns the path of the test video into one of the input
    forms: a path, raw bytes, or a dict holding a path and/or bytes.
    """
    from decord import VideoReader

    video_path = str(shared_datadir / "test_video_66x50.mov")
    video = Video()
    encoded_example = video.encode_example(build_example(video_path))
    assert isinstance(encoded_example, dict)
    assert encoded_example.keys() == {"bytes", "path"}
    # At least one of the two fields must carry the video.
    assert encoded_example["bytes"] is not None or encoded_example["path"] is not None
    decoded_example = video.decode_example(encoded_example)
    assert isinstance(decoded_example, VideoReader)


@require_decord
def test_dataset_with_video_feature(shared_datadir):
    """Items, batches and columns of a Video dataset yield VideoReader objects.

    The first frame of each reader must be a decord NDArray of shape (50, 66, 3),
    whether the dataset was built from a file path or from raw bytes.
    """
    from decord import VideoReader
    from decord.ndarray import NDArray

    frame_shape = (50, 66, 3)

    def assert_first_frame(reader):
        assert reader[0].shape == frame_shape
        assert isinstance(reader[0], NDArray)

    video_path = str(shared_datadir / "test_video_66x50.mov")
    features = Features({"video": Video()})
    dset = Dataset.from_dict({"video": [video_path]}, features=features)

    # single example
    item = dset[0]
    assert item.keys() == {"video"}
    assert isinstance(item["video"], VideoReader)
    assert_first_frame(item["video"])

    # batch
    batch = dset[:1]
    assert len(batch) == 1
    assert batch.keys() == {"video"}
    videos = batch["video"]
    assert isinstance(videos, list) and all(isinstance(video, VideoReader) for video in videos)
    assert_first_frame(videos[0])

    # column
    column = dset["video"]
    assert len(column) == 1
    assert isinstance(column, list) and all(isinstance(video, VideoReader) for video in column)
    assert_first_frame(column[0])

    # from bytes
    with open(video_path, "rb") as f:
        video_bytes = f.read()
    dset = Dataset.from_dict({"video": [video_bytes]}, features=features)
    item = dset[0]
    assert item.keys() == {"video"}
    assert isinstance(item["video"], VideoReader)
    assert_first_frame(item["video"])


@require_decord
def test_dataset_with_video_map_and_formatted(shared_datadir):
    """After `map` and `with_format("numpy")`, videos stay VideoReader objects with numpy frames.

    Checked for a dataset built from a file path, then for one built from raw bytes.
    """
    import numpy as np
    from decord import VideoReader

    video_path = str(shared_datadir / "test_video_66x50.mov")
    features = Features({"video": Video()})
    with open(video_path, "rb") as f:
        video_bytes = f.read()

    # First the path, then the raw bytes.
    for payload in (video_path, video_bytes):
        dset = Dataset.from_dict({"video": [payload]}, features=features)
        dset = dset.map(lambda x: x).with_format("numpy")
        example = dset[0]
        assert isinstance(example["video"], VideoReader)
        assert isinstance(example["video"][0], np.ndarray)
12 changes: 12 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,18 @@ def require_pil(test_case):
return test_case


def require_decord(test_case):
    """
    Decorator marking a test that requires decord.
    These tests are skipped when decord isn't installed.
    """
    # Check the decord flag (not the PIL one): the test must be skipped
    # whenever decord is missing, even if PIL is installed.
    if not config.DECORD_AVAILABLE:
        test_case = unittest.skip("test requires decord")(test_case)
    return test_case


def require_transformers(test_case):
"""
Decorator marking a test that requires transformers.
Expand Down

0 comments on commit 36a49b1

Please sign in to comment.