diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py
index f92a1a8afda..40ca0cd7312 100644
--- a/src/datasets/dataset_dict.py
+++ b/src/datasets/dataset_dict.py
@@ -1230,9 +1230,9 @@ def save_to_disk(
         """
         Saves a dataset dict to a filesystem using `fsspec.spec.AbstractFileSystem`.
 
-        For [`Image`] and [`Audio`] data:
+        For [`Image`], [`Audio`] and [`Video`] data:
 
-        All the Image() and Audio() data are stored in the arrow files.
+        All the Image(), Audio() and Video() data are stored in the arrow files.
         If you want to store paths or urls, please use the Value("string") type.
 
         Args:
diff --git a/src/datasets/features/video.py b/src/datasets/features/video.py
index 5287630bbfc..ebdd1eb4dcb 100644
--- a/src/datasets/features/video.py
+++ b/src/datasets/features/video.py
@@ -19,30 +19,6 @@
 from .features import FeatureType
 
 
-def _patched_init(self: "VideoReader", uri: Union[str, BytesIO], *args, **kwargs) -> None:
-    if hasattr(uri, "read"):
-        self._hf_encoded = {"bytes": uri.read(), "path": None}
-        uri.seek(0)
-    elif isinstance(uri, str):
-        self._hf_encoded = {"bytes": None, "path": uri}
-    self._original_init(uri, *args, **kwargs)
-
-
-def patch_decord():
-    if config.DECORD_AVAILABLE:
-        from decord import VideoReader
-    else:
-        raise ImportError("To support decoding videos, please install 'decord'.")
-    if not hasattr(VideoReader, "_hf_patched"):
-        VideoReader._original_init = VideoReader.__init__
-        VideoReader.__init__ = _patched_init
-        VideoReader._hf_patched = True
-
-
-if config.DECORD_AVAILABLE:
-    patch_decord()
-
-
 @dataclass
 class Video:
     """Video [`Feature`] to read video data from a video file.
@@ -104,7 +80,6 @@ def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "VideoReader
         if config.DECORD_AVAILABLE:
             from decord import VideoReader
 
-            patch_decord()
         else:
             raise ImportError("To support encoding videos, please install 'decord'.")
 
@@ -154,7 +129,6 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "VideoReader":
             raise RuntimeError("Decoding is disabled for this feature. Please use Video(decode=True) instead.")
 
         if config.DECORD_AVAILABLE:
-            patch_decord()
             from decord import VideoReader
         else:
             raise ImportError("To support decoding videos, please install 'decord'.")
@@ -298,4 +272,50 @@ def encode_decord_video(video: "VideoReader") -> dict:
 
 def encode_np_array(array: np.ndarray) -> dict:
     raise NotImplementedError()
-    return {"path": None, "bytes": video_to_bytes(video)}
+
+
+# Patching decord a little bit to:
+# 1. store the encoded video data {"path": ..., "bytes": ...} in `video._hf_encoded`
+# 2. set the decord bridge to numpy/torch/tf/jax using `video._hf_bridge_out` (per video instance) instead of decord.bridge.bridge_out (global)
+# This doesn't affect the normal usage of decord.
+
+
+def _patched_init(self: "VideoReader", uri: Union[str, BytesIO], *args, **kwargs) -> None:
+    from decord.bridge import bridge_out
+
+    if hasattr(uri, "read"):
+        self._hf_encoded = {"bytes": uri.read(), "path": None}
+        uri.seek(0)
+    elif isinstance(uri, str):
+        self._hf_encoded = {"bytes": None, "path": uri}
+    self._hf_bridge_out = bridge_out
+    self._original_init(uri, *args, **kwargs)
+
+
+def _patched_next(self: "VideoReader", *args, **kwargs):
+    return self._hf_bridge_out(self._original_next(*args, **kwargs))
+
+
+def _patched_get_batch(self: "VideoReader", *args, **kwargs):
+    return self._hf_bridge_out(self._original_get_batch(*args, **kwargs))
+
+
+def patch_decord():
+    if config.DECORD_AVAILABLE:
+        import decord.video_reader
+        from decord import VideoReader
+    else:
+        raise ImportError("To support decoding videos, please install 'decord'.")
+    if not hasattr(VideoReader, "_hf_patched"):
+        decord.video_reader.bridge_out = lambda x: x
+        VideoReader._original_init = VideoReader.__init__
+        VideoReader.__init__ = _patched_init
+        VideoReader._original_next = VideoReader.next
+        VideoReader.next = _patched_next
+        VideoReader._original_get_batch = VideoReader.get_batch
+        VideoReader.get_batch = _patched_get_batch
+        VideoReader._hf_patched = True
+
+
+if config.DECORD_AVAILABLE:
+    patch_decord()
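[Reviewer note] A minimal sketch of what the patched reader exposes, assuming decord is installed; the file path is illustrative and not part of the diff. Indexing a decord VideoReader dispatches to next()/get_batch() internally, which is why patching those two methods is enough:

# Sketch only: path is hypothetical; attribute names come from the patch above.
from decord import VideoReader

vr = VideoReader("path/to/video.mp4")
print(vr._hf_encoded)  # {"bytes": None, "path": "path/to/video.mp4"}
frame = vr[0]          # routed through vr._hf_bridge_out; a decord NDArray under the default bridge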
diff --git a/src/datasets/formatting/jax_formatter.py b/src/datasets/formatting/jax_formatter.py
index 8035341c5cd..accb9c52499 100644
--- a/src/datasets/formatting/jax_formatter.py
+++ b/src/datasets/formatting/jax_formatter.py
@@ -105,6 +105,12 @@ def _tensorize(self, value):
         if isinstance(value, PIL.Image.Image):
             value = np.asarray(value)
+        elif config.DECORD_AVAILABLE and "decord" in sys.modules:
+            from decord import VideoReader
+
+            if isinstance(value, VideoReader):
+                value._hf_bridge_out = lambda x: jnp.array(np.asarray(x))
+                return value
 
         # using global variable since `jaxlib.xla_extension.Device` is not serializable neither
         # with `pickle` nor with `dill`, so we need to use a global variable instead
diff --git a/src/datasets/formatting/np_formatter.py b/src/datasets/formatting/np_formatter.py
index 95bcff2b517..3f03b06523b 100644
--- a/src/datasets/formatting/np_formatter.py
+++ b/src/datasets/formatting/np_formatter.py
@@ -62,6 +62,12 @@ def _tensorize(self, value):
         if isinstance(value, PIL.Image.Image):
             return np.asarray(value, **self.np_array_kwargs)
+        elif config.DECORD_AVAILABLE and "decord" in sys.modules:
+            from decord import VideoReader
+
+            if isinstance(value, VideoReader):
+                value._hf_bridge_out = np.asarray
+                return value
 
         return np.asarray(value, **{**default_dtype, **self.np_array_kwargs})
diff --git a/src/datasets/formatting/tf_formatter.py b/src/datasets/formatting/tf_formatter.py
index adb15cda381..650fb0984cc 100644
--- a/src/datasets/formatting/tf_formatter.py
+++ b/src/datasets/formatting/tf_formatter.py
@@ -69,6 +69,13 @@ def _tensorize(self, value):
         if isinstance(value, PIL.Image.Image):
             value = np.asarray(value)
+        elif config.DECORD_AVAILABLE and "decord" in sys.modules:
+            from decord import VideoReader
+            from decord.bridge import to_tensorflow
+
+            if isinstance(value, VideoReader):
+                value._hf_bridge_out = to_tensorflow
+                return value
 
         return tf.convert_to_tensor(value, **{**default_dtype, **self.tf_tensor_kwargs})
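[Reviewer note] The intent of the per-formatter hooks, as a hedged usage sketch grounded in the tests below; the column name and path are illustrative, and decord plus the target framework are assumed installed:

# Sketch only: each formatter swaps in its own bridge, so frames come back
# as that framework's array type instead of decord NDArrays.
from datasets import Dataset, Features, Video

ds = Dataset.from_dict({"video": ["path/to/video.mp4"]}, features=Features({"video": Video()}))
frame_np = ds.with_format("numpy")[0]["video"][0]  # np.ndarray (np_formatter sets np.asarray)
frame_pt = ds.with_format("torch")[0]["video"][0]  # torch.Tensor (torch_formatter sets decord.bridge.to_torch)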
diff --git a/src/datasets/formatting/torch_formatter.py b/src/datasets/formatting/torch_formatter.py
index 8efe759a144..1fd1701ae6b 100644
--- a/src/datasets/formatting/torch_formatter.py
+++ b/src/datasets/formatting/torch_formatter.py
@@ -75,6 +75,14 @@ def _tensorize(self, value):
             value = value[:, :, np.newaxis]
             value = value.transpose((2, 0, 1))
+        elif config.DECORD_AVAILABLE and "decord" in sys.modules:
+            from decord import VideoReader
+            from decord.bridge import to_torch
+
+            if isinstance(value, VideoReader):
+                value._hf_bridge_out = to_torch
+                return value
+
         return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
 
     def _recursive_tensorize(self, data_struct):
diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py
index 6a23170db5e..6d89270dbe3 100644
--- a/src/datasets/packaged_modules/__init__.py
+++ b/src/datasets/packaged_modules/__init__.py
@@ -14,6 +14,7 @@
 from .parquet import parquet
 from .sql import sql
 from .text import text
+from .videofolder import videofolder
 from .webdataset import webdataset
 
@@ -40,6 +41,7 @@ def _hash_python_lines(lines: List[str]) -> str:
     "text": (text.__name__, _hash_python_lines(inspect.getsource(text).splitlines())),
     "imagefolder": (imagefolder.__name__, _hash_python_lines(inspect.getsource(imagefolder).splitlines())),
     "audiofolder": (audiofolder.__name__, _hash_python_lines(inspect.getsource(audiofolder).splitlines())),
+    "videofolder": (videofolder.__name__, _hash_python_lines(inspect.getsource(videofolder).splitlines())),
     "webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())),
 }
 
@@ -74,7 +76,9 @@ def _hash_python_lines(lines: List[str]) -> str:
 _EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
 _EXTENSION_TO_MODULE.update({ext: ("audiofolder", {}) for ext in audiofolder.AudioFolder.EXTENSIONS})
 _EXTENSION_TO_MODULE.update({ext.upper(): ("audiofolder", {}) for ext in audiofolder.AudioFolder.EXTENSIONS})
-_MODULE_SUPPORTS_METADATA = {"imagefolder", "audiofolder"}
+_EXTENSION_TO_MODULE.update({ext: ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS})
+_EXTENSION_TO_MODULE.update({ext.upper(): ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS})
+_MODULE_SUPPORTS_METADATA = {"imagefolder", "audiofolder", "videofolder"}
 
 # Used to filter data files based on extensions given a module name
 _MODULE_TO_EXTENSIONS: Dict[str, List[str]] = {}
diff --git a/src/datasets/packaged_modules/videofolder/__init__.py b/src/datasets/packaged_modules/videofolder/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/src/datasets/packaged_modules/videofolder/videofolder.py b/src/datasets/packaged_modules/videofolder/videofolder.py
new file mode 100644
index 00000000000..7ce5bcf5655
--- /dev/null
+++ b/src/datasets/packaged_modules/videofolder/videofolder.py
@@ -0,0 +1,36 @@
+from typing import List
+
+import datasets
+
+from ..folder_based_builder import folder_based_builder
+
+
+logger = datasets.utils.logging.get_logger(__name__)
+
+
+class VideoFolderConfig(folder_based_builder.FolderBasedBuilderConfig):
+    """BuilderConfig for VideoFolder."""
+
+    drop_labels: bool = None
+    drop_metadata: bool = None
+
+    def __post_init__(self):
+        super().__post_init__()
+
+
+class VideoFolder(folder_based_builder.FolderBasedBuilder):
+    BASE_FEATURE = datasets.Video
+    BASE_COLUMN_NAME = "video"
+    BUILDER_CONFIG_CLASS = VideoFolderConfig
+    EXTENSIONS: List[str]  # definition at the bottom of the script
+
+
+# TODO: initial list, we should check the compatibility of other formats
+VIDEO_EXTENSIONS = [
+    ".mkv",
+    ".mp4",
+    ".avi",
+    ".mpeg",
+    ".mov",
+]
+VideoFolder.EXTENSIONS = VIDEO_EXTENSIONS
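[Reviewer note] Loading follows the existing imagefolder/audiofolder convention, since the module is registered in _EXTENSION_TO_MODULE above; the directory layout and label names here are hypothetical:

# Sketch only: assumes a class-per-subdirectory layout like imagefolder uses.
from datasets import load_dataset

# my_videos/train/cat/clip0.mp4
# my_videos/train/dog/clip1.mov
ds = load_dataset("videofolder", data_dir="my_videos")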
diff --git a/tests/features/data/test_video_66x50.mov b/tests/features/data/test_video_66x50.mov
new file mode 100644
index 00000000000..a55dcaa8f7b
Binary files /dev/null and b/tests/features/data/test_video_66x50.mov differ
diff --git a/tests/features/test_video.py b/tests/features/test_video.py
new file mode 100644
index 00000000000..f4c9a8d830b
--- /dev/null
+++ b/tests/features/test_video.py
@@ -0,0 +1,92 @@
+import pytest
+
+from datasets import Dataset, Features, Video
+
+from ..utils import require_decord
+
+
+@require_decord
+@pytest.mark.parametrize(
+    "build_example",
+    [
+        lambda video_path: video_path,
+        lambda video_path: open(video_path, "rb").read(),
+        lambda video_path: {"path": video_path},
+        lambda video_path: {"path": video_path, "bytes": None},
+        lambda video_path: {"path": video_path, "bytes": open(video_path, "rb").read()},
+        lambda video_path: {"path": None, "bytes": open(video_path, "rb").read()},
+        lambda video_path: {"bytes": open(video_path, "rb").read()},
+    ],
+)
+def test_video_feature_encode_example(shared_datadir, build_example):
+    from decord import VideoReader
+
+    video_path = str(shared_datadir / "test_video_66x50.mov")
+    video = Video()
+    encoded_example = video.encode_example(build_example(video_path))
+    assert isinstance(encoded_example, dict)
+    assert encoded_example.keys() == {"bytes", "path"}
+    assert encoded_example["bytes"] is not None or encoded_example["path"] is not None
+    decoded_example = video.decode_example(encoded_example)
+    assert isinstance(decoded_example, VideoReader)
+
+
+@require_decord
+def test_dataset_with_video_feature(shared_datadir):
+    from decord import VideoReader
+    from decord.ndarray import NDArray
+
+    video_path = str(shared_datadir / "test_video_66x50.mov")
+    data = {"video": [video_path]}
+    features = Features({"video": Video()})
+    dset = Dataset.from_dict(data, features=features)
+    item = dset[0]
+    assert item.keys() == {"video"}
+    assert isinstance(item["video"], VideoReader)
+    assert item["video"][0].shape == (50, 66, 3)
+    assert isinstance(item["video"][0], NDArray)
+    batch = dset[:1]
+    assert len(batch) == 1
+    assert batch.keys() == {"video"}
+    assert isinstance(batch["video"], list) and all(isinstance(item, VideoReader) for item in batch["video"])
+    assert batch["video"][0][0].shape == (50, 66, 3)
+    assert isinstance(batch["video"][0][0], NDArray)
+    column = dset["video"]
+    assert len(column) == 1
+    assert isinstance(column, list) and all(isinstance(item, VideoReader) for item in column)
+    assert column[0][0].shape == (50, 66, 3)
+    assert isinstance(column[0][0], NDArray)
+
+    # from bytes
+    with open(video_path, "rb") as f:
+        data = {"video": [f.read()]}
+    dset = Dataset.from_dict(data, features=features)
+    item = dset[0]
+    assert item.keys() == {"video"}
+    assert isinstance(item["video"], VideoReader)
+    assert item["video"][0].shape == (50, 66, 3)
+    assert isinstance(item["video"][0], NDArray)
+
+
+@require_decord
+def test_dataset_with_video_map_and_formatted(shared_datadir):
+    import numpy as np
+    from decord import VideoReader
+
+    video_path = str(shared_datadir / "test_video_66x50.mov")
+    data = {"video": [video_path]}
+    features = Features({"video": Video()})
+    dset = Dataset.from_dict(data, features=features)
+    dset = dset.map(lambda x: x).with_format("numpy")
+    example = dset[0]
+    assert isinstance(example["video"], VideoReader)
+    assert isinstance(example["video"][0], np.ndarray)
+
+    # from bytes
+    with open(video_path, "rb") as f:
+        data = {"video": [f.read()]}
+    dset = Dataset.from_dict(data, features=features)
+    dset = dset.map(lambda x: x).with_format("numpy")
+    example = dset[0]
+    assert isinstance(example["video"], VideoReader)
+    assert isinstance(example["video"][0], np.ndarray)
diff --git a/tests/utils.py b/tests/utils.py
index e19740a2a12..b9f6ff15e0a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -178,6 +178,18 @@ def require_pil(test_case):
     return test_case
 
 
+def require_decord(test_case):
+    """
+    Decorator marking a test that requires decord.
+
+    These tests are skipped when decord isn't installed.
+
+    """
+    if not config.DECORD_AVAILABLE:
+        test_case = unittest.skip("test requires decord")(test_case)
+    return test_case
+
+
 def require_transformers(test_case):
     """
     Decorator marking a test that requires transformers.
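[Reviewer note] Usage of the new decorator mirrors require_pil; the test name below is hypothetical:

# Sketch only: the decorator skips the test when decord is missing.
from tests.utils import require_decord

@require_decord
def test_my_video_case():
    from decord import VideoReader  # safe to import here: decord is guaranteed installed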