diff --git a/tfx/dsl/compiler/placeholder_utils.py b/tfx/dsl/compiler/placeholder_utils.py index 884b75e68f..3106bc7aa1 100644 --- a/tfx/dsl/compiler/placeholder_utils.py +++ b/tfx/dsl/compiler/placeholder_utils.py @@ -733,6 +733,16 @@ def _resolve_binary_logical_operator( raise ValueError(f"Unrecognized binary logical operation {op.op}.") + @_register(placeholder_pb2.DirNameOperator) + def _resolve_dir_name_operator( + self, + op: placeholder_pb2.DirNameOperator, + pool: Optional[descriptor_pool.DescriptorPool] = None, + ) -> str: + """Returns the directory name of the file.""" + path = self.resolve(op.expression, pool) + return os.path.dirname(path) + def debug_str(expression: placeholder_pb2.PlaceholderExpression) -> str: """Gets the debug string of a placeholder expression proto. @@ -876,6 +886,10 @@ def debug_str(expression: placeholder_pb2.PlaceholderExpression) -> str: ) return f"MakeProto({str(operator_pb.base).strip()}, {expression_str})" + if operator_name == "dir_name_op": + expression_str = debug_str(operator_pb.expression) + return f"dirname({expression_str})" + return "Unknown placeholder operator" return "Unknown placeholder expression" diff --git a/tfx/dsl/compiler/placeholder_utils_test.py b/tfx/dsl/compiler/placeholder_utils_test.py index 49fe6446d9..08fe38161e 100644 --- a/tfx/dsl/compiler/placeholder_utils_test.py +++ b/tfx/dsl/compiler/placeholder_utils_test.py @@ -116,7 +116,7 @@ } } } -output_metadata_uri: "test_executor_output_uri" +output_metadata_uri: "/execution_output_dir/file" input_dict { key: "examples" value { @@ -192,7 +192,7 @@ } } } -stateful_working_dir: "test_stateful_working_dir" +stateful_working_dir: "/stateful_working_dir/" pipeline_info { id: "test_pipeline_id" } @@ -233,15 +233,20 @@ def setUp(self): "proto_property": proto_utils.proto_to_json(self._serving_spec), "list_proto_property": [self._serving_spec], }, - execution_output_uri="test_executor_output_uri", - stateful_working_dir="test_stateful_working_dir", + execution_output_uri="/execution_output_dir/file", + stateful_working_dir="/stateful_working_dir/", pipeline_node=pipeline_pb2.PipelineNode( node_info=pipeline_pb2.NodeInfo( type=metadata_store_pb2.ExecutionType( - name="infra_validator"))), - pipeline_info=pipeline_pb2.PipelineInfo(id="test_pipeline_id")), + name="infra_validator" + ) + ) + ), + pipeline_info=pipeline_pb2.PipelineInfo(id="test_pipeline_id"), + ), executor_spec=executable_spec_pb2.PythonClassExecutableSpec( - class_path="test_class_path"), + class_path="test_class_path" + ), ) # Resolution context to simulate missing optional values. self._none_resolution_context = placeholder_utils.ResolutionContext( @@ -309,7 +314,7 @@ def testJoinPath(self): ) self.assertEqual( resolved_str, - "test_stateful_working_dir/foo/test_pipeline_id", + "/stateful_working_dir/foo/test_pipeline_id", ) def testArtifactProperty(self): @@ -823,7 +828,7 @@ def testMakeDict(self): ) expected_result = { "plain_key": 42, - "test_stateful_working_dir": "plain_value", + "/stateful_working_dir/": "plain_value", } self.assertEqual( placeholder_utils.resolve_placeholder_expression( @@ -1141,7 +1146,7 @@ def testExecutionInvocationPlaceholderAccessProtoField(self): placeholder_pb2.PlaceholderExpression()) resolved = placeholder_utils.resolve_placeholder_expression( pb, self._resolution_context) - self.assertEqual(resolved, "test_stateful_working_dir") + self.assertEqual(resolved, "/stateful_working_dir/") def testExecutionInvocationDescriptor(self): # Test if ExecutionInvocation proto is in the default descriptor pool @@ -1634,6 +1639,7 @@ def testGetsOperatorsFromProtoReflection(self): "unary_logical_op", "artifact_property_op", "list_serialization_op", + "dir_name_op", }, ) self.assertSetEqual( @@ -1698,6 +1704,38 @@ def testMakeProtoOpResolvesProto(self): resolved_proto, ) + def testDirNameOp(self): + placeholder_expression = text_format.Parse( + r""" + operator { + dir_name_op { + expression { + operator { + proto_op { + expression { + placeholder { + type: EXEC_INVOCATION + } + } + proto_field_path: ".output_metadata_uri" + } + } + } + } + } + """, + placeholder_pb2.PlaceholderExpression(), + ) + resolved_result = placeholder_utils.resolve_placeholder_expression( + placeholder_expression, self._resolution_context + ) + self.assertEqual(resolved_result, "/execution_output_dir") + + actual = placeholder_utils.debug_str(placeholder_expression) + self.assertEqual( + actual, + "dirname(execution_invocation().output_metadata_uri)") + class PredicateResolutionTest(parameterized.TestCase, tf.test.TestCase): diff --git a/tfx/dsl/placeholder/placeholder.py b/tfx/dsl/placeholder/placeholder.py index 4f94a18f2f..43545b2293 100644 --- a/tfx/dsl/placeholder/placeholder.py +++ b/tfx/dsl/placeholder/placeholder.py @@ -17,6 +17,7 @@ # for historical reasons, it's not actually in the __init__ file. # pylint: disable=g-multiple-import,g-importing-member,unused-import,g-bad-import-order,redefined-builtin from tfx.dsl.placeholder.placeholder_base import Placeholder, Predicate, ListPlaceholder +from tfx.dsl.placeholder.placeholder_base import dirname from tfx.dsl.placeholder.placeholder_base import logical_not, logical_and, logical_or from tfx.dsl.placeholder.placeholder_base import join, join_path, make_list from tfx.dsl.placeholder.placeholder_base import ListSerializationFormat, ProtoSerializationFormat diff --git a/tfx/dsl/placeholder/placeholder_base.py b/tfx/dsl/placeholder/placeholder_base.py index 5d129a9fe2..07a792a7d7 100644 --- a/tfx/dsl/placeholder/placeholder_base.py +++ b/tfx/dsl/placeholder/placeholder_base.py @@ -757,6 +757,25 @@ def encode( return result +def dirname( + placeholder: Placeholder, +) -> _DirNameOperator: + """Runs os.path.dirname() on the path resolved from the input placeholder. + + Args: + placeholder: Another placeholder to be wrapped in a _DirNameOperator. + + Example: + ``` + ph.dirname(ph.execution_invocation().output_metadata_uri) + ``` + + Returns: + A _DirNameOperator operator. + """ + return _DirNameOperator(placeholder) + + class _ListSerializationOperator(UnaryPlaceholderOperator): """ListSerializationOperator serializes list type placeholder. @@ -810,6 +829,28 @@ class _CompareOp(enum.Enum): GREATER_THAN = placeholder_pb2.ComparisonOperator.Operation.GREATER_THAN +class _DirNameOperator(UnaryPlaceholderOperator): + """_DirNameOperator returns directory path given a path.""" + + def __init__( + self, + value: Placeholder, + ): + super().__init__( + value, + expected_type=str, + ) + + def encode( + self, component_spec: Optional[type['types.ComponentSpec']] = None + ) -> placeholder_pb2.PlaceholderExpression: + result = placeholder_pb2.PlaceholderExpression() + op = result.operator.dir_name_op + op.expression.CopyFrom(self._value.encode(component_spec)) + + return result + + def internal_equals_value_like( a: Optional[ValueLikeType], b: Optional[ValueLikeType] ) -> bool: diff --git a/tfx/proto/orchestration/placeholder.proto b/tfx/proto/orchestration/placeholder.proto index 29710d8a1c..4aac0d6351 100644 --- a/tfx/proto/orchestration/placeholder.proto +++ b/tfx/proto/orchestration/placeholder.proto @@ -51,9 +51,16 @@ message PlaceholderExpressionOperator { ListConcatOperator list_concat_op = 12; MakeDictOperator make_dict_op = 13; MakeProtoOperator make_proto_op = 14; + DirNameOperator dir_name_op = 16; } } +// DirNameOperator extracts the directory name from a file path. +message DirNameOperator { + // Required. It must evaluate to a file path string. + PlaceholderExpression expression = 1; +} + // ArtifactUriOperator extracts the Artifact URI from a placeholder expression. // ArtifactUriOperator: Artifact -> String message ArtifactUriOperator {