From 031f2633ed936729ce985e863c909cf00eb09881 Mon Sep 17 00:00:00 2001 From: tfx-team Date: Wed, 12 Jun 2024 19:14:43 -0700 Subject: [PATCH] Apply the `ph.make_proto` optimization to execution parameters with `use_proto=True`. PiperOrigin-RevId: 642815569 --- .../experimental/core/task_gen_utils.py | 3 +- .../experimental/core/task_gen_utils_test.py | 15 +++++ .../core/testing/test_async_pipeline.py | 9 +++ .../portable/inputs_utils_test.py | 63 +++++++++++++++++++ tfx/types/component_spec.py | 11 +++- tfx/types/component_spec_test.py | 24 ++++--- 6 files changed, 115 insertions(+), 10 deletions(-) diff --git a/tfx/orchestration/experimental/core/task_gen_utils.py b/tfx/orchestration/experimental/core/task_gen_utils.py index 514c042fd21..8fa62aee907 100644 --- a/tfx/orchestration/experimental/core/task_gen_utils.py +++ b/tfx/orchestration/experimental/core/task_gen_utils.py @@ -280,10 +280,11 @@ def generate_resolved_info( ) raise + exec_properties.update(dynamic_exec_properties) result.input_and_params.append( InputAndParam( input_artifacts=input_artifacts, - exec_properties={**exec_properties, **dynamic_exec_properties}, + exec_properties=exec_properties, ) ) diff --git a/tfx/orchestration/experimental/core/task_gen_utils_test.py b/tfx/orchestration/experimental/core/task_gen_utils_test.py index 689bd2eb562..cff01b67406 100644 --- a/tfx/orchestration/experimental/core/task_gen_utils_test.py +++ b/tfx/orchestration/experimental/core/task_gen_utils_test.py @@ -473,6 +473,21 @@ def test_generate_resolved_info_with_dynamic_exec_prop(self): resolved_info.input_and_params[0].exec_properties['input_str'], ) + def test_generate_resolved_info_with_ph_exec_parameter(self): + otu.fake_example_gen_run(self._mlmd_connection, self._example_gen, 2, 1) + otu.fake_component_output(self._mlmd_connection, self._transform) + resolved_info = task_gen_utils.generate_resolved_info( + self._mlmd_connection_manager, + node_proto_view.get_view(self._trainer), + self._pipeline, + ) + self.assertProtoEquals( + """ + splits: "train" + """, + resolved_info.input_and_params[0].exec_properties['train_args'], + ) + @parameterized.named_parameters( dict( testcase_name='per_execution_idx_latest', diff --git a/tfx/orchestration/experimental/core/testing/test_async_pipeline.py b/tfx/orchestration/experimental/core/testing/test_async_pipeline.py index 61279a08802..8c2aac7d908 100644 --- a/tfx/orchestration/experimental/core/testing/test_async_pipeline.py +++ b/tfx/orchestration/experimental/core/testing/test_async_pipeline.py @@ -20,7 +20,9 @@ from tfx.dsl.component.experimental.decorators import component from tfx.dsl.control_flow import for_each from tfx.dsl.input_resolution.canned_resolver_functions import latest_created +from tfx.dsl.placeholder import placeholder as ph from tfx.orchestration import pipeline as pipeline_lib +from tfx.proto import trainer_pb2 from tfx.proto.orchestration import pipeline_pb2 from tfx.types import standard_artifacts @@ -82,5 +84,12 @@ def create_pipeline() -> pipeline_pb2.Pipeline: assert trainer.node_info.id == 'my_trainer' for value in trainer.inputs.inputs.values(): value.min_count = 1 + train_args_proto = trainer_pb2.TrainArgs(splits=['train']) + train_args = ph.make_proto(train_args_proto) + trainer.parameters.parameters['train_args'].CopyFrom( + pipeline_pb2.Value( + placeholder=train_args.encode() + ) + ) return compiled_pipeline diff --git a/tfx/orchestration/portable/inputs_utils_test.py b/tfx/orchestration/portable/inputs_utils_test.py index 8e61c459020..326bae1ed40 100644 --- a/tfx/orchestration/portable/inputs_utils_test.py +++ b/tfx/orchestration/portable/inputs_utils_test.py @@ -385,6 +385,69 @@ def test_resolve_dynamic_parameters(self): dynamic_parameters, placeholder_utils.ResolutionContext() ) + def test_resolve_ph_execution_parameters(self): + execution_parameters = pipeline_pb2.NodeParameters() + text_format.Parse( + r""" + parameters: { + key: "train_args" + value: { + placeholder: { + operator: { + proto_op: { + expression: { + operator: { + make_proto_op: { + base: { + type_url: "type.googleapis.com/tensorflow.service.TrainArgs" + value: "\n\005train" + } + file_descriptors: { + file: { + name: "third_party/tfx/trainer.proto" + package: "tensorflow.service" + message_type { + name: "TrainArgs" + field { + name: "splits" + number: 1 + label: LABEL_REPEATED + type: TYPE_STRING + } + } + syntax: "proto3" + } + } + } + } + } + } + } + } + } + } + """, + execution_parameters, + ) + test_artifact = types.standard_artifacts.String() + test_artifact.uri = self.create_tempfile().full_path + test_artifact.value = 'testvalue' + input_dict = {'_test_placeholder': [test_artifact]} + exec_params_resolved = inputs_utils.resolve_dynamic_parameters( + execution_parameters, + placeholder_utils.ResolutionContext( + exec_info=data_types.ExecutionInfo( + input_dict=input_dict, pipeline_run_id='testrunid' + ) + ), + ) + self.assertProtoEquals( + """ + splits: "train" + """, + exec_params_resolved['train_args'], + ) + if __name__ == '__main__': tf.test.main() diff --git a/tfx/types/component_spec.py b/tfx/types/component_spec.py index 6abaf1a6db2..256a5d3444e 100644 --- a/tfx/types/component_spec.py +++ b/tfx/types/component_spec.py @@ -16,7 +16,7 @@ import copy import inspect import itertools -from typing import Any, Dict, List, Mapping, Optional, Type, cast +from typing import Any, cast, Dict, List, Mapping, Optional, Type from tfx.dsl.component.experimental.json_compat import check_strict_json_compat from tfx.dsl.placeholder import placeholder @@ -31,6 +31,13 @@ # Use Any to avoid cyclic import. _BaseNode = Any +# Execution parameters that have `use_proto=True` but cannot be optimized with +# Placeholder ph.make_proto. +_EXEMPT_EXEC_PARAMETERS = [ + 'tensorflow_trainer', + 'example_diff_config' +] + def _is_runtime_param(data: Any) -> bool: return data.__class__.__name__ == 'RuntimeParameter' @@ -234,6 +241,8 @@ def _parse_parameters(self, raw_args: Mapping[str, Any]): value = proto_utils.dict_to_proto(value, arg.type()) elif isinstance(value, str): value = proto_utils.json_to_proto(value, arg.type()) + if arg_name not in _EXEMPT_EXEC_PARAMETERS: + value = placeholder.make_proto(value) else: # Create deterministic json string as it will be stored in metadata # for cache check. diff --git a/tfx/types/component_spec_test.py b/tfx/types/component_spec_test.py index f1d3a3bfcdc..2bb24ccc2e6 100644 --- a/tfx/types/component_spec_test.py +++ b/tfx/types/component_spec_test.py @@ -19,8 +19,10 @@ import unittest import tensorflow as tf +from tfx.dsl.compiler import placeholder_utils from tfx.dsl.components.base.testing import test_node from tfx.dsl.placeholder import placeholder +from tfx.orchestration.portable import data_types from tfx.proto import example_gen_pb2 from tfx.types import artifact from tfx.types import channel @@ -32,7 +34,6 @@ from tfx.utils import proto_utils from google.protobuf import json_format -from google.protobuf import text_format class _InputArtifact(artifact.Artifact): @@ -433,14 +434,21 @@ class SpecWithNonPrimitiveTypes(ComponentSpec): output=channel.Channel(type=_OutputArtifact)) # Verify exec_properties store parsed value when use_proto set to True. - expected_proto = text_format.Parse( + resolved_proto = placeholder_utils.resolve_placeholder_expression( + spec.exec_properties['config_proto'].encode(), + placeholder_utils.ResolutionContext( + exec_info=data_types.ExecutionInfo() + ) + ) + self.assertProtoEquals( """ - splits { - name: "name" - pattern: "pattern" - } - """, example_gen_pb2.Input()) - self.assertProtoEquals(expected_proto, spec.exec_properties['config_proto']) + splits { + name: "name" + pattern: "pattern" + } + """, + resolved_proto + ) self.assertEqual(True, spec.exec_properties['boolean']) self.assertIsInstance(spec.exec_properties['list_config_proto'], list) self.assertEqual(spec.exec_properties['list_boolean'], [False, True])