Skip to content

Commit

Permalink
Apply the ph.make_proto optimization to execution parameters with `use_proto=True`.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 642815569
  • Loading branch information
tfx-copybara committed Jul 2, 2024
1 parent 4e71a35 commit 031f263
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 10 deletions.
3 changes: 2 additions & 1 deletion tfx/orchestration/experimental/core/task_gen_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,11 @@ def generate_resolved_info(
)
raise

exec_properties.update(dynamic_exec_properties)
result.input_and_params.append(
InputAndParam(
input_artifacts=input_artifacts,
exec_properties={**exec_properties, **dynamic_exec_properties},
exec_properties=exec_properties,
)
)

Expand Down
15 changes: 15 additions & 0 deletions tfx/orchestration/experimental/core/task_gen_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,21 @@ def test_generate_resolved_info_with_dynamic_exec_prop(self):
resolved_info.input_and_params[0].exec_properties['input_str'],
)

def test_generate_resolved_info_with_ph_exec_parameter(self):
  """Checks that `train_args` comes back from resolution as a proto.

  Fakes an example-gen run and a transform output, resolves the trainer
  node via `task_gen_utils.generate_resolved_info`, and asserts that the
  `train_args` exec property of the first resolved input/param set equals a
  proto whose `splits` field is `"train"`.

  NOTE(review): presumably the test pipeline stores `train_args` as an
  encoded `ph.make_proto` placeholder; confirm against the pipeline fixture.
  """
  # Populate MLMD with upstream outputs so the trainer's inputs can resolve.
  otu.fake_example_gen_run(self._mlmd_connection, self._example_gen, 2, 1)
  otu.fake_component_output(self._mlmd_connection, self._transform)

  resolved_info = task_gen_utils.generate_resolved_info(
      self._mlmd_connection_manager,
      node_proto_view.get_view(self._trainer),
      self._pipeline,
  )

  # The exec property must already be a proto message, not a placeholder
  # expression, for assertProtoEquals to compare it against text format.
  self.assertProtoEquals(
      """
splits: "train"
""",
      resolved_info.input_and_params[0].exec_properties['train_args'],
  )

