Skip to content

Commit

Permalink
updates to support .from_jsonl() and loading TaskConfigs externally
Browse files Browse the repository at this point in the history
  • Loading branch information
jpgard committed Jan 19, 2024
1 parent 955b587 commit 7dc2115
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 11 deletions.
57 changes: 48 additions & 9 deletions tableshift/core/features.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import json
import logging
import os
from dataclasses import dataclass, field
from functools import partial
from typing import List, Any, Sequence, Optional, Mapping, Tuple, Union, Dict
Expand Down Expand Up @@ -72,6 +73,19 @@ def get_dtype(dtype):
raise ValueError(f"unknown dtype: {dtype}")


def cast_number(number_str: str):
    """Best-effort numeric cast of a string.

    Tries int first, then float; if neither parses, the original
    string is returned unchanged.
    """
    for caster in (int, float):
        try:
            return caster(number_str)
        except ValueError:
            continue
    # Neither numeric conversion succeeded.
    return number_str


@dataclass(frozen=True)
class Feature:
name: str
Expand Down Expand Up @@ -141,21 +155,46 @@ def from_dataframe(cls, df: pd.DataFrame, target_colname: str,

return cls(features=features, **kwargs)

def to_jsonl(self, file: str):
    """Serialize each feature as one JSON object per line into *file* (a path)."""
    with open(file, "w") as out:
        for feature in self.features:
            record = copy.deepcopy(feature.__dict__)
            # `kind` is an object (e.g. a type); store its string name instead.
            kind = record.pop('kind')
            record['kind'] = getattr(kind, "name", kind.__name__)
            out.write(json.dumps(record) + "\n")
def to_jsonl(self, file):
    """Write features to a JSONL file.

    Args:
        file: either a string path to a file (opened in text mode and
            written as str), or a writable handle-like object (written
            as encoded bytes, matching handles opened in binary mode).
    """

    def _serialize(feature) -> str:
        """Return one feature as a single JSON line (trailing newline included)."""
        record = copy.deepcopy(feature.__dict__)
        # `kind` is an object (e.g. a type); store its string name instead.
        kind = record.pop('kind')
        record['kind'] = getattr(kind, "name", kind.__name__)
        return json.dumps(record) + "\n"

    if isinstance(file, str):
        # Case: file is a string; open a handle to a file at this path.
        with open(file, "w") as f:
            for feature in self.features:
                f.write(_serialize(feature))
    else:
        # Case: file is a handle; try to write to it directly.
        try:
            for feature in self.features:
                file.write(_serialize(feature).encode())
        except Exception:
            logging.error(f"error writing to file {file} of type {type(file)}")
            raise

@classmethod
def from_jsonl(cls, file: str):
def from_jsonl(cls, file: str, auto_cast_value_mappings: bool = False):
assert os.path.exists(file), f"file {file} does not exist."
with open(file, "r") as f:
lines = f.readlines()
feature_dicts = [json.loads(l) for l in lines]

# JSON parsing automatically casts any int/float keys to string; if auto_cast_value_mapping
# is True, we attempt to recover these.
if auto_cast_value_mappings:
for feature_dict in feature_dicts:
if feature_dict['value_mapping']:
feature_dict['value_mapping'] = {cast_number(k): cast_number(v) for k, v in
feature_dict['value_mapping'].items()}

# For each element in feature_dicts, create the actual class object
# corresponding to the feature kind, from its string representation.

Expand Down
8 changes: 8 additions & 0 deletions tableshift/core/getters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import logging
from typing import Optional, Dict, Any, Union

from tableshift.core.tasks import _TASK_REGISTRY
from tableshift.core.data_source import DataSource
from tableshift import exceptions
from tableshift.configs.experiment_defaults import DEFAULT_RANDOM_STATE
from tableshift.configs.benchmark_configs import BENCHMARK_CONFIGS
Expand All @@ -15,6 +17,12 @@
**NON_BENCHMARK_CONFIGS
}

def get_data_source(name: str, cache_dir: str, download=True, **kwargs) -> DataSource:
    """Get the data source for a dataset, if it exists in the task registry.

    Raises:
        ValueError: if *name* is not a key of the task registry.
    """
    if name not in _TASK_REGISTRY:
        raise ValueError(f"Dataset '{name}' not in available task registry: {sorted(_TASK_REGISTRY.keys())}")
    config = _TASK_REGISTRY[name]
    # Instantiate the registered data-source class for this task.
    return config.data_source_cls(cache_dir=cache_dir, download=download, **kwargs)

def get_dataset(name: str, cache_dir: str = "tmp",
preprocessor_config: Optional[
Expand Down
3 changes: 3 additions & 0 deletions tableshift/core/tabular_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ def features(self) -> List[str]:
"""Fetch a list of the feature names."""
raise

def _get_info(self) -> Dict[str, Any]:
    """Return dataset metadata as a dict; presumably overridden elsewhere — TODO confirm."""
    # Bare `raise` outside an active exception handler triggers
    # "RuntimeError: No active exception to re-raise"; this matches the
    # bare-raise stub style used by `features` above.
    raise

def _check_split(self, split):
"""Check that a split name is valid."""
assert self._is_valid_split(split), \
Expand Down
5 changes: 3 additions & 2 deletions tableshift/core/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""

from dataclasses import dataclass
from typing import Any
from typing import Any, Dict
from .data_source import *
from .features import FeatureList

Expand All @@ -25,10 +25,11 @@ class TaskConfig:
feature_list: FeatureList



# Mapping of task names to their configs. An arbitrary number of tasks
# can be created from a single data source, by specifying different
# preprocess_fn and features.
_TASK_REGISTRY = {
_TASK_REGISTRY: Dict[str, TaskConfig] = {
"acsincome":
TaskConfig(ACSDataSource,
ACS_INCOME_FEATURES + ACS_SHARED_FEATURES),
Expand Down

0 comments on commit 7dc2115

Please sign in to comment.