Skip to content

Commit

Permalink
Make configs call super post_init in packaged modules
Browse files: browse the repository at this point in the history
  • Loading branch information
albertvillanova committed May 22, 2024
1 parent 60d21ef commit 815ac93
Show file tree
Hide file tree
Showing 12 changed files with 30 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/arrow/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class ArrowConfig(datasets.BuilderConfig):

features: Optional[datasets.Features] = None

# Post-init hook for ArrowConfig: adds no logic of its own and only forwards to
# the parent class (datasets.BuilderConfig) so the parent's post-init step runs.
def __post_init__(self):
super().__post_init__()


class Arrow(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = ArrowConfig
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/audiofolder/audiofolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ class AudioFolderConfig(folder_based_builder.FolderBasedBuilderConfig):
drop_labels: bool = None
drop_metadata: bool = None

# Post-init hook for AudioFolderConfig: only forwards to the parent class
# (folder_based_builder.FolderBasedBuilderConfig); no audio-specific checks here.
def __post_init__(self):
super().__post_init__()


class AudioFolder(folder_based_builder.FolderBasedBuilder):
BASE_FEATURE = datasets.Audio
Expand Down
1 change: 1 addition & 0 deletions src/datasets/packaged_modules/csv/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class CsvConfig(datasets.BuilderConfig):
date_format: Optional[str] = None

def __post_init__(self):
super().__post_init__()
if self.delimiter is not None:
self.sep = self.delimiter
if self.column_names is not None:
Expand Down
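3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/folder_based/folder_based_builder.py
Original file line number Diff line number Diff line change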
Expand Up @@ -28,6 +28,9 @@ class FolderBasedBuilderConfig(datasets.BuilderConfig):
drop_labels: bool = None
drop_metadata: bool = None

# Post-init hook for FolderBasedBuilderConfig: only forwards to the parent class
# (datasets.BuilderConfig) so the parent's post-init step is not skipped.
def __post_init__(self):
super().__post_init__()


class FolderBasedBuilder(datasets.GeneratorBasedBuilder):
"""
Expand Down
4 changes: 3 additions & 1 deletion src/datasets/packaged_modules/generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ class GeneratorConfig(datasets.BuilderConfig):
features: Optional[datasets.Features] = None

def __post_init__(self):
assert self.generator is not None, "generator must be specified"
super().__post_init__()
if self.generator is None:
raise ValueError("generator must be specified")

if self.gen_kwargs is None:
self.gen_kwargs = {}
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/imagefolder/imagefolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ class ImageFolderConfig(folder_based_builder.FolderBasedBuilderConfig):
drop_labels: bool = None
drop_metadata: bool = None

# Post-init hook for ImageFolderConfig: only forwards to the parent class
# (folder_based_builder.FolderBasedBuilderConfig); no image-specific checks here.
def __post_init__(self):
super().__post_init__()


class ImageFolder(folder_based_builder.FolderBasedBuilder):
BASE_FEATURE = datasets.Image
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/json/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class JsonConfig(datasets.BuilderConfig):
chunksize: int = 10 << 20 # 10MB
newlines_in_values: Optional[bool] = None

# Post-init hook for JsonConfig: adds no logic of its own and only forwards to
# the parent class (datasets.BuilderConfig).
def __post_init__(self):
super().__post_init__()


class Json(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = JsonConfig
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/pandas/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ class PandasConfig(datasets.BuilderConfig):

features: Optional[datasets.Features] = None

# Post-init hook for PandasConfig: adds no logic of its own and only forwards to
# the parent class (datasets.BuilderConfig).
def __post_init__(self):
super().__post_init__()


class Pandas(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = PandasConfig
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/parquet/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class ParquetConfig(datasets.BuilderConfig):
columns: Optional[List[str]] = None
features: Optional[datasets.Features] = None

# Post-init hook for ParquetConfig: adds no logic of its own and only forwards to
# the parent class (datasets.BuilderConfig).
def __post_init__(self):
super().__post_init__()


class Parquet(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = ParquetConfig
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/spark/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ class SparkConfig(datasets.BuilderConfig):

features: Optional[datasets.Features] = None

# Post-init hook for SparkConfig: adds no logic of its own and only forwards to
# the parent class (datasets.BuilderConfig).
def __post_init__(self):
super().__post_init__()


def _reorder_dataframe_by_partition(df: "pyspark.sql.DataFrame", new_partition_order: List[int]):
df_combined = df.select("*").where(f"part_id = {new_partition_order[0]}")
Expand Down
1 change: 1 addition & 0 deletions src/datasets/packaged_modules/sql/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class SqlConfig(datasets.BuilderConfig):
features: Optional[datasets.Features] = None

def __post_init__(self):
super().__post_init__()
if self.sql is None:
raise ValueError("sql must be specified")
if self.con is None:
Expand Down
1 change: 1 addition & 0 deletions src/datasets/packaged_modules/text/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class TextConfig(datasets.BuilderConfig):
sample_by: str = "line"

def __post_init__(self, errors):
super().__post_init__()
if errors != "deprecated":
warnings.warn(
"'errors' was deprecated in favor of 'encoding_errors' in version 2.14.0 and will be removed in 3.0.0.\n"
Expand Down

0 comments on commit 815ac93

Please sign in to comment.