diff --git a/tfx/orchestration/experimental/core/env.py b/tfx/orchestration/experimental/core/env.py index 41407bd6a0..6e2378d334 100644 --- a/tfx/orchestration/experimental/core/env.py +++ b/tfx/orchestration/experimental/core/env.py @@ -71,6 +71,16 @@ def set_health_status(self, status: status_lib.Status) -> None: def check_if_can_orchestrate(self, pipeline: pipeline_pb2.Pipeline) -> None: """Check if this orchestrator is capable of orchestrating the pipeline.""" + @abc.abstractmethod + def pipeline_start_postprocess(self, pipeline: pipeline_pb2.Pipeline): + """Method for processing a pipeline at the end of its initialization, before it starts running. + + This *can* mutate the provided IR in-place. + + Args: + pipeline: The pipeline IR to process. + """ + class _DefaultEnv(Env): """Default environment.""" @@ -104,6 +114,9 @@ def set_health_status(self, status: status_lib.Status) -> None: def check_if_can_orchestrate(self, pipeline: pipeline_pb2.Pipeline) -> None: pass + def pipeline_start_postprocess(self, pipeline: pipeline_pb2.Pipeline): + pass + _ENV = _DefaultEnv() diff --git a/tfx/orchestration/experimental/core/env_test.py b/tfx/orchestration/experimental/core/env_test.py index 4dce1b191c..7074565fa5 100644 --- a/tfx/orchestration/experimental/core/env_test.py +++ b/tfx/orchestration/experimental/core/env_test.py @@ -16,6 +16,7 @@ import tensorflow as tf from tfx.orchestration.experimental.core import env from tfx.orchestration.experimental.core import test_utils +from tfx.proto.orchestration import pipeline_pb2 from tfx.utils import status as status_lib @@ -45,6 +46,9 @@ def set_health_status(self, status: status_lib.Status) -> None: def check_if_can_orchestrate(self, pipeline) -> None: raise NotImplementedError() + def pipeline_start_postprocess(self, pipeline: pipeline_pb2.Pipeline): + raise NotImplementedError() + class EnvTest(test_utils.TfxTest): diff --git a/tfx/orchestration/experimental/core/pipeline_ops.py b/tfx/orchestration/experimental/core/pipeline_ops.py index cf7672465f..d4e68f1ffd 100644 --- a/tfx/orchestration/experimental/core/pipeline_ops.py +++ b/tfx/orchestration/experimental/core/pipeline_ops.py @@ -249,7 +249,7 @@ def initiate_pipeline_start( raise status_lib.StatusNotOkError( code=status_lib.Code.FAILED_PRECONDITION, message=str(e) ) - + env.get_env().pipeline_start_postprocess(pipeline) return pstate.PipelineState.new( mlmd_handle, pipeline, pipeline_run_metadata, reused_pipeline_view ) diff --git a/tfx/orchestration/experimental/core/pipeline_ops_test.py b/tfx/orchestration/experimental/core/pipeline_ops_test.py index e08bd7b4b5..23fb433834 100644 --- a/tfx/orchestration/experimental/core/pipeline_ops_test.py +++ b/tfx/orchestration/experimental/core/pipeline_ops_test.py @@ -1148,6 +1148,17 @@ def _stop_pipeline(pipeline_state): self.assertEqual(expected_pipeline, pipeline_state_run2.pipeline) mock_snapshot.assert_called() + def test_initiate_pipeline_start_gets_post_processed(self): + with self._mlmd_connection as m: + with test_utils.pipeline_start_postprocess_env(): + pipeline = _test_pipeline('test_pipeline', pipeline_pb2.Pipeline.SYNC) + pipeline_state = pipeline_ops.initiate_pipeline_start(m, pipeline) + + self.assertEqual( + pipeline_state.pipeline.pipeline_info.id, + 'test_pipeline_postprocessed', + ) + @parameterized.named_parameters( dict(testcase_name='async', pipeline=_test_pipeline('pipeline1')), dict( diff --git a/tfx/orchestration/experimental/core/test_utils.py b/tfx/orchestration/experimental/core/test_utils.py index 611e6b5b71..5371d28cf3 100644 --- a/tfx/orchestration/experimental/core/test_utils.py +++ b/tfx/orchestration/experimental/core/test_utils.py @@ -499,3 +499,13 @@ def concurrent_pipeline_runs_enabled(self) -> bool: return True return _TestEnv() + + +def pipeline_start_postprocess_env(): + + class _TestEnv(env._DefaultEnv): # pylint: disable=protected-access + + def pipeline_start_postprocess(self, pipeline: pipeline_pb2.Pipeline): + pipeline.pipeline_info.id = pipeline.pipeline_info.id + '_postprocessed' + + return _TestEnv()