Skip to content

Commit

Permalink
no-up
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 652899920
  • Loading branch information
tfx-copybara committed Jul 16, 2024
1 parent 5b93b5d commit 553108f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 7 deletions.
29 changes: 22 additions & 7 deletions tfx/components/statistics_gen/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

_TELEMETRY_DESCRIPTORS = ['StatisticsGen']
STATS_DASHBOARD_LINK = 'stats_dashboard_link'
SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME = 'sample_rate_by_split'


class Executor(base_beam_executor.BaseBeamExecutor):
Expand Down Expand Up @@ -132,13 +133,6 @@ def Do(

split_names = [split for split in splits if split not in exclude_splits]

# Check if sample_rate_by_split contains invalid split names
for split in sample_rate_by_split:
if split not in split_names:
logging.error(
'Split %s provided in sample_rate_by_split is not valid.', split
)

statistics_artifact = artifact_utils.get_single_instance(
output_dict[standard_component_specs.STATISTICS_KEY]
)
Expand Down Expand Up @@ -169,6 +163,27 @@ def Do(
# json_utils
stats_options = options.StatsOptions.from_json(stats_options_json)

sample_rate_by_split_property = (
{split: stats_options.sample_rate for split in split_names}
if stats_options.sample_rate
else {}
)
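# Overlay per-split rates from sample_rate_by_split onto any defaults taken
# from stats_options.sample_rate; split names that are not in split_names
# are logged as errors and skipped.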
for split in sample_rate_by_split:
if split not in split_names:
logging.error(
'Split %s provided in sample_rate_by_split is not valid.', split
)
continue
sample_rate_by_split_property[split] = sample_rate_by_split[split]

if sample_rate_by_split_property:
# Add sample_rate_by_split property to statistics artifact
statistics_artifact.set_json_value_custom_property(
SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME,
json_utils.dumps(sample_rate_by_split_property),
)

write_sharded_output = exec_properties.get(
standard_component_specs.SHARDED_STATS_OUTPUT_KEY, False
)
Expand Down
21 changes: 21 additions & 0 deletions tfx/components/statistics_gen/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,29 @@
'sharded_output': False,
'custom_split_uri': False,
'sample_rate_by_split': 'null',
'have_sample_rate_by_split_property': False,
},
{
'testcase_name': 'custom_split_uri',
'sharded_output': False,
'custom_split_uri': True,
'sample_rate_by_split': 'null',
'have_sample_rate_by_split_property': False,
},
{
'testcase_name': 'sample_rate_by_split',
'sharded_output': False,
'custom_split_uri': False,
# set a higher sample rate since test data is small
'sample_rate_by_split': '{"train": 0.4, "eval": 0.6}',
'have_sample_rate_by_split_property': True,
},
{
'testcase_name': 'sample_rate_split_nonexist',
'sharded_output': False,
'custom_split_uri': False,
'sample_rate_by_split': '{"test": 0.05}',
'have_sample_rate_by_split_property': False,
},
]
if tfdv.default_sharded_output_supported():
Expand All @@ -62,6 +66,7 @@
'sharded_output': True,
'custom_split_uri': False,
'sample_rate_by_split': 'null',
'have_sample_rate_by_split_property': False,
})
_TEST_SPAN_NUMBER = 16000

Expand Down Expand Up @@ -96,6 +101,7 @@ def testDo(
sharded_output: bool,
custom_split_uri: bool,
sample_rate_by_split: str,
have_sample_rate_by_split_property: bool,
):
source_data_dir = os.path.join(
os.path.dirname(os.path.dirname(__file__)), 'testdata')
Expand Down Expand Up @@ -149,6 +155,21 @@ def testDo(
artifact_utils.encode_split_names(['train', 'eval']), stats.split_names)
self.assertEqual(
stats.get_string_custom_property(executor.STATS_DASHBOARD_LINK), '')
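# NOTE(review): the commented-out block below is a disabled draft of a
# value-level check. It nests get_json_value_custom_property inside
# get_string_custom_property and refers to `expected_sample_rate_property`,
# which does not appear among the visible test parameters - confirm before
# re-enabling. Only the presence of the property is asserted for now.
# if standard_component_specs.SAMPLE_RATE_BY_SPLIT_KEY: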
# self.assertEqual(
# json_utils.loads(
# stats.get_string_custom_property(
# stats.get_json_value_custom_property(
# executor.SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME
# )
# )
# ),
# json_utils.loads(expected_sample_rate_property),
# )
self.assertEqual(
stats.has_custom_property(executor.SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME),
have_sample_rate_by_split_property,
)
self.assertEqual(stats.span, _TEST_SPAN_NUMBER)

# Check statistics_gen outputs.
Expand Down

0 comments on commit 553108f

Please sign in to comment.