Skip to content

Commit

Permalink
Fix webdataset pickling (#6972)
Browse files Browse the repository at this point in the history
* fix webdataset pickling

* more general fix
  • Loading branch information
lhoestq authored Jun 14, 2024
1 parent ef2fb35 commit 5e72fb1
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 26 deletions.
8 changes: 5 additions & 3 deletions src/datasets/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from . import config
from .table import CastError
from .utils.deprecation_utils import deprecated
from .utils.track import TrackedIterable, tracked_list, tracked_str
from .utils.track import TrackedIterableFromGenerator, tracked_list, tracked_str


class DatasetsError(Exception):
Expand Down Expand Up @@ -65,9 +65,11 @@ def from_cast_error(
)
formatted_tracked_gen_kwargs: List[str] = []
for gen_kwarg in gen_kwargs.values():
if not isinstance(gen_kwarg, (tracked_str, tracked_list, TrackedIterable)):
if not isinstance(gen_kwarg, (tracked_str, tracked_list, TrackedIterableFromGenerator)):
continue
while isinstance(gen_kwarg, (tracked_list, TrackedIterable)) and gen_kwarg.last_item is not None:
while (
isinstance(gen_kwarg, (tracked_list, TrackedIterableFromGenerator)) and gen_kwarg.last_item is not None
):
gen_kwarg = gen_kwarg.last_item
if isinstance(gen_kwarg, tracked_str):
gen_kwarg = gen_kwarg.get_origin()
Expand Down
24 changes: 4 additions & 20 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from io import BytesIO
from itertools import chain
from pathlib import Path, PurePosixPath
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, TypeVar, Union
from typing import Any, Dict, Generator, List, Optional, Tuple, TypeVar, Union
from unittest.mock import patch
from urllib.parse import urljoin, urlparse
from xml.etree import ElementTree as ET
Expand All @@ -47,7 +47,7 @@
from . import tqdm as hf_tqdm
from ._filelock import FileLock
from .extract import ExtractManager
from .track import TrackedIterable
from .track import TrackedIterableFromGenerator


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -1564,23 +1564,7 @@ def xxml_dom_minidom_parse(filename_or_file, download_config: Optional[DownloadC
return xml.dom.minidom.parse(f, **kwargs)


class _IterableFromGenerator(TrackedIterable):
"""Utility class to create an iterable from a generator function, in order to reset the generator when needed."""

def __init__(self, generator: Callable, *args, **kwargs):
super().__init__()
self.generator = generator
self.args = args
self.kwargs = kwargs

def __iter__(self):
for x in self.generator(*self.args, **self.kwargs):
self.last_item = x
yield x
self.last_item = None


class ArchiveIterable(_IterableFromGenerator):
class ArchiveIterable(TrackedIterableFromGenerator):
"""An iterable of (path, fileobj) from a TAR archive, used by `iter_archive`"""

@staticmethod
Expand Down Expand Up @@ -1645,7 +1629,7 @@ def from_urlpath(cls, urlpath_or_buf, download_config: Optional[DownloadConfig]
return cls(cls._iter_from_urlpath, urlpath_or_buf, download_config)


class FilesIterable(_IterableFromGenerator):
class FilesIterable(TrackedIterableFromGenerator):
"""An iterable of paths from a list of directories or files"""

@classmethod
Expand Down
19 changes: 16 additions & 3 deletions src/datasets/utils/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,26 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(current={self.last_item})"


class TrackedIterable(Iterable):
def __init__(self) -> None:
class TrackedIterableFromGenerator(Iterable):
"""Utility class to create an iterable from a generator function, in order to reset the generator when needed."""

def __init__(self, generator, *args):
super().__init__()
self.generator = generator
self.args = args
self.last_item = None

def __iter__(self):
for x in self.generator(*self.args):
self.last_item = x
yield x
self.last_item = None

def __repr__(self) -> str:
if self.last_item is None:
super().__repr__()
return super().__repr__()
else:
return f"{self.__class__.__name__}(current={self.last_item})"

def __reduce__(self):
return (self.__class__, (self.generator, *self.args))

0 comments on commit 5e72fb1

Please sign in to comment.