Commit

filter_files_by_extension function
Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
sarahyurick committed Oct 22, 2024
1 parent 4ad1a4d commit 64788e5
Showing 2 changed files with 20 additions and 2 deletions.
5 changes: 3 additions & 2 deletions nemo_curator/datasets/doc_dataset.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from typing import List, Optional, Union
 
 import dask.dataframe as dd
@@ -196,8 +197,8 @@ def _read_json_or_parquet(
     file_ext = "." + file_type
 
     if isinstance(input_files, list):
-        # List of jsonl or parquet files
-        if all(f.endswith(file_ext) for f in input_files):
+        # List of files
+        if all(os.path.isfile(f) for f in input_files):
             raw_data = read_data(
                 input_files,
                 file_type=file_type,
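The net effect of this hunk: a list of input paths no longer has to match the expected extension up front, it only has to point at files that exist, and extension filtering is deferred to read_data. A minimal sketch of the old check versus the new one, using hypothetical paths:

import os

input_files = ["data/a.jsonl", "data/b.parquet"]
file_ext = ".jsonl"

# Old check: every path must end with the expected extension.
all(f.endswith(file_ext) for f in input_files)   # False: b.parquet fails

# New check: every path must simply exist on disk; files with other
# extensions are dropped (with a warning) later, inside read_data.
all(os.path.isfile(f) for f in input_files)      # True if both files exist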
17 changes: 17 additions & 0 deletions nemo_curator/utils/distributed_utils.py

@@ -356,6 +356,16 @@ def read_pandas_pickle(file, add_filename=False) -> pd.DataFrame:
     return pd.read_pickle(file)
 
 
+def filter_files_by_extension(files_list, file_ext):
+    filtered_files = []
+    for file in files_list:
+        if file.endswith(file_ext):
+            filtered_files.append(file)
+        else:
+            warnings.warn(f"Skipping read for file: {file}")
+    return filtered_files
+
+
 def read_data(
     input_files,
     file_type: str = "pickle",
@@ -391,15 +401,20 @@ def read_data(
         df = df.to_backend("cudf")
 
     elif file_type in ["json", "jsonl", "parquet"]:
+        file_ext = "." + file_type
+        input_files = filter_files_by_extension(input_files, file_ext)
         print(f"Reading {len(input_files)} files", flush=True)
         input_files = sorted(input_files)
+
         if files_per_partition > 1:
             input_files = [
                 input_files[i : i + files_per_partition]
                 for i in range(0, len(input_files), files_per_partition)
             ]
+
         else:
             input_files = [[file] for file in input_files]
+
         return dd.from_map(
             read_single_partition,
             input_files,
@@ -409,8 +424,10 @@ def read_data(
             input_meta=input_meta,
             enforce_metadata=False,
         )
+
     else:
         raise RuntimeError("Could not read data, please check file type")
+
     return df
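For reference, a standalone sketch of the new helper, with the function body copied from the diff above. The file names are hypothetical, and distributed_utils.py is assumed to already import warnings, since the new code calls warnings.warn:

import warnings

def filter_files_by_extension(files_list, file_ext):
    # Keep files whose names end with the expected extension;
    # warn about, and drop, everything else.
    filtered_files = []
    for file in files_list:
        if file.endswith(file_ext):
            filtered_files.append(file)
        else:
            warnings.warn(f"Skipping read for file: {file}")
    return filtered_files

files = ["a.jsonl", "b.parquet", "c.jsonl"]
print(filter_files_by_extension(files, ".jsonl"))
# UserWarning: Skipping read for file: b.parquet  (emitted on stderr)
# ['a.jsonl', 'c.jsonl']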

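The partitioning step that follows the new filter call is unchanged: it slices the sorted file list into chunks of files_per_partition, and dd.from_map then builds one Dask partition per chunk. A small illustration of that comprehension with hypothetical file names:

input_files = ["a.jsonl", "b.jsonl", "c.jsonl", "d.jsonl", "e.jsonl"]
files_per_partition = 2

chunks = [
    input_files[i : i + files_per_partition]
    for i in range(0, len(input_files), files_per_partition)
]
print(chunks)
# [['a.jsonl', 'b.jsonl'], ['c.jsonl', 'd.jsonl'], ['e.jsonl']]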