Apply the ph.make_proto optimization to execution parameters with use_proto=True. #6869

Open · wants to merge 1 commit into master
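What this change does, as a minimal sketch (illustrative code, not part of this PR's diff): a proto-typed execution parameter declared with use_proto=True used to be stored in exec_properties as the parsed proto message; it is now wrapped in a ph.make_proto placeholder that the orchestrator resolves at runtime.

from tfx.dsl.placeholder import placeholder as ph
from tfx.proto import trainer_pb2

# The parsed proto value becomes a placeholder expression instead of a
# concrete message stored at pipeline-definition time.
train_args = ph.make_proto(trainer_pb2.TrainArgs(splits=['train']))
expression = train_args.encode()  # a placeholder_pb2.PlaceholderExpression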
15 changes: 15 additions & 0 deletions tfx/orchestration/experimental/core/task_gen_utils_test.py
@@ -473,6 +473,21 @@ def test_generate_resolved_info_with_dynamic_exec_prop(self):
        resolved_info.input_and_params[0].exec_properties['input_str'],
    )

  def test_generate_resolved_info_with_ph_exec_parameter(self):
    otu.fake_example_gen_run(self._mlmd_connection, self._example_gen, 2, 1)
    otu.fake_component_output(self._mlmd_connection, self._transform)
    resolved_info = task_gen_utils.generate_resolved_info(
        self._mlmd_connection_manager,
        node_proto_view.get_view(self._trainer),
        self._pipeline,
    )
    self.assertProtoEquals(
        """
        splits: "train"
        """,
        resolved_info.input_and_params[0].exec_properties['train_args'],
    )

  @parameterized.named_parameters(
      dict(
          testcase_name='per_execution_idx_latest',
@@ -20,7 +20,9 @@
from tfx.dsl.component.experimental.decorators import component
from tfx.dsl.control_flow import for_each
from tfx.dsl.input_resolution.canned_resolver_functions import latest_created
from tfx.dsl.placeholder import placeholder as ph
from tfx.orchestration import pipeline as pipeline_lib
from tfx.proto import trainer_pb2
from tfx.proto.orchestration import pipeline_pb2
from tfx.types import standard_artifacts

@@ -82,5 +84,12 @@ def create_pipeline() -> pipeline_pb2.Pipeline:
  assert trainer.node_info.id == 'my_trainer'
  for value in trainer.inputs.inputs.values():
    value.min_count = 1
  train_args_proto = trainer_pb2.TrainArgs(splits=['train'])
  train_args = ph.make_proto(train_args_proto)
  trainer.parameters.parameters['train_args'].CopyFrom(
      pipeline_pb2.Value(
          placeholder=train_args.encode()
      )
  )

  return compiled_pipeline
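For context, a hedged sketch of how a placeholder stored this way can be resolved later; it mirrors the runtime path exercised by the tests below, and the empty ExecutionInfo is an assumption for this self-contained example.

from tfx.dsl.compiler import placeholder_utils
from tfx.orchestration.portable import data_types

# 'train_args' was stored above as pipeline_pb2.Value(placeholder=...); at
# execution time the orchestrator resolves it back into a TrainArgs proto.
resolved = placeholder_utils.resolve_placeholder_expression(
    trainer.parameters.parameters['train_args'].placeholder,
    placeholder_utils.ResolutionContext(
        exec_info=data_types.ExecutionInfo()
    ),
)
# resolved should equal trainer_pb2.TrainArgs(splits=['train']).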
63 changes: 63 additions & 0 deletions tfx/orchestration/portable/inputs_utils_test.py
@@ -385,6 +385,69 @@ def test_resolve_dynamic_parameters(self):
        dynamic_parameters, placeholder_utils.ResolutionContext()
    )

  def test_resolve_ph_execution_parameters(self):
    execution_parameters = pipeline_pb2.NodeParameters()
    text_format.Parse(
        r"""
        parameters: {
          key: "train_args"
          value: {
            placeholder: {
              operator: {
                proto_op: {
                  expression: {
                    operator: {
                      make_proto_op: {
                        base: {
                          type_url: "type.googleapis.com/tensorflow.service.TrainArgs"
                          value: "\n\005train"
                        }
                        file_descriptors: {
                          file: {
                            name: "third_party/tfx/trainer.proto"
                            package: "tensorflow.service"
                            message_type {
                              name: "TrainArgs"
                              field {
                                name: "splits"
                                number: 1
                                label: LABEL_REPEATED
                                type: TYPE_STRING
                              }
                            }
                            syntax: "proto3"
                          }
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
        """,
        execution_parameters,
    )
    test_artifact = types.standard_artifacts.String()
    test_artifact.uri = self.create_tempfile().full_path
    test_artifact.value = 'testvalue'
    input_dict = {'_test_placeholder': [test_artifact]}
    exec_params_resolved = inputs_utils.resolve_dynamic_parameters(
        execution_parameters,
        placeholder_utils.ResolutionContext(
            exec_info=data_types.ExecutionInfo(
                input_dict=input_dict, pipeline_run_id='testrunid'
            )
        ),
    )
    self.assertProtoEquals(
        """
        splits: "train"
        """,
        exec_params_resolved['train_args'],
    )


if __name__ == '__main__':
  tf.test.main()
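The base bytes "\n\005train" in the text proto above are simply the serialized TrainArgs message. A quick sanity check, assuming the standard proto wire format (0x0A = field 1, length-delimited; 0x05 = len('train')) and using tfx.proto.trainer_pb2.TrainArgs, whose splits field has the same number and type as the test-local message:

from tfx.proto import trainer_pb2

# Field 1 (splits), wire type 2, length 5, then the UTF-8 bytes of 'train'.
assert trainer_pb2.TrainArgs(splits=['train']).SerializeToString() == b'\n\x05train'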
22 changes: 21 additions & 1 deletion tfx/types/component_spec.py
@@ -16,7 +16,7 @@
import copy
import inspect
import itertools
from typing import Any, Dict, List, Mapping, Optional, Type, cast
from typing import Any, cast, Dict, List, Mapping, Optional, Type

from tfx.dsl.component.experimental.json_compat import check_strict_json_compat
from tfx.dsl.placeholder import placeholder
@@ -31,6 +31,21 @@
# Use Any to avoid cyclic import.
_BaseNode = Any

# Execution parameters that have `use_proto=True` but cannot be optimized with
# the ph.make_proto placeholder.
# TODO(b/350820714): Placeholders need to be supported at runtime so that the
# TensorflowTrainerConfig placeholder can be used to create the Trainer and
# Tuner jobs.
# TODO(b/349459258): The ExampleDiff executor needs to be updated to support
# placeholder proto fields not being present.
# TODO(b/352623284): The DistributionValidator test needs to be updated to
# support placeholder protos.
_MAKE_PROTO_EXEMPT_EXEC_PARAMETERS = [
    'tensorflow_trainer',
    'example_diff_config',
    'distribution_validator_config',
]


def _is_runtime_param(data: Any) -> bool:
  return data.__class__.__name__ == 'RuntimeParameter'
@@ -229,11 +244,16 @@ def _parse_parameters(self, raw_args: Mapping[str, Any]):
      if (inspect.isclass(arg.type) and issubclass(arg.type, message.Message)  # pytype: disable=not-supported-yet
          and value and not _is_runtime_param(value)) and not isinstance(
              value, placeholder.Placeholder):
        # If the parameter is defined with use_proto=True, convert the value
        # from a dict or JSON string to a proto, if necessary, before creating
        # the proto placeholder.
        if arg.use_proto:
          if isinstance(value, dict):
            value = proto_utils.dict_to_proto(value, arg.type())
          elif isinstance(value, str):
            value = proto_utils.json_to_proto(value, arg.type())
          if arg_name not in _MAKE_PROTO_EXEMPT_EXEC_PARAMETERS:
            value = placeholder.make_proto(value)
        else:
          # Create a deterministic JSON string, as it will be stored in
          # metadata for the cache check.
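From a component author's perspective, the new path looks roughly like the following hedged sketch (_MySpec is hypothetical; only the use_proto=True parameter matters here):

from tfx.proto import example_gen_pb2
from tfx.types import component_spec

class _MySpec(component_spec.ComponentSpec):
  PARAMETERS = {
      'config_proto': component_spec.ExecutionParameter(
          type=example_gen_pb2.Input, use_proto=True
      ),
  }
  INPUTS = {}
  OUTPUTS = {}

# The dict is converted to an example_gen_pb2.Input proto and, because
# 'config_proto' is not in _MAKE_PROTO_EXEMPT_EXEC_PARAMETERS, wrapped with
# placeholder.make_proto rather than stored as a concrete value.
spec = _MySpec(
    config_proto={'splits': [{'name': 'name', 'pattern': 'pattern'}]}
)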
27 changes: 18 additions & 9 deletions tfx/types/component_spec_test.py
@@ -19,8 +19,10 @@
import unittest

import tensorflow as tf
from tfx.dsl.compiler import placeholder_utils
from tfx.dsl.components.base.testing import test_node
from tfx.dsl.placeholder import placeholder
from tfx.orchestration.portable import data_types
from tfx.proto import example_gen_pb2
from tfx.types import artifact
from tfx.types import channel
@@ -32,7 +34,6 @@
from tfx.utils import proto_utils

from google.protobuf import json_format
from google.protobuf import text_format


class _InputArtifact(artifact.Artifact):
@@ -432,15 +433,23 @@ class SpecWithNonPrimitiveTypes(ComponentSpec):
        input=channel.Channel(type=_InputArtifact),
        output=channel.Channel(type=_OutputArtifact))

    # Verify exec_properties store parsed value when use_proto set to True.
    expected_proto = text_format.Parse(
        """
        splits {
          name: "name"
          pattern: "pattern"
        }
        """, example_gen_pb2.Input())
    self.assertProtoEquals(expected_proto, spec.exec_properties['config_proto'])
    # Verify exec_properties stores the correct placeholder when use_proto set
    # to True.
    resolved_proto = placeholder_utils.resolve_placeholder_expression(
        spec.exec_properties['config_proto'].encode(),
        placeholder_utils.ResolutionContext(
            exec_info=data_types.ExecutionInfo()
        )
    )
    self.assertProtoEquals(
        """
        splits {
          name: "name"
          pattern: "pattern"
        }
        """,
        resolved_proto
    )
    self.assertEqual(True, spec.exec_properties['boolean'])
    self.assertIsInstance(spec.exec_properties['list_config_proto'], list)
    self.assertEqual(spec.exec_properties['list_boolean'], [False, True])