Skip to content

Commit

Permalink
New placeholder operator for getting file dir path
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 654779027
  • Loading branch information
tfx-copybara committed Jul 24, 2024
1 parent db68ac5 commit 13b1153
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 10 deletions.
14 changes: 14 additions & 0 deletions tfx/dsl/compiler/placeholder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,16 @@ def _resolve_binary_logical_operator(

raise ValueError(f"Unrecognized binary logical operation {op.op}.")

@_register(placeholder_pb2.DirNameOperator)
def _resolve_dir_name_operator(
    self,
    op: placeholder_pb2.DirNameOperator,
    pool: Optional[descriptor_pool.DescriptorPool] = None,
) -> str:
  """Resolves the wrapped expression and returns its directory part.

  Args:
    op: The DirNameOperator proto; its `expression` field is resolved first.
    pool: Optional descriptor pool, forwarded unchanged to `self.resolve`.

  Returns:
    The result of `os.path.dirname()` applied to the resolved value.
  """
  return os.path.dirname(self.resolve(op.expression, pool))


def debug_str(expression: placeholder_pb2.PlaceholderExpression) -> str:
"""Gets the debug string of a placeholder expression proto.
Expand Down Expand Up @@ -876,6 +886,10 @@ def debug_str(expression: placeholder_pb2.PlaceholderExpression) -> str:
)
return f"MakeProto({str(operator_pb.base).strip()}, {expression_str})"

if operator_name == "dir_name_op":
expression_str = debug_str(operator_pb.expression)
return f"dirname({expression_str})"

return "Unknown placeholder operator"

return "Unknown placeholder expression"
Expand Down
58 changes: 48 additions & 10 deletions tfx/dsl/compiler/placeholder_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
}
}
}
output_metadata_uri: "test_executor_output_uri"
output_metadata_uri: "/execution_output_dir/file"
input_dict {
key: "examples"
value {
Expand Down Expand Up @@ -192,7 +192,7 @@
}
}
}
stateful_working_dir: "test_stateful_working_dir"
stateful_working_dir: "/stateful_working_dir/"
pipeline_info {
id: "test_pipeline_id"
}
Expand Down Expand Up @@ -233,15 +233,20 @@ def setUp(self):
"proto_property": proto_utils.proto_to_json(self._serving_spec),
"list_proto_property": [self._serving_spec],
},
execution_output_uri="test_executor_output_uri",
stateful_working_dir="test_stateful_working_dir",
execution_output_uri="/execution_output_dir/file",
stateful_working_dir="/stateful_working_dir/",
pipeline_node=pipeline_pb2.PipelineNode(
node_info=pipeline_pb2.NodeInfo(
type=metadata_store_pb2.ExecutionType(
name="infra_validator"))),
pipeline_info=pipeline_pb2.PipelineInfo(id="test_pipeline_id")),
name="infra_validator"
)
)
),
pipeline_info=pipeline_pb2.PipelineInfo(id="test_pipeline_id"),
),
executor_spec=executable_spec_pb2.PythonClassExecutableSpec(
class_path="test_class_path"),
class_path="test_class_path"
),
)
# Resolution context to simulate missing optional values.
self._none_resolution_context = placeholder_utils.ResolutionContext(
Expand Down Expand Up @@ -309,7 +314,7 @@ def testJoinPath(self):
)
self.assertEqual(
resolved_str,
"test_stateful_working_dir/foo/test_pipeline_id",
"/stateful_working_dir/foo/test_pipeline_id",
)

def testArtifactProperty(self):
Expand Down Expand Up @@ -823,7 +828,7 @@ def testMakeDict(self):
)
expected_result = {
"plain_key": 42,
"test_stateful_working_dir": "plain_value",
"/stateful_working_dir/": "plain_value",
}
self.assertEqual(
placeholder_utils.resolve_placeholder_expression(
Expand Down Expand Up @@ -1141,7 +1146,7 @@ def testExecutionInvocationPlaceholderAccessProtoField(self):
placeholder_pb2.PlaceholderExpression())
resolved = placeholder_utils.resolve_placeholder_expression(
pb, self._resolution_context)
self.assertEqual(resolved, "test_stateful_working_dir")
self.assertEqual(resolved, "/stateful_working_dir/")

def testExecutionInvocationDescriptor(self):
# Test if ExecutionInvocation proto is in the default descriptor pool
Expand Down Expand Up @@ -1634,6 +1639,7 @@ def testGetsOperatorsFromProtoReflection(self):
"unary_logical_op",
"artifact_property_op",
"list_serialization_op",
"dir_name_op",
},
)
self.assertSetEqual(
Expand Down Expand Up @@ -1698,6 +1704,38 @@ def testMakeProtoOpResolvesProto(self):
resolved_proto,
)

def testDirNameOp(self):
  """Checks both resolution and debug string of a dir_name_op expression."""
  # dir_name_op wrapping a proto_op that reads `.output_metadata_uri` from the
  # EXEC_INVOCATION placeholder.
  expr_pb = placeholder_pb2.PlaceholderExpression()
  text_format.Parse(
      r"""
operator {
dir_name_op {
expression {
operator {
proto_op {
expression {
placeholder {
type: EXEC_INVOCATION
}
}
proto_field_path: ".output_metadata_uri"
}
}
}
}
}
""",
      expr_pb,
  )

  # Resolution should drop the trailing file component of the output URI
  # (presumably "/execution_output_dir/file" from the test fixture).
  dirname_value = placeholder_utils.resolve_placeholder_expression(
      expr_pb, self._resolution_context
  )
  self.assertEqual(dirname_value, "/execution_output_dir")

  # The debug string should render the operator as a dirname(...) call.
  self.assertEqual(
      placeholder_utils.debug_str(expr_pb),
      "dirname(execution_invocation().output_metadata_uri)",
  )


class PredicateResolutionTest(parameterized.TestCase, tf.test.TestCase):

Expand Down
1 change: 1 addition & 0 deletions tfx/dsl/placeholder/placeholder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# for historical reasons, it's not actually in the __init__ file.
# pylint: disable=g-multiple-import,g-importing-member,unused-import,g-bad-import-order,redefined-builtin
from tfx.dsl.placeholder.placeholder_base import Placeholder, Predicate, ListPlaceholder
from tfx.dsl.placeholder.placeholder_base import dirname
from tfx.dsl.placeholder.placeholder_base import logical_not, logical_and, logical_or
from tfx.dsl.placeholder.placeholder_base import join, join_path, make_list
from tfx.dsl.placeholder.placeholder_base import ListSerializationFormat, ProtoSerializationFormat
Expand Down
41 changes: 41 additions & 0 deletions tfx/dsl/placeholder/placeholder_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,25 @@ def encode(
return result


def dirname(
    placeholder: Placeholder,
) -> _DirNameOperator:
  """Wraps a placeholder so that it resolves to the directory of its path.

  os.path.dirname() is run on the path resolved from the input placeholder.

  Example:
  ```
  ph.dirname(ph.execution_invocation().output_metadata_uri)
  ```

  Args:
    placeholder: Another placeholder to be wrapped in a _DirNameOperator.

  Returns:
    A _DirNameOperator wrapping `placeholder`.
  """
  dir_name_operator = _DirNameOperator(placeholder)
  return dir_name_operator


class _ListSerializationOperator(UnaryPlaceholderOperator):
"""ListSerializationOperator serializes list type placeholder.
Expand Down Expand Up @@ -810,6 +829,28 @@ class _CompareOp(enum.Enum):
GREATER_THAN = placeholder_pb2.ComparisonOperator.Operation.GREATER_THAN


class _DirNameOperator(UnaryPlaceholderOperator):
  """Placeholder operator that yields the directory path of a given path.

  Encodes to a PlaceholderExpression whose `operator.dir_name_op.expression`
  holds the encoded form of the wrapped placeholder.
  """

  def __init__(self, value: Placeholder):
    """Wraps `value`, passing expected_type=str to the base class."""
    super().__init__(value, expected_type=str)

  def encode(
      self, component_spec: Optional[type['types.ComponentSpec']] = None
  ) -> placeholder_pb2.PlaceholderExpression:
    """Builds the dir_name_op expression proto for this operator.

    Args:
      component_spec: Optional component spec type, forwarded to the wrapped
        placeholder's own encode() call.

    Returns:
      A PlaceholderExpression with the dir_name_op operator populated.
    """
    encoded_input = self._value.encode(component_spec)
    expression_pb = placeholder_pb2.PlaceholderExpression()
    expression_pb.operator.dir_name_op.expression.CopyFrom(encoded_input)
    return expression_pb


def internal_equals_value_like(
a: Optional[ValueLikeType], b: Optional[ValueLikeType]
) -> bool:
Expand Down
7 changes: 7 additions & 0 deletions tfx/proto/orchestration/placeholder.proto
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,16 @@ message PlaceholderExpressionOperator {
ListConcatOperator list_concat_op = 12;
MakeDictOperator make_dict_op = 13;
MakeProtoOperator make_proto_op = 14;
DirNameOperator dir_name_op = 16;
}
}

// DirNameOperator extracts the directory name from a file path.
// DirNameOperator: String -> String
message DirNameOperator {
  // Required. It must evaluate to a file path string.
  // This is the nested expression whose resolved value is the path that the
  // directory name is taken from.
  PlaceholderExpression expression = 1;
}

// ArtifactUriOperator extracts the Artifact URI from a placeholder expression.
// ArtifactUriOperator: Artifact -> String
message ArtifactUriOperator {
Expand Down

0 comments on commit 13b1153

Please sign in to comment.