diff --git a/RELEASE.md b/RELEASE.md
index a703afedd1..63705f958f 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -13,8 +13,6 @@
 pre-defined schema file. ImportSchemaGen will replace `Importer` with
 simpler syntax and less constraints. You have to pass the file path to the
 schema file instead of the parent directory unlike `Importer`.
-* Added support for outputting and encoding `tf.RaggedTensor`s in TFX
-  Transform component.
 
 ## Breaking Changes
 
diff --git a/tfx/components/transform/executor.py b/tfx/components/transform/executor.py
index 9823610df6..f618047f8e 100644
--- a/tfx/components/transform/executor.py
+++ b/tfx/components/transform/executor.py
@@ -16,7 +16,7 @@
 import functools
 import hashlib
 import os
-from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union
 
 from absl import logging
 import apache_beam as beam
@@ -275,24 +275,6 @@ def _InvokeStatsOptionsUpdaterFn(
   return stats_options_updater_fn(stats_type, tfdv.StatsOptions(**options))
 
 
-def _FilterInternalColumn(
-    record_batch: pa.RecordBatch,
-    internal_column_index: Optional[int] = None) -> pa.RecordBatch:
-  """Returns shallow copy of a RecordBatch with internal column removed."""
-  if (internal_column_index is None and
-      _TRANSFORM_INTERNAL_FEATURE_FOR_KEY not in record_batch.schema.names):
-    return record_batch
-  else:
-    internal_column_index = (
-        internal_column_index or
-        record_batch.schema.names.index(_TRANSFORM_INTERNAL_FEATURE_FOR_KEY))
-    # Making shallow copy since input modification is not allowed.
-    filtered_columns = list(record_batch.columns)
-    filtered_columns.pop(internal_column_index)
-    filtered_schema = record_batch.schema.remove(internal_column_index)
-    return pa.RecordBatch.from_arrays(filtered_columns, schema=filtered_schema)
-
-
 class Executor(base_beam_executor.BaseBeamExecutor):
   """Transform executor."""
 
@@ -693,7 +675,7 @@ def _GenerateAndMaybeValidateStats(
 
     generated_stats = (
         pcoll
-        | 'FilterInternalColumn' >> beam.Map(_FilterInternalColumn)
+        | 'FilterInternalColumn' >> beam.Map(Executor._FilterInternalColumn)
        | 'GenerateStatistics' >> tfdv.GenerateStatistics(stats_options))
 
     stats_result = (
@@ -751,42 +733,6 @@ def setup(self):
     def process(self, element: List[bytes]) -> Iterable[pa.RecordBatch]:
       yield self._decoder.DecodeBatch(element)
 
-  @beam.typehints.with_input_types(Tuple[pa.RecordBatch, Dict[str, pa.Array]])
-  @beam.typehints.with_output_types(Tuple[Any, bytes])
-  class _RecordBatchToExamplesFn(beam.DoFn):
-    """Maps `pa.RecordBatch` to a generator of serialized `tf.Example`s."""
-
-    def __init__(self, schema: schema_pb2.Schema):
-      self._coder = tfx_bsl.coders.example_coder.RecordBatchToExamplesEncoder(
-          schema)
-
-    def process(
-        self, data_batch: Tuple[pa.RecordBatch, Dict[str, pa.Array]]
-    ) -> Iterable[Tuple[Any, bytes]]:
-      record_batch, unary_passthrough_features = data_batch
-      if _TRANSFORM_INTERNAL_FEATURE_FOR_KEY in record_batch.schema.names:
-        keys_index = record_batch.schema.names.index(
-            _TRANSFORM_INTERNAL_FEATURE_FOR_KEY)
-        keys = record_batch.column(keys_index).to_pylist()
-        # Filter the record batch to make sure that the internal column doesn't
-        # get encoded.
-        record_batch = _FilterInternalColumn(record_batch, keys_index)
-        examples = self._coder.encode(record_batch)
-        for key, example in zip(keys, examples):
-          yield (None if key is None else key[0], example)
-      else:
-        # Internal feature key is not present in the record batch but may be
-        # present in the unary pass-through features dict.
-        key = unary_passthrough_features.get(
-            _TRANSFORM_INTERNAL_FEATURE_FOR_KEY, None)
-        if key is not None:
-          # The key is `pa.large_list()` and is, therefore, doubly nested.
-          key_list = key.to_pylist()[0]
-          key = None if key_list is None else key_list[0]
-        examples = self._coder.encode(record_batch)
-        for example in examples:
-          yield (key, example)
-
   @beam.typehints.with_input_types(beam.Pipeline)
   class _OptimizeRun(beam.PTransform):
     """Utilizes TFT cache if applicable and removes unused datasets."""
@@ -1461,15 +1407,15 @@ def _ExtractRawExampleBatches(record_batch):
             | 'Transform[{}]'.format(infix) >> tft_beam.TransformDataset(
                 output_record_batches=True))
 
-        _, metadata = transform_fn
-
-        # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
-        # schema. Currently input dataset schema only contains dtypes,
-        # and other metadata is dropped due to roundtrip to tensors.
-        transformed_schema_proto = metadata.schema
-
         if not disable_statistics:
           # Aggregated feature stats after transformation.
+          _, metadata = transform_fn
+
+          # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
+          # schema. Currently input dataset schema only contains dtypes,
+          # and other metadata is dropped due to roundtrip to tensors.
+          transformed_schema_proto = metadata.schema
+
           for dataset in transform_data_list:
             infix = 'TransformIndex{}'.format(dataset.index)
             dataset.transformed_and_standardized = (
@@ -1543,8 +1489,8 @@
         for dataset in transform_data_list:
           infix = 'TransformIndex{}'.format(dataset.index)
           (dataset.transformed
-           | 'EncodeAndSerialize[{}]'.format(infix) >> beam.ParDo(
-               self._RecordBatchToExamplesFn(transformed_schema_proto))
+           | 'EncodeAndSerialize[{}]'.format(infix) >> beam.FlatMap(
+               Executor._RecordBatchToExamples)
           | 'Materialize[{}]'.format(infix) >> self._WriteExamples(
               materialization_format, dataset.materialize_output_path))
 
@@ -1732,6 +1678,54 @@ def _GetTFXIOPassthroughKeys() -> Optional[Set[str]]:
     """Always returns None."""
     return None
 
+  @staticmethod
+  def _FilterInternalColumn(
+      record_batch: pa.RecordBatch,
+      internal_column_index: Optional[int] = None) -> pa.RecordBatch:
+    """Returns shallow copy of a batch with internal column removed."""
+    if (internal_column_index is None and
+        _TRANSFORM_INTERNAL_FEATURE_FOR_KEY not in record_batch.schema.names):
+      return record_batch
+    else:
+      internal_column_index = (
+          internal_column_index or
+          record_batch.schema.names.index(_TRANSFORM_INTERNAL_FEATURE_FOR_KEY))
+      # Making shallow copy since input modification is not allowed.
+      filtered_columns = list(record_batch.columns)
+      filtered_columns.pop(internal_column_index)
+      filtered_schema = record_batch.schema.remove(internal_column_index)
+      return pa.RecordBatch.from_arrays(
+          filtered_columns, schema=filtered_schema)
+
+  @staticmethod
+  def _RecordBatchToExamples(
+      data_batch: Tuple[pa.RecordBatch, Dict[str, pa.Array]]
+  ) -> Generator[Tuple[Any, bytes], None, None]:
+    """Maps `pa.RecordBatch` to a generator of serialized `tf.Example`s."""
+    record_batch, unary_passthrough_features = data_batch
+    if _TRANSFORM_INTERNAL_FEATURE_FOR_KEY in record_batch.schema.names:
+      keys_index = record_batch.schema.names.index(
+          _TRANSFORM_INTERNAL_FEATURE_FOR_KEY)
+      keys = record_batch.column(keys_index).to_pylist()
+      # Filter the record batch to make sure that the internal column doesn't
+      # get encoded.
+      record_batch = Executor._FilterInternalColumn(record_batch, keys_index)
+      examples = tfx_bsl.coders.example_coder.RecordBatchToExamples(
+          record_batch)
+      for key, example in zip(keys, examples):
+        yield (None if key is None else key[0], example)
+    else:
+      # Internal feature key is not present in the record batch but may be
+      # present in the unary pass-through features dict.
+      key = unary_passthrough_features.get(_TRANSFORM_INTERNAL_FEATURE_FOR_KEY,
+                                           None)
+      if key is not None:
+        key = None if key.to_pylist()[0] is None else key.to_pylist()[0][0]
+      examples = tfx_bsl.coders.example_coder.RecordBatchToExamples(
+          record_batch)
+      for example in examples:
+        yield (key, example)
+
   # TODO(b/130885503): clean this up once the sketch-based generator is the
   # default.
   @staticmethod
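
For context on the pyarrow calls involved, here is a minimal, self-contained
sketch of the shallow-copy column filtering that the relocated
`Executor._FilterInternalColumn` staticmethod performs. The column name below
is a hypothetical placeholder; the executor uses its own
`_TRANSFORM_INTERNAL_FEATURE_FOR_KEY` constant defined in executor.py.

import pyarrow as pa

# Hypothetical stand-in for executor.py's _TRANSFORM_INTERNAL_FEATURE_FOR_KEY.
INTERNAL_KEY_COLUMN = '__internal_key__'


def filter_internal_column(record_batch: pa.RecordBatch) -> pa.RecordBatch:
  """Returns a shallow copy of `record_batch` without the internal column."""
  if INTERNAL_KEY_COLUMN not in record_batch.schema.names:
    return record_batch
  index = record_batch.schema.names.index(INTERNAL_KEY_COLUMN)
  # Copy the column list so the input batch stays unmodified; the underlying
  # arrays are shared rather than copied, hence "shallow".
  columns = list(record_batch.columns)
  columns.pop(index)
  return pa.RecordBatch.from_arrays(
      columns, schema=record_batch.schema.remove(index))


batch = pa.RecordBatch.from_arrays(
    [pa.array([[1], [2]]), pa.array([['k1'], ['k2']])],
    names=['feature_a', INTERNAL_KEY_COLUMN])
assert filter_internal_column(batch).schema.names == ['feature_a']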