Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

Permalink
Merge PlacementSpec and LoadInfo (#540)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Jun 23, 2022
1 parent 3c2d3cc commit c7873aa
Show file tree
Hide file tree
Showing 15 changed files with 226 additions and 169 deletions.
3 changes: 2 additions & 1 deletion alpa/create_state_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

from alpa.device_mesh import ReplicatedDistributedArray, PhysicalDeviceMeshGroup
from alpa.mesh_executable import (NormalMeshDriverExecutable,
GradAccMeshDriverExecutable, PlacementSpec)
GradAccMeshDriverExecutable)
from alpa.parallel_plan import PlacementSpec
from alpa.pipeline_parallel.compile_executable import compile_pipeshard_executable_internal
from alpa.pipeline_parallel.layer_construction import add_pipeline_marks_for_sliced_eqns
from alpa.pipeline_parallel.pipeshard_executable import PipeshardDriverExecutable
Expand Down
30 changes: 16 additions & 14 deletions alpa/device_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from alpa.global_env import global_config
from alpa.monkey_patch import set_override_backend
from alpa.shard_parallel.auto_sharding import LogicalDeviceMesh
from alpa.parallel_plan import PlacementSpec
from alpa.timer import timers
from alpa.util import (benchmark_func, list_gpu_info, jax_tensor_to_cupy,
cupy_to_jax_tensor, jax_tensor_set,
Expand Down Expand Up @@ -290,9 +291,7 @@ def save_buffers(self, ckpt_dir: str, local_cache_dir: Union[str, None],
for uuid in uuids:
assert uuid in self.buffers

shard_names = [
str(self.host_id) + "." + str(i) for i in range(len(uuids))
]
shard_names = [f"shard_{self.host_id}.{i}" for i in range(len(uuids))]

metadata = {
"global_shape": global_shape,
Expand All @@ -313,7 +312,7 @@ def save_buffers(self, ckpt_dir: str, local_cache_dir: Union[str, None],
with open(os.path.join(save_dir, shard_name), "wb") as datafile:
np.save(datafile, self.buffers[uuid])

with open(os.path.join(save_dir, f".metadata{self.host_id}"),
with open(os.path.join(save_dir, f"metadata_{self.host_id}"),
"wb") as metafile:
pickle.dump(metadata, metafile)

Expand All @@ -325,7 +324,7 @@ def load_buffers(self, ckpt_dir: str, uuids: Sequence[int],
shard_indices: Sequence[Index], device_ids: Sequence[int]):
assert len(uuids) > 0
metadatas = list(
filter(lambda fname: fname.startswith(".metadata"),
filter(lambda fname: fname.startswith("metadata"),
os.listdir(ckpt_dir)))
# pylint: disable=import-outside-toplevel
from alpa.serialization import load_sharded_array
Expand Down Expand Up @@ -1892,26 +1891,29 @@ def instantiate_nccl_group(self, src_mesh_id: int, dst_mesh_id: int):
cg = self.collective_groups[src_mesh_id][dst_mesh_id]
self._instantiate_nccl_group(cg)

def shard_args_to_arrays(self, load_infos: "LoadInfo", args: Sequence[Any]):
def shard_args_to_arrays(self, load_infos: PlacementSpec,
args: Sequence[Any]):
rets = []

for info, arg in zip(load_infos, args):
aval = info.aval
if info.is_replicated():
if len(info.mesh_ids) == 1:
mesh = self.meshes[info.mesh_ids[0]]
spec = info.sharding_specs[0]
indices = pxla.spec_to_indices(aval.shape, spec)
rets.append(
mesh.shard_args_to_arrays((aval,), (indices,), (spec,),
(arg,))[0])
else:
meshes, arrays = [], []
for mesh, spec in info.get_info():
for mesh_id, spec in zip(info.mesh_ids, info.sharding_specs):
mesh = self.meshes[mesh_id]
meshes.append(mesh)
indices = pxla.spec_to_indices(aval.shape, spec)
arrays.append(
mesh.shard_args_to_arrays((aval,), (indices,), (spec,),
(arg,))[0])
rets.append(ReplicatedDistributedArray(meshes, arrays))
else:
mesh, spec = info.get_info()
indices = pxla.spec_to_indices(aval.shape, spec)
rets.append(
mesh.shard_args_to_arrays((aval,), (indices,), (spec,),
(arg,))[0])

return rets

Expand Down
10 changes: 6 additions & 4 deletions alpa/mesh_executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,9 @@ def launch_on_driver(self, *args, **kwargs):
def get_input_placement_specs(self):
"""Return the preferred placement specs for input arguments."""
placement_specs = [
PlacementSpec((self.physical_mesh.mesh_id,), (sharding_spec,))
for sharding_spec in self.input_sharding_specs
PlacementSpec(aval, (self.physical_mesh.mesh_id,),
(sharding_spec,)) for aval, sharding_spec in zip(
self.avals, self.input_sharding_specs)
]
return tree_unflatten(self.in_tree, placement_specs)

Expand Down Expand Up @@ -822,8 +823,9 @@ def launch_on_driver(self, *args):
def get_input_placement_specs(self):
"""Return the preferred placement specs for input arguments."""
placement_specs = [
PlacementSpec((self.physical_mesh.mesh_id,), (sharding_spec,))
for sharding_spec in self.global_arg_sharding_specs
PlacementSpec(aval, (self.physical_mesh.mesh_id,),
(sharding_spec,)) for aval, sharding_spec in zip(
self.avals, self.global_arg_sharding_specs)
]
return tree_unflatten(self.in_tree, placement_specs)

Expand Down
2 changes: 2 additions & 0 deletions alpa/parallel_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
from typing import Sequence, Tuple

import numpy as np
from jax.core import ShapedArray
from jax.interpreters import pxla
from jax.tree_util import PyTreeDef


@dataclass
class PlacementSpec:
"""Specify how a tensor is stored distributedly."""
aval: ShapedArray
mesh_ids: Sequence[int]
sharding_specs: Sequence[pxla.ShardingSpec]

Expand Down
27 changes: 9 additions & 18 deletions alpa/pipeline_parallel/pipeshard_executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@
import time
from typing import Optional, Sequence

from jax.tree_util import tree_map, tree_flatten, tree_unflatten, PyTreeDef
from jax.tree_util import tree_flatten, tree_unflatten, PyTreeDef
import numpy as np
import ray.exceptions

from alpa.device_mesh import MeshHostWorker
from alpa.global_env import global_config
from alpa.device_mesh import PhysicalDeviceMeshGroup
from alpa.mesh_executable import (
AllocZeroBufferWorkerExecutable, ConcatMeshWorkerExecutable,
MemzeroWorkerExecutable, PartialGradAccMeshWorkerExecutable,
next_mesh_executable_uuid, get_uuid_np_array, next_remote_buffer_uuid,
RemoteBufferRef, PlacementSpec)
from alpa.mesh_executable import (AllocZeroBufferWorkerExecutable,
ConcatMeshWorkerExecutable,
MemzeroWorkerExecutable,
PartialGradAccMeshWorkerExecutable,
next_mesh_executable_uuid, get_uuid_np_array,
next_remote_buffer_uuid, RemoteBufferRef)
from alpa.pipeline_parallel.runtime_emitter import (
AllocateZeroWorkerExecutableConfig, ConcatWorkerExecutableConfig,
ExecutableConfig, MemZeroWorkerExecutableConfig,
Expand Down Expand Up @@ -57,7 +58,7 @@ def __init__(self,
self.stages = pipeshard_config.xla_stages
self.schedule = pipeshard_config.schedule
self.flop_count = pipeshard_config.flop_count
self.load_info = pipeshard_config.load_info
self.input_placement_specs = pipeshard_config.input_placement_specs
# List[stage_idx -> str]
self.fully_optimized_hlo_texts = []
self.sharding_annotated_hlo_texts = (
Expand Down Expand Up @@ -204,12 +205,7 @@ def launch_on_driver(self, *args):

def get_input_placement_specs(self):
"""Return the preferred placement specs for input arguments."""

def load_info_to_placement_spec(load_info):
return PlacementSpec([x.mesh_id for x in load_info.meshes],
load_info.specs)

return tree_map(load_info_to_placement_spec, self.load_info)
return self.input_placement_specs

def __call__(self, *args):
"""Fast call without signature matching."""
Expand All @@ -225,11 +221,6 @@ def __call__(self, *args):
out = self.launch_on_driver(*args_flat)
return tree_unflatten(self.out_tree, out)

##### Load/Store Related Functions #####
def get_load_info(self):
"""Get the load info for model checkpoints."""
return self.load_info

##### Profiling and Debugging Related Functions #####
def get_execution_time_costs(self,
warmup=2,
Expand Down
36 changes: 20 additions & 16 deletions alpa/pipeline_parallel/runtime_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
from alpa.device_mesh import (DistributedArray, PhysicalDeviceMeshGroup,
ReplicatedDistributedArray)
from alpa.mesh_executable import next_mesh_executable_uuid
from alpa.parallel_plan import PlacementSpec
from alpa.pipeline_parallel.computation import XlaShardedPipelineComputation
from alpa.pipeline_parallel.schedules import PipelineSchedule
from alpa.pipeline_parallel.cross_mesh_resharding import (
CrossMeshCommunicator, SymbolicBroadcastReshardingTask,
SymbolicReshardingTask, ReshardingTask)
from alpa.serialization import LoadInfo
from alpa.util import (DisjointDict, OrderedSet, get_shard_shape,
get_microbatch_sharding_spec, compile_concatenate)

Expand Down Expand Up @@ -247,7 +247,7 @@ class PipeshardConfig:
output_local_uuid_list: Sequence[Sequence[int]]
outs_handler: Callable
# Others
load_info: LoadInfo
input_placement_specs: Sequence[PlacementSpec]
sharding_annotated_hlo_texts: Sequence[str]
flop_count: int

Expand Down Expand Up @@ -435,8 +435,8 @@ def compile(self):
worker, used_outside, donated, instruction_lists)

# Compile input placement specs
load_info = self._compile_load_info(input_config.mesh_arg_indices,
input_shard_specs)
input_placement_specs = self._compile_input_placement_spec(
input_config.mesh_arg_indices, input_shard_specs)
return PipeshardConfig(
# Executable configs
instruction_lists,
Expand All @@ -455,7 +455,7 @@ def compile(self):
output_local_uuid_list,
outs_handler,
# Others
load_info,
input_placement_specs,
self.sharding_annotated_hlo_texts,
self.flop_count)

Expand Down Expand Up @@ -1013,23 +1013,27 @@ def outs_handler(mesh_group, bufs):

return outs_handler

def _compile_load_info(self, mesh_arg_indices, input_shard_specs):
def _compile_input_placement_spec(self, mesh_arg_indices,
input_shard_specs):
assert self.in_tree is not None

# build load_info_arr: flatten global index => LoadInfo object
load_info_arr = [None] * len(self.is_batch)
# build spec_arr: List[flatten global index -> PlacementSpec]
spec_arr = [None] * len(self.is_batch)
for mesh_idx, physical_mesh in enumerate(self.mesh_group):
for local_idx, global_idx in enumerate(mesh_arg_indices[mesh_idx]):
aval, mesh, spec = (self.global_invars[global_idx].aval,
physical_mesh,
input_shard_specs[mesh_idx][local_idx])
if load_info_arr[global_idx] is None:
load_info_arr[global_idx] = LoadInfo(aval, [mesh], [spec])
shard_spec = input_shard_specs[mesh_idx][local_idx]
if spec_arr[global_idx] is None:
spec_arr[global_idx] = PlacementSpec(
self.global_invars[global_idx].aval,
(physical_mesh.mesh_id,), (shard_spec,))
else:
load_info_arr[global_idx].add_replica(mesh, spec)
old_val = spec_arr[global_idx]
spec_arr[global_idx] = PlacementSpec(
old_val.aval,
old_val.mesh_ids + (physical_mesh.mesh_id,),
old_val.sharding_specs + (shard_spec,))

# build load_info_arr
return tree_unflatten(self.in_tree, load_info_arr)
return tree_unflatten(self.in_tree, spec_arr)

# TODO(yonghao): setting an empty buffer is not compatible with local allgather
@staticmethod
Expand Down
Loading

0 comments on commit c7873aa

Please sign in to comment.