Skip to content

Commit

Permalink
no-up
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 652899920
  • Loading branch information
tfx-copybara committed Jul 17, 2024
1 parent ef54551 commit 4878a75
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 7 deletions.
26 changes: 19 additions & 7 deletions tfx/components/statistics_gen/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

_TELEMETRY_DESCRIPTORS = ['StatisticsGen']
STATS_DASHBOARD_LINK = 'stats_dashboard_link'
SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME = 'sample_rate_by_split'


class Executor(base_beam_executor.BaseBeamExecutor):
Expand Down Expand Up @@ -132,13 +133,6 @@ def Do(

split_names = [split for split in splits if split not in exclude_splits]

# Check if sample_rate_by_split contains invalid split names
for split in sample_rate_by_split:
if split not in split_names:
logging.error(
'Split %s provided in sample_rate_by_split is not valid.', split
)

statistics_artifact = artifact_utils.get_single_instance(
output_dict[standard_component_specs.STATISTICS_KEY]
)
Expand Down Expand Up @@ -169,6 +163,24 @@ def Do(
# json_utils
stats_options = options.StatsOptions.from_json(stats_options_json)

sample_rate_by_split_property = {
split: stats_options.sample_rate or 1.0 for split in split_names
}
for split in sample_rate_by_split:
# Check if sample_rate_by_split contains invalid split names
if split not in split_names:
logging.error(
'Split %s provided in sample_rate_by_split is not valid.', split
)
continue
sample_rate_by_split_property[split] = sample_rate_by_split[split]

# Add sample_rate_by_split property to statistics artifact
statistics_artifact.set_json_value_custom_property(
SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME,
json_utils.dumps(sample_rate_by_split_property),
)

write_sharded_output = exec_properties.get(
standard_component_specs.SHARDED_STATS_OUTPUT_KEY, False
)
Expand Down
99 changes: 99 additions & 0 deletions tfx/components/statistics_gen/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ def testDo(
artifact_utils.encode_split_names(['train', 'eval']), stats.split_names)
self.assertEqual(
stats.get_string_custom_property(executor.STATS_DASHBOARD_LINK), '')
self.assertEqual(
stats.has_custom_property(executor.SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME),
True,
)
self.assertEqual(stats.span, _TEST_SPAN_NUMBER)

# Check statistics_gen outputs.
Expand Down Expand Up @@ -228,6 +232,101 @@ def testDoWithSchemaAndStatsOptions(self):
self._validate_stats_output(
os.path.join(stats.uri, 'Split-eval', 'FeatureStats.pb'))

@parameterized.named_parameters(
{
'testcase_name': 'sample_rate_only',
'sample_rate': 0.2,
'sample_rate_by_split': 'null',
'expected_sample_rate_by_split_property': {'train': 0.2, 'eval': 0.2},
},
{
'testcase_name': 'sample_rate_by_split_only',
'sample_rate': None,
'sample_rate_by_split': '{"train": 0.4, "eval": 0.6}',
'expected_sample_rate_by_split_property': {'train': 0.4, 'eval': 0.6},
},
{
'testcase_name': 'sample_rate_for_some_split_only',
'sample_rate': None,
'sample_rate_by_split': '{"train": 0.4}',
'expected_sample_rate_by_split_property': {'train': 0.4, 'eval': 1.0},
},
{
'testcase_name': 'sample_rate_by_split_override',
'sample_rate': 0.2,
'sample_rate_by_split': '{"train": 0.4}',
'expected_sample_rate_by_split_property': {'train': 0.4, 'eval': 0.2},
},
{
'testcase_name': 'sample_rate_by_split_invalid',
'sample_rate': 0.2,
'sample_rate_by_split': '{"test": 0.4}',
'expected_sample_rate_by_split_property': {'train': 0.2, 'eval': 0.2},
},
)
def testDoWithSamplingProperty(
self,
sample_rate,
sample_rate_by_split,
expected_sample_rate_by_split_property
):
source_data_dir = os.path.join(
os.path.dirname(os.path.dirname(__file__)), 'testdata'
)
output_data_dir = os.path.join(
os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
self._testMethodName,
)
fileio.makedirs(output_data_dir)

# Create input dict.
examples = standard_artifacts.Examples()
examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])

schema = standard_artifacts.Schema()
schema.uri = os.path.join(source_data_dir, 'schema_gen')

input_dict = {
standard_component_specs.EXAMPLES_KEY: [examples],
standard_component_specs.SCHEMA_KEY: [schema],
}

exec_properties = {
standard_component_specs.STATS_OPTIONS_JSON_KEY: tfdv.StatsOptions(
sample_rate=sample_rate
).to_json(),
standard_component_specs.EXCLUDE_SPLITS_KEY: json_utils.dumps([]),
standard_component_specs.SAMPLE_RATE_BY_SPLIT_KEY: sample_rate_by_split,
}

# Create output dict.
stats = standard_artifacts.ExampleStatistics()
stats.uri = output_data_dir
output_dict = {
standard_component_specs.STATISTICS_KEY: [stats],
}

# Run executor.
stats_gen_executor = executor.Executor()
stats_gen_executor.Do(input_dict, output_dict, exec_properties)

# Check statistics artifact sample_rate_by_split property.
self.assertEqual(
json_utils.loads(stats.get_json_value_custom_property(
executor.SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME
)),
expected_sample_rate_by_split_property,
)

# Check statistics_gen outputs.
self._validate_stats_output(
os.path.join(stats.uri, 'Split-train', 'FeatureStats.pb')
)
self._validate_stats_output(
os.path.join(stats.uri, 'Split-eval', 'FeatureStats.pb')
)

def testDoWithTwoSchemas(self):
source_data_dir = os.path.join(
os.path.dirname(os.path.dirname(__file__)), 'testdata')
Expand Down

0 comments on commit 4878a75

Please sign in to comment.