Update _PipelineIRCodec to use base dir encoded into pipeline IR
PiperOrigin-RevId: 652927035
kmonte authored and tfx-copybara committed Jul 19, 2024
1 parent e3ebdca commit 73548b9
Showing 7 changed files with 104 additions and 59 deletions.
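In short, `_PipelineIRCodec` previously resolved its IR storage directory once, at construction, from the orchestration environment; after this change, `encode()` derives it per pipeline from the `pipeline_root` already recorded in the IR's runtime spec. A minimal plain-Python sketch of the before/after resolution, using hypothetical values for the codec's private directory constants:

```python
import os
from typing import Optional

# Hypothetical values standing in for _PipelineIRCodec's private constants.
_ORCHESTRATOR_METADATA_DIR = 'orchestrator_metadata'
_PIPELINE_IRS_DIR = 'irs'


def irs_dir_before(env_base_dir: Optional[str]) -> Optional[str]:
  """Pre-change: a single process-wide dir, read from the env at codec init."""
  if not env_base_dir:
    return None
  return os.path.join(
      env_base_dir, _ORCHESTRATOR_METADATA_DIR, _PIPELINE_IRS_DIR)


def irs_dir_after(pipeline_root: str) -> Optional[str]:
  """Post-change: derived on every encode() from the root stored in the IR."""
  if not pipeline_root:
    return None
  return os.path.join(
      pipeline_root, _ORCHESTRATOR_METADATA_DIR, _PIPELINE_IRS_DIR)
```

The test updates below follow from this: pipeline roots must now point at real, writable directories, so each test fixture grows a `temp_dir` parameter.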
41 changes: 26 additions & 15 deletions tfx/orchestration/experimental/core/pipeline_ops_test.py
@@ -564,11 +564,12 @@ def test_revive_pipeline_run_active_pipeline_run_concurrent_runs_disabled(

def test_revive_pipeline_run_with_subpipelines(self):
with self._mlmd_connection as m:
pipeline = test_sync_pipeline.create_pipeline_with_subpipeline()
pipeline = test_sync_pipeline.create_pipeline_with_subpipeline(
temp_dir=self.create_tempdir().full_path
)
runtime_parameter_utils.substitute_runtime_parameter(
pipeline,
{
constants.PIPELINE_ROOT_PARAMETER_NAME: '/path/to/root',
constants.PIPELINE_RUN_ID_PARAMETER_NAME: 'run0',
},
)
@@ -820,11 +821,12 @@ def test_initiate_pipeline_start_with_partial_run_and_subpipeline(
self, mock_snapshot, run_subpipeline
):
with self._mlmd_connection as m:
pipeline = test_sync_pipeline.create_pipeline_with_subpipeline()
pipeline = test_sync_pipeline.create_pipeline_with_subpipeline(
temp_dir=self.create_tempdir().full_path
)
runtime_parameter_utils.substitute_runtime_parameter(
pipeline,
{
constants.PIPELINE_ROOT_PARAMETER_NAME: '/my/pipeline/root',
constants.PIPELINE_RUN_ID_PARAMETER_NAME: 'run-0123',
},
)
@@ -1519,7 +1521,9 @@ def test_record_orchestration_time(self, pipeline, expected_run_id):
def test_record_orchestration_time_subpipeline(self):
with self._mlmd_cm as mlmd_connection_manager:
m = mlmd_connection_manager.primary_mlmd_handle
pipeline = test_sync_pipeline.create_pipeline_with_subpipeline()
pipeline = test_sync_pipeline.create_pipeline_with_subpipeline(
temp_dir=self.create_tempdir().full_path
)
runtime_parameter_utils.substitute_runtime_parameter(
pipeline,
{
@@ -2653,22 +2657,25 @@ def test_executor_node_stop_then_start_flow(
self.assertEqual(pstate.NodeState.STARTED, node_state.state)

@parameterized.named_parameters(
dict(
testcase_name='async', pipeline=test_async_pipeline.create_pipeline()
),
dict(
testcase_name='sync',
pipeline=test_sync_pipeline.create_pipeline(),
),
dict(testcase_name='async', mode='async'),
dict(testcase_name='sync', mode='sync'),
)
@mock.patch.object(sync_pipeline_task_gen, 'SyncPipelineTaskGenerator')
@mock.patch.object(async_pipeline_task_gen, 'AsyncPipelineTaskGenerator')
def test_pure_service_node_stop_then_start_flow(
self,
mock_async_task_gen,
mock_sync_task_gen,
pipeline,
mode,
):
if mode == 'async':
pipeline = test_async_pipeline.create_pipeline(
temp_dir=self.create_tempdir().full_path
)
else:
pipeline = test_sync_pipeline.create_pipeline(
temp_dir=self.create_tempdir().full_path
)
runtime_parameter_utils.substitute_runtime_parameter(
pipeline,
{
@@ -2862,7 +2869,9 @@ def test_wait_for_predicate_timeout_secs_None(self, mock_sleep):
self.assertEqual(mock_sleep.call_count, 2)

def test_resume_manual_node(self):
pipeline = test_manual_node.create_pipeline()
pipeline = test_manual_node.create_pipeline(
temp_dir=self.create_tempdir().full_path
)
runtime_parameter_utils.substitute_runtime_parameter(
pipeline,
{
@@ -3516,7 +3525,9 @@ def health_status(self) -> status_lib.Status:
)

def test_delete_pipeline_run(self):
pipeline = test_sync_pipeline.create_pipeline()
pipeline = test_sync_pipeline.create_pipeline(
temp_dir=self.create_tempdir().full_path
)
runtime_parameter_utils.substitute_runtime_parameter(
pipeline,
{
37 changes: 23 additions & 14 deletions tfx/orchestration/experimental/core/pipeline_state.py
@@ -424,29 +424,38 @@ def testonly_reset(cls) -> None:
with cls._lock:
cls._obj = None

def __init__(self):
self.base_dir = env.get_env().get_base_dir()
if self.base_dir:
self.pipeline_irs_dir = os.path.join(self.base_dir,
self._ORCHESTRATOR_METADATA_DIR,
self._PIPELINE_IRS_DIR)
fileio.makedirs(self.pipeline_irs_dir)
else:
self.pipeline_irs_dir = None

def encode(self, pipeline: pipeline_pb2.Pipeline) -> str:
"""Encodes pipeline IR."""
# Attempt to store as a base64 encoded string. If base_dir is provided
# and the length is too large, store the IR on disk and retain the URL.
# TODO(b/248786921): Always store pipeline IR to base_dir once the
# accessibility issue is resolved.

# Note that this setup means that every *subpipeline* will have its own
# "irs" dir. This is fine, though ideally we would put all pipeline IRs
# under the root pipeline dir, which would require us to *also* store the
# root pipeline dir in the IR.

base_dir = pipeline.runtime_spec.pipeline_root.field_value.string_value
if base_dir:
pipeline_ir_dir = os.path.join(
base_dir, self._ORCHESTRATOR_METADATA_DIR, self._PIPELINE_IRS_DIR
)
fileio.makedirs(pipeline_ir_dir)
else:
pipeline_ir_dir = None
pipeline_encoded = _base64_encode(pipeline)
max_mlmd_str_value_len = env.get_env().max_mlmd_str_value_length()
if self.base_dir and max_mlmd_str_value_len is not None and len(
pipeline_encoded) > max_mlmd_str_value_len:
if (
base_dir
and pipeline_ir_dir
and max_mlmd_str_value_len is not None
and len(pipeline_encoded) > max_mlmd_str_value_len
):
pipeline_id = task_lib.PipelineUid.from_pipeline(pipeline).pipeline_id
pipeline_url = os.path.join(self.pipeline_irs_dir,
f'{pipeline_id}_{uuid.uuid4()}.pb')
pipeline_url = os.path.join(
pipeline_ir_dir, f'{pipeline_id}_{uuid.uuid4()}.pb'
)
with fileio.open(pipeline_url, 'wb') as file:
file.write(pipeline.SerializeToString())
pipeline_encoded = json.dumps({self._PIPELINE_IR_URL_KEY: pipeline_url})
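Reading the new hunk as a whole: small IRs are still stored inline as base64, and only when the encoded string would exceed the MLMD limit is the serialized proto written under `<pipeline_root>/.../irs`, with a JSON URL stub stored in its place. A self-contained sketch of that flow, assuming hypothetical constant values and plain `open` in place of TFX's `fileio`:

```python
import base64
import json
import os
import uuid
from typing import Optional

# Hypothetical values standing in for _PipelineIRCodec's private constants.
_ORCHESTRATOR_METADATA_DIR = 'orchestrator_metadata'
_PIPELINE_IRS_DIR = 'irs'
_PIPELINE_IR_URL_KEY = 'pipeline_ir_url'


def encode(pipeline_root: str, pipeline_id: str, serialized_ir: bytes,
           max_mlmd_str_value_len: Optional[int]) -> str:
  """Mirrors the post-change logic: inline base64 with an on-disk fallback."""
  pipeline_ir_dir = None
  if pipeline_root:
    pipeline_ir_dir = os.path.join(
        pipeline_root, _ORCHESTRATOR_METADATA_DIR, _PIPELINE_IRS_DIR)
    os.makedirs(pipeline_ir_dir, exist_ok=True)  # fileio.makedirs in TFX
  pipeline_encoded = base64.b64encode(serialized_ir).decode('utf-8')
  if (pipeline_ir_dir
      and max_mlmd_str_value_len is not None
      and len(pipeline_encoded) > max_mlmd_str_value_len):
    # Too large for an MLMD string property: persist the IR, store its URL.
    pipeline_url = os.path.join(
        pipeline_ir_dir, f'{pipeline_id}_{uuid.uuid4()}.pb')
    with open(pipeline_url, 'wb') as f:
      f.write(serialized_ir)
    pipeline_encoded = json.dumps({_PIPELINE_IR_URL_KEY: pipeline_url})
  return pipeline_encoded
```

As the new in-code comment notes, because the directory now hangs off each pipeline's own root, every subpipeline gets its own `irs` directory rather than sharing one under the root pipeline's directory.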
8 changes: 7 additions & 1 deletion tfx/orchestration/experimental/core/pipeline_state_test.py
@@ -49,6 +49,7 @@ def _test_pipeline(
param=1,
pipeline_nodes: List[str] = None,
pipeline_run_id: str = 'run0',
pipeline_root: str = '',
):
pipeline = pipeline_pb2.Pipeline()
pipeline.pipeline_info.id = pipeline_id
@@ -63,6 +64,7 @@
pipeline.runtime_spec.pipeline_run_id.field_value.string_value = (
pipeline_run_id
)
pipeline.runtime_spec.pipeline_root.field_value.string_value = pipeline_root
return pipeline


@@ -202,7 +204,11 @@ def test_encode_decode_with_base_dir(self):

def test_encode_decode_exceeds_max_len(self):
with TestEnv(self._pipeline_root, 0):
pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer'])
pipeline = _test_pipeline(
'pipeline1',
pipeline_nodes=['Trainer'],
pipeline_root=self.create_tempdir().full_path,
)
pipeline_encoded = pstate._PipelineIRCodec.get().encode(pipeline)
self.assertEqual(
pipeline, pstate._PipelineIRCodec.get().decode(pipeline_encoded)
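For orientation, a minimal sketch of the proto the updated `_test_pipeline` helper builds; the field paths are copied from the hunk above, and since the surrounding `TestEnv` caps the MLMD string length at 0, even this tiny IR takes the on-disk path:

```python
import tempfile

from tfx.proto.orchestration import pipeline_pb2

pipeline = pipeline_pb2.Pipeline()
pipeline.pipeline_info.id = 'pipeline1'
pipeline.runtime_spec.pipeline_run_id.field_value.string_value = 'run0'
# New in this commit: a real, writable root in the IR gives the codec's
# file fallback somewhere to write the serialized proto.
pipeline.runtime_spec.pipeline_root.field_value.string_value = (
    tempfile.mkdtemp())
```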
37 changes: 23 additions & 14 deletions tfx/orchestration/experimental/core/sample_mlmd_creator.py
@@ -52,8 +52,12 @@ def _get_mlmd_connection(path: str) -> metadata.Metadata:
return metadata.Metadata(connection_config=connection_config)


def _test_pipeline(ir_path: str, pipeline_id: str, run_id: str,
deployment_config: Optional[message.Message]):
def _test_pipeline(
ir_path: str,
pipeline_id: str,
run_id: str,
deployment_config: Optional[message.Message],
):
"""Creates test pipeline with pipeline_id and run_id."""
pipeline = pipeline_pb2.Pipeline()
io_utils.parse_pbtxt_file(ir_path, pipeline)
@@ -85,25 +89,30 @@ def _execute_nodes(handle: metadata.Metadata, pipeline: pipeline_pb2.Pipeline,
)


def _get_ir_path(external_ir_file: str):
def _get_ir_path(external_ir_file: str, temp_dir: str = ''):
if external_ir_file:
return external_ir_file
ir_file_path = tempfile.mktemp(suffix='.pbtxt')
io_utils.write_pbtxt_file(ir_file_path, test_sync_pipeline.create_pipeline())
io_utils.write_pbtxt_file(
ir_file_path, test_sync_pipeline.create_pipeline(temp_dir=temp_dir)
)
return ir_file_path


def create_sample_pipeline(m: metadata.Metadata,
pipeline_id: str,
run_num: int,
export_ir_path: str = '',
external_ir_file: str = '',
deployment_config: Optional[message.Message] = None,
execute_nodes_func: Callable[
[metadata.Metadata, pipeline_pb2.Pipeline, int],
None] = _execute_nodes):
def create_sample_pipeline(
m: metadata.Metadata,
pipeline_id: str,
run_num: int,
export_ir_path: str = '',
external_ir_file: str = '',
deployment_config: Optional[message.Message] = None,
execute_nodes_func: Callable[
[metadata.Metadata, pipeline_pb2.Pipeline, int], None
] = _execute_nodes,
temp_dir: str = '',
):
"""Creates a list of pipeline and node execution."""
ir_path = _get_ir_path(external_ir_file)
ir_path = _get_ir_path(external_ir_file, temp_dir=temp_dir)
for i in range(run_num):
run_id = 'run%02d' % i
pipeline = _test_pipeline(ir_path, pipeline_id, run_id, deployment_config)
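A hedged usage sketch of the extended entry point. `sqlite_metadata_connection_config` is a standard TFX helper, but the values here are illustrative, and whether node execution runs cleanly outside the test environment isn't shown in this diff:

```python
import tempfile

from tfx.orchestration import metadata
from tfx.orchestration.experimental.core import sample_mlmd_creator

tmp = tempfile.mkdtemp()
connection_config = metadata.sqlite_metadata_connection_config(
    f'{tmp}/mlmd.sqlite')
with metadata.Metadata(connection_config=connection_config) as m:
  # temp_dir is threaded down to test_sync_pipeline.create_pipeline, which
  # uses it as the prefix of the generated pipeline_root.
  sample_mlmd_creator.create_sample_pipeline(
      m, 'my_pipeline', 2, temp_dir=tmp)
```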
8 changes: 5 additions & 3 deletions tfx/orchestration/experimental/core/testing/test_async_pipeline.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Async pipeline for testing."""
import os

from tfx.dsl.compiler import compiler
from tfx.dsl.component.experimental.annotations import InputArtifact
@@ -51,7 +52,7 @@ def _trainer(examples: InputArtifact[standard_artifacts.Examples],
del examples, transform_graph, model


def create_pipeline() -> pipeline_pb2.Pipeline:
def create_pipeline(temp_dir: str = '/') -> pipeline_pb2.Pipeline:
"""Creates an async pipeline for testing."""
# pylint: disable=no-value-for-parameter
example_gen = _example_gen().with_id('my_example_gen')
@@ -68,13 +69,14 @@ def create_pipeline() -> pipeline_pb2.Pipeline:

pipeline = pipeline_lib.Pipeline(
pipeline_name='my_pipeline',
pipeline_root='/path/to/root',
pipeline_root=os.path.join(temp_dir, 'path/to/root'),
components=[
example_gen,
transform,
trainer,
],
execution_mode=pipeline_lib.ExecutionMode.ASYNC)
execution_mode=pipeline_lib.ExecutionMode.ASYNC,
)
dsl_compiler = compiler.Compiler(use_input_v2=True)
compiled_pipeline: pipeline_pb2.Pipeline = dsl_compiler.compile(pipeline)

12 changes: 6 additions & 6 deletions tfx/orchestration/experimental/core/testing/test_manual_node.py
@@ -12,23 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test pipeline with only manual node."""
import os

from tfx.dsl.compiler import compiler
from tfx.dsl.components.common import manual_node
from tfx.orchestration import pipeline as pipeline_lib
from tfx.proto.orchestration import pipeline_pb2


def create_pipeline() -> pipeline_pb2.Pipeline:
def create_pipeline(temp_dir: str = '/') -> pipeline_pb2.Pipeline:
"""Builds a test pipeline with only manual node."""
manual = manual_node.ManualNode(description='Do something.')

pipeline = pipeline_lib.Pipeline(
pipeline_name='my_pipeline',
pipeline_root='/path/to/root',
components=[
manual
],
enable_cache=True)
pipeline_root=os.path.join(temp_dir, 'path/to/root'),
components=[manual],
enable_cache=True,
)
dsl_compiler = compiler.Compiler()
return dsl_compiler.compile(pipeline)
20 changes: 14 additions & 6 deletions tfx/orchestration/experimental/core/testing/test_sync_pipeline.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sync pipeline for testing."""
import os

from tfx.dsl.compiler import compiler
from tfx.dsl.component.experimental.annotations import InputArtifact
@@ -82,7 +83,7 @@ def _chore():
pass


def create_pipeline() -> pipeline_pb2.Pipeline:
def create_pipeline(temp_dir: str = '/') -> pipeline_pb2.Pipeline:
"""Builds a test pipeline.
┌───────────┐
Expand All @@ -107,6 +108,10 @@ def create_pipeline() -> pipeline_pb2.Pipeline:
│chore_b │
└────────┘
Args:
temp_dir: If provided, a temporary test directory to use as a prefix for
  the pipeline root.
Returns:
A pipeline proto for the above DAG
"""
@@ -142,7 +147,7 @@ def create_pipeline() -> pipeline_pb2.Pipeline:

pipeline = pipeline_lib.Pipeline(
pipeline_name='my_pipeline',
pipeline_root='/path/to/root',
pipeline_root=os.path.join(temp_dir, 'path/to/root'),
components=[
example_gen,
stats_gen,
Expand All @@ -154,7 +159,8 @@ def create_pipeline() -> pipeline_pb2.Pipeline:
chore_a,
chore_b,
],
enable_cache=True)
enable_cache=True,
)
dsl_compiler = compiler.Compiler()
return dsl_compiler.compile(pipeline)

@@ -300,7 +306,9 @@ def create_resource_lifetime_pipeline() -> pipeline_pb2.Pipeline:
return dsl_compiler.compile(pipeline)


def create_pipeline_with_subpipeline() -> pipeline_pb2.Pipeline:
def create_pipeline_with_subpipeline(
temp_dir: str = '/',
) -> pipeline_pb2.Pipeline:
"""Creates a pipeline with a subpipeline."""
# pylint: disable=no-value-for-parameter
example_gen = _example_gen().with_id('my_example_gen')
@@ -318,7 +326,7 @@ def create_pipeline_with_subpipeline() -> pipeline_pb2.Pipeline:

componsable_pipeline = pipeline_lib.Pipeline(
pipeline_name='sub-pipeline',
pipeline_root='/path/to/root/sub',
pipeline_root=os.path.join(temp_dir, 'path/to/root/sub'),
components=[stats_gen, schema_gen],
enable_cache=True,
inputs=p_in,
@@ -332,7 +340,7 @@ def create_pipeline_with_subpipeline() -> pipeline_pb2.Pipeline:

pipeline = pipeline_lib.Pipeline(
pipeline_name='my_pipeline',
pipeline_root='/path/to/root',
pipeline_root=os.path.join(temp_dir, 'path/to/root'),
components=[
example_gen,
componsable_pipeline,
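On the test side, the new keyword is exercised as in the updated `pipeline_ops_test` cases above; a condensed sketch, assuming the parameter-name constants live in `tfx.dsl.compiler.constants` (hedged, since the test's imports aren't shown in this diff):

```python
import tempfile

from tfx.dsl.compiler import constants  # assumed home of the parameter names
from tfx.orchestration.experimental.core.testing import test_sync_pipeline
from tfx.orchestration.portable import runtime_parameter_utils

pipeline = test_sync_pipeline.create_pipeline_with_subpipeline(
    temp_dir=tempfile.mkdtemp())
# pipeline_root is now concrete in the IR, so the tests only substitute the
# run id; the PIPELINE_ROOT_PARAMETER_NAME entry was dropped in this commit.
runtime_parameter_utils.substitute_runtime_parameter(
    pipeline,
    {constants.PIPELINE_RUN_ID_PARAMETER_NAME: 'run0'},
)
```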
