From be39f17ac93920156408bd39a93cd0ca7876c135 Mon Sep 17 00:00:00 2001 From: tfx-team Date: Wed, 12 Jun 2024 19:14:43 -0700 Subject: [PATCH] Apply the `ph.make_proto` optimization to execution parameters with `use_proto=True`. PiperOrigin-RevId: 642815569 --- .../experimental/core/task_gen_utils_test.py | 15 +++++ .../core/testing/test_async_pipeline.py | 9 +++ .../portable/inputs_utils_test.py | 63 +++++++++++++++++++ tfx/types/component_spec.py | 22 ++++++- tfx/types/component_spec_test.py | 27 +++++--- 5 files changed, 126 insertions(+), 10 deletions(-) diff --git a/tfx/orchestration/experimental/core/task_gen_utils_test.py b/tfx/orchestration/experimental/core/task_gen_utils_test.py index 689bd2eb56..cff01b6740 100644 --- a/tfx/orchestration/experimental/core/task_gen_utils_test.py +++ b/tfx/orchestration/experimental/core/task_gen_utils_test.py @@ -473,6 +473,21 @@ def test_generate_resolved_info_with_dynamic_exec_prop(self): resolved_info.input_and_params[0].exec_properties['input_str'], ) + def test_generate_resolved_info_with_ph_exec_parameter(self): + otu.fake_example_gen_run(self._mlmd_connection, self._example_gen, 2, 1) + otu.fake_component_output(self._mlmd_connection, self._transform) + resolved_info = task_gen_utils.generate_resolved_info( + self._mlmd_connection_manager, + node_proto_view.get_view(self._trainer), + self._pipeline, + ) + self.assertProtoEquals( + """ + splits: "train" + """, + resolved_info.input_and_params[0].exec_properties['train_args'], + ) + @parameterized.named_parameters( dict( testcase_name='per_execution_idx_latest', diff --git a/tfx/orchestration/experimental/core/testing/test_async_pipeline.py b/tfx/orchestration/experimental/core/testing/test_async_pipeline.py index 61279a0880..8c2aac7d90 100644 --- a/tfx/orchestration/experimental/core/testing/test_async_pipeline.py +++ b/tfx/orchestration/experimental/core/testing/test_async_pipeline.py @@ -20,7 +20,9 @@ from tfx.dsl.component.experimental.decorators import component from tfx.dsl.control_flow import for_each from tfx.dsl.input_resolution.canned_resolver_functions import latest_created +from tfx.dsl.placeholder import placeholder as ph from tfx.orchestration import pipeline as pipeline_lib +from tfx.proto import trainer_pb2 from tfx.proto.orchestration import pipeline_pb2 from tfx.types import standard_artifacts @@ -82,5 +84,12 @@ def create_pipeline() -> pipeline_pb2.Pipeline: assert trainer.node_info.id == 'my_trainer' for value in trainer.inputs.inputs.values(): value.min_count = 1 + train_args_proto = trainer_pb2.TrainArgs(splits=['train']) + train_args = ph.make_proto(train_args_proto) + trainer.parameters.parameters['train_args'].CopyFrom( + pipeline_pb2.Value( + placeholder=train_args.encode() + ) + ) return compiled_pipeline diff --git a/tfx/orchestration/portable/inputs_utils_test.py b/tfx/orchestration/portable/inputs_utils_test.py index 8e61c45902..326bae1ed4 100644 --- a/tfx/orchestration/portable/inputs_utils_test.py +++ b/tfx/orchestration/portable/inputs_utils_test.py @@ -385,6 +385,69 @@ def test_resolve_dynamic_parameters(self): dynamic_parameters, placeholder_utils.ResolutionContext() ) + def test_resolve_ph_execution_parameters(self): + execution_parameters = pipeline_pb2.NodeParameters() + text_format.Parse( + r""" + parameters: { + key: "train_args" + value: { + placeholder: { + operator: { + proto_op: { + expression: { + operator: { + make_proto_op: { + base: { + type_url: "type.googleapis.com/tensorflow.service.TrainArgs" + value: "\n\005train" + } + file_descriptors: { + file: { + name: "third_party/tfx/trainer.proto" + package: "tensorflow.service" + message_type { + name: "TrainArgs" + field { + name: "splits" + number: 1 + label: LABEL_REPEATED + type: TYPE_STRING + } + } + syntax: "proto3" + } + } + } + } + } + } + } + } + } + } + """, + execution_parameters, + ) + test_artifact = types.standard_artifacts.String() + test_artifact.uri = self.create_tempfile().full_path + test_artifact.value = 'testvalue' + input_dict = {'_test_placeholder': [test_artifact]} + exec_params_resolved = inputs_utils.resolve_dynamic_parameters( + execution_parameters, + placeholder_utils.ResolutionContext( + exec_info=data_types.ExecutionInfo( + input_dict=input_dict, pipeline_run_id='testrunid' + ) + ), + ) + self.assertProtoEquals( + """ + splits: "train" + """, + exec_params_resolved['train_args'], + ) + if __name__ == '__main__': tf.test.main() diff --git a/tfx/types/component_spec.py b/tfx/types/component_spec.py index 6abaf1a6db..16c0be2634 100644 --- a/tfx/types/component_spec.py +++ b/tfx/types/component_spec.py @@ -16,7 +16,7 @@ import copy import inspect import itertools -from typing import Any, Dict, List, Mapping, Optional, Type, cast +from typing import Any, cast, Dict, List, Mapping, Optional, Type from tfx.dsl.component.experimental.json_compat import check_strict_json_compat from tfx.dsl.placeholder import placeholder @@ -31,6 +31,21 @@ # Use Any to avoid cyclic import. _BaseNode = Any +# Execution parameters that have `use_proto=True` but cannot be optimized with +# Placeholder ph.make_proto. +# TODO(b/350820714): Placeholder needs to be supported at runtime so that +# TensorflowTrainerConfig placeholder can be used to create the Trainer and +# Tuner jobs. +# TODO(b/349459258): ExampleDiff executor needs to be updated to support +# placeholder proto fields not being present. +# TODO(b/352623284); DistributionValidator test needs to be updated to +# support placeholder proto. +_MAKE_PROTO_EXEMPT_EXEC_PARAMETERS = [ + 'tensorflow_trainer', + 'example_diff_config', + 'default_slice_config', +] + def _is_runtime_param(data: Any) -> bool: return data.__class__.__name__ == 'RuntimeParameter' @@ -229,11 +244,16 @@ def _parse_parameters(self, raw_args: Mapping[str, Any]): if (inspect.isclass(arg.type) and issubclass(arg.type, message.Message) # pytype: disable=not-supported-yet and value and not _is_runtime_param(value)) and not isinstance( value, placeholder.Placeholder): + # If the parameter is defined with use_proto=True, convert the value to + # proto from dict or json string if necessary before creating the proto + # placeholder. if arg.use_proto: if isinstance(value, dict): value = proto_utils.dict_to_proto(value, arg.type()) elif isinstance(value, str): value = proto_utils.json_to_proto(value, arg.type()) + if arg_name not in _MAKE_PROTO_EXEMPT_EXEC_PARAMETERS: + value = placeholder.make_proto(value) else: # Create deterministic json string as it will be stored in metadata # for cache check. diff --git a/tfx/types/component_spec_test.py b/tfx/types/component_spec_test.py index f1d3a3bfcd..c82b0f48ad 100644 --- a/tfx/types/component_spec_test.py +++ b/tfx/types/component_spec_test.py @@ -19,8 +19,10 @@ import unittest import tensorflow as tf +from tfx.dsl.compiler import placeholder_utils from tfx.dsl.components.base.testing import test_node from tfx.dsl.placeholder import placeholder +from tfx.orchestration.portable import data_types from tfx.proto import example_gen_pb2 from tfx.types import artifact from tfx.types import channel @@ -32,7 +34,6 @@ from tfx.utils import proto_utils from google.protobuf import json_format -from google.protobuf import text_format class _InputArtifact(artifact.Artifact): @@ -432,15 +433,23 @@ class SpecWithNonPrimitiveTypes(ComponentSpec): input=channel.Channel(type=_InputArtifact), output=channel.Channel(type=_OutputArtifact)) - # Verify exec_properties store parsed value when use_proto set to True. - expected_proto = text_format.Parse( + # Verify exec_properties stores the correct placeholder when use_proto set + # to True. + resolved_proto = placeholder_utils.resolve_placeholder_expression( + spec.exec_properties['config_proto'].encode(), + placeholder_utils.ResolutionContext( + exec_info=data_types.ExecutionInfo() + ) + ) + self.assertProtoEquals( """ - splits { - name: "name" - pattern: "pattern" - } - """, example_gen_pb2.Input()) - self.assertProtoEquals(expected_proto, spec.exec_properties['config_proto']) + splits { + name: "name" + pattern: "pattern" + } + """, + resolved_proto + ) self.assertEqual(True, spec.exec_properties['boolean']) self.assertIsInstance(spec.exec_properties['list_config_proto'], list) self.assertEqual(spec.exec_properties['list_boolean'], [False, True])