diff --git a/tfx/components/statistics_gen/executor.py b/tfx/components/statistics_gen/executor.py index ee9f43dda8c..31b6c62f6bd 100644 --- a/tfx/components/statistics_gen/executor.py +++ b/tfx/components/statistics_gen/executor.py @@ -35,6 +35,7 @@ _TELEMETRY_DESCRIPTORS = ['StatisticsGen'] STATS_DASHBOARD_LINK = 'stats_dashboard_link' +SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME = 'sample_rate_by_split' class Executor(base_beam_executor.BaseBeamExecutor): @@ -132,13 +133,6 @@ def Do( split_names = [split for split in splits if split not in exclude_splits] - # Check if sample_rate_by_split contains invalid split names - for split in sample_rate_by_split: - if split not in split_names: - logging.error( - 'Split %s provided in sample_rate_by_split is not valid.', split - ) - statistics_artifact = artifact_utils.get_single_instance( output_dict[standard_component_specs.STATISTICS_KEY] ) @@ -169,6 +163,27 @@ def Do( # json_utils stats_options = options.StatsOptions.from_json(stats_options_json) + sample_rate_by_split_property = ( + {split: stats_options.sample_rate for split in split_names} + if stats_options.sample_rate + else {} + ) + # Check if sample_rate_by_split contains invalid split names + for split in sample_rate_by_split: + if split not in split_names: + logging.error( + 'Split %s provided in sample_rate_by_split is not valid.', split + ) + continue + sample_rate_by_split_property[split] = sample_rate_by_split[split] + + if sample_rate_by_split_property: + # Add sample_rate_by_split property to statistics artifact + statistics_artifact.set_json_value_custom_property( + SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME, + json_utils.dumps(sample_rate_by_split_property), + ) + write_sharded_output = exec_properties.get( standard_component_specs.SHARDED_STATS_OUTPUT_KEY, False ) diff --git a/tfx/components/statistics_gen/executor_test.py b/tfx/components/statistics_gen/executor_test.py index 3bfab22a6a4..5a1d460571d 100644 --- a/tfx/components/statistics_gen/executor_test.py +++ b/tfx/components/statistics_gen/executor_test.py @@ -35,12 +35,14 @@ 'sharded_output': False, 'custom_split_uri': False, 'sample_rate_by_split': 'null', + 'have_sample_rate_by_split_property': False, }, { 'testcase_name': 'custom_split_uri', 'sharded_output': False, 'custom_split_uri': True, 'sample_rate_by_split': 'null', + 'have_sample_rate_by_split_property': False, }, { 'testcase_name': 'sample_rate_by_split', @@ -48,12 +50,14 @@ 'custom_split_uri': False, # set a higher sample rate since test data is small 'sample_rate_by_split': '{"train": 0.4, "eval": 0.6}', + 'have_sample_rate_by_split_property': True, }, { 'testcase_name': 'sample_rate_split_nonexist', 'sharded_output': False, 'custom_split_uri': False, 'sample_rate_by_split': '{"test": 0.05}', + 'have_sample_rate_by_split_property': False, }, ] if tfdv.default_sharded_output_supported(): @@ -62,6 +66,7 @@ 'sharded_output': True, 'custom_split_uri': False, 'sample_rate_by_split': 'null', + 'have_sample_rate_by_split_property': False, }) _TEST_SPAN_NUMBER = 16000 @@ -96,6 +101,7 @@ def testDo( sharded_output: bool, custom_split_uri: bool, sample_rate_by_split: str, + have_sample_rate_by_split_property: bool, ): source_data_dir = os.path.join( os.path.dirname(os.path.dirname(__file__)), 'testdata') @@ -149,6 +155,21 @@ def testDo( artifact_utils.encode_split_names(['train', 'eval']), stats.split_names) self.assertEqual( stats.get_string_custom_property(executor.STATS_DASHBOARD_LINK), '') + # if standard_component_specs.SAMPLE_RATE_BY_SPLIT_KEY: + # self.assertEqual( + # json_utils.loads( + # stats.get_string_custom_property( + # stats.get_json_value_custom_property( + # executor.SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME + # ) + # ) + # ), + # json_utils.loads(expected_sample_rate_property), + # ) + self.assertEqual( + stats.has_custom_property(executor.SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME), + have_sample_rate_by_split_property, + ) self.assertEqual(stats.span, _TEST_SPAN_NUMBER) # Check statistics_gen outputs.