Skip to content

Commit

Permalink
import torch before decord to fix random_device could not be read
Browse files Browse the repository at this point in the history
  • Loading branch information
lhoestq committed Oct 23, 2024
1 parent d263ad7 commit 99f4ba9
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@
"transformers>=4.42.0", # Pins numpy < 2
"zstandard",
"polars[timezone]>=0.20.0",
# "decord==0.6.0",
"decord==0.6.0",
]


Expand Down
15 changes: 15 additions & 0 deletions src/datasets/features/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "VideoReader
`dict` with "path" and "bytes" fields
"""
if config.DECORD_AVAILABLE:
# We need to import torch first, otherwise later it can cause issues
# e.g. "RuntimeError: random_device could not be read"
# when running `torch.tensor(value).share_memory_()`
if config.TORCH_AVAILABLE:
import torch # noqa
from decord import VideoReader

else:
Expand Down Expand Up @@ -129,6 +134,11 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "VideoReader":
raise RuntimeError("Decoding is disabled for this feature. Please use Video(decode=True) instead.")

if config.DECORD_AVAILABLE:
# We need to import torch first, otherwise later it can cause issues
# e.g. "RuntimeError: random_device could not be read"
# when running `torch.tensor(value).share_memory_()`
if config.TORCH_AVAILABLE:
import torch # noqa
from decord import VideoReader
else:
raise ImportError("To support decoding videos, please install 'decord'.")
Expand Down Expand Up @@ -302,6 +312,11 @@ def _patched_get_batch(self: "VideoReader", *args, **kwargs):

def patch_decord():
if config.DECORD_AVAILABLE:
# We need to import torch first, otherwise later it can cause issues
# e.g. "RuntimeError: random_device could not be read"
# when running `torch.tensor(value).share_memory_()`
if config.TORCH_AVAILABLE:
import torch # noqa
import decord.video_reader
from decord import VideoReader
else:
Expand Down
5 changes: 5 additions & 0 deletions src/datasets/formatting/jax_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def _tensorize(self, value):
if isinstance(value, PIL.Image.Image):
value = np.asarray(value)
elif config.DECORD_AVAILABLE and "decord" in sys.modules:
# We need to import torch first, otherwise later it can cause issues
# e.g. "RuntimeError: random_device could not be read"
# when running `torch.tensor(value).share_memory_()`
if config.TORCH_AVAILABLE:
import torch # noqa
from decord import VideoReader

if isinstance(value, VideoReader):
Expand Down
5 changes: 5 additions & 0 deletions src/datasets/formatting/np_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ def _tensorize(self, value):
if isinstance(value, PIL.Image.Image):
return np.asarray(value, **self.np_array_kwargs)
elif config.DECORD_AVAILABLE and "decord" in sys.modules:
# We need to import torch first, otherwise later it can cause issues
# e.g. "RuntimeError: random_device could not be read"
# when running `torch.tensor(value).share_memory_()`
if config.TORCH_AVAILABLE:
import torch # noqa
from decord import VideoReader

if isinstance(value, VideoReader):
Expand Down
5 changes: 5 additions & 0 deletions src/datasets/formatting/tf_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ def _tensorize(self, value):
if isinstance(value, PIL.Image.Image):
value = np.asarray(value)
elif config.DECORD_AVAILABLE and "decord" in sys.modules:
# We need to import torch first, otherwise later it can cause issues
# e.g. "RuntimeError: random_device could not be read"
# when running `torch.tensor(value).share_memory_()`
if config.TORCH_AVAILABLE:
import torch # noqa
from decord import VideoReader
from decord.bridge import to_tensorflow

Expand Down

0 comments on commit 99f4ba9

Please sign in to comment.