From 73548b92a01ee49ba7c8173920620cef331fba6d Mon Sep 17 00:00:00 2001 From: kmonte Date: Tue, 16 Jul 2024 11:54:37 -0700 Subject: [PATCH] Update _PipelineIRCodec to use base dir encoded into pipeline IR PiperOrigin-RevId: 652927035 --- .../experimental/core/pipeline_ops_test.py | 41 ++++++++++++------- .../experimental/core/pipeline_state.py | 37 ++++++++++------- .../experimental/core/pipeline_state_test.py | 8 +++- .../experimental/core/sample_mlmd_creator.py | 37 ++++++++++------- .../core/testing/test_async_pipeline.py | 8 ++-- .../core/testing/test_manual_node.py | 12 +++--- .../core/testing/test_sync_pipeline.py | 20 ++++++--- 7 files changed, 104 insertions(+), 59 deletions(-) diff --git a/tfx/orchestration/experimental/core/pipeline_ops_test.py b/tfx/orchestration/experimental/core/pipeline_ops_test.py index e238ad39e9..a136622f36 100644 --- a/tfx/orchestration/experimental/core/pipeline_ops_test.py +++ b/tfx/orchestration/experimental/core/pipeline_ops_test.py @@ -564,11 +564,12 @@ def test_revive_pipeline_run_active_pipeline_run_concurrent_runs_disabled( def test_revive_pipeline_run_with_subpipelines(self): with self._mlmd_connection as m: - pipeline = test_sync_pipeline.create_pipeline_with_subpipeline() + pipeline = test_sync_pipeline.create_pipeline_with_subpipeline( + temp_dir=self.create_tempdir().full_path + ) runtime_parameter_utils.substitute_runtime_parameter( pipeline, { - constants.PIPELINE_ROOT_PARAMETER_NAME: '/path/to/root', constants.PIPELINE_RUN_ID_PARAMETER_NAME: 'run0', }, ) @@ -820,11 +821,12 @@ def test_initiate_pipeline_start_with_partial_run_and_subpipeline( self, mock_snapshot, run_subpipeline ): with self._mlmd_connection as m: - pipeline = test_sync_pipeline.create_pipeline_with_subpipeline() + pipeline = test_sync_pipeline.create_pipeline_with_subpipeline( + temp_dir=self.create_tempdir().full_path + ) runtime_parameter_utils.substitute_runtime_parameter( pipeline, { - constants.PIPELINE_ROOT_PARAMETER_NAME: '/my/pipeline/root', constants.PIPELINE_RUN_ID_PARAMETER_NAME: 'run-0123', }, ) @@ -1519,7 +1521,9 @@ def test_record_orchestration_time(self, pipeline, expected_run_id): def test_record_orchestration_time_subpipeline(self): with self._mlmd_cm as mlmd_connection_manager: m = mlmd_connection_manager.primary_mlmd_handle - pipeline = test_sync_pipeline.create_pipeline_with_subpipeline() + pipeline = test_sync_pipeline.create_pipeline_with_subpipeline( + temp_dir=self.create_tempdir().full_path + ) runtime_parameter_utils.substitute_runtime_parameter( pipeline, { @@ -2653,13 +2657,8 @@ def test_executor_node_stop_then_start_flow( self.assertEqual(pstate.NodeState.STARTED, node_state.state) @parameterized.named_parameters( - dict( - testcase_name='async', pipeline=test_async_pipeline.create_pipeline() - ), - dict( - testcase_name='sync', - pipeline=test_sync_pipeline.create_pipeline(), - ), + dict(testcase_name='async', mode='async'), + dict(testcase_name='sync', mode='sync'), ) @mock.patch.object(sync_pipeline_task_gen, 'SyncPipelineTaskGenerator') @mock.patch.object(async_pipeline_task_gen, 'AsyncPipelineTaskGenerator') @@ -2667,8 +2666,16 @@ def test_pure_service_node_stop_then_start_flow( self, mock_async_task_gen, mock_sync_task_gen, - pipeline, + mode, ): + if mode == 'async': + pipeline = test_async_pipeline.create_pipeline( + temp_dir=self.create_tempdir().full_path + ) + else: + pipeline = test_sync_pipeline.create_pipeline( + temp_dir=self.create_tempdir().full_path + ) runtime_parameter_utils.substitute_runtime_parameter( pipeline, { @@ 
-2862,7 +2869,9 @@ def test_wait_for_predicate_timeout_secs_None(self, mock_sleep): self.assertEqual(mock_sleep.call_count, 2) def test_resume_manual_node(self): - pipeline = test_manual_node.create_pipeline() + pipeline = test_manual_node.create_pipeline( + temp_dir=self.create_tempdir().full_path + ) runtime_parameter_utils.substitute_runtime_parameter( pipeline, { @@ -3516,7 +3525,9 @@ def health_status(self) -> status_lib.Status: ) def test_delete_pipeline_run(self): - pipeline = test_sync_pipeline.create_pipeline() + pipeline = test_sync_pipeline.create_pipeline( + temp_dir=self.create_tempdir().full_path + ) runtime_parameter_utils.substitute_runtime_parameter( pipeline, { diff --git a/tfx/orchestration/experimental/core/pipeline_state.py b/tfx/orchestration/experimental/core/pipeline_state.py index 9db976639d..76559e8391 100644 --- a/tfx/orchestration/experimental/core/pipeline_state.py +++ b/tfx/orchestration/experimental/core/pipeline_state.py @@ -424,29 +424,38 @@ def testonly_reset(cls) -> None: with cls._lock: cls._obj = None - def __init__(self): - self.base_dir = env.get_env().get_base_dir() - if self.base_dir: - self.pipeline_irs_dir = os.path.join(self.base_dir, - self._ORCHESTRATOR_METADATA_DIR, - self._PIPELINE_IRS_DIR) - fileio.makedirs(self.pipeline_irs_dir) - else: - self.pipeline_irs_dir = None - def encode(self, pipeline: pipeline_pb2.Pipeline) -> str: """Encodes pipeline IR.""" # Attempt to store as a base64 encoded string. If base_dir is provided # and the length is too large, store the IR on disk and retain the URL. # TODO(b/248786921): Always store pipeline IR to base_dir once the # accessibility issue is resolved. + + # Note that this setup means that every *subpipeline* will have its own + # "irs" dir. This is fine, though ideally we would put all pipeline IRs + # under the root pipeline dir, which would require us to *also* store the + # root pipeline dir in the IR. 
+ + base_dir = pipeline.runtime_spec.pipeline_root.field_value.string_value + if base_dir: + pipeline_ir_dir = os.path.join( + base_dir, self._ORCHESTRATOR_METADATA_DIR, self._PIPELINE_IRS_DIR + ) + fileio.makedirs(pipeline_ir_dir) + else: + pipeline_ir_dir = None pipeline_encoded = _base64_encode(pipeline) max_mlmd_str_value_len = env.get_env().max_mlmd_str_value_length() - if self.base_dir and max_mlmd_str_value_len is not None and len( - pipeline_encoded) > max_mlmd_str_value_len: + if ( + base_dir + and pipeline_ir_dir + and max_mlmd_str_value_len is not None + and len(pipeline_encoded) > max_mlmd_str_value_len + ): pipeline_id = task_lib.PipelineUid.from_pipeline(pipeline).pipeline_id - pipeline_url = os.path.join(self.pipeline_irs_dir, - f'{pipeline_id}_{uuid.uuid4()}.pb') + pipeline_url = os.path.join( + pipeline_ir_dir, f'{pipeline_id}_{uuid.uuid4()}.pb' + ) with fileio.open(pipeline_url, 'wb') as file: file.write(pipeline.SerializeToString()) pipeline_encoded = json.dumps({self._PIPELINE_IR_URL_KEY: pipeline_url}) diff --git a/tfx/orchestration/experimental/core/pipeline_state_test.py b/tfx/orchestration/experimental/core/pipeline_state_test.py index cc6fd85056..dd001b1fe9 100644 --- a/tfx/orchestration/experimental/core/pipeline_state_test.py +++ b/tfx/orchestration/experimental/core/pipeline_state_test.py @@ -49,6 +49,7 @@ def _test_pipeline( param=1, pipeline_nodes: List[str] = None, pipeline_run_id: str = 'run0', + pipeline_root: str = '', ): pipeline = pipeline_pb2.Pipeline() pipeline.pipeline_info.id = pipeline_id @@ -63,6 +64,7 @@ def _test_pipeline( pipeline.runtime_spec.pipeline_run_id.field_value.string_value = ( pipeline_run_id ) + pipeline.runtime_spec.pipeline_root.field_value.string_value = pipeline_root return pipeline @@ -202,7 +204,11 @@ def test_encode_decode_with_base_dir(self): def test_encode_decode_exceeds_max_len(self): with TestEnv(self._pipeline_root, 0): - pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer']) + pipeline = _test_pipeline( + 'pipeline1', + pipeline_nodes=['Trainer'], + pipeline_root=self.create_tempdir().full_path, + ) pipeline_encoded = pstate._PipelineIRCodec.get().encode(pipeline) self.assertEqual( pipeline, pstate._PipelineIRCodec.get().decode(pipeline_encoded) diff --git a/tfx/orchestration/experimental/core/sample_mlmd_creator.py b/tfx/orchestration/experimental/core/sample_mlmd_creator.py index 217d89c0f0..cea0a85771 100644 --- a/tfx/orchestration/experimental/core/sample_mlmd_creator.py +++ b/tfx/orchestration/experimental/core/sample_mlmd_creator.py @@ -52,8 +52,12 @@ def _get_mlmd_connection(path: str) -> metadata.Metadata: return metadata.Metadata(connection_config=connection_config) -def _test_pipeline(ir_path: str, pipeline_id: str, run_id: str, - deployment_config: Optional[message.Message]): +def _test_pipeline( + ir_path: str, + pipeline_id: str, + run_id: str, + deployment_config: Optional[message.Message], +): """Creates test pipeline with pipeline_id and run_id.""" pipeline = pipeline_pb2.Pipeline() io_utils.parse_pbtxt_file(ir_path, pipeline) @@ -85,25 +89,30 @@ def _execute_nodes(handle: metadata.Metadata, pipeline: pipeline_pb2.Pipeline, ) -def _get_ir_path(external_ir_file: str): +def _get_ir_path(external_ir_file: str, temp_dir: str = ''): if external_ir_file: return external_ir_file ir_file_path = tempfile.mktemp(suffix='.pbtxt') - io_utils.write_pbtxt_file(ir_file_path, test_sync_pipeline.create_pipeline()) + io_utils.write_pbtxt_file( + ir_file_path, 
test_sync_pipeline.create_pipeline(temp_dir=temp_dir) + ) return ir_file_path -def create_sample_pipeline(m: metadata.Metadata, - pipeline_id: str, - run_num: int, - export_ir_path: str = '', - external_ir_file: str = '', - deployment_config: Optional[message.Message] = None, - execute_nodes_func: Callable[ - [metadata.Metadata, pipeline_pb2.Pipeline, int], - None] = _execute_nodes): +def create_sample_pipeline( + m: metadata.Metadata, + pipeline_id: str, + run_num: int, + export_ir_path: str = '', + external_ir_file: str = '', + deployment_config: Optional[message.Message] = None, + execute_nodes_func: Callable[ + [metadata.Metadata, pipeline_pb2.Pipeline, int], None + ] = _execute_nodes, + temp_dir: str = '', +): """Creates a list of pipeline and node execution.""" - ir_path = _get_ir_path(external_ir_file) + ir_path = _get_ir_path(external_ir_file, temp_dir=temp_dir) for i in range(run_num): run_id = 'run%02d' % i pipeline = _test_pipeline(ir_path, pipeline_id, run_id, deployment_config) diff --git a/tfx/orchestration/experimental/core/testing/test_async_pipeline.py b/tfx/orchestration/experimental/core/testing/test_async_pipeline.py index 8c2aac7d90..452f3523cc 100644 --- a/tfx/orchestration/experimental/core/testing/test_async_pipeline.py +++ b/tfx/orchestration/experimental/core/testing/test_async_pipeline.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Async pipeline for testing.""" +import os from tfx.dsl.compiler import compiler from tfx.dsl.component.experimental.annotations import InputArtifact @@ -51,7 +52,7 @@ def _trainer(examples: InputArtifact[standard_artifacts.Examples], del examples, transform_graph, model -def create_pipeline() -> pipeline_pb2.Pipeline: +def create_pipeline(temp_dir: str = '/') -> pipeline_pb2.Pipeline: """Creates an async pipeline for testing.""" # pylint: disable=no-value-for-parameter example_gen = _example_gen().with_id('my_example_gen') @@ -68,13 +69,14 @@ def create_pipeline() -> pipeline_pb2.Pipeline: pipeline = pipeline_lib.Pipeline( pipeline_name='my_pipeline', - pipeline_root='/path/to/root', + pipeline_root=os.path.join(temp_dir, 'path/to/root'), components=[ example_gen, transform, trainer, ], - execution_mode=pipeline_lib.ExecutionMode.ASYNC) + execution_mode=pipeline_lib.ExecutionMode.ASYNC, + ) dsl_compiler = compiler.Compiler(use_input_v2=True) compiled_pipeline: pipeline_pb2.Pipeline = dsl_compiler.compile(pipeline) diff --git a/tfx/orchestration/experimental/core/testing/test_manual_node.py b/tfx/orchestration/experimental/core/testing/test_manual_node.py index c246551001..31a746f28d 100644 --- a/tfx/orchestration/experimental/core/testing/test_manual_node.py +++ b/tfx/orchestration/experimental/core/testing/test_manual_node.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Test pipeline with only manual node.""" +import os from tfx.dsl.compiler import compiler from tfx.dsl.components.common import manual_node @@ -19,16 +20,15 @@ from tfx.proto.orchestration import pipeline_pb2 -def create_pipeline() -> pipeline_pb2.Pipeline: +def create_pipeline(temp_dir: str = '/') -> pipeline_pb2.Pipeline: """Builds a test pipeline with only manual node.""" manual = manual_node.ManualNode(description='Do something.') pipeline = pipeline_lib.Pipeline( pipeline_name='my_pipeline', - pipeline_root='/path/to/root', - components=[ - manual - ], - enable_cache=True) + pipeline_root=os.path.join(temp_dir, 'path/to/root'), + components=[manual], + enable_cache=True, + ) dsl_compiler = compiler.Compiler() return dsl_compiler.compile(pipeline) diff --git a/tfx/orchestration/experimental/core/testing/test_sync_pipeline.py b/tfx/orchestration/experimental/core/testing/test_sync_pipeline.py index 8ba9d786f5..129f2af7b4 100644 --- a/tfx/orchestration/experimental/core/testing/test_sync_pipeline.py +++ b/tfx/orchestration/experimental/core/testing/test_sync_pipeline.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Sync pipeline for testing.""" +import os from tfx.dsl.compiler import compiler from tfx.dsl.component.experimental.annotations import InputArtifact @@ -82,7 +83,7 @@ def _chore(): pass -def create_pipeline() -> pipeline_pb2.Pipeline: +def create_pipeline(temp_dir: str = '/') -> pipeline_pb2.Pipeline: """Builds a test pipeline. ┌───────────┐ @@ -107,6 +108,10 @@ def create_pipeline() -> pipeline_pb2.Pipeline: │chore_b │ └────────┘ + Args: + temp_dir: If provieded, a temporary test directory to use as prefix to the + pipeline root. + Returns: A pipeline proto for the above DAG """ @@ -142,7 +147,7 @@ def create_pipeline() -> pipeline_pb2.Pipeline: pipeline = pipeline_lib.Pipeline( pipeline_name='my_pipeline', - pipeline_root='/path/to/root', + pipeline_root=os.path.join(temp_dir, 'path/to/root'), components=[ example_gen, stats_gen, @@ -154,7 +159,8 @@ def create_pipeline() -> pipeline_pb2.Pipeline: chore_a, chore_b, ], - enable_cache=True) + enable_cache=True, + ) dsl_compiler = compiler.Compiler() return dsl_compiler.compile(pipeline) @@ -300,7 +306,9 @@ def create_resource_lifetime_pipeline() -> pipeline_pb2.Pipeline: return dsl_compiler.compile(pipeline) -def create_pipeline_with_subpipeline() -> pipeline_pb2.Pipeline: +def create_pipeline_with_subpipeline( + temp_dir: str = '/', +) -> pipeline_pb2.Pipeline: """Creates a pipeline with a subpipeline.""" # pylint: disable=no-value-for-parameter example_gen = _example_gen().with_id('my_example_gen') @@ -318,7 +326,7 @@ def create_pipeline_with_subpipeline() -> pipeline_pb2.Pipeline: componsable_pipeline = pipeline_lib.Pipeline( pipeline_name='sub-pipeline', - pipeline_root='/path/to/root/sub', + pipeline_root=os.path.join(temp_dir, 'path/to/root/sub'), components=[stats_gen, schema_gen], enable_cache=True, inputs=p_in, @@ -332,7 +340,7 @@ def create_pipeline_with_subpipeline() -> pipeline_pb2.Pipeline: pipeline = pipeline_lib.Pipeline( pipeline_name='my_pipeline', - pipeline_root='/path/to/root', + pipeline_root=os.path.join(temp_dir, 'path/to/root'), components=[ example_gen, componsable_pipeline,