Skip to content

Commit

Permalink
fix iter_arrow resuming
Browse files: browse the repository at this point in the history
  • Loading branch information
lhoestq committed May 30, 2024
1 parent c323af0 commit c35a036
Showing 1 changed file with 32 additions and 10 deletions.
42 changes: 32 additions & 10 deletions src/datasets/iterable_dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -170,13 +170,16 @@ class _BaseExamplesIterable:
"""Base class for the examples iterable used by an IterableDataset"""

def __init__(self) -> None:
self.iter_arrow: Optional[Callable[[], Iterator[Tuple[Key, pa.Table]]]] = None
self._state_dict: Optional[Union[list, dict]] = None

def __iter__(self) -> Iterator[Tuple[Key, dict]]:
"""An examples iterable should yield tuples (example_key, example) of type (int/str, dict)"""
raise NotImplementedError(f"{type(self)} doesn't implement __iter__ yet")

@property
def iter_arrow(self) -> Optional[Callable[[], Iterator[Tuple[Key, pa.Table]]]]:
return None

def shuffle_data_sources(self, generator: np.random.Generator) -> "_BaseExamplesIterable":
"""
Either shuffle the shards/sources of the dataset, or propagate the shuffling to the underlying iterable.
Expand All @@ -200,11 +203,11 @@ def _init_state_dict(self) -> dict:

def load_state_dict(self, state_dict: dict) -> dict:
def _inner_load_state_dict(state, new_state):
if new_state and isinstance(state, dict):
if new_state is not None and isinstance(state, dict):
for key in state:
state[key] = _inner_load_state_dict(state[key], new_state[key])
return state
elif new_state and isinstance(state, list):
elif new_state is not None and isinstance(state, list):
for i in range(len(state)):
state[i] = _inner_load_state_dict(state[i], new_state[i])
return state
Expand Down Expand Up @@ -297,7 +300,10 @@ def __init__(self, generate_tables_fn: Callable[..., Tuple[Key, pa.Table]], kwar
super().__init__()
self.generate_tables_fn = generate_tables_fn
self.kwargs = kwargs
self.iter_arrow = self._iter_arrow

@property
def iter_arrow(self):
return self._iter_arrow

def _init_state_dict(self) -> dict:
self._state_dict = {"shard_idx": 0, "shard_example_idx": 0}
Expand Down Expand Up @@ -417,8 +423,11 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, column_names: List[str]):
super().__init__()
self.ex_iterable = ex_iterable
self.column_names = column_names

@property
def iter_arrow(self):
if self.ex_iterable.iter_arrow:
self.iter_arrow = self._iter_arrow
return self._iter_arrow

def _init_state_dict(self) -> dict:
self._state_dict = self.ex_iterable._init_state_dict()
Expand Down Expand Up @@ -587,8 +596,11 @@ class VerticallyConcatenatedMultiSourcesExamplesIterable(_BaseExamplesIterable):
def __init__(self, ex_iterables: List[_BaseExamplesIterable]):
super().__init__()
self.ex_iterables = ex_iterables
if all(ex_iterable.iter_arrow is not None for ex_iterable in ex_iterables):
self.iter_arrow = self._iter_arrow

@property
def iter_arrow(self):
if all(ex_iterable.iter_arrow is not None for ex_iterable in self.ex_iterables):
return self._iter_arrow

def _init_state_dict(self) -> dict:
self._state_dict = {
Expand Down Expand Up @@ -814,8 +826,11 @@ def __init__(
self.input_columns = input_columns
self.fn_kwargs = fn_kwargs or {}
self.formatting = formatting

@property
def iter_arrow(self):
if self.formatting and self.formatting.format_type == "arrow":
self.iter_arrow = self._iter_arrow
return self._iter_arrow

def _init_state_dict(self) -> dict:
self._state_dict = self.ex_iterable._init_state_dict()
Expand Down Expand Up @@ -1006,8 +1021,11 @@ def __init__(
self.input_columns = input_columns
self.fn_kwargs = fn_kwargs or {}
self.formatting = formatting

@property
def iter_arrow(self):
if self.formatting and self.formatting.format_type == "arrow":
self.iter_arrow = self._iter_arrow
return self._iter_arrow

def _init_state_dict(self) -> dict:
self._state_dict = self.ex_iterable._init_state_dict()
Expand Down Expand Up @@ -1292,8 +1310,11 @@ def __init__(
self.ex_iterable = ex_iterable
self.features = features
self.token_per_repo_id = token_per_repo_id

@property
def iter_arrow(self):
if self.ex_iterable.iter_arrow is not None:
self.iter_arrow = self._iter_arrow
return self._iter_arrow

def _init_state_dict(self) -> dict:
if not self._state_dict:
Expand Down Expand Up @@ -1580,6 +1601,7 @@ def __iter__(self):
format_dict = None

if self._formatting and (ex_iterable.iter_arrow or self._formatting.format_type == "arrow"):
assert self._state_dict is ex_iterable._state_dict
if ex_iterable.iter_arrow:
iterator = _batch_arrow_tables(ex_iterable.iter_arrow(), batch_size=1)
else:
Expand Down

0 comments on commit c35a036

Please sign in to comment.