This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

[FEATURE] Support beam sample (#491)
ZYHowell authored Jul 1, 2022
1 parent bd47646 commit 14b3153
Showing 5 changed files with 145 additions and 25 deletions.
45 changes: 41 additions & 4 deletions alpa/device_mesh.py
@@ -32,7 +32,7 @@
import jax
from jax import core, xla, device_put
from jax._src.api import ShapeDtypeStruct
from jax._src.lib import xla_bridge as xb, xla_extension as xe
from jax._src.lib import xla_bridge as xb, xla_client as xc, xla_extension as xe
from jax._src.tree_util import tree_leaves
from jax.abstract_arrays import array_types
from jax.core import ShapedArray
@@ -48,11 +48,14 @@
import alpa.collective as col
from alpa.global_env import global_config
from alpa.monkey_patch import set_override_backend
from alpa.shard_parallel.auto_sharding import LogicalDeviceMesh
from alpa.parallel_plan import PlacementSpec
from alpa.shard_parallel.auto_sharding import (AutoShardingOption,
LogicalDeviceMesh,
run_spmd_partitioner_pass)
from alpa.parallel_plan import PlacementSpec, StagePlan
from alpa.timer import timers
from alpa.util import (benchmark_func, list_gpu_info, OrderedSet,
update_jax_platform, is_ray_node_resource)
update_jax_platform, is_ray_node_resource,
get_index_select_computation)

if global_config.nccl_mode == "cupy":
import alpa.collective.worker_nccl_util_cupy as worker_nccl_util
@@ -608,6 +611,7 @@ class PhysicalDeviceMesh(ABC):
num_hosts: int
num_devices_per_host: int
mesh_id: int
operation_executables: dict

def get_signature(self) -> str:
"""Return a signature string that contains the mesh shape and GPU
@@ -810,6 +814,7 @@ def __init__(self, devices: Sequence["Device"] = None):
self.num_devices_per_host = len(self.devices)
self.mesh_id = 0
self.device_strs = []
self.operation_executables = {}

self.set_runtime_random_seed(global_config.runtime_random_seed)

@@ -898,6 +903,7 @@ def sync_workers(self):

def shutdown(self, forced=False):
self.sync_workers()
self.operation_executables.clear()


def device_id_to_str(host_ip, device_id, device_type="gpu"):
@@ -934,6 +940,7 @@ def __init__(self,
self.workers = None
self.launched = False
self.service_server = None
self.operation_executables = {}

if devices is not None:
if len(devices) != len(host_ids):
@@ -1301,6 +1308,7 @@ def shutdown(self, forced=False):
if not self.launched:
return
if not forced:
self.operation_executables.clear()
ray.get([w.shutdown.remote() for w in self.workers])
for worker in self.workers:
ray.kill(worker)
@@ -1309,6 +1317,7 @@ def shutdown(self, forced=False):
self.service_server.shutdown()
self.service_server = None
self.launched = False
self.operation_executables.clear() # clear with forced shutdown


########################################
@@ -1525,6 +1534,34 @@ def __float__(self):

# TODO(lmzheng): copy more functions from DeviceArray
# (jax/_src/device_array.py)
def index_select(self, dim, index):
"""Compile and run index select operation."""
# pylint: disable=import-outside-toplevel
from alpa.mesh_executable import NormalMeshDriverExecutable
if type(index) not in [ShapedArray, ShapeDtypeStruct]:
index = xla.canonicalize_dtype(index)
index_shape = xc.shape_from_pyval(index)
key = hash(("index_select", self.aval, dim, index_shape))
if key in self.device_mesh.operation_executables:
executable = self.device_mesh.operation_executables[key]
else:
index_aval = ShapedArray(index.shape, index.dtype)
c = get_index_select_computation(self.sharding_spec, dim, self.aval,
index_shape).as_hlo_module()
hlo_module = run_spmd_partitioner_pass(c,
self.device_mesh.num_devices)

as_option = AutoShardingOption()
strategy_config = StagePlan(global_config.compile_random_seed,
self.device_mesh.shape, 1 << 60,
as_option.all_reduce_threshold, None,
-1)
executable = NormalMeshDriverExecutable(self.device_mesh,
hlo_module, strategy_config,
[self.aval, index_aval],
[self.aval], [False, False])
self.device_mesh.operation_executables[key] = executable
return executable.launch_on_driver(self, index)

def __str__(self):
return (f"DistributedArray(sharding_spec={self.sharding_spec}, "
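As a usage sketch of the new DistributedArray.index_select added in this file (cache_entry and new_order are illustrative names, assuming a DistributedArray already sharded on a live mesh): the first call compiles a NormalMeshDriverExecutable and memoizes it in device_mesh.operation_executables under a key built from the array's aval, the dimension, and the index shape, so repeated beam reorders reuse the same executable.

import numpy as np

# Hypothetical setup: cache_entry is a DistributedArray of shape
# (batch_size * num_beams, seq_len, num_heads, head_dim) sharded on a mesh.
new_order = np.array([2, 0, 3, 1], dtype=np.int32)  # new beam order along dim 0
reordered = cache_entry.index_select(0, new_order)  # compiled once, cached afterwards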
17 changes: 17 additions & 0 deletions alpa/util.py
@@ -560,6 +560,23 @@ def compile_concatenate(backend, mesh_shape, sharding_spec, batch_size,
return hlo_proto


def get_index_select_computation(sharding_spec, dim, aval, index_shape):
sharding = pxla.sharding_spec_sharding_proto(sharding_spec)
c = xc.XlaBuilder("index_select")
c.set_sharding(sharding)
operand = xc.ops.Parameter(
c, 0, xc.shape_from_pyval(np.ones(aval.shape, aval.dtype)))
c.clear_sharding()
index = xc.ops.Parameter(c, 1, index_shape)
index_selected = xc.ops.IndexSelect(operand, index, dim)
sharding2 = xc.OpSharding()
sharding2.type = sharding.type.TUPLE
sharding2.tuple_shardings = [sharding]
c.set_sharding(sharding2)
c = c.build(xc.ops.Tuple(c, [index_selected]))
return c


def get_shard_shape(aval: ShapedArray, sharding_spec: pxla.ShardingSpec):
"""Return the shape of a shard."""
shape = []
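For intuition, the computation built by get_index_select_computation (ignoring its sharding annotations) has the same semantics as a plain take along one axis; a minimal unsharded sketch in JAX, with small made-up values:

import jax.numpy as jnp
import numpy as np

operand = jnp.arange(12.0).reshape(4, 3)
index = np.array([2, 0], dtype=np.int32)
result = jnp.take(operand, index, axis=0)  # rows 2 and 0 of operand, i.e. the
                                           # result of selecting along dim 0 by index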
15 changes: 6 additions & 9 deletions examples/opt_serving/model/opt_model.py
@@ -57,7 +57,6 @@ class OPTConfig:
decoder_attention_heads: int = 12
decoder_input_dim: int = 768
decoder_ffn_embed_dim: int = 3072
batch_size: int = 1
pad: int = 1
activation_fn: str = 'relu'
dtype: any = jnp.float16
@@ -530,17 +529,18 @@ def init_model_aval(config):


def init_cache_aval(config, batch_size):
dtype = config.dtype
head_dim = config.decoder_embed_dim // config.decoder_attention_heads

all_cache = []
for i in range(config.decoder_layers):
layer_cache = (
jax.core.ShapedArray((batch_size, config.max_target_positions,
config.decoder_attention_heads, head_dim),
config.dtype),
dtype),
jax.core.ShapedArray((batch_size, config.max_target_positions,
config.decoder_attention_heads, head_dim),
config.dtype),
dtype),
jax.core.ShapedArray((batch_size,), jnp.int32),
)
all_cache.append(layer_cache)
@@ -679,9 +679,6 @@ def get_pipeshard_executable(config,
support_output_attentions=False,
support_output_hidden_states=False,
autoregressive=True):
if autoregressive:
assert num_micro_batches == 1, "we only support num_micro_batches=1 for autoregressive!"
assert batch_size == 1, "we only support batch_sie = 1 for autoregressive!"

# Init model
model, params = init_model_aval(config)
@@ -708,9 +705,9 @@ def inference_step_with_cache(params, batch):
alpa.global_config.always_donate_micro_batch_vars = False
executable = inference_step_with_cache.get_executable(
params, {
"input_ids": jax.core.ShapedArray((1, 1), jnp.int32),
"position_ids": jax.core.ShapedArray((1, 1), jnp.int32),
"cache": init_cache_aval(config, 1),
"input_ids": jax.core.ShapedArray((batch_size, 1), jnp.int32),
"position_ids": jax.core.ShapedArray((batch_size, 1), jnp.int32),
"cache": init_cache_aval(config, batch_size),
})
else:

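To make the relaxed batch handling above concrete, a small arithmetic sketch of the per-layer key/value cache aval that init_cache_aval produces when called with batch_size * num_beams; the config numbers below are illustrative (only decoder_attention_heads matches a default shown above), not tied to a particular checkpoint:

batch_size, num_beams = 1, 4
decoder_embed_dim = 768                  # assumed
decoder_attention_heads = 12             # OPTConfig default shown above
max_target_positions = 2048              # assumed
head_dim = decoder_embed_dim // decoder_attention_heads   # 64
key_cache_shape = (batch_size * num_beams, max_target_positions,
                   decoder_attention_heads, head_dim)      # (4, 2048, 12, 64)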
84 changes: 74 additions & 10 deletions examples/opt_serving/model/wrapper.py
@@ -1,8 +1,15 @@
from functools import partial
import os
from typing import Sequence, Any

import alpa
import jax
from jax import xla
from jax import ShapeDtypeStruct, ShapedArray
from jax._src.lib import xla_client as xc
from jax.core import Primitive
from jax.interpreters import pxla
from jax.interpreters.pxla import NoSharding, Replicated, ShardingSpec
import jax.numpy as jnp
import numpy as np
import torch
@@ -15,6 +22,21 @@
from examples.opt_serving.model.opt_utils import TransformerModelConfig


index_select_p = Primitive("index-select")
def jax_index_select(input, index, dim=0):
return index_select_p.bind(input, index, dim=dim)

def _index_select_eval(input, index, dim):
return input

def _index_select_translation(c, input, index, dim):
return xc.ops.IndexSelect(input, index, dim)

index_select_p.def_abstract_eval(_index_select_eval)
index_select_p.def_impl(partial(xla.apply_primitive, index_select_p))
xla.translations[index_select_p] = _index_select_translation


@dataclass
class InferenceFuncOutput(ModelOutput):
logits: Any = None
@@ -100,8 +122,46 @@ def __call__(self,
past_key_values = ret.past_key_values
return ret


def get_hf_gpt_model(model_name, device):
def _reorder_cache(self, past, beam_idx):
# beam_idx currently arrives as a torch tensor from the beam scorer. To speed
# this up, alpa would need its own beam scorer.
cache = {}
cpu_idx = beam_idx.to("cpu").numpy()
if type(cpu_idx) not in [ShapedArray, ShapeDtypeStruct]:
cpu_idx = xla.canonicalize_dtype(cpu_idx)

def to_mesh(mesh):
if mesh in cache:
return cache[mesh]
avals = [ShapedArray(cpu_idx.shape, cpu_idx.dtype)]
replicated_spec = ShardingSpec([NoSharding()] * len(cpu_idx.shape),
[Replicated(mesh.num_devices)])
specs = [replicated_spec]
indices = [pxla.spec_to_indices(cpu_idx.shape, replicated_spec)]
ary = mesh.shard_args_to_arrays(avals, indices, specs, [cpu_idx])[0]
cache[mesh] = ary
return ary

def single_element_reorder_cache(ary):
if hasattr(ary, "index_select"):
# Torch or Alpa path
device_idx = None
if hasattr(ary, "device"): # Torch to_device
device_idx = beam_idx.to(ary.device)
else:
device_idx = to_mesh(ary.device_mesh)
return ary.index_select(0, device_idx)
else:
# Jax path
return jax_index_select(ary, cpu_idx, 0)
return tuple(
tuple(
single_element_reorder_cache(past_state)
for past_state in layer_past)
for layer_past in past)


def get_hf_gpt_model(model_name, device, num_beams):
raw_model = GPT2LMHeadModel.from_pretrained(model_name)
raw_model = raw_model.to(device)

@@ -116,6 +176,7 @@ def inference_func(input_ids,
return InferenceFuncOutput(out.logits, out.past_key_values)

inference_func_config = raw_model.config
inference_func_config.num_beams = num_beams
transformer_config = TransformerModelConfig(
H=raw_model.config.n_embd,
L=raw_model.config.n_layer,
@@ -127,7 +188,7 @@ def inference_func(input_ids,
executable, transformer_config)


def get_hf_opt_model(model_name, device):
def get_hf_opt_model(model_name, device, num_beams):
raw_model = OPTForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if "cuda" in device else torch.float32)
@@ -150,7 +211,7 @@ def inference_func(input_ids,
output_hidden_states=output_hidden_states)
return InferenceFuncOutput(out.logits, out.past_key_values)

inference_func_config = InferenceFuncConfig()
inference_func_config = InferenceFuncConfig(num_beams=num_beams)
for key in inference_func_config.__dataclass_fields__.keys():
setattr(inference_func_config, key, getattr(raw_model.config, key))
transformer_config = TransformerModelConfig(
@@ -171,6 +232,7 @@ def get_model(model_name: str,
dtype=jnp.float16,
dummy=False,
batch_size=1,
num_beams=1,
decoding_length_per_step=1,
num_micro_batches=1,
support_output_attentions=False,
@@ -191,9 +253,9 @@ def get_model(model_name: str,
f"Cannot support num_micro_batches > 1 in autoregressive mode.")

if "gpt" in model_name:
return get_hf_gpt_model(model_name, device)
return get_hf_gpt_model(model_name, device, num_beams)
if "facebook/opt" in model_name:
return get_hf_opt_model(model_name, device)
return get_hf_opt_model(model_name, device, num_beams)

assert ("jax/opt" in model_name or "alpa/opt" in model_name)
name = model_name.split("-")[1].upper()
@@ -220,7 +282,7 @@ def get_model(model_name: str,

# load params
params = load_params_np(params_aval, path, config, dummy)
init_cache = init_cache_np(config, batch_size=1)
init_cache = init_cache_np(config, batch_size=batch_size * num_beams)
params, init_cache = jax.tree_map(jnp.array, (params, init_cache))
else:
assert "alpa/opt" in model_name
@@ -238,9 +300,11 @@ def get_model(model_name: str,
seq_len=config.max_target_positions,
vocab_size=config.vocab_size)

if autoregressive:
assert batch_size == 1, "we only support batch_size = 1 for autoregressive!"
executable, params_aval = get_pipeshard_executable(
config,
batch_size=batch_size,
batch_size=batch_size * num_beams,
num_micro_batches=num_micro_batches,
decoding_length_per_step=decoding_length_per_step,
support_output_attentions=support_output_attentions,
@@ -253,7 +317,7 @@ def inference_func(input_ids,
if autoregressive:
init_cache = init_cache_dis_array(executable,
config,
1,
batch_size * num_beams,
dummy=dummy)
set_skip_shard_args_check(init_cache)
executable.sync()
@@ -292,7 +356,7 @@ def inference_func(input_ids,
return InferenceFuncOutput(logits_step, output.attention_cache,
output.hidden_states, output.attentions)

inference_func_config = InferenceFuncConfig()
inference_func_config = InferenceFuncConfig(num_beams=num_beams)
return WrappedInferenceFunc(inference_func, inference_func_config,
executable, transformer_config)

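As a usage sketch of the pieces added in this file: on the plain JAX path of single_element_reorder_cache (arrays without an index_select method), each cached tensor is reordered along its beam dimension by binding the new index-select primitive. Shapes and values below are illustrative; note that _index_select_eval simply reuses the input aval, which is consistent only when the index length equals the input's size along dim, as it does for beam reordering.

import numpy as np
import jax.numpy as jnp

past_state = jnp.zeros((4, 128, 12, 64), jnp.float16)  # (batch*num_beams, seq, heads, head_dim), assumed layout
beam_idx = np.array([1, 1, 0, 2], dtype=np.int32)       # beams kept by the scorer
reordered = jax_index_select(past_state, beam_idx, 0)   # lowers to an XLA IndexSelect along dim 0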
9 changes: 7 additions & 2 deletions examples/opt_serving/textgen_demo.py
@@ -7,16 +7,21 @@
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False)
tokenizer.add_bos_token = False

num_beams = 1
# Load the model
model = get_model(model_name="alpa/opt-2.7b",
device="cuda",
path="/home/ubuntu/opt_weights/")
path="/home/ubuntu/efs/parax-proj/",
num_beams=num_beams)

# Generate
prompt = "Paris is the capital city of"

input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids=input_ids, max_length=256, do_sample=True)
output = model.generate(input_ids=input_ids,
max_length=256,
do_sample=True,
num_beams=num_beams)
generated_string = tokenizer.batch_decode(output, skip_special_tokens=True)

print(generated_string)
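A variant of the demo with beam sample actually exercised (the snippet above keeps num_beams = 1); only num_beams changes, and the weight path stays whatever your deployment uses:

num_beams = 4
model = get_model(model_name="alpa/opt-2.7b",
                  device="cuda",
                  path="/home/ubuntu/efs/parax-proj/",
                  num_beams=num_beams)
output = model.generate(input_ids=input_ids,
                        max_length=256,
                        do_sample=True,       # sampling within a 4-beam search, i.e. beam sample
                        num_beams=num_beams)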
