From 005d3561f058981c174bc53cd2c48ceae4ffccfd Mon Sep 17 00:00:00 2001 From: tfx-team Date: Fri, 12 Jan 2024 06:30:47 -0800 Subject: [PATCH] Correctly resolve None items inside ListPlaceholder PiperOrigin-RevId: 597825457 --- tfx/dsl/compiler/placeholder_utils.py | 7 ++-- tfx/dsl/compiler/placeholder_utils_test.py | 37 +++++++++++++++++++ tfx/dsl/placeholder/placeholder_base.py | 3 +- tfx/dsl/placeholder/proto_placeholder_test.py | 30 ++++++++++++++- 4 files changed, 71 insertions(+), 6 deletions(-) diff --git a/tfx/dsl/compiler/placeholder_utils.py b/tfx/dsl/compiler/placeholder_utils.py index 979301bd51d..5e15cc38582 100644 --- a/tfx/dsl/compiler/placeholder_utils.py +++ b/tfx/dsl/compiler/placeholder_utils.py @@ -433,9 +433,10 @@ def _resolve_list_concat_operator( """Evaluates the list concat operator.""" result = [] for sub_expression in op.expressions: - value = self.resolve(sub_expression, pool) - if value is None: - raise NullDereferenceError(sub_expression) + try: + value = self.resolve(sub_expression, pool) + except NullDereferenceError: + value = None result.append(value) return result diff --git a/tfx/dsl/compiler/placeholder_utils_test.py b/tfx/dsl/compiler/placeholder_utils_test.py index f808e08dd25..b28fd9fe3cb 100644 --- a/tfx/dsl/compiler/placeholder_utils_test.py +++ b/tfx/dsl/compiler/placeholder_utils_test.py @@ -665,6 +665,43 @@ def testListConcat(self): placeholder_utils.resolve_placeholder_expression( pb, self._resolution_context), expected_result) + def testListConcatWithAbsentElement(self): + # When an exec prop has type Union[T, None] and the user passes None, it is + # actually completely absent from the exec_properties dict in + # ExecutionInvocation. See also b/172001324 and the corresponding todo in + # placeholder_utils.py. + placeholder_expression = """ + operator { + list_concat_op { + expressions { + value { + string_value: "random_before" + } + } + expressions { + placeholder { + type: EXEC_PROPERTY + key: "doesnotexist" + } + } + expressions { + value { + string_value: "random_after" + } + } + } + } + """ + pb = text_format.Parse( + placeholder_expression, placeholder_pb2.PlaceholderExpression() + ) + self.assertEqual( + placeholder_utils.resolve_placeholder_expression( + pb, self._resolution_context + ), + ["random_before", None, "random_after"], + ) + def testListConcatAndSerialize(self): placeholder_expression = """ operator { diff --git a/tfx/dsl/placeholder/placeholder_base.py b/tfx/dsl/placeholder/placeholder_base.py index b7d9aa251c1..74024d5b6bc 100644 --- a/tfx/dsl/placeholder/placeholder_base.py +++ b/tfx/dsl/placeholder/placeholder_base.py @@ -354,8 +354,7 @@ def serialize_list( """Serializes list-value placeholder to JSON or comma-separated string. Only supports primitive type list element (a.k.a bool, int, float or str) at - the - moment; throws runtime error otherwise. + the moment; throws runtime error otherwise. Args: serialization_format: The format of how the proto is serialized. diff --git a/tfx/dsl/placeholder/proto_placeholder_test.py b/tfx/dsl/placeholder/proto_placeholder_test.py index 36d472d2918..1b8975e3229 100644 --- a/tfx/dsl/placeholder/proto_placeholder_test.py +++ b/tfx/dsl/placeholder/proto_placeholder_test.py @@ -220,7 +220,8 @@ def test_NonePlaceholderIntoOptionalField(self): def test_NoneExecPropIntoOptionalField(self): # When an exec prop has type Union[T, None] and the user passes None, it is # actually completely absent from the exec_properties dict in - # ExecutionInvocation. + # ExecutionInvocation. See also b/172001324 and the corresponding todo in + # placeholder_utils.py. actual = resolve( _UpdateOptions(reload_policy=ph.exec_property('reload_policy')), exec_properties={}, # Intentionally empty. @@ -385,6 +386,33 @@ def test_RepeatedFieldFalsyItem(self): parse_text_proto(actual), ) + def test_RepeatedFieldNoneItem(self): + actual = resolve( + ph.make_proto( + execution_invocation_pb2.ExecutionInvocation( + pipeline_node=pipeline_pb2.PipelineNode() + ), + pipeline_node=ph.make_proto( + pipeline_pb2.PipelineNode(), + upstream_nodes=[ + 'foo', + ph.exec_property('reload_policy'), # Will be None. + 'bar', + ], + ), + ), + exec_properties={}, # Intentionally empty. + ) + self.assertProtoEquals( + """ + pipeline_node { + upstream_nodes: "foo" + upstream_nodes: "bar" + } + """, + parse_text_proto(actual), + ) + def test_NoneIntoRepeatedField(self): actual = resolve( ph.make_proto(pipeline_pb2.PipelineNode(), upstream_nodes=None)