@parameterized.named_parameters(
dict(
testcase_name='per_execution_idx_latest',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
from tfx.dsl.component.experimental.decorators import component
from tfx.dsl.control_flow import for_each
from tfx.dsl.input_resolution.canned_resolver_functions import latest_created
from tfx.dsl.placeholder import placeholder as ph
from tfx.orchestration import pipeline as pipeline_lib
from tfx.proto import trainer_pb2
from tfx.proto.orchestration import pipeline_pb2
from tfx.types import standard_artifacts

Expand Down Expand Up @@ -82,5 +84,12 @@ def create_pipeline() -> pipeline_pb2.Pipeline:
assert trainer.node_info.id == 'my_trainer'
for value in trainer.inputs.inputs.values():
value.min_count = 1
train_args_proto = trainer_pb2.TrainArgs(splits=['train'])
train_args = ph.make_proto(train_args_proto)
trainer.parameters.parameters['train_args'].CopyFrom(
pipeline_pb2.Value(
placeholder=train_args.encode()
)
)

return compiled_pipeline
63 changes: 63 additions & 0 deletions tfx/orchestration/portable/inputs_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,69 @@ def test_resolve_dynamic_parameters(self):
dynamic_parameters, placeholder_utils.ResolutionContext()
)

def test_resolve_ph_execution_parameters(self):
  """Checks that a `make_proto_op` execution parameter resolves to a proto.

  Builds a `NodeParameters` message whose `train_args` value is a
  placeholder expression (a `proto_op` wrapping a `make_proto_op`), passes
  it to `inputs_utils.resolve_dynamic_parameters`, and asserts that the
  resolved `train_args` equals a proto whose `splits` field is `"train"`.
  """
  execution_parameters = pipeline_pb2.NodeParameters()
  # The placeholder carries the serialized base message plus the file
  # descriptor that declares `TrainArgs`.
  text_format.Parse(
      r"""
parameters: {
key: "train_args"
value: {
placeholder: {
operator: {
proto_op: {
expression: {
operator: {
make_proto_op: {
base: {
type_url: "type.googleapis.com/tensorflow.service.TrainArgs"
value: "\n\005train"
}
file_descriptors: {
file: {
name: "third_party/tfx/trainer.proto"
package: "tensorflow.service"
message_type {
name: "TrainArgs"
field {
name: "splits"
number: 1
label: LABEL_REPEATED
type: TYPE_STRING
}
}
syntax: "proto3"
}
}
}
}
}
}
}
}
}
}
""",
      execution_parameters,
  )

  # Minimal execution context: one String artifact backed by a temp file.
  test_artifact = types.standard_artifacts.String()
  test_artifact.uri = self.create_tempfile().full_path
  test_artifact.value = 'testvalue'
  input_dict = {'_test_placeholder': [test_artifact]}

  exec_params_resolved = inputs_utils.resolve_dynamic_parameters(
      execution_parameters,
      placeholder_utils.ResolutionContext(
          exec_info=data_types.ExecutionInfo(
              input_dict=input_dict, pipeline_run_id='testrunid'
          )
      ),
  )

  self.assertProtoEquals(
      """
splits: "train"
""",
      exec_params_resolved['train_args'],
  )


if __name__ == '__main__':
tf.test.main()
11 changes: 10 additions & 1 deletion tfx/types/component_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import copy
import inspect
import itertools
from typing import Any, Dict, List, Mapping, Optional, Type, cast
from typing import Any, cast, Dict, List, Mapping, Optional, Type

from tfx.dsl.component.experimental.json_compat import check_strict_json_compat
from tfx.dsl.placeholder import placeholder
Expand All @@ -31,6 +31,13 @@
# Use Any to avoid cyclic import.
_BaseNode = Any

# Names of execution parameters that are declared with `use_proto=True` but
# cannot be optimized with the Placeholder `ph.make_proto`.
# `_parse_parameters` checks each `arg_name` against this list: names listed
# here keep the parsed proto value as-is, while every other `use_proto`
# parameter is wrapped with `placeholder.make_proto(value)`.
# NOTE(review): the reason these two parameters are exempt is not visible
# here; confirm before adding or removing entries.
_EXEMPT_EXEC_PARAMETERS = [
'tensorflow_trainer',
'example_diff_config'
]


def _is_runtime_param(data: Any) -> bool:
return data.__class__.__name__ == 'RuntimeParameter'
Expand Down Expand Up @@ -234,6 +241,8 @@ def _parse_parameters(self, raw_args: Mapping[str, Any]):
value = proto_utils.dict_to_proto(value, arg.type())
elif isinstance(value, str):
value = proto_utils.json_to_proto(value, arg.type())
if arg_name not in _EXEMPT_EXEC_PARAMETERS:
value = placeholder.make_proto(value)
else:
# Create deterministic json string as it will be stored in metadata
# for cache check.
Expand Down
24 changes: 16 additions & 8 deletions tfx/types/component_spec_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import unittest

import tensorflow as tf
from tfx.dsl.compiler import placeholder_utils
from tfx.dsl.components.base.testing import test_node
from tfx.dsl.placeholder import placeholder
from tfx.orchestration.portable import data_types
from tfx.proto import example_gen_pb2
from tfx.types import artifact
from tfx.types import channel
Expand All @@ -32,7 +34,6 @@
from tfx.utils import proto_utils

from google.protobuf import json_format
from google.protobuf import text_format


class _InputArtifact(artifact.Artifact):
Expand Down Expand Up @@ -433,14 +434,21 @@ class SpecWithNonPrimitiveTypes(ComponentSpec):
output=channel.Channel(type=_OutputArtifact))

# Verify exec_properties store parsed value when use_proto set to True.
expected_proto = text_format.Parse(
resolved_proto = placeholder_utils.resolve_placeholder_expression(
spec.exec_properties['config_proto'].encode(),
placeholder_utils.ResolutionContext(
exec_info=data_types.ExecutionInfo()
)
)
self.assertProtoEquals(
"""
splits {
name: "name"
pattern: "pattern"
}
""", example_gen_pb2.Input())
self.assertProtoEquals(expected_proto, spec.exec_properties['config_proto'])
splits {
name: "name"
pattern: "pattern"
}
""",
resolved_proto
)
self.assertEqual(True, spec.exec_properties['boolean'])
self.assertIsInstance(spec.exec_properties['list_config_proto'], list)
self.assertEqual(spec.exec_properties['list_boolean'], [False, True])
Expand Down

0 comments on commit 031f263

Please sign in to comment.