Skip to content

Commit

Permalink
Automated rollback of commit 92fef51
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 621274683
  • Loading branch information
tfx-copybara committed Apr 2, 2024
1 parent b8b9423 commit b0ab1f3
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 132 deletions.
1 change: 0 additions & 1 deletion build/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ sh_binary(
name = "gen_proto",
srcs = ["gen_proto.sh"],
data = [
"//tfx/dsl/component/experimental:annotations_test_proto_pb2.py",
"//tfx/examples/custom_components/presto_example_gen/proto:presto_config_pb2.py",
"//tfx/extensions/experimental/kfp_compatibility/proto:kfp_component_spec_pb2.py",
"//tfx/extensions/google_cloud_big_query/experimental/elwc_example_gen/proto:elwc_config_pb2.py",
Expand Down
25 changes: 0 additions & 25 deletions tfx/dsl/component/experimental/BUILD

This file was deleted.

39 changes: 11 additions & 28 deletions tfx/dsl/component/experimental/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
from tfx.types import artifact
from tfx.utils import deprecation_utils

from google.protobuf import message

try:
import apache_beam as beam # pytype: disable=import-error # pylint: disable=g-import-not-at-top

Expand Down Expand Up @@ -109,43 +107,31 @@ def __repr__(self):
return '%s[%s]' % (self.__class__.__name__, self.type)


class _PrimitiveAndProtoTypeGenericMeta(type):
class _PrimitiveTypeGenericMeta(type):
"""Metaclass for _PrimitiveTypeGeneric, to enable primitive type indexing."""

def __getitem__(
cls: Type['_PrimitiveAndProtoTypeGeneric'],
params: Type[
Union[
int,
float,
str,
bool,
List[Any],
Dict[Any, Any],
message.Message,
],
],
cls: Type['_PrimitiveTypeGeneric'],
params: Type[Union[int, float, str, bool, List[Any], Dict[Any, Any]]],
):
"""Metaclass method allowing indexing class (`_PrimitiveTypeGeneric[T]`)."""
return cls._generic_getitem(params) # pytype: disable=attribute-error


class _PrimitiveAndProtoTypeGeneric(
metaclass=_PrimitiveAndProtoTypeGenericMeta
):
class _PrimitiveTypeGeneric(metaclass=_PrimitiveTypeGenericMeta):
"""A generic that takes a primitive type as its single argument."""

def __init__( # pylint: disable=invalid-name
self,
artifact_type: Type[Union[int, float, str, bool, message.Message]],
artifact_type: Type[Union[int, float, str, bool]],
_init_via_getitem=False,
):
if not _init_via_getitem:
class_name = self.__class__.__name__
raise ValueError(
(
'%s should be instantiated via the syntax `%s[T]`, where T is '
'`int`, `float`, `str`, `bool` or proto type.'
'`int`, `float`, `str`, or `bool`.'
)
% (class_name, class_name)
)
Expand All @@ -157,20 +143,17 @@ def _generic_getitem(cls, params):
# Check that the given parameter is a primitive type.
if (
inspect.isclass(params)
and (
params in (int, float, str, bool)
or issubclass(params, message.Message)
)
and params in (int, float, str, bool)
or json_compat.is_json_compatible(params)
):
return cls(params, _init_via_getitem=True)
else:
class_name = cls.__name__
raise ValueError(
(
'Generic type `%s[T]` expects the single parameter T to be `int`,'
' `float`, `str`, `bool`, JSON-compatible types (Dict[str, T],'
' List[T]) or a proto type. (got %r instead).'
'Generic type `%s[T]` expects the single parameter T to be '
'`int`, `float`, `str`, `bool` or JSON-compatible types '
'(Dict[str, T], List[T]) (got %r instead).'
)
% (class_name, params)
)
Expand Down Expand Up @@ -269,7 +252,7 @@ class AsyncOutputArtifact(Generic[T]):
"""Intermediate artifact object type annotation."""


class Parameter(_PrimitiveAndProtoTypeGeneric):
class Parameter(_PrimitiveTypeGeneric):
"""Component parameter type annotation."""


Expand Down
60 changes: 28 additions & 32 deletions tfx/dsl/component/experimental/annotations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import apache_beam as beam
import tensorflow as tf
from tfx.dsl.component.experimental import annotations
from tfx.dsl.component.experimental import annotations_test_proto_pb2
from tfx.types import artifact
from tfx.types import standard_artifacts
from tfx.types import value_artifact
Expand All @@ -28,21 +27,18 @@ class AnnotationsTest(tf.test.TestCase):

def testArtifactGenericAnnotation(self):
# Error: type hint whose parameter is not an Artifact subclass.
with self.assertRaisesRegex(
ValueError, 'expects .* a concrete subclass of'
):
with self.assertRaisesRegex(ValueError,
'expects .* a concrete subclass of'):
_ = annotations._ArtifactGeneric[int] # pytype: disable=unsupported-operands

# Error: type hint with abstract Artifact subclass.
with self.assertRaisesRegex(
ValueError, 'expects .* a concrete subclass of'
):
with self.assertRaisesRegex(ValueError,
'expects .* a concrete subclass of'):
_ = annotations._ArtifactGeneric[artifact.Artifact]

# Error: type hint with abstract Artifact subclass.
with self.assertRaisesRegex(
ValueError, 'expects .* a concrete subclass of'
):
with self.assertRaisesRegex(ValueError,
'expects .* a concrete subclass of'):
_ = annotations._ArtifactGeneric[value_artifact.ValueArtifact]

# OK.
Expand All @@ -53,55 +49,56 @@ def testArtifactAnnotationUsage(self):
_ = annotations.OutputArtifact[standard_artifacts.Examples]
_ = annotations.AsyncOutputArtifact[standard_artifacts.Model]

def testPrimitivAndProtoTypeGenericAnnotation(self):
# Error: type hint whose parameter is not a primitive or a proto type
def testPrimitiveTypeGenericAnnotation(self):
# Error: type hint whose parameter is not a primitive type
# pytype: disable=unsupported-operands
with self.assertRaisesRegex(
ValueError, 'T to be `int`, `float`, `str`, `bool`'
):
_ = annotations._PrimitiveAndProtoTypeGeneric[artifact.Artifact]
_ = annotations._PrimitiveTypeGeneric[artifact.Artifact]
with self.assertRaisesRegex(
ValueError, 'T to be `int`, `float`, `str`, `bool`'
):
_ = annotations._PrimitiveAndProtoTypeGeneric[object]
_ = annotations._PrimitiveTypeGeneric[object]
with self.assertRaisesRegex(
ValueError, 'T to be `int`, `float`, `str`, `bool`'
):
_ = annotations._PrimitiveAndProtoTypeGeneric[123]
_ = annotations._PrimitiveTypeGeneric[123]
with self.assertRaisesRegex(
ValueError, 'T to be `int`, `float`, `str`, `bool`'
):
_ = annotations._PrimitiveAndProtoTypeGeneric['string']
_ = annotations._PrimitiveTypeGeneric['string']
with self.assertRaisesRegex(
ValueError, 'T to be `int`, `float`, `str`, `bool`'
):
_ = annotations._PrimitiveAndProtoTypeGeneric[Dict[int, int]]
_ = annotations._PrimitiveTypeGeneric[Dict[int, int]]
with self.assertRaisesRegex(
ValueError, 'T to be `int`, `float`, `str`, `bool`'
):
_ = annotations._PrimitiveAndProtoTypeGeneric[bytes]
_ = annotations._PrimitiveTypeGeneric[bytes]
# pytype: enable=unsupported-operands
# OK.
_ = annotations._PrimitiveAndProtoTypeGeneric[int]
_ = annotations._PrimitiveAndProtoTypeGeneric[float]
_ = annotations._PrimitiveAndProtoTypeGeneric[str]
_ = annotations._PrimitiveAndProtoTypeGeneric[bool]
_ = annotations._PrimitiveAndProtoTypeGeneric[Dict[str, float]]
_ = annotations._PrimitiveAndProtoTypeGeneric[bool]
_ = annotations._PrimitiveAndProtoTypeGeneric[
annotations_test_proto_pb2.TestMessage
]
_ = annotations._PrimitiveTypeGeneric[int]
_ = annotations._PrimitiveTypeGeneric[float]
_ = annotations._PrimitiveTypeGeneric[str]
_ = annotations._PrimitiveTypeGeneric[bool]
_ = annotations._PrimitiveTypeGeneric[Dict[str, float]]
_ = annotations._PrimitiveTypeGeneric[bool]

def testPipelineTypeGenericAnnotation(self):
# Error: type hint whose parameter is not a primitive type
with self.assertRaisesRegex(ValueError, 'T to be `beam.Pipeline`'):
with self.assertRaisesRegex(
ValueError, 'T to be `beam.Pipeline`'):
_ = annotations._PipelineTypeGeneric[artifact.Artifact]
with self.assertRaisesRegex(ValueError, 'T to be `beam.Pipeline`'):
with self.assertRaisesRegex(
ValueError, 'T to be `beam.Pipeline`'):
_ = annotations._PipelineTypeGeneric[object]
# pytype: disable=unsupported-operands
with self.assertRaisesRegex(ValueError, 'T to be `beam.Pipeline`'):
with self.assertRaisesRegex(
ValueError, 'T to be `beam.Pipeline`'):
_ = annotations._PipelineTypeGeneric[123]
with self.assertRaisesRegex(ValueError, 'T to be `beam.Pipeline`'):
with self.assertRaisesRegex(
ValueError, 'T to be `beam.Pipeline`'):
_ = annotations._PipelineTypeGeneric['string']
# pytype: enable=unsupported-operands

Expand All @@ -113,7 +110,6 @@ def testParameterUsage(self):
_ = annotations.Parameter[float]
_ = annotations.Parameter[str]
_ = annotations.Parameter[bool]
_ = annotations.Parameter[annotations_test_proto_pb2.TestMessage]


if __name__ == '__main__':
Expand Down
21 changes: 0 additions & 21 deletions tfx/dsl/component/experimental/annotations_test_proto.proto

This file was deleted.

16 changes: 4 additions & 12 deletions tfx/dsl/component/experimental/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from tfx.types import artifact
from tfx.types import component_spec
from tfx.types import system_executions
from google.protobuf import message


class ArgFormats(enum.Enum):
Expand Down Expand Up @@ -207,17 +206,10 @@ def _create_component_spec_class(
json_compatible_outputs[key],
)
if parameters:
for key, param_type in parameters.items():
if inspect.isclass(param_type) and issubclass(
param_type, message.Message
):
spec_parameters[key] = component_spec.ExecutionParameter(
type=param_type, optional=(key in arg_defaults), use_proto=True
)
else:
spec_parameters[key] = component_spec.ExecutionParameter(
type=param_type, optional=(key in arg_defaults)
)
for key, primitive_type in parameters.items():
spec_parameters[key] = component_spec.ExecutionParameter(
type=primitive_type, optional=(key in arg_defaults)
)
component_spec_class = type(
'%s_Spec' % func.__name__,
(tfx_types.ComponentSpec,),
Expand Down
14 changes: 1 addition & 13 deletions tfx/dsl/component/experimental/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing import Dict, List
import tensorflow as tf
from tfx.dsl.component.experimental import annotations
from tfx.dsl.component.experimental import annotations_test_proto_pb2
from tfx.dsl.component.experimental import decorators
from tfx.dsl.component.experimental import function_parser
from tfx.dsl.component.experimental import utils
Expand Down Expand Up @@ -95,9 +94,6 @@ def func_with_primitive_parameter(
float_param: annotations.Parameter[float],
str_param: annotations.Parameter[str],
bool_param: annotations.Parameter[bool],
proto_param: annotations.Parameter[
annotations_test_proto_pb2.TestMessage
],
dict_int_param: annotations.Parameter[Dict[str, int]],
list_bool_param: annotations.Parameter[List[bool]],
dict_list_bool_param: annotations.Parameter[Dict[str, List[bool]]],
Expand All @@ -116,7 +112,6 @@ def func_with_primitive_parameter(
'float_param': float,
'str_param': str,
'bool_param': bool,
'proto_param': annotations_test_proto_pb2.TestMessage,
'dict_int_param': Dict[str, int],
'list_bool_param': List[bool],
'dict_list_bool_param': Dict[str, List[bool]],
Expand Down Expand Up @@ -186,9 +181,6 @@ def func(
standard_artifacts.Examples
],
int_param: annotations.Parameter[int],
proto_param: annotations.Parameter[
annotations_test_proto_pb2.TestMessage
],
json_compat_param: annotations.Parameter[Dict[str, int]],
str_param: annotations.Parameter[str] = 'foo',
) -> annotations.OutputDict(
Expand Down Expand Up @@ -253,15 +245,11 @@ def func(
spec_outputs['map_str_float_output'].type, standard_artifacts.JsonValue
)
spec_parameter = actual_spec_class.PARAMETERS
self.assertLen(spec_parameter, 4)
self.assertLen(spec_parameter, 3)
self.assertEqual(spec_parameter['int_param'].type, int)
self.assertEqual(spec_parameter['int_param'].optional, False)
self.assertEqual(spec_parameter['str_param'].type, str)
self.assertEqual(spec_parameter['str_param'].optional, True)
self.assertEqual(
spec_parameter['proto_param'].type,
annotations_test_proto_pb2.TestMessage,
)
self.assertEqual(spec_parameter['json_compat_param'].type, Dict[str, int])
self.assertEqual(spec_parameter['json_compat_param'].optional, False)
self.assertEqual(actual_spec_class.TYPE_ANNOTATION, type_annotation)
Expand Down

0 comments on commit b0ab1f3

Please sign in to comment.