Skip to content

Commit

Permalink
Add pipeline_start_post_processing to Orchestrator env and call in pi…
Browse files Browse the repository at this point in the history
…peline_ops.intitiate_pipeline_start

PiperOrigin-RevId: 621556952
  • Loading branch information
kmonte authored and tfx-copybara committed Apr 3, 2024
1 parent b0ab1f3 commit 135a198
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 1 deletion.
13 changes: 13 additions & 0 deletions tfx/orchestration/experimental/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ def set_health_status(self, status: status_lib.Status) -> None:
def check_if_can_orchestrate(self, pipeline: pipeline_pb2.Pipeline) -> None:
"""Check if this orchestrator is capable of orchestrating the pipeline."""

@abc.abstractmethod
def pipeline_start_postprocess(self, pipeline: pipeline_pb2.Pipeline):
"""Method for processing a pipeline at the end of its initialization, before it starts running.
This *can* mutate the provided IR in-place.
Args:
pipeline: The pipeline IR to process.
"""


class _DefaultEnv(Env):
"""Default environment."""
Expand Down Expand Up @@ -104,6 +114,9 @@ def set_health_status(self, status: status_lib.Status) -> None:
def check_if_can_orchestrate(self, pipeline: pipeline_pb2.Pipeline) -> None:
pass

def pipeline_start_postprocess(self, pipeline: pipeline_pb2.Pipeline):
pass


_ENV = _DefaultEnv()

Expand Down
4 changes: 4 additions & 0 deletions tfx/orchestration/experimental/core/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import tensorflow as tf
from tfx.orchestration.experimental.core import env
from tfx.orchestration.experimental.core import test_utils
from tfx.proto.orchestration import pipeline_pb2
from tfx.utils import status as status_lib


Expand Down Expand Up @@ -45,6 +46,9 @@ def set_health_status(self, status: status_lib.Status) -> None:
def check_if_can_orchestrate(self, pipeline) -> None:
raise NotImplementedError()

def pipeline_start_postprocess(self, pipeline: pipeline_pb2.Pipeline):
raise NotImplementedError()


class EnvTest(test_utils.TfxTest):

Expand Down
2 changes: 1 addition & 1 deletion tfx/orchestration/experimental/core/pipeline_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def initiate_pipeline_start(
raise status_lib.StatusNotOkError(
code=status_lib.Code.FAILED_PRECONDITION, message=str(e)
)

env.get_env().pipeline_start_postprocess(pipeline)
return pstate.PipelineState.new(
mlmd_handle, pipeline, pipeline_run_metadata, reused_pipeline_view
)
Expand Down
11 changes: 11 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,6 +1148,17 @@ def _stop_pipeline(pipeline_state):
self.assertEqual(expected_pipeline, pipeline_state_run2.pipeline)
mock_snapshot.assert_called()

def test_initiate_pipeline_start_gets_post_processed(self):
with self._mlmd_connection as m:
with test_utils.pipeline_start_postprocess_env():
pipeline = _test_pipeline('test_pipeline', pipeline_pb2.Pipeline.SYNC)
pipeline_state = pipeline_ops.initiate_pipeline_start(m, pipeline)

self.assertEqual(
pipeline_state.pipeline.pipeline_info.id,
'test_pipeline_postprocessed',
)

@parameterized.named_parameters(
dict(testcase_name='async', pipeline=_test_pipeline('pipeline1')),
dict(
Expand Down
10 changes: 10 additions & 0 deletions tfx/orchestration/experimental/core/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,3 +499,13 @@ def concurrent_pipeline_runs_enabled(self) -> bool:
return True

return _TestEnv()


def pipeline_start_postprocess_env():

class _TestEnv(env._DefaultEnv): # pylint: disable=protected-access

def pipeline_start_postprocess(self, pipeline: pipeline_pb2.Pipeline):
pipeline.pipeline_info.id = pipeline.pipeline_info.id + '_postprocessed'

return _TestEnv()

0 comments on commit 135a198

Please sign in to comment.