diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 25b27d091a3..7de112345df 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3489,7 +3489,9 @@ def init_buffer_and_writer(): else: buf_writer = None logger.info(f"Caching processed dataset at {cache_file_name}") - tmp_file = tempfile.NamedTemporaryFile("wb", dir=os.path.dirname(cache_file_name), delete=False) + cache_dir = os.path.dirname(cache_file_name) + os.makedirs(cache_dir, exist_ok=True) + tmp_file = tempfile.NamedTemporaryFile("wb", dir=cache_dir, delete=False) writer = ArrowWriter( features=writer_features, path=tmp_file.name, @@ -4082,7 +4084,9 @@ def _select_with_indices_mapping( else: buf_writer = None logger.info(f"Caching indices mapping at {indices_cache_file_name}") - tmp_file = tempfile.NamedTemporaryFile("wb", dir=os.path.dirname(indices_cache_file_name), delete=False) + cache_dir = os.path.dirname(indices_cache_file_name) + os.makedirs(cache_dir, exist_ok=True) + tmp_file = tempfile.NamedTemporaryFile("wb", dir=cache_dir, delete=False) writer = ArrowWriter( path=tmp_file.name, writer_batch_size=writer_batch_size, fingerprint=new_fingerprint, unit="indices" )