Preserve JSON column order and support list of strings field (#6914)
* Test JSON generates tables with sorted columns

* Test JSON generates tables for multiple JSON structures

* Fix style

* Make JSON builder use pandas read_json for JSON files
albertvillanova authored and lhoestq committed May 29, 2024
1 parent eafed0d commit d374880
Showing 2 changed files with 122 additions and 41 deletions.
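
The user-visible effect: columns produced by the JSON builder now follow the key order of the file instead of an arbitrary set order, and field= can point at a list of strings. A minimal usage sketch (hypothetical file names and data; assumes a datasets release that includes this commit; the tests below only exercise _generate_tables, so the end-to-end expectations here are illustrative):

import json

from datasets import load_dataset

# Records with a deliberately non-alphabetical key order.
records = [
    {"Topic": "Topic-0", "ID": 0, "Language": "Language-0"},
    {"Topic": "Topic-1", "ID": 1, "Language": "Language-1"},
]
with open("records.json", "w") as f:
    json.dump(records, f)

ds = load_dataset("json", data_files="records.json", split="train")
print(ds.column_names)  # expected: ['Topic', 'ID', 'Language'] (file order, no longer arbitrary)

# A field that holds a list of strings is now supported and becomes a single "text" column.
with open("texts.json", "w") as f:
    json.dump({"field1": 1, "field3": ["First text.", "Second text."]}, f)

ds = load_dataset("json", data_files="texts.json", field="field3", split="train")
print(ds.column_names)  # expected: ['text']
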
66 changes: 26 additions & 40 deletions src/datasets/packaged_modules/json/json.py
@@ -15,6 +15,14 @@
 logger = datasets.utils.logging.get_logger(__name__)
 
 
+def ujson_dumps(*args, **kwargs):
+    try:
+        return pd.io.json.ujson_dumps(*args, **kwargs)
+    except AttributeError:
+        # Before pandas-2.2.0, the function was named dumps (pandas-2.2.0 renamed it to ujson_dumps)
+        return pd.io.json.dumps(*args, **kwargs)
+
+
 def ujson_loads(*args, **kwargs):
     try:
         return pd.io.json.ujson_loads(*args, **kwargs)
@@ -85,21 +93,16 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:

     def _generate_tables(self, files):
         for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
-            # If the file is one json object and if we need to look at the list of items in one specific field
+            # If the file is one json object and if we need to look at the items in one specific field
             if self.config.field is not None:
                 with open(file, encoding=self.config.encoding, errors=self.config.encoding_errors) as f:
                     dataset = ujson_loads(f.read())
-
                 # We keep only the field we are interested in
                 dataset = dataset[self.config.field]
-
-                # We accept two format: a list of dicts or a dict of lists
-                if isinstance(dataset, (list, tuple)):
-                    keys = set().union(*[row.keys() for row in dataset])
-                    mapping = {col: [row.get(col) for row in dataset] for col in keys}
-                else:
-                    mapping = dataset
-                pa_table = pa.Table.from_pydict(mapping)
+                df = pd.read_json(io.StringIO(ujson_dumps(dataset)), dtype_backend="pyarrow")
+                if df.columns.tolist() == [0]:
+                    df.columns = list(self.config.features) if self.config.features else ["text"]
+                pa_table = pa.Table.from_pandas(df, preserve_index=False)
                 yield file_idx, self._cast_table(pa_table)
 
             # If the file has one json object per line
@@ -150,39 +153,22 @@ def _generate_tables(self, files):
                                 with open(
                                     file, encoding=self.config.encoding, errors=self.config.encoding_errors
                                 ) as f:
-                                    dataset = ujson_loads(f.read())
+                                    df = pd.read_json(f, dtype_backend="pyarrow")
                             except ValueError:
-                                logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
+                                logger.error(f"Failed to load JSON from file '{file}' with error {type(e)}: {e}")
                                 raise e
-                            # If possible, parse the file as a list of json objects/strings and exit the loop
-                            if isinstance(dataset, list):  # list is the only sequence type supported in JSON
-                                try:
-                                    if dataset and isinstance(dataset[0], str):
-                                        pa_table_names = (
-                                            list(self.config.features)
-                                            if self.config.features is not None
-                                            else ["text"]
-                                        )
-                                        pa_table = pa.Table.from_arrays([pa.array(dataset)], names=pa_table_names)
-                                    else:
-                                        keys = set().union(*[row.keys() for row in dataset])
-                                        mapping = {col: [row.get(col) for row in dataset] for col in keys}
-                                        pa_table = pa.Table.from_pydict(mapping)
-                                except (pa.ArrowInvalid, AttributeError) as e:
-                                    logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
-                                    raise ValueError(f"Not able to read records in the JSON file at {file}.") from None
-                                yield file_idx, self._cast_table(pa_table)
-                                break
-                            else:
-                                logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
+                            if df.columns.tolist() == [0]:
+                                df.columns = list(self.config.features) if self.config.features else ["text"]
+                            try:
+                                pa_table = pa.Table.from_pandas(df, preserve_index=False)
+                            except pa.ArrowInvalid as e:
+                                logger.error(
+                                    f"Failed to convert pandas DataFrame to Arrow Table from file '{file}' with error {type(e)}: {e}"
+                                )
                                 raise ValueError(
-                                    f"Not able to read records in the JSON file at {file}. "
-                                    f"You should probably indicate the field of the JSON file containing your records. "
-                                    f"This JSON file contain the following fields: {str(list(dataset.keys()))}. "
-                                    f"Select the correct one and provide it as `field='XXX'` to the dataset loading method. "
+                                    f"Failed to convert pandas DataFrame to Arrow Table from file {file}."
                                 ) from None
-                        # Uncomment for debugging (will print the Arrow table size and elements)
-                        # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
-                        # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
+                            yield file_idx, self._cast_table(pa_table)
+                            break
                        yield (file_idx, batch_idx), self._cast_table(pa_table)
                        batch_idx += 1
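
Both rewritten code paths above funnel the data through pandas and then into Arrow: pd.read_json keeps the column order it encounters in the file, and a bare JSON list of strings parses as a single column named 0, which the builder renames to "text" (or to the configured feature name). A standalone sketch of that conversion path, assuming pandas >= 2.0 for dtype_backend="pyarrow" (this shows the underlying pandas/pyarrow behavior, not the datasets API itself):

import io

import pandas as pd
import pyarrow as pa

# A list of JSON objects keeps its key order through pd.read_json, unlike
# the removed set().union(...) approach, which collected keys unordered.
data = '[{"Topic": "t", "ID": 0}, {"Topic": "u", "ID": 1}]'
df = pd.read_json(io.StringIO(data), dtype_backend="pyarrow")
pa_table = pa.Table.from_pandas(df, preserve_index=False)
print(pa_table.column_names)  # expected: ['Topic', 'ID'] (file order preserved)

# A bare JSON list of strings parses as a single unnamed column (0), which is
# then renamed, mirroring the `df.columns.tolist() == [0]` check in the diff.
df = pd.read_json(io.StringIO('["a", "b", "c"]'), dtype_backend="pyarrow")
if df.columns.tolist() == [0]:
    df.columns = ["text"]
print(pa.Table.from_pandas(df, preserve_index=False).column_names)  # expected: ['text']
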
97 changes: 96 additions & 1 deletion tests/packaged_modules/test_json.py
@@ -92,6 +92,85 @@ def json_file_with_list_of_dicts_field(tmp_path):
     return str(filename)
 
 
+@pytest.fixture
+def json_file_with_list_of_strings_field(tmp_path):
+    path = tmp_path / "file.json"
+    data = textwrap.dedent(
+        """\
+        {
+            "field1": 1,
+            "field2": "aabb",
+            "field3": [
+                "First text.",
+                "Second text.",
+                "Third text."
+            ]
+        }
+        """
+    )
+    with open(path, "w") as f:
+        f.write(data)
+    return str(path)
+
+
+@pytest.fixture
+def json_file_with_dict_of_lists_field(tmp_path):
+    path = tmp_path / "file.json"
+    data = textwrap.dedent(
+        """\
+        {
+            "field1": 1,
+            "field2": "aabb",
+            "field3": {
+                "col_1": [-1, 1, 10],
+                "col_2": [null, 2, 20]
+            }
+        }
+        """
+    )
+    with open(path, "w") as f:
+        f.write(data)
+    return str(path)
+
+
+@pytest.fixture
+def json_file_with_list_of_dicts_with_sorted_columns(tmp_path):
+    path = tmp_path / "file.json"
+    data = textwrap.dedent(
+        """\
+        [
+            {"ID": 0, "Language": "Language-0", "Topic": "Topic-0"},
+            {"ID": 1, "Language": "Language-1", "Topic": "Topic-1"},
+            {"ID": 2, "Language": "Language-2", "Topic": "Topic-2"}
+        ]
+        """
+    )
+    with open(path, "w") as f:
+        f.write(data)
+    return str(path)
+
+
+@pytest.fixture
+def json_file_with_list_of_dicts_with_sorted_columns_field(tmp_path):
+    path = tmp_path / "file.json"
+    data = textwrap.dedent(
+        """\
+        {
+            "field1": 1,
+            "field2": "aabb",
+            "field3": [
+                {"ID": 0, "Language": "Language-0", "Topic": "Topic-0"},
+                {"ID": 1, "Language": "Language-1", "Topic": "Topic-1"},
+                {"ID": 2, "Language": "Language-2", "Topic": "Topic-2"}
+            ]
+        }
+        """
+    )
+    with open(path, "w") as f:
+        f.write(data)
+    return str(path)
+
+
 @pytest.mark.parametrize(
     "file_fixture, config_kwargs",
     [
@@ -100,13 +179,15 @@ def json_file_with_list_of_dicts_field(tmp_path):
("json_file_with_list_of_dicts", {}),
("json_file_with_list_of_dicts_field", {"field": "field3"}),
("json_file_with_list_of_strings", {}),
("json_file_with_list_of_strings_field", {"field": "field3"}),
("json_file_with_dict_of_lists_field", {"field": "field3"}),
],
)
def test_json_generate_tables(file_fixture, config_kwargs, request):
json = Json(**config_kwargs)
generator = json._generate_tables([[request.getfixturevalue(file_fixture)]])
pa_table = pa.concat_tables([table for _, table in generator])
if file_fixture == "json_file_with_list_of_strings":
if "list_of_strings" in file_fixture:
expected = {"text": ["First text.", "Second text.", "Third text."]}
else:
expected = {"col_1": [-1, 1, 10], "col_2": [None, 2, 20]}
Expand Down Expand Up @@ -140,3 +221,17 @@ def test_json_generate_tables_with_missing_features(file_fixture, config_kwargs,
     generator = json._generate_tables([[request.getfixturevalue(file_fixture)]])
     pa_table = pa.concat_tables([table for _, table in generator])
     assert pa_table.to_pydict() == {"col_1": [-1, 1, 10], "col_2": [None, 2, 20], "missing_col": [None, None, None]}
+
+
+@pytest.mark.parametrize(
+    "file_fixture, config_kwargs",
+    [
+        ("json_file_with_list_of_dicts_with_sorted_columns", {}),
+        ("json_file_with_list_of_dicts_with_sorted_columns_field", {"field": "field3"}),
+    ],
+)
+def test_json_generate_tables_with_sorted_columns(file_fixture, config_kwargs, request):
+    builder = Json(**config_kwargs)
+    generator = builder._generate_tables([[request.getfixturevalue(file_fixture)]])
+    pa_table = pa.concat_tables([table for _, table in generator])
+    assert pa_table.column_names == ["ID", "Language", "Topic"]
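
For context on why these tests pin the column order: Python sets do not preserve insertion order, so the previous set().union(...) over row keys could yield columns in a different order from run to run. A small illustration in plain Python (not part of the diff):

rows = [
    {"ID": 0, "Language": "Language-0", "Topic": "Topic-0"},
    {"ID": 1, "Language": "Language-1", "Topic": "Topic-1"},
]

# Old approach: set iteration order is arbitrary, so column order varied.
keys = set().union(*[row.keys() for row in rows])
print(list(keys))  # e.g. ['Topic', 'ID', 'Language'] -- no guaranteed order

# Order-preserving equivalent (dicts keep insertion order since Python 3.7):
ordered = list(dict.fromkeys(key for row in rows for key in row))
print(ordered)  # ['ID', 'Language', 'Topic']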
