From 36a49b11e224147f4a9a9e3696c3685037825b6a Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Wed, 16 Oct 2024 17:20:59 +0200
Subject: [PATCH] support map and formatting

---
 src/datasets/dataset_dict.py                 |  4 +-
 src/datasets/features/video.py               | 74 +++++++++-----
 src/datasets/formatting/jax_formatter.py     |  6 ++
 src/datasets/formatting/np_formatter.py      |  6 ++
 src/datasets/formatting/tf_formatter.py      |  7 ++
 src/datasets/formatting/torch_formatter.py   |  8 ++
 src/datasets/packaged_modules/__init__.py    |  6 +-
 .../packaged_modules/videofolder/__init__.py |  0
 .../videofolder/videofolder.py               | 36 +++++++
 tests/features/data/test_video_66x50.mov     | Bin 0 -> 44391 bytes
 tests/features/test_video.py                 | 92 ++++++++++++++++++
 tests/utils.py                               | 12 +++
 12 files changed, 221 insertions(+), 30 deletions(-)
 create mode 100644 src/datasets/packaged_modules/videofolder/__init__.py
 create mode 100644 src/datasets/packaged_modules/videofolder/videofolder.py
 create mode 100644 tests/features/data/test_video_66x50.mov
 create mode 100644 tests/features/test_video.py

diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py
index f92a1a8afda..40ca0cd7312 100644
--- a/src/datasets/dataset_dict.py
+++ b/src/datasets/dataset_dict.py
@@ -1230,9 +1230,9 @@ def save_to_disk(
         """
         Saves a dataset dict to a filesystem using `fsspec.spec.AbstractFileSystem`.
 
-        For [`Image`] and [`Audio`] data:
+        For [`Image`], [`Audio`] and [`Video`] data:
 
-        All the Image() and Audio() data are stored in the arrow files.
+        All the Image(), Audio() and Video() data are stored in the arrow files.
         If you want to store paths or urls, please use the Value("string") type.
 
         Args:
diff --git a/src/datasets/features/video.py b/src/datasets/features/video.py
index 5287630bbfc..ebdd1eb4dcb 100644
--- a/src/datasets/features/video.py
+++ b/src/datasets/features/video.py
@@ -19,30 +19,6 @@
     from .features import FeatureType
 
 
-def _patched_init(self: "VideoReader", uri: Union[str, BytesIO], *args, **kwargs) -> None:
-    if hasattr(uri, "read"):
-        self._hf_encoded = {"bytes": uri.read(), "path": None}
-        uri.seek(0)
-    elif isinstance(uri, str):
-        self._hf_encoded = {"bytes": None, "path": uri}
-    self._original_init(uri, *args, **kwargs)
-
-
-def patch_decord():
-    if config.DECORD_AVAILABLE:
-        from decord import VideoReader
-    else:
-        raise ImportError("To support decoding videos, please install 'decord'.")
-    if not hasattr(VideoReader, "_hf_patched"):
-        VideoReader._original_init = VideoReader.__init__
-        VideoReader.__init__ = _patched_init
-        VideoReader._hf_patched = True
-
-
-if config.DECORD_AVAILABLE:
-    patch_decord()
-
-
 @dataclass
 class Video:
     """Video [`Feature`] to read video data from an video file.
@@ -104,7 +80,6 @@ def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "VideoReader
         if config.DECORD_AVAILABLE:
             from decord import VideoReader
 
-            patch_decord()
         else:
             raise ImportError("To support encoding videos, please install 'decord'.")
 
@@ -154,7 +129,6 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "VideoReader":
            raise RuntimeError("Decoding is disabled for this feature. Please use Video(decode=True) instead.")
 
         if config.DECORD_AVAILABLE:
-            patch_decord()
             from decord import VideoReader
         else:
             raise ImportError("To support decoding videos, please install 'decord'.")
@@ -298,4 +272,50 @@ def encode_decord_video(video: "VideoReader") -> dict:
 
 def encode_np_array(array: np.ndarray) -> dict:
     raise NotImplementedError()
-    return {"path": None, "bytes": video_to_bytes(video)}
+
+
+# Patching decord a little bit to:
+# 1. store the encoded video data {"path": ..., "bytes": ...} in `video._hf_encoded`
+# 2. set the decord bridge to numpy/torch/tf/jax using `video._hf_bridge_out` (per video instance) instead of decord.bridge.bridge_out (global)
+# This doesn't affect the normal usage of decord.
+
+
+def _patched_init(self: "VideoReader", uri: Union[str, BytesIO], *args, **kwargs) -> None:
+    from decord.bridge import bridge_out
+
+    if hasattr(uri, "read"):
+        self._hf_encoded = {"bytes": uri.read(), "path": None}
+        uri.seek(0)
+    elif isinstance(uri, str):
+        self._hf_encoded = {"bytes": None, "path": uri}
+    self._hf_bridge_out = bridge_out
+    self._original_init(uri, *args, **kwargs)
+
+
+def _patched_next(self: "VideoReader", *args, **kwargs):
+    return self._hf_bridge_out(self._original_next(*args, **kwargs))
+
+
+def _patched_get_batch(self: "VideoReader", *args, **kwargs):
+    return self._hf_bridge_out(self._original_get_batch(*args, **kwargs))
+
+
+def patch_decord():
+    if config.DECORD_AVAILABLE:
+        import decord.video_reader
+        from decord import VideoReader
+    else:
+        raise ImportError("To support decoding videos, please install 'decord'.")
+    if not hasattr(VideoReader, "_hf_patched"):
+        decord.video_reader.bridge_out = lambda x: x
+        VideoReader._original_init = VideoReader.__init__
+        VideoReader.__init__ = _patched_init
+        VideoReader._original_next = VideoReader.next
+        VideoReader.next = _patched_next
+        VideoReader._original_get_batch = VideoReader.get_batch
+        VideoReader.get_batch = _patched_get_batch
+        VideoReader._hf_patched = True
+
+
+if config.DECORD_AVAILABLE:
+    patch_decord()
diff --git a/src/datasets/formatting/jax_formatter.py b/src/datasets/formatting/jax_formatter.py
index 8035341c5cd..accb9c52499 100644
--- a/src/datasets/formatting/jax_formatter.py
+++ b/src/datasets/formatting/jax_formatter.py
@@ -105,6 +105,12 @@ def _tensorize(self, value):
 
         if isinstance(value, PIL.Image.Image):
             value = np.asarray(value)
+        elif config.DECORD_AVAILABLE and "decord" in sys.modules:
+            from decord import VideoReader
+
+            if isinstance(value, VideoReader):
+                value._hf_bridge_out = lambda x: jnp.array(np.asarray(x))
+                return value
 
         # using global variable since `jaxlib.xla_extension.Device` is not serializable neither
         # with `pickle` nor with `dill`, so we need to use a global variable instead
diff --git a/src/datasets/formatting/np_formatter.py b/src/datasets/formatting/np_formatter.py
index 95bcff2b517..3f03b06523b 100644
--- a/src/datasets/formatting/np_formatter.py
+++ b/src/datasets/formatting/np_formatter.py
@@ -62,6 +62,12 @@ def _tensorize(self, value):
 
         if isinstance(value, PIL.Image.Image):
             return np.asarray(value, **self.np_array_kwargs)
+        elif config.DECORD_AVAILABLE and "decord" in sys.modules:
+            from decord import VideoReader
+
+            if isinstance(value, VideoReader):
+                value._hf_bridge_out = np.asarray
+                return value
 
         return np.asarray(value, **{**default_dtype, **self.np_array_kwargs})
diff --git a/src/datasets/formatting/tf_formatter.py b/src/datasets/formatting/tf_formatter.py
index adb15cda381..650fb0984cc 100644
--- a/src/datasets/formatting/tf_formatter.py
+++ b/src/datasets/formatting/tf_formatter.py
@@ -69,6 +69,13 @@ def _tensorize(self, value):
 
         if isinstance(value, PIL.Image.Image):
             value = np.asarray(value)
+        elif config.DECORD_AVAILABLE and "decord" in sys.modules:
+            from decord import VideoReader
+            from decord.bridge import to_tensorflow
+
+            if isinstance(value, VideoReader):
+                value._hf_bridge_out = to_tensorflow
+                return value
 
         return tf.convert_to_tensor(value, **{**default_dtype, **self.tf_tensor_kwargs})
 
diff --git a/src/datasets/formatting/torch_formatter.py b/src/datasets/formatting/torch_formatter.py
index 8efe759a144..1fd1701ae6b 100644
--- a/src/datasets/formatting/torch_formatter.py
+++ b/src/datasets/formatting/torch_formatter.py
@@ -75,6 +75,14 @@ def _tensorize(self, value):
                 value = value[:, :, np.newaxis]
 
             value = value.transpose((2, 0, 1))
+        elif config.DECORD_AVAILABLE and "decord" in sys.modules:
+            from decord import VideoReader
+            from decord.bridge import to_torch
+
+            if isinstance(value, VideoReader):
+                value._hf_bridge_out = to_torch
+                return value
+
         return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
 
     def _recursive_tensorize(self, data_struct):
diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py
index 6a23170db5e..6d89270dbe3 100644
--- a/src/datasets/packaged_modules/__init__.py
+++ b/src/datasets/packaged_modules/__init__.py
@@ -14,6 +14,7 @@
 from .parquet import parquet
 from .sql import sql
 from .text import text
+from .videofolder import videofolder
 from .webdataset import webdataset
 
 
@@ -40,6 +41,7 @@ def _hash_python_lines(lines: List[str]) -> str:
     "text": (text.__name__, _hash_python_lines(inspect.getsource(text).splitlines())),
     "imagefolder": (imagefolder.__name__, _hash_python_lines(inspect.getsource(imagefolder).splitlines())),
     "audiofolder": (audiofolder.__name__, _hash_python_lines(inspect.getsource(audiofolder).splitlines())),
+    "videofolder": (videofolder.__name__, _hash_python_lines(inspect.getsource(videofolder).splitlines())),
     "webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())),
 }
 
@@ -74,7 +76,9 @@ def _hash_python_lines(lines: List[str]) -> str:
 _EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
 _EXTENSION_TO_MODULE.update({ext: ("audiofolder", {}) for ext in audiofolder.AudioFolder.EXTENSIONS})
 _EXTENSION_TO_MODULE.update({ext.upper(): ("audiofolder", {}) for ext in audiofolder.AudioFolder.EXTENSIONS})
-_MODULE_SUPPORTS_METADATA = {"imagefolder", "audiofolder"}
+_EXTENSION_TO_MODULE.update({ext: ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS})
+_EXTENSION_TO_MODULE.update({ext.upper(): ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS})
+_MODULE_SUPPORTS_METADATA = {"imagefolder", "audiofolder", "videofolder"}
 
 # Used to filter data files based on extensions given a module name
 _MODULE_TO_EXTENSIONS: Dict[str, List[str]] = {}
diff --git a/src/datasets/packaged_modules/videofolder/__init__.py b/src/datasets/packaged_modules/videofolder/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/src/datasets/packaged_modules/videofolder/videofolder.py b/src/datasets/packaged_modules/videofolder/videofolder.py
new file mode 100644
index 00000000000..7ce5bcf5655
--- /dev/null
+++ b/src/datasets/packaged_modules/videofolder/videofolder.py
@@ -0,0 +1,36 @@
+from typing import List
+
+import datasets
+
+from ..folder_based_builder import folder_based_builder
+
+
+logger = datasets.utils.logging.get_logger(__name__)
+
+
+class VideoFolderConfig(folder_based_builder.FolderBasedBuilderConfig):
+    """BuilderConfig for VideoFolder."""
+
+    drop_labels: bool = None
+    drop_metadata: bool = None
+
+    def __post_init__(self):
+        super().__post_init__()
+
+
+class VideoFolder(folder_based_builder.FolderBasedBuilder):
+    BASE_FEATURE = datasets.Video
+    BASE_COLUMN_NAME = "video"
+    BUILDER_CONFIG_CLASS = VideoFolderConfig
+    EXTENSIONS: List[str]  # definition at the bottom of the script
+
+
+# TODO: initial list, we should check the compatibility of other formats
+VIDEO_EXTENSIONS = [
+    ".mkv",
+    ".mp4",
+    ".avi",
+    ".mpeg",
+    ".mov",
+]
+VideoFolder.EXTENSIONS = VIDEO_EXTENSIONS
diff --git a/tests/features/data/test_video_66x50.mov b/tests/features/data/test_video_66x50.mov
new file mode 100644
index 0000000000000000000000000000000000000000..a55dcaa8f7be611cd3473cb1ffaf123d5a21df2d
GIT binary patch
literal 44391
[base85-encoded binary payload omitted]
z&)(6H-#^M9{Pec$ZRUV8vyj^04B;eU3neJk;e_!5<(3b1Izo}4pN$vZ&t}TwO2nwI z5id-P$HIbUf#*hcbo*D%<8y=A$WHOO1My>a&aaa9vku~hiExe|6efxY=JMD=VVJxI z{n2m~h=G`lR!EPVRuDYtXN1TR|}gd_}=)(M4zTn6$_2(NO};1>4${oIF#nHH?>oa+VeCt z3>XZO7{qFuE=BFMID#-yI46vQ=7K-nqA!-nUl1-~n;%j%XNB{e4|Z(pZI4s@Oe9Uy zKlIrD?q&3t^5=)={NLk1al~OO@C!u3?fK5)_lShqiT)U&HHz0Q7Dc4^`%NL8fs`ZW zO3((53>x~M`h2qb6P0LW7IM$QE|B1T0K1>wlj5>rcb*iYc56_+SMZXIM#@Awyibbr zBajAk+sP`Fjz@~mrzJ>7BgOUdpFu(?Zr`i;7d7yBIK?`$Pu`*@uZL3HCI_j0k8Lta z@ff&{>Pc~bQ=~Y>_E9}4E}I}7$WHx<{lDoGZHpXqgvbTG2M3gq;zBB)Ct;)1r1$v1 zMm|?=!`TmF7s`w3{bG^G%)+fW0=@-4Xm}PYR|RuICBb~2g)lgh7slmUjVo#sQ#2^F&#oY&vOz_#EIHJJ(XuW{T!*OoYk2Mer0wt^2 zB0R`P!ygU)5Fh{qfB+Bx0zd!=00AHX1b_e#00Mt^0<4~YXhBIt TJp7-XV8M|JtX!ngB+C5@O%8-C literal 0 HcmV?d00001 diff --git a/tests/features/test_video.py b/tests/features/test_video.py new file mode 100644 index 00000000000..f4c9a8d830b --- /dev/null +++ b/tests/features/test_video.py @@ -0,0 +1,92 @@ +import pytest + +from datasets import Dataset, Features, Video + +from ..utils import require_decord + + +@require_decord +@pytest.mark.parametrize( + "build_example", + [ + lambda video_path: video_path, + lambda video_path: open(video_path, "rb").read(), + lambda video_path: {"path": video_path}, + lambda video_path: {"path": video_path, "bytes": None}, + lambda video_path: {"path": video_path, "bytes": open(video_path, "rb").read()}, + lambda video_path: {"path": None, "bytes": open(video_path, "rb").read()}, + lambda video_path: {"bytes": open(video_path, "rb").read()}, + ], +) +def test_video_feature_encode_example(shared_datadir, build_example): + from decord import VideoReader + + video_path = str(shared_datadir / "test_video_66x50.mov") + video = Video() + encoded_example = video.encode_example(build_example(video_path)) + assert isinstance(encoded_example, dict) + assert encoded_example.keys() == {"bytes", "path"} + assert encoded_example["bytes"] is not None or encoded_example["path"] is not None + decoded_example = video.decode_example(encoded_example) + assert isinstance(decoded_example, VideoReader) + + +@require_decord +def test_dataset_with_video_feature(shared_datadir): + from decord import VideoReader + from decord.ndarray import NDArray + + video_path = str(shared_datadir / "test_video_66x50.mov") + data = {"video": [video_path]} + features = Features({"video": Video()}) + dset = Dataset.from_dict(data, features=features) + item = dset[0] + assert item.keys() == {"video"} + assert isinstance(item["video"], VideoReader) + assert item["video"][0].shape == (50, 66, 3) + assert isinstance(item["video"][0], NDArray) + batch = dset[:1] + assert len(batch) == 1 + assert batch.keys() == {"video"} + assert isinstance(batch["video"], list) and all(isinstance(item, VideoReader) for item in batch["video"]) + assert batch["video"][0][0].shape == (50, 66, 3) + assert isinstance(batch["video"][0][0], NDArray) + column = dset["video"] + assert len(column) == 1 + assert isinstance(column, list) and all(isinstance(item, VideoReader) for item in column) + assert column[0][0].shape == (50, 66, 3) + assert isinstance(column[0][0], NDArray) + + # from bytes + with open(video_path, "rb") as f: + data = {"video": [f.read()]} + dset = Dataset.from_dict(data, features=features) + item = dset[0] + assert item.keys() == {"video"} + assert isinstance(item["video"], VideoReader) + assert item["video"][0].shape == (50, 66, 3) + assert isinstance(item["video"][0], NDArray) + + +@require_decord +def test_dataset_with_video_map_and_formatted(shared_datadir): + import numpy as np + from 
decord import VideoReader + + video_path = str(shared_datadir / "test_video_66x50.mov") + data = {"video": [video_path]} + features = Features({"video": Video()}) + dset = Dataset.from_dict(data, features=features) + dset = dset.map(lambda x: x).with_format("numpy") + example = dset[0] + assert isinstance(example["video"], VideoReader) + assert isinstance(example["video"][0], np.ndarray) + + # from bytes + with open(video_path, "rb") as f: + data = {"video": [f.read()]} + dset = Dataset.from_dict(data, features=features) + dset = dset.map(lambda x: x).with_format("numpy") + example = dset[0] + assert isinstance(example["video"], VideoReader) + assert isinstance(example["video"][0], np.ndarray) diff --git a/tests/utils.py b/tests/utils.py index e19740a2a12..b9f6ff15e0a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -178,6 +178,18 @@ def require_pil(test_case): return test_case +def require_decord(test_case): + """ + Decorator marking a test that requires decord. + + These tests are skipped when decord isn't installed. + + """ + if not config.PIL_AVAILABLE: + test_case = unittest.skip("test requires decord")(test_case) + return test_case + + def require_transformers(test_case): """ Decorator marking a test that requires transformers.
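
For reference, a minimal usage sketch of what this patch enables, mirroring the new tests; it assumes decord is installed, and "path/to/video.mp4" is a hypothetical placeholder for any local video file decord can read:

    from datasets import Dataset, Features, Video

    # Videos are stored as {"path", "bytes"} in Arrow and decoded lazily into decord.VideoReader objects.
    features = Features({"video": Video()})
    dset = Dataset.from_dict({"video": ["path/to/video.mp4"]}, features=features)

    # map() can re-encode videos because the patched VideoReader keeps its source in _hf_encoded.
    dset = dset.map(lambda x: x)

    # with_format() only swaps the per-instance bridge (VideoReader._hf_bridge_out), so frames come
    # back as NumPy arrays here; the torch/tf/jax formatters wire up their own bridges the same way.
    dset = dset.with_format("numpy")
    video = dset[0]["video"]  # decord.VideoReader
    first_frame = video[0]    # np.ndarray of shape (height, width, 3)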