Skip to content

Commit

Permalink
Add 'tfma_eval' model_type in model_specs to replace the hard-coded e…
Browse files Browse the repository at this point in the history
…stimator model with signature='eval' pattern.

PiperOrigin-RevId: 548028225
  • Loading branch information
genehwung authored and tfx-copybara committed Jul 14, 2023
1 parent d41a789 commit 96d4e43
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
4 changes: 3 additions & 1 deletion RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

### For Component Authors

* N/A
* Replace "tf_estimator" with "tfma_eval" as the identifier for tfma
EvalSavedModel. "tf_estimator" is now serves as the identifier for the normal
estimator model with any signature (by default 'serving').

## Deprecations

Expand Down
16 changes: 10 additions & 6 deletions tfx/components/evaluator/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from absl import logging
import apache_beam as beam
import tensorflow_model_analysis as tfma
from tensorflow_model_analysis import constants as tfma_constants
# Need to import the following module so that the fairness indicator post-export
# metric is registered.
import tensorflow_model_analysis.addons.fairness.post_export_metrics.fairness_indicators # pylint: disable=unused-import
Expand Down Expand Up @@ -174,10 +173,14 @@ def Do(self, input_dict: Dict[str, List[types.Artifact]],
model_artifact = artifact_utils.get_single_instance(
input_dict[standard_component_specs.MODEL_KEY])
# TODO(b/171992041): tfma.get_model_type replaced by tfma.utils.
if ((hasattr(tfma, 'utils') and
tfma.utils.get_model_type(model_spec) == tfma.TF_ESTIMATOR) or
hasattr(tfma, 'get_model_type') and
tfma.get_model_type(model_spec) == tfma.TF_ESTIMATOR):
if (
(
hasattr(tfma, 'utils')
and tfma.utils.get_model_type(model_spec) == tfma.TFMA_EVAL
)
or hasattr(tfma, 'get_model_type')
and tfma.get_model_type(model_spec) == tfma.TFMA_EVAL
):
model_path = path_utils.eval_model_path(
model_artifact.uri,
path_utils.is_old_model_artifact(model_artifact))
Expand Down Expand Up @@ -248,7 +251,8 @@ def Do(self, input_dict: Dict[str, List[types.Artifact]],
examples=input_dict[standard_component_specs.EXAMPLES_KEY],
telemetry_descriptors=_TELEMETRY_DESCRIPTORS,
schema=schema,
raw_record_column_name=tfma_constants.ARROW_INPUT_COLUMN)
raw_record_column_name=tfma.constants.ARROW_INPUT_COLUMN,
)
# TODO(b/161935932): refactor after TFXIO supports multiple patterns.
for split in example_splits:
split_uris = artifact_utils.get_split_uris(
Expand Down

0 comments on commit 96d4e43

Please sign in to comment.