From 4878a754a8ca8314300e68e8812eb055f7d53683 Mon Sep 17 00:00:00 2001 From: tfx-team Date: Tue, 16 Jul 2024 10:41:30 -0700 Subject: [PATCH] no-up PiperOrigin-RevId: 652899920 --- tfx/components/statistics_gen/executor.py | 26 +++-- .../statistics_gen/executor_test.py | 99 +++++++++++++++++++ 2 files changed, 118 insertions(+), 7 deletions(-) diff --git a/tfx/components/statistics_gen/executor.py b/tfx/components/statistics_gen/executor.py index ee9f43dda8..20f4f49f77 100644 --- a/tfx/components/statistics_gen/executor.py +++ b/tfx/components/statistics_gen/executor.py @@ -35,6 +35,7 @@ _TELEMETRY_DESCRIPTORS = ['StatisticsGen'] STATS_DASHBOARD_LINK = 'stats_dashboard_link' +SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME = 'sample_rate_by_split' class Executor(base_beam_executor.BaseBeamExecutor): @@ -132,13 +133,6 @@ def Do( split_names = [split for split in splits if split not in exclude_splits] - # Check if sample_rate_by_split contains invalid split names - for split in sample_rate_by_split: - if split not in split_names: - logging.error( - 'Split %s provided in sample_rate_by_split is not valid.', split - ) - statistics_artifact = artifact_utils.get_single_instance( output_dict[standard_component_specs.STATISTICS_KEY] ) @@ -169,6 +163,24 @@ def Do( # json_utils stats_options = options.StatsOptions.from_json(stats_options_json) + sample_rate_by_split_property = { + split: stats_options.sample_rate or 1.0 for split in split_names + } + for split in sample_rate_by_split: + # Check if sample_rate_by_split contains invalid split names + if split not in split_names: + logging.error( + 'Split %s provided in sample_rate_by_split is not valid.', split + ) + continue + sample_rate_by_split_property[split] = sample_rate_by_split[split] + + # Add sample_rate_by_split property to statistics artifact + statistics_artifact.set_json_value_custom_property( + SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME, + json_utils.dumps(sample_rate_by_split_property), + ) + write_sharded_output = exec_properties.get( standard_component_specs.SHARDED_STATS_OUTPUT_KEY, False ) diff --git a/tfx/components/statistics_gen/executor_test.py b/tfx/components/statistics_gen/executor_test.py index 3bfab22a6a..d55abaa4a0 100644 --- a/tfx/components/statistics_gen/executor_test.py +++ b/tfx/components/statistics_gen/executor_test.py @@ -149,6 +149,10 @@ def testDo( artifact_utils.encode_split_names(['train', 'eval']), stats.split_names) self.assertEqual( stats.get_string_custom_property(executor.STATS_DASHBOARD_LINK), '') + self.assertEqual( + stats.has_custom_property(executor.SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME), + True, + ) self.assertEqual(stats.span, _TEST_SPAN_NUMBER) # Check statistics_gen outputs. @@ -228,6 +232,101 @@ def testDoWithSchemaAndStatsOptions(self): self._validate_stats_output( os.path.join(stats.uri, 'Split-eval', 'FeatureStats.pb')) + @parameterized.named_parameters( + { + 'testcase_name': 'sample_rate_only', + 'sample_rate': 0.2, + 'sample_rate_by_split': 'null', + 'expected_sample_rate_by_split_property': {'train': 0.2, 'eval': 0.2}, + }, + { + 'testcase_name': 'sample_rate_by_split_only', + 'sample_rate': None, + 'sample_rate_by_split': '{"train": 0.4, "eval": 0.6}', + 'expected_sample_rate_by_split_property': {'train': 0.4, 'eval': 0.6}, + }, + { + 'testcase_name': 'sample_rate_for_some_split_only', + 'sample_rate': None, + 'sample_rate_by_split': '{"train": 0.4}', + 'expected_sample_rate_by_split_property': {'train': 0.4, 'eval': 1.0}, + }, + { + 'testcase_name': 'sample_rate_by_split_override', + 'sample_rate': 0.2, + 'sample_rate_by_split': '{"train": 0.4}', + 'expected_sample_rate_by_split_property': {'train': 0.4, 'eval': 0.2}, + }, + { + 'testcase_name': 'sample_rate_by_split_invalid', + 'sample_rate': 0.2, + 'sample_rate_by_split': '{"test": 0.4}', + 'expected_sample_rate_by_split_property': {'train': 0.2, 'eval': 0.2}, + }, + ) + def testDoWithSamplingProperty( + self, + sample_rate, + sample_rate_by_split, + expected_sample_rate_by_split_property + ): + source_data_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), 'testdata' + ) + output_data_dir = os.path.join( + os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()), + self._testMethodName, + ) + fileio.makedirs(output_data_dir) + + # Create input dict. + examples = standard_artifacts.Examples() + examples.uri = os.path.join(source_data_dir, 'csv_example_gen') + examples.split_names = artifact_utils.encode_split_names(['train', 'eval']) + + schema = standard_artifacts.Schema() + schema.uri = os.path.join(source_data_dir, 'schema_gen') + + input_dict = { + standard_component_specs.EXAMPLES_KEY: [examples], + standard_component_specs.SCHEMA_KEY: [schema], + } + + exec_properties = { + standard_component_specs.STATS_OPTIONS_JSON_KEY: tfdv.StatsOptions( + sample_rate=sample_rate + ).to_json(), + standard_component_specs.EXCLUDE_SPLITS_KEY: json_utils.dumps([]), + standard_component_specs.SAMPLE_RATE_BY_SPLIT_KEY: sample_rate_by_split, + } + + # Create output dict. + stats = standard_artifacts.ExampleStatistics() + stats.uri = output_data_dir + output_dict = { + standard_component_specs.STATISTICS_KEY: [stats], + } + + # Run executor. + stats_gen_executor = executor.Executor() + stats_gen_executor.Do(input_dict, output_dict, exec_properties) + + # Check statistics artifact sample_rate_by_split property. + self.assertEqual( + json_utils.loads(stats.get_json_value_custom_property( + executor.SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME + )), + expected_sample_rate_by_split_property, + ) + + # Check statistics_gen outputs. + self._validate_stats_output( + os.path.join(stats.uri, 'Split-train', 'FeatureStats.pb') + ) + self._validate_stats_output( + os.path.join(stats.uri, 'Split-eval', 'FeatureStats.pb') + ) + def testDoWithTwoSchemas(self): source_data_dir = os.path.join( os.path.dirname(os.path.dirname(__file__)), 'testdata')