Skip to content

Commit

Permalink
no-up
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 652899920
  • Loading branch information
tfx-copybara committed Jul 16, 2024
1 parent 5b93b5d commit 64ce511
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 7 deletions.
30 changes: 23 additions & 7 deletions tfx/components/statistics_gen/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

_TELEMETRY_DESCRIPTORS = ['StatisticsGen']
STATS_DASHBOARD_LINK = 'stats_dashboard_link'
SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME = 'sample_rate_by_split'


class Executor(base_beam_executor.BaseBeamExecutor):
Expand Down Expand Up @@ -132,13 +133,6 @@ def Do(

split_names = [split for split in splits if split not in exclude_splits]

# Check if sample_rate_by_split contains invalid split names
for split in sample_rate_by_split:
if split not in split_names:
logging.error(
'Split %s provided in sample_rate_by_split is not valid.', split
)

statistics_artifact = artifact_utils.get_single_instance(
output_dict[standard_component_specs.STATISTICS_KEY]
)
Expand Down Expand Up @@ -169,6 +163,28 @@ def Do(
# json_utils
stats_options = options.StatsOptions.from_json(stats_options_json)

sample_rate_by_split_property = (
{split: stats_options.sample_rate for split in split_names}
if stats_options.sample_rate
else {}
)
# Check if sample_rate_by_split contains invalid split names
for split in sample_rate_by_split:
if split not in split_names:
logging.error(
'Split %s provided in sample_rate_by_split is not valid.', split
)
continue
sample_rate_by_split_property[split] = sample_rate_by_split[split]

# Add sample_rate_by_split property to statistics artifact
# when sampling is set.
if sample_rate_by_split_property:
statistics_artifact.set_json_value_custom_property(
SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME,
json_utils.dumps(sample_rate_by_split_property),
)

write_sharded_output = exec_properties.get(
standard_component_specs.SHARDED_STATS_OUTPUT_KEY, False
)
Expand Down
98 changes: 98 additions & 0 deletions tfx/components/statistics_gen/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,29 @@
'sharded_output': False,
'custom_split_uri': False,
'sample_rate_by_split': 'null',
'have_sample_rate_by_split_property': False,
},
{
'testcase_name': 'custom_split_uri',
'sharded_output': False,
'custom_split_uri': True,
'sample_rate_by_split': 'null',
'have_sample_rate_by_split_property': False,
},
{
'testcase_name': 'sample_rate_by_split',
'sharded_output': False,
'custom_split_uri': False,
# set a higher sample rate since test data is small
'sample_rate_by_split': '{"train": 0.4, "eval": 0.6}',
'have_sample_rate_by_split_property': True,
},
{
'testcase_name': 'sample_rate_split_nonexist',
'sharded_output': False,
'custom_split_uri': False,
'sample_rate_by_split': '{"test": 0.05}',
'have_sample_rate_by_split_property': False,
},
]
if tfdv.default_sharded_output_supported():
Expand All @@ -62,6 +66,7 @@
'sharded_output': True,
'custom_split_uri': False,
'sample_rate_by_split': 'null',
'have_sample_rate_by_split_property': False,
})
_TEST_SPAN_NUMBER = 16000

Expand Down Expand Up @@ -96,6 +101,7 @@ def testDo(
sharded_output: bool,
custom_split_uri: bool,
sample_rate_by_split: str,
have_sample_rate_by_split_property: bool,
):
source_data_dir = os.path.join(
os.path.dirname(os.path.dirname(__file__)), 'testdata')
Expand Down Expand Up @@ -149,6 +155,10 @@ def testDo(
artifact_utils.encode_split_names(['train', 'eval']), stats.split_names)
self.assertEqual(
stats.get_string_custom_property(executor.STATS_DASHBOARD_LINK), '')
self.assertEqual(
stats.has_custom_property(executor.SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME),
have_sample_rate_by_split_property,
)
self.assertEqual(stats.span, _TEST_SPAN_NUMBER)

# Check statistics_gen outputs.
Expand Down Expand Up @@ -228,6 +238,94 @@ def testDoWithSchemaAndStatsOptions(self):
self._validate_stats_output(
os.path.join(stats.uri, 'Split-eval', 'FeatureStats.pb'))

@parameterized.named_parameters(
{
'testcase_name': 'sample_rate_only',
'sample_rate': 0.2,
'sample_rate_by_split': 'null',
'expected_sample_rate_by_split_property': {'train': 0.2, 'eval': 0.2},
},
{
'testcase_name': 'sample_rate_by_split_only',
'sample_rate': None,
'sample_rate_by_split': '{"train": 0.4, "eval": 0.6}',
'expected_sample_rate_by_split_property': {'train': 0.4, 'eval': 0.6},
},
{
'testcase_name': 'sample_rate_by_split_override',
'sample_rate': 0.2,
'sample_rate_by_split': '{"train": 0.4}',
'expected_sample_rate_by_split_property': {'train': 0.4, 'eval': 0.2},
},
{
'testcase_name': 'sample_rate_by_split_invalid',
'sample_rate': 0.2,
'sample_rate_by_split': '{"test": 0.4}',
'expected_sample_rate_by_split_property': {'train': 0.2, 'eval': 0.2},
},
)
def testDoWithSamplingProperty(
self,
sample_rate,
sample_rate_by_split,
expected_sample_rate_by_split_property
):
source_data_dir = os.path.join(
os.path.dirname(os.path.dirname(__file__)), 'testdata'
)
output_data_dir = os.path.join(
os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
self._testMethodName,
)
fileio.makedirs(output_data_dir)

# Create input dict.
examples = standard_artifacts.Examples()
examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])

schema = standard_artifacts.Schema()
schema.uri = os.path.join(source_data_dir, 'schema_gen')

input_dict = {
standard_component_specs.EXAMPLES_KEY: [examples],
standard_component_specs.SCHEMA_KEY: [schema],
}

exec_properties = {
standard_component_specs.STATS_OPTIONS_JSON_KEY: tfdv.StatsOptions(
sample_rate=sample_rate
).to_json(),
standard_component_specs.EXCLUDE_SPLITS_KEY: json_utils.dumps([]),
standard_component_specs.SAMPLE_RATE_BY_SPLIT_KEY: sample_rate_by_split,
}

# Create output dict.
stats = standard_artifacts.ExampleStatistics()
stats.uri = output_data_dir
output_dict = {
standard_component_specs.STATISTICS_KEY: [stats],
}

# Run executor.
stats_gen_executor = executor.Executor()
stats_gen_executor.Do(input_dict, output_dict, exec_properties)

self.assertEqual(
json_utils.loads(stats.get_json_value_custom_property(
executor.SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME
)),
expected_sample_rate_by_split_property,
)

# Check statistics_gen outputs.
self._validate_stats_output(
os.path.join(stats.uri, 'Split-train', 'FeatureStats.pb')
)
self._validate_stats_output(
os.path.join(stats.uri, 'Split-eval', 'FeatureStats.pb')
)

def testDoWithTwoSchemas(self):
source_data_dir = os.path.join(
os.path.dirname(os.path.dirname(__file__)), 'testdata')
Expand Down

0 comments on commit 64ce511

Please sign in to comment.