[Fix] Improve compilation speed by using Set instead of List for query (
merrymercy authored Jul 1, 2022
1 parent 14b3153 commit 7b2a023
Showing 1 changed file with 23 additions and 24 deletions.
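The speedup comes from replacing O(n) list membership tests with O(1) set lookups on the emitter's hot path: helpers such as `_get_var_key` and `_compile_collect_mesh_input` check `var in global_invars` for every variable of every stage and micro batch, so a linear scan over the invar list is paid many thousands of times during compilation. A minimal sketch of the effect, using strings as stand-ins for `jax.core.Var` objects (illustration only, not Alpa code):

```python
# Illustration only, not Alpa code: the same membership queries against a
# list (linear scan per query) and a set (hash lookup per query).
import timeit

n = 5_000
invars = [f"var_{i}" for i in range(n)]  # stand-ins for jax.core.Var
invar_list = list(invars)
invar_set = set(invars)

# Query every element once, roughly what the emitter does per micro batch.
list_time = timeit.timeit(lambda: [v in invar_list for v in invars], number=1)
set_time = timeit.timeit(lambda: [v in invar_set for v in invars], number=1)
print(f"list: {list_time:.4f}s  set: {set_time:.4f}s")  # the set version is faster by orders of magnitude
```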
alpa/pipeline_parallel/runtime_emitter.py (23 additions, 24 deletions)
@@ -3,7 +3,7 @@
 from dataclasses import dataclass
 import enum
 import logging
-from typing import Any, Callable, Dict, Optional, Sequence, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Union, Set
 
 from jax._src.tree_util import PyTreeDef, tree_unflatten
 from jax.core import Var
@@ -160,19 +160,20 @@ def flatten_uuid_set(container):
 class PipelineInstEmitterHelper:
     """Environment for PipelineInstEmitter."""
 
-    def __init__(self, global_invars, grad_dummy_invars, is_batch,
-                 schedule: PipelineSchedule):
-        self.global_invars = global_invars
-        self.global_batch_invars = OrderedSet(
-            v for v, b in zip(global_invars, is_batch) if b)
+    def __init__(self, global_invar_set: Set[Var],
+                 global_batch_invar_set: Set[Var],
+                 grad_dummy_invars: Dict[Var, Var], schedule: PipelineSchedule):
+        self.global_invar_set = global_invar_set
+        self.global_batch_invar_set = global_batch_invar_set
         self.grad_dummy_invars = grad_dummy_invars
         self.schedule = schedule
         # Dict[var_key -> Dict[mesh_idx -> array_uuid]]
         # The shape of the numpy array is [num_hosts, num_devices_per_host]
         self.env = {}
 
     def _get_var_key(self, var, batch_idx):
-        if var in self.global_invars and var not in self.global_batch_invars:
+        if (var in self.global_invar_set and
+                var not in self.global_batch_invar_set):
             key = (var, 0)
         elif (var in self.grad_dummy_invars and
               batch_idx != self.schedule.first_backward_batch_index):
@@ -283,8 +284,12 @@ def __init__(self, *, stages: Sequence[XlaShardedPipelineComputation],
 
         ##### Internal states #####
         self.uuid_counter = 0  # counter for local buffer uuid
-        self.env = PipelineInstEmitterHelper(global_invars, grad_dummy_invars,
-                                             is_batch, schedule)
+        global_invar_set = OrderedSet(global_invars)
+        global_batch_invar_set = OrderedSet(
+            v for v, b in zip(global_invars, is_batch) if b)
+        self.env = PipelineInstEmitterHelper(global_invar_set,
+                                             global_batch_invar_set,
+                                             grad_dummy_invars, schedule)
         self._communicator = None
         self._resharding_tasks = [
             [{} for _ in range(self.num_mesh)] for _ in range(self.num_mesh)
@@ -390,12 +395,8 @@ def compile(self):
                                             executable_config_lists)
 
         # Split input into micro batches
-        global_batch_invar_set = OrderedSet([
-            var for var, batch in zip(self.global_invars, self.is_batch)
-            if batch
-        ])
-        (input_config, input_shard_specs
-        ) = self._compile_split_input_to_microbatches(global_batch_invar_set)
+        (input_config,
+         input_shard_specs) = self._compile_split_input_to_microbatches()
 
         # Simulate the pipeline schedule and generate instructions
         donation_mapping = [DisjointDict() for _ in range(num_mesh)]
@@ -618,7 +619,7 @@ def _compile_grad_buffer_allocations(self, executable_config_lists):
 
         return grad_uuids, instruction_lists
 
-    def _compile_collect_mesh_input(self, mesh_idx, batch_vars):
+    def _compile_collect_mesh_input(self, mesh_idx):
         mesh_arg_set = OrderedSet()
         var_to_spec = {}
         mesh_batch_vars = OrderedSet()
@@ -630,9 +631,9 @@ def _compile_collect_mesh_input(self, mesh_idx, batch_vars):
         for stage_idx in self.schedule.mesh_stage_mapping[mesh_idx]:
             stage = self.stages[stage_idx]
             for spec, invar in zip(stage.input_sharding_specs, stage.invars):
-                if invar in self.global_invars:
+                if invar in self.env.global_invar_set:
                     var_to_spec[invar] = spec
-                    if invar in batch_vars:
+                    if invar in self.env.global_batch_invar_set:
                         # Split batch arg
                         for batch_idx in range(num_batch):
                             mesh_arg_set.add((invar, batch_idx))
@@ -666,7 +667,7 @@ def _compile_collect_mesh_input(self, mesh_idx, batch_vars):
         return (mesh_arg_list, mesh_arg_indices, input_shard_indices,
                 input_shard_specs, mesh_invar_is_batch)
 
-    def _compile_split_input_to_microbatches(self, global_batch_invar_set):
+    def _compile_split_input_to_microbatches(self):
         """
         Split batch arguments into micro batches.
@@ -675,10 +676,9 @@ def _compile_split_input_to_microbatches(self, global_batch_invar_set):
         after (b, d are batch args and #mb=2): a, b0, b1, c, d0, d1
         """
         donated_invar_set = OrderedSet()
-        global_invar_set = OrderedSet(self.global_invars)
         for stage in self.stages:
             for invar, donate in zip(stage.invars, stage.donated_invars):
-                if donate and invar in global_invar_set:
+                if donate and invar in self.env.global_invar_set:
                     donated_invar_set.add(invar)
         num_mesh = len(self.mesh_group)
         mesh_arg_lists = [None for _ in range(num_mesh)]
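As a side note on the docstring in the hunk above, the flattening it describes can be reproduced by a small standalone sketch (a hypothetical helper for illustration; the real method works on per-mesh argument sets and sharding specs):

```python
# Hypothetical illustration of the docstring's example; not the emitter's code.
def flatten_to_microbatches(args, is_batch, num_microbatch):
    """before: a, b, c, d -> after (b, d are batch args, #mb=2): a, b0, b1, c, d0, d1"""
    out = []
    for name, batch in zip(args, is_batch):
        if batch:
            # Each batch argument is split into one slice per micro batch.
            out.extend(f"{name}{i}" for i in range(num_microbatch))
        else:
            out.append(name)
    return out

print(flatten_to_microbatches(["a", "b", "c", "d"], [False, True, False, True], 2))
# ['a', 'b0', 'b1', 'c', 'd0', 'd1']
```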
@@ -692,13 +692,12 @@ def _compile_split_input_to_microbatches(self, global_batch_invar_set):
         batch_invars = []
         for mesh_idx in range(num_mesh):
             (mesh_arg_list, arg_indices, shard_indices, shard_specs,
-             is_batch) = self._compile_collect_mesh_input(
-                 mesh_idx, global_batch_invar_set)
+             is_batch) = self._compile_collect_mesh_input(mesh_idx)
 
             mesh_arg_lists[mesh_idx] = mesh_arg_list
             delete_after_run = [
                 var in donated_invar_set or
-                (var in global_batch_invar_set and
+                (var in self.env.global_batch_invar_set and
                  global_config.always_donate_micro_batch_vars)
                 for var, _ in mesh_arg_list
             ]
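Note that the patch builds `OrderedSet`s rather than plain Python sets. Assuming Alpa's `OrderedSet` follows the usual pattern of hash-based membership plus insertion-ordered iteration, its behavior is roughly:

```python
# Minimal OrderedSet sketch (an assumption about Alpa's OrderedSet, for
# illustration): set-speed membership plus insertion-ordered iteration.
class OrderedSet:
    def __init__(self, iterable=()):
        self._items = dict.fromkeys(iterable)  # dicts preserve insertion order

    def add(self, item):
        self._items[item] = None

    def __contains__(self, item):  # O(1) hash lookup, the point of this commit
        return item in self._items

    def __iter__(self):
        return iter(self._items)
```

Ordered iteration keeps derived structures such as `mesh_arg_list` deterministic across runs, while the membership tests this commit moves onto the set stay O(1).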