Skip to content

Commit

Permalink
Revert "Introducing RecordBatchToExamplesEncoder to encode nested lists representing `tf.RaggedTensor` as tf.Examples." (#4306)
Browse files Browse the repository at this point in the history

This reverts commit f6beebf.
  • Loading branch information
Jiyong Jung authored Sep 24, 2021
1 parent 9f05db3 commit 7f432b9
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 67 deletions.
2 changes: 0 additions & 2 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
pre-defined schema file. ImportSchemaGen will replace `Importer` with
simpler syntax and less constraints. You have to pass the file path to the
schema file instead of the parent directory unlike `Importer`.
* Added support for outputting and encoding `tf.RaggedTensor`s in TFX
Transform component.

## Breaking Changes

Expand Down
124 changes: 59 additions & 65 deletions tfx/components/transform/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import functools
import hashlib
import os
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union
from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union

from absl import logging
import apache_beam as beam
Expand Down Expand Up @@ -275,24 +275,6 @@ def _InvokeStatsOptionsUpdaterFn(
return stats_options_updater_fn(stats_type, tfdv.StatsOptions(**options))


def _FilterInternalColumn(
    record_batch: pa.RecordBatch,
    internal_column_index: Optional[int] = None) -> pa.RecordBatch:
  """Returns a shallow copy of a RecordBatch with the internal column removed.

  Args:
    record_batch: Arrow RecordBatch that may contain the internal per-example
      key column (`_TRANSFORM_INTERNAL_FEATURE_FOR_KEY`).
    internal_column_index: Optional index of the internal column. When None,
      the column is located by name; if the name is also absent, the input
      batch is returned unchanged.

  Returns:
    The input `record_batch` itself when there is nothing to remove, otherwise
    a new RecordBatch over the remaining column arrays (a shallow copy, since
    input modification is not allowed).
  """
  # Use an explicit `is None` check rather than `internal_column_index or ...`:
  # the `or` form would silently discard a caller-supplied index of 0 and
  # re-resolve the column by name (wrong column removed, or ValueError if the
  # internal feature name is not present in the schema).
  if internal_column_index is None:
    if _TRANSFORM_INTERNAL_FEATURE_FOR_KEY not in record_batch.schema.names:
      return record_batch
    internal_column_index = record_batch.schema.names.index(
        _TRANSFORM_INTERNAL_FEATURE_FOR_KEY)
  # Making shallow copy since input modification is not allowed.
  filtered_columns = list(record_batch.columns)
  filtered_columns.pop(internal_column_index)
  filtered_schema = record_batch.schema.remove(internal_column_index)
  return pa.RecordBatch.from_arrays(filtered_columns, schema=filtered_schema)


class Executor(base_beam_executor.BaseBeamExecutor):
"""Transform executor."""

Expand Down Expand Up @@ -693,7 +675,7 @@ def _GenerateAndMaybeValidateStats(

generated_stats = (
pcoll
| 'FilterInternalColumn' >> beam.Map(_FilterInternalColumn)
| 'FilterInternalColumn' >> beam.Map(Executor._FilterInternalColumn)
| 'GenerateStatistics' >> tfdv.GenerateStatistics(stats_options))

stats_result = (
Expand Down Expand Up @@ -751,42 +733,6 @@ def setup(self):
def process(self, element: List[bytes]) -> Iterable[pa.RecordBatch]:
yield self._decoder.DecodeBatch(element)

@beam.typehints.with_input_types(Tuple[pa.RecordBatch, Dict[str, pa.Array]])
@beam.typehints.with_output_types(Tuple[Any, bytes])
class _RecordBatchToExamplesFn(beam.DoFn):
  """Encodes each `pa.RecordBatch` as keyed, serialized `tf.Example`s."""

  def __init__(self, schema: schema_pb2.Schema):
    self._coder = tfx_bsl.coders.example_coder.RecordBatchToExamplesEncoder(
        schema)

  def process(
      self, data_batch: Tuple[pa.RecordBatch, Dict[str, pa.Array]]
  ) -> Iterable[Tuple[Any, bytes]]:
    record_batch, unary_passthrough_features = data_batch
    schema_names = record_batch.schema.names
    if _TRANSFORM_INTERNAL_FEATURE_FOR_KEY not in schema_names:
      # The internal key column is absent from the record batch; it may still
      # be supplied through the unary pass-through features dict.
      key = None
      key_array = unary_passthrough_features.get(
          _TRANSFORM_INTERNAL_FEATURE_FOR_KEY, None)
      if key_array is not None:
        # The key is `pa.large_list()` and is, therefore, doubly nested.
        nested_key = key_array.to_pylist()[0]
        key = nested_key[0] if nested_key is not None else None
      for example in self._coder.encode(record_batch):
        yield (key, example)
    else:
      keys_index = schema_names.index(_TRANSFORM_INTERNAL_FEATURE_FOR_KEY)
      keys = record_batch.column(keys_index).to_pylist()
      # Drop the internal column so that it never gets encoded.
      record_batch = _FilterInternalColumn(record_batch, keys_index)
      for key, example in zip(keys, self._coder.encode(record_batch)):
        yield (key[0] if key is not None else None, example)

@beam.typehints.with_input_types(beam.Pipeline)
class _OptimizeRun(beam.PTransform):
"""Utilizes TFT cache if applicable and removes unused datasets."""
Expand Down Expand Up @@ -1461,15 +1407,15 @@ def _ExtractRawExampleBatches(record_batch):
| 'Transform[{}]'.format(infix) >>
tft_beam.TransformDataset(output_record_batches=True))

_, metadata = transform_fn

# TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
# schema. Currently input dataset schema only contains dtypes,
# and other metadata is dropped due to roundtrip to tensors.
transformed_schema_proto = metadata.schema

if not disable_statistics:
# Aggregated feature stats after transformation.
_, metadata = transform_fn

# TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
# schema. Currently input dataset schema only contains dtypes,
# and other metadata is dropped due to roundtrip to tensors.
transformed_schema_proto = metadata.schema

for dataset in transform_data_list:
infix = 'TransformIndex{}'.format(dataset.index)
dataset.transformed_and_standardized = (
Expand Down Expand Up @@ -1543,8 +1489,8 @@ def _ExtractRawExampleBatches(record_batch):
for dataset in transform_data_list:
infix = 'TransformIndex{}'.format(dataset.index)
(dataset.transformed
| 'EncodeAndSerialize[{}]'.format(infix) >> beam.ParDo(
self._RecordBatchToExamplesFn(transformed_schema_proto))
| 'EncodeAndSerialize[{}]'.format(infix) >> beam.FlatMap(
Executor._RecordBatchToExamples)
| 'Materialize[{}]'.format(infix) >> self._WriteExamples(
materialization_format, dataset.materialize_output_path))

Expand Down Expand Up @@ -1732,6 +1678,54 @@ def _GetTFXIOPassthroughKeys() -> Optional[Set[str]]:
"""Always returns None."""
return None

@staticmethod
def _FilterInternalColumn(
    record_batch: pa.RecordBatch,
    internal_column_index: Optional[int] = None) -> pa.RecordBatch:
  """Returns shallow copy of a batch with internal column removed.

  Args:
    record_batch: Arrow RecordBatch that may contain the internal per-example
      key column (`_TRANSFORM_INTERNAL_FEATURE_FOR_KEY`).
    internal_column_index: Optional index of the internal column. When None,
      the column is located by name; if the name is also absent, the input
      batch is returned unchanged.

  Returns:
    The input `record_batch` itself when there is nothing to remove, otherwise
    a new RecordBatch over the remaining column arrays (a shallow copy, since
    input modification is not allowed).
  """
  # Use an explicit `is None` check rather than `internal_column_index or ...`:
  # the `or` form would silently discard a caller-supplied index of 0 and
  # re-resolve the column by name (wrong column removed, or ValueError if the
  # internal feature name is not present in the schema).
  if internal_column_index is None:
    if _TRANSFORM_INTERNAL_FEATURE_FOR_KEY not in record_batch.schema.names:
      return record_batch
    internal_column_index = record_batch.schema.names.index(
        _TRANSFORM_INTERNAL_FEATURE_FOR_KEY)
  # Making shallow copy since input modification is not allowed.
  filtered_columns = list(record_batch.columns)
  filtered_columns.pop(internal_column_index)
  filtered_schema = record_batch.schema.remove(internal_column_index)
  return pa.RecordBatch.from_arrays(
      filtered_columns, schema=filtered_schema)

@staticmethod
def _RecordBatchToExamples(
    data_batch: Tuple[pa.RecordBatch, Dict[str, pa.Array]]
) -> Generator[Tuple[Any, bytes], None, None]:
  """Maps `pa.RecordBatch` to a generator of serialized `tf.Example`s.

  Args:
    data_batch: A tuple of the RecordBatch to encode and a dict of unary
      pass-through features (which may carry the internal key feature when it
      is not a column of the batch).

  Yields:
    (key, serialized tf.Example) tuples; `key` is None when no internal key
    feature is available.
  """
  record_batch, unary_passthrough_features = data_batch
  if _TRANSFORM_INTERNAL_FEATURE_FOR_KEY in record_batch.schema.names:
    keys_index = record_batch.schema.names.index(
        _TRANSFORM_INTERNAL_FEATURE_FOR_KEY)
    keys = record_batch.column(keys_index).to_pylist()
    # Filter the record batch to make sure that the internal column doesn't
    # get encoded.
    record_batch = Executor._FilterInternalColumn(record_batch, keys_index)
    examples = tfx_bsl.coders.example_coder.RecordBatchToExamples(
        record_batch)
    for key, example in zip(keys, examples):
      yield (None if key is None else key[0], example)
  else:
    # Internal feature key is not present in the record batch but may be
    # present in the unary pass-through features dict.
    key = unary_passthrough_features.get(_TRANSFORM_INTERNAL_FEATURE_FOR_KEY,
                                         None)
    if key is not None:
      # The key is `pa.large_list()` and is, therefore, doubly nested.
      # Convert to a Python list once instead of twice per batch.
      key_list = key.to_pylist()[0]
      key = None if key_list is None else key_list[0]
    examples = tfx_bsl.coders.example_coder.RecordBatchToExamples(
        record_batch)
    for example in examples:
      yield (key, example)

# TODO(b/130885503): clean this up once the sketch-based generator is the
# default.
@staticmethod
Expand Down

0 comments on commit 7f432b9

Please sign in to comment.