Skip to content

Commit

Permalink
Automatically create cache_dir from cache_file_name (#7096)
Browse files Browse the repository at this point in the history
Automatically create cache_dir from cache_file_name

Specifying a `cache_file_name` inside a directory that doesn't exist, e.g. `cache_file_name="./cache/data.map"`, produces a pretty unhelpful error message:

FileNotFoundError: [Errno 2] No such file or directory: '/.../cache/tmp48r61siw'
  • Loading branch information
ringohoffman authored Aug 15, 2024
1 parent 69d9f45 commit 93dc735
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3489,7 +3489,9 @@ def init_buffer_and_writer():
else:
buf_writer = None
logger.info(f"Caching processed dataset at {cache_file_name}")
- tmp_file = tempfile.NamedTemporaryFile("wb", dir=os.path.dirname(cache_file_name), delete=False)
+ cache_dir = os.path.dirname(cache_file_name)
+ os.makedirs(cache_dir, exist_ok=True)
+ tmp_file = tempfile.NamedTemporaryFile("wb", dir=cache_dir, delete=False)
writer = ArrowWriter(
features=writer_features,
path=tmp_file.name,
Expand Down Expand Up @@ -4082,7 +4084,9 @@ def _select_with_indices_mapping(
else:
buf_writer = None
logger.info(f"Caching indices mapping at {indices_cache_file_name}")
- tmp_file = tempfile.NamedTemporaryFile("wb", dir=os.path.dirname(indices_cache_file_name), delete=False)
+ cache_dir = os.path.dirname(indices_cache_file_name)
+ os.makedirs(cache_dir, exist_ok=True)
+ tmp_file = tempfile.NamedTemporaryFile("wb", dir=cache_dir, delete=False)
writer = ArrowWriter(
path=tmp_file.name, writer_batch_size=writer_batch_size, fingerprint=new_fingerprint, unit="indices"
)
Expand Down

0 comments on commit 93dc735

Please sign in to comment.