diff --git a/build/BUILD b/build/BUILD index 33d4efd72c..0d92eb4f8d 100644 --- a/build/BUILD +++ b/build/BUILD @@ -20,7 +20,6 @@ sh_binary( name = "gen_proto", srcs = ["gen_proto.sh"], data = [ - "//tfx/dsl/component/experimental:annotations_test_proto_pb2.py", "//tfx/examples/custom_components/presto_example_gen/proto:presto_config_pb2.py", "//tfx/extensions/experimental/kfp_compatibility/proto:kfp_component_spec_pb2.py", "//tfx/extensions/google_cloud_big_query/experimental/elwc_example_gen/proto:elwc_config_pb2.py", diff --git a/tfx/dsl/component/experimental/BUILD b/tfx/dsl/component/experimental/BUILD deleted file mode 100644 index 930e6d5594..0000000000 --- a/tfx/dsl/component/experimental/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -load("//tfx:tfx.bzl", "tfx_py_proto_library") - -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -tfx_py_proto_library( - name = "annotations_test_proto_py_pb2", - srcs = ["annotations_test_proto.proto"], -) diff --git a/tfx/dsl/component/experimental/annotations.py b/tfx/dsl/component/experimental/annotations.py index 2d61340dbc..3a33164080 100644 --- a/tfx/dsl/component/experimental/annotations.py +++ b/tfx/dsl/component/experimental/annotations.py @@ -23,8 +23,6 @@ from tfx.types import artifact from tfx.utils import deprecation_utils -from google.protobuf import message - try: import apache_beam as beam # pytype: disable=import-error # pylint: disable=g-import-not-at-top @@ -109,35 +107,23 @@ def __repr__(self): return '%s[%s]' % (self.__class__.__name__, self.type) -class _PrimitiveAndProtoTypeGenericMeta(type): +class _PrimitiveTypeGenericMeta(type): """Metaclass for _PrimitiveTypeGeneric, to enable primitive type indexing.""" def __getitem__( - cls: Type['_PrimitiveAndProtoTypeGeneric'], - params: Type[ - Union[ - int, - float, - str, - bool, - List[Any], - Dict[Any, Any], - message.Message, - ], - ], + cls: Type['_PrimitiveTypeGeneric'], + params: Type[Union[int, float, str, bool, List[Any], Dict[Any, Any]]], ): """Metaclass method allowing indexing class (`_PrimitiveTypeGeneric[T]`).""" return cls._generic_getitem(params) # pytype: disable=attribute-error -class _PrimitiveAndProtoTypeGeneric( - metaclass=_PrimitiveAndProtoTypeGenericMeta -): +class _PrimitiveTypeGeneric(metaclass=_PrimitiveTypeGenericMeta): """A generic that takes a primitive type as its single argument.""" def __init__( # pylint: disable=invalid-name self, - artifact_type: Type[Union[int, float, str, bool, message.Message]], + artifact_type: Type[Union[int, float, str, bool]], _init_via_getitem=False, ): if not _init_via_getitem: @@ -145,7 +131,7 @@ def __init__( # pylint: disable=invalid-name raise ValueError( ( '%s should be instantiated via the syntax `%s[T]`, where T is ' - '`int`, `float`, `str`, `bool` or proto type.' + '`int`, `float`, `str`, or `bool`.' ) % (class_name, class_name) ) @@ -157,10 +143,7 @@ def _generic_getitem(cls, params): # Check that the given parameter is a primitive type. if ( inspect.isclass(params) - and ( - params in (int, float, str, bool) - or issubclass(params, message.Message) - ) + and params in (int, float, str, bool) or json_compat.is_json_compatible(params) ): return cls(params, _init_via_getitem=True) @@ -168,9 +151,9 @@ def _generic_getitem(cls, params): class_name = cls.__name__ raise ValueError( ( - 'Generic type `%s[T]` expects the single parameter T to be `int`,' - ' `float`, `str`, `bool`, JSON-compatible types (Dict[str, T],' - ' List[T]) or a proto type. (got %r instead).' + 'Generic type `%s[T]` expects the single parameter T to be ' + '`int`, `float`, `str`, `bool` or JSON-compatible types ' + '(Dict[str, T], List[T]) (got %r instead).' ) % (class_name, params) ) @@ -269,7 +252,7 @@ class AsyncOutputArtifact(Generic[T]): """Intermediate artifact object type annotation.""" -class Parameter(_PrimitiveAndProtoTypeGeneric): +class Parameter(_PrimitiveTypeGeneric): """Component parameter type annotation.""" diff --git a/tfx/dsl/component/experimental/annotations_test.py b/tfx/dsl/component/experimental/annotations_test.py index 38970c38aa..c342bbfe15 100644 --- a/tfx/dsl/component/experimental/annotations_test.py +++ b/tfx/dsl/component/experimental/annotations_test.py @@ -18,7 +18,6 @@ import apache_beam as beam import tensorflow as tf from tfx.dsl.component.experimental import annotations -from tfx.dsl.component.experimental import annotations_test_proto_pb2 from tfx.types import artifact from tfx.types import standard_artifacts from tfx.types import value_artifact @@ -28,21 +27,18 @@ class AnnotationsTest(tf.test.TestCase): def testArtifactGenericAnnotation(self): # Error: type hint whose parameter is not an Artifact subclass. - with self.assertRaisesRegex( - ValueError, 'expects .* a concrete subclass of' - ): + with self.assertRaisesRegex(ValueError, + 'expects .* a concrete subclass of'): _ = annotations._ArtifactGeneric[int] # pytype: disable=unsupported-operands # Error: type hint with abstract Artifact subclass. - with self.assertRaisesRegex( - ValueError, 'expects .* a concrete subclass of' - ): + with self.assertRaisesRegex(ValueError, + 'expects .* a concrete subclass of'): _ = annotations._ArtifactGeneric[artifact.Artifact] # Error: type hint with abstract Artifact subclass. - with self.assertRaisesRegex( - ValueError, 'expects .* a concrete subclass of' - ): + with self.assertRaisesRegex(ValueError, + 'expects .* a concrete subclass of'): _ = annotations._ArtifactGeneric[value_artifact.ValueArtifact] # OK. @@ -53,55 +49,56 @@ def testArtifactAnnotationUsage(self): _ = annotations.OutputArtifact[standard_artifacts.Examples] _ = annotations.AsyncOutputArtifact[standard_artifacts.Model] - def testPrimitivAndProtoTypeGenericAnnotation(self): - # Error: type hint whose parameter is not a primitive or a proto type + def testPrimitiveTypeGenericAnnotation(self): + # Error: type hint whose parameter is not a primitive type # pytype: disable=unsupported-operands with self.assertRaisesRegex( ValueError, 'T to be `int`, `float`, `str`, `bool`' ): - _ = annotations._PrimitiveAndProtoTypeGeneric[artifact.Artifact] + _ = annotations._PrimitiveTypeGeneric[artifact.Artifact] with self.assertRaisesRegex( ValueError, 'T to be `int`, `float`, `str`, `bool`' ): - _ = annotations._PrimitiveAndProtoTypeGeneric[object] + _ = annotations._PrimitiveTypeGeneric[object] with self.assertRaisesRegex( ValueError, 'T to be `int`, `float`, `str`, `bool`' ): - _ = annotations._PrimitiveAndProtoTypeGeneric[123] + _ = annotations._PrimitiveTypeGeneric[123] with self.assertRaisesRegex( ValueError, 'T to be `int`, `float`, `str`, `bool`' ): - _ = annotations._PrimitiveAndProtoTypeGeneric['string'] + _ = annotations._PrimitiveTypeGeneric['string'] with self.assertRaisesRegex( ValueError, 'T to be `int`, `float`, `str`, `bool`' ): - _ = annotations._PrimitiveAndProtoTypeGeneric[Dict[int, int]] + _ = annotations._PrimitiveTypeGeneric[Dict[int, int]] with self.assertRaisesRegex( ValueError, 'T to be `int`, `float`, `str`, `bool`' ): - _ = annotations._PrimitiveAndProtoTypeGeneric[bytes] + _ = annotations._PrimitiveTypeGeneric[bytes] # pytype: enable=unsupported-operands # OK. - _ = annotations._PrimitiveAndProtoTypeGeneric[int] - _ = annotations._PrimitiveAndProtoTypeGeneric[float] - _ = annotations._PrimitiveAndProtoTypeGeneric[str] - _ = annotations._PrimitiveAndProtoTypeGeneric[bool] - _ = annotations._PrimitiveAndProtoTypeGeneric[Dict[str, float]] - _ = annotations._PrimitiveAndProtoTypeGeneric[bool] - _ = annotations._PrimitiveAndProtoTypeGeneric[ - annotations_test_proto_pb2.TestMessage - ] + _ = annotations._PrimitiveTypeGeneric[int] + _ = annotations._PrimitiveTypeGeneric[float] + _ = annotations._PrimitiveTypeGeneric[str] + _ = annotations._PrimitiveTypeGeneric[bool] + _ = annotations._PrimitiveTypeGeneric[Dict[str, float]] + _ = annotations._PrimitiveTypeGeneric[bool] def testPipelineTypeGenericAnnotation(self): # Error: type hint whose parameter is not a primitive type - with self.assertRaisesRegex(ValueError, 'T to be `beam.Pipeline`'): + with self.assertRaisesRegex( + ValueError, 'T to be `beam.Pipeline`'): _ = annotations._PipelineTypeGeneric[artifact.Artifact] - with self.assertRaisesRegex(ValueError, 'T to be `beam.Pipeline`'): + with self.assertRaisesRegex( + ValueError, 'T to be `beam.Pipeline`'): _ = annotations._PipelineTypeGeneric[object] # pytype: disable=unsupported-operands - with self.assertRaisesRegex(ValueError, 'T to be `beam.Pipeline`'): + with self.assertRaisesRegex( + ValueError, 'T to be `beam.Pipeline`'): _ = annotations._PipelineTypeGeneric[123] - with self.assertRaisesRegex(ValueError, 'T to be `beam.Pipeline`'): + with self.assertRaisesRegex( + ValueError, 'T to be `beam.Pipeline`'): _ = annotations._PipelineTypeGeneric['string'] # pytype: enable=unsupported-operands @@ -113,7 +110,6 @@ def testParameterUsage(self): _ = annotations.Parameter[float] _ = annotations.Parameter[str] _ = annotations.Parameter[bool] - _ = annotations.Parameter[annotations_test_proto_pb2.TestMessage] if __name__ == '__main__': diff --git a/tfx/dsl/component/experimental/annotations_test_proto.proto b/tfx/dsl/component/experimental/annotations_test_proto.proto deleted file mode 100644 index cd9513c1d3..0000000000 --- a/tfx/dsl/component/experimental/annotations_test_proto.proto +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2024 Google LLC. All Rights Reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -syntax = "proto3"; - -package tfx.dsl.component.experimental; - -message TestMessage { - int32 number = 1; - string name = 2; -} diff --git a/tfx/dsl/component/experimental/utils.py b/tfx/dsl/component/experimental/utils.py index c04331d8c2..4053a3742c 100644 --- a/tfx/dsl/component/experimental/utils.py +++ b/tfx/dsl/component/experimental/utils.py @@ -25,7 +25,6 @@ from tfx.types import artifact from tfx.types import component_spec from tfx.types import system_executions -from google.protobuf import message class ArgFormats(enum.Enum): @@ -207,17 +206,10 @@ def _create_component_spec_class( json_compatible_outputs[key], ) if parameters: - for key, param_type in parameters.items(): - if inspect.isclass(param_type) and issubclass( - param_type, message.Message - ): - spec_parameters[key] = component_spec.ExecutionParameter( - type=param_type, optional=(key in arg_defaults), use_proto=True - ) - else: - spec_parameters[key] = component_spec.ExecutionParameter( - type=param_type, optional=(key in arg_defaults) - ) + for key, primitive_type in parameters.items(): + spec_parameters[key] = component_spec.ExecutionParameter( + type=primitive_type, optional=(key in arg_defaults) + ) component_spec_class = type( '%s_Spec' % func.__name__, (tfx_types.ComponentSpec,), diff --git a/tfx/dsl/component/experimental/utils_test.py b/tfx/dsl/component/experimental/utils_test.py index 30c5f0eeb5..cbb56e36ba 100644 --- a/tfx/dsl/component/experimental/utils_test.py +++ b/tfx/dsl/component/experimental/utils_test.py @@ -18,7 +18,6 @@ from typing import Dict, List import tensorflow as tf from tfx.dsl.component.experimental import annotations -from tfx.dsl.component.experimental import annotations_test_proto_pb2 from tfx.dsl.component.experimental import decorators from tfx.dsl.component.experimental import function_parser from tfx.dsl.component.experimental import utils @@ -95,9 +94,6 @@ def func_with_primitive_parameter( float_param: annotations.Parameter[float], str_param: annotations.Parameter[str], bool_param: annotations.Parameter[bool], - proto_param: annotations.Parameter[ - annotations_test_proto_pb2.TestMessage - ], dict_int_param: annotations.Parameter[Dict[str, int]], list_bool_param: annotations.Parameter[List[bool]], dict_list_bool_param: annotations.Parameter[Dict[str, List[bool]]], @@ -116,7 +112,6 @@ def func_with_primitive_parameter( 'float_param': float, 'str_param': str, 'bool_param': bool, - 'proto_param': annotations_test_proto_pb2.TestMessage, 'dict_int_param': Dict[str, int], 'list_bool_param': List[bool], 'dict_list_bool_param': Dict[str, List[bool]], @@ -186,9 +181,6 @@ def func( standard_artifacts.Examples ], int_param: annotations.Parameter[int], - proto_param: annotations.Parameter[ - annotations_test_proto_pb2.TestMessage - ], json_compat_param: annotations.Parameter[Dict[str, int]], str_param: annotations.Parameter[str] = 'foo', ) -> annotations.OutputDict( @@ -253,15 +245,11 @@ def func( spec_outputs['map_str_float_output'].type, standard_artifacts.JsonValue ) spec_parameter = actual_spec_class.PARAMETERS - self.assertLen(spec_parameter, 4) + self.assertLen(spec_parameter, 3) self.assertEqual(spec_parameter['int_param'].type, int) self.assertEqual(spec_parameter['int_param'].optional, False) self.assertEqual(spec_parameter['str_param'].type, str) self.assertEqual(spec_parameter['str_param'].optional, True) - self.assertEqual( - spec_parameter['proto_param'].type, - annotations_test_proto_pb2.TestMessage, - ) self.assertEqual(spec_parameter['json_compat_param'].type, Dict[str, int]) self.assertEqual(spec_parameter['json_compat_param'].optional, False) self.assertEqual(actual_spec_class.TYPE_ANNOTATION, type_annotation)