diff --git a/alpa/device_mesh.py b/alpa/device_mesh.py
index 4a253b658..6a66cef97 100644
--- a/alpa/device_mesh.py
+++ b/alpa/device_mesh.py
@@ -32,7 +32,7 @@
 import jax
 from jax import core, xla, device_put
 from jax._src.api import ShapeDtypeStruct
-from jax._src.lib import xla_bridge as xb, xla_extension as xe
+from jax._src.lib import xla_bridge as xb, xla_client as xc, xla_extension as xe
 from jax._src.tree_util import tree_leaves
 from jax.abstract_arrays import array_types
 from jax.core import ShapedArray
@@ -48,11 +48,14 @@
 import alpa.collective as col
 from alpa.global_env import global_config
 from alpa.monkey_patch import set_override_backend
-from alpa.shard_parallel.auto_sharding import LogicalDeviceMesh
-from alpa.parallel_plan import PlacementSpec
+from alpa.shard_parallel.auto_sharding import (AutoShardingOption,
+                                               LogicalDeviceMesh,
+                                               run_spmd_partitioner_pass)
+from alpa.parallel_plan import PlacementSpec, StagePlan
 from alpa.timer import timers
 from alpa.util import (benchmark_func, list_gpu_info, OrderedSet,
-                       update_jax_platform, is_ray_node_resource)
+                       update_jax_platform, is_ray_node_resource,
+                       get_index_select_computation)

 if global_config.nccl_mode == "cupy":
     import alpa.collective.worker_nccl_util_cupy as worker_nccl_util
@@ -608,6 +611,7 @@ class PhysicalDeviceMesh(ABC):
     num_hosts: int
     num_devices_per_host: int
     mesh_id: int
+    operation_executables: dict

     def get_signature(self) -> str:
         """Return a signature string that contains the mesh shape and GPU
@@ -810,6 +814,7 @@ def __init__(self, devices: Sequence["Device"] = None):
         self.num_devices_per_host = len(self.devices)
         self.mesh_id = 0
         self.device_strs = []
+        self.operation_executables = {}

         self.set_runtime_random_seed(global_config.runtime_random_seed)

@@ -898,6 +903,7 @@ def sync_workers(self):

     def shutdown(self, forced=False):
         self.sync_workers()
+        self.operation_executables.clear()


 def device_id_to_str(host_ip, device_id, device_type="gpu"):
@@ -934,6 +940,7 @@ def __init__(self,
         self.workers = None
         self.launched = False
         self.service_server = None
+        self.operation_executables = {}

         if devices is not None:
             if len(devices) != len(host_ids):
@@ -1301,6 +1308,7 @@ def shutdown(self, forced=False):
         if not self.launched:
             return
         if not forced:
+            self.operation_executables.clear()
             ray.get([w.shutdown.remote() for w in self.workers])
             for worker in self.workers:
                 ray.kill(worker)
@@ -1309,6 +1317,7 @@ def shutdown(self, forced=False):
         self.service_server.shutdown()
         self.service_server = None
         self.launched = False
+        self.operation_executables.clear()  # clear with forced shutdown


 ########################################
@@ -1525,6 +1534,34 @@ def __float__(self):

     # TODO(lmzheng): copy more functions from DeviceArray
     # (jax/_src/device_array.py)
+    def index_select(self, dim, index):
+        """Compile and run index select operation."""
+        # pylint: disable=import-outside-toplevel
+        from alpa.mesh_executable import NormalMeshDriverExecutable
+        if type(index) not in [ShapedArray, ShapeDtypeStruct]:
+            index = xla.canonicalize_dtype(index)
+        index_shape = xc.shape_from_pyval(index)
+        key = hash(("index_select", self.aval, dim, index_shape))
+        if key in self.device_mesh.operation_executables:
+            executable = self.device_mesh.operation_executables[key]
+        else:
+            index_aval = ShapedArray(index.shape, index.dtype)
+            c = get_index_select_computation(self.sharding_spec, dim, self.aval,
+                                             index_shape).as_hlo_module()
+            hlo_module = run_spmd_partitioner_pass(c,
+                                                   self.device_mesh.num_devices)
+
+            as_option = AutoShardingOption()
+            strategy_config = StagePlan(global_config.compile_random_seed,
+                                        self.device_mesh.shape, 1 << 60,
+                                        as_option.all_reduce_threshold, None,
+                                        -1)
+            executable = NormalMeshDriverExecutable(self.device_mesh,
+                                                    hlo_module, strategy_config,
+                                                    [self.aval, index_aval],
+                                                    [self.aval], [False, False])
+            self.device_mesh.operation_executables[key] = executable
+        return executable.launch_on_driver(self, index)

     def __str__(self):
         return (f"DistributedArray(sharding_spec={self.sharding_spec}, "
diff --git a/alpa/util.py b/alpa/util.py
index cd9ffbd02..7069706b2 100644
--- a/alpa/util.py
+++ b/alpa/util.py
@@ -560,6 +560,23 @@ def compile_concatenate(backend, mesh_shape, sharding_spec, batch_size,
     return hlo_proto


+def get_index_select_computation(sharding_spec, dim, aval, index_shape):
+    sharding = pxla.sharding_spec_sharding_proto(sharding_spec)
+    c = xc.XlaBuilder("index_select")
+    c.set_sharding(sharding)
+    operand = xc.ops.Parameter(
+        c, 0, xc.shape_from_pyval(np.ones(aval.shape, aval.dtype)))
+    c.clear_sharding()
+    index = xc.ops.Parameter(c, 1, index_shape)
+    index_selected = xc.ops.IndexSelect(operand, index, dim)
+    sharding2 = xc.OpSharding()
+    sharding2.type = sharding.type.TUPLE
+    sharding2.tuple_shardings = [sharding]
+    c.set_sharding(sharding2)
+    c = c.build(xc.ops.Tuple(c, [index_selected]))
+    return c
+
+
 def get_shard_shape(aval: ShapedArray, sharding_spec: pxla.ShardingSpec):
     """Return the shape of a shard."""
     shape = []
diff --git a/examples/opt_serving/model/opt_model.py b/examples/opt_serving/model/opt_model.py
index 9a2716d3f..b568da662 100644
--- a/examples/opt_serving/model/opt_model.py
+++ b/examples/opt_serving/model/opt_model.py
@@ -57,7 +57,6 @@ class OPTConfig:
     decoder_attention_heads: int = 12
     decoder_input_dim: int = 768
     decoder_ffn_embed_dim: int = 3072
-    batch_size: int = 1
     pad: int = 1
     activation_fn: str = 'relu'
     dtype: any = jnp.float16
@@ -530,6 +529,7 @@ def init_model_aval(config):


 def init_cache_aval(config, batch_size):
+    dtype = config.dtype
     head_dim = config.decoder_embed_dim // config.decoder_attention_heads

     all_cache = []
@@ -537,10 +537,10 @@
         layer_cache = (
             jax.core.ShapedArray((batch_size, config.max_target_positions,
                                   config.decoder_attention_heads, head_dim),
-                                 config.dtype),
+                                 dtype),
             jax.core.ShapedArray((batch_size, config.max_target_positions,
                                   config.decoder_attention_heads, head_dim),
-                                 config.dtype),
+                                 dtype),
             jax.core.ShapedArray((batch_size,), jnp.int32),
         )
         all_cache.append(layer_cache)
@@ -679,9 +679,6 @@ def get_pipeshard_executable(config,
                              support_output_attentions=False,
                              support_output_hidden_states=False,
                              autoregressive=True):
-    if autoregressive:
-        assert num_micro_batches == 1, "we only support num_micro_batches=1 for autoregressive!"
-        assert batch_size == 1, "we only support batch_sie = 1 for autoregressive!"

     # Init model
     model, params = init_model_aval(config)
@@ -708,9 +705,9 @@ def inference_step_with_cache(params, batch):
         alpa.global_config.always_donate_micro_batch_vars = False
         executable = inference_step_with_cache.get_executable(
             params, {
-                "input_ids": jax.core.ShapedArray((1, 1), jnp.int32),
-                "position_ids": jax.core.ShapedArray((1, 1), jnp.int32),
-                "cache": init_cache_aval(config, 1),
+                "input_ids": jax.core.ShapedArray((batch_size, 1), jnp.int32),
+                "position_ids": jax.core.ShapedArray((batch_size, 1), jnp.int32),
+                "cache": init_cache_aval(config, batch_size),
             })

     else:
diff --git a/examples/opt_serving/model/wrapper.py b/examples/opt_serving/model/wrapper.py
index d36ae37d7..c6f5540da 100644
--- a/examples/opt_serving/model/wrapper.py
+++ b/examples/opt_serving/model/wrapper.py
@@ -1,8 +1,15 @@
+from functools import partial
 import os
 from typing import Sequence, Any

 import alpa
 import jax
+from jax import xla
+from jax import ShapeDtypeStruct, ShapedArray
+from jax._src.lib import xla_client as xc
+from jax.core import Primitive
+from jax.interpreters import pxla
+from jax.interpreters.pxla import NoSharding, Replicated, ShardingSpec
 import jax.numpy as jnp
 import numpy as np
 import torch
@@ -15,6 +22,21 @@
 from examples.opt_serving.model.opt_utils import TransformerModelConfig


+index_select_p = Primitive("index-select")
+def jax_index_select(input, index, dim=0):
+    return index_select_p.bind(input, index, dim=dim)
+
+def _index_select_eval(input, index, dim):
+    return input
+
+def _index_select_translation(c, input, index, dim):
+    return xc.ops.IndexSelect(input, index, dim)
+
+index_select_p.def_abstract_eval(_index_select_eval)
+index_select_p.def_impl(partial(xla.apply_primitive, index_select_p))
+xla.translations[index_select_p] = _index_select_translation
+
+
 @dataclass
 class InferenceFuncOutput(ModelOutput):
     logits: Any = None
@@ -100,8 +122,46 @@ def __call__(self,
         past_key_values = ret.past_key_values
         return ret

-
-def get_hf_gpt_model(model_name, device):
+    def _reorder_cache(self, past, beam_idx):
+        # Current beam_idx is a torch tensor from beam scorer. To speed up,
+        # we need to have alpa's own beam scorer
+        cache = {}
+        cpu_idx = beam_idx.to("cpu").numpy()
+        if type(cpu_idx) not in [ShapedArray, ShapeDtypeStruct]:
+            cpu_idx = xla.canonicalize_dtype(cpu_idx)
+
+        def to_mesh(mesh):
+            if mesh in cache:
+                return cache[mesh]
+            avals = [ShapedArray(cpu_idx.shape, cpu_idx.dtype)]
+            replicated_spec = ShardingSpec([NoSharding()] * len(cpu_idx.shape),
+                                           [Replicated(mesh.num_devices)])
+            specs = [replicated_spec]
+            indices = [pxla.spec_to_indices(cpu_idx.shape, replicated_spec)]
+            ary = mesh.shard_args_to_arrays(avals, indices, specs, [cpu_idx])[0]
+            cache[mesh] = ary
+            return ary
+
+        def single_element_reorder_cache(ary):
+            if hasattr(ary, "index_select"):
+                # Torch or Alpa path
+                device_idx = None
+                if hasattr(ary, "device"):  # Torch to_device
+                    device_idx = beam_idx.to(ary.device)
+                else:
+                    device_idx = to_mesh(ary.device_mesh)
+                return ary.index_select(0, device_idx)
+            else:
+                # Jax path
+                return jax_index_select(ary, cpu_idx, 0)
+        return tuple(
+            tuple(
+                single_element_reorder_cache(past_state)
+                for past_state in layer_past)
+            for layer_past in past)
+
+
+def get_hf_gpt_model(model_name, device, num_beams):
     raw_model = GPT2LMHeadModel.from_pretrained(model_name)
     raw_model = raw_model.to(device)

@@ -116,6 +176,7 @@ def inference_func(input_ids,
         return InferenceFuncOutput(out.logits, out.past_key_values)

     inference_func_config = raw_model.config
+    inference_func_config.num_beams = num_beams
     transformer_config = TransformerModelConfig(
         H=raw_model.config.n_embd,
         L=raw_model.config.n_layer,
@@ -127,7 +188,7 @@ def inference_func(input_ids,
                                 executable, transformer_config)


-def get_hf_opt_model(model_name, device):
+def get_hf_opt_model(model_name, device, num_beams):
     raw_model = OPTForCausalLM.from_pretrained(
         model_name,
         torch_dtype=torch.float16 if "cuda" in device else torch.float32)
@@ -150,7 +211,7 @@ def inference_func(input_ids,
                        output_hidden_states=output_hidden_states)
         return InferenceFuncOutput(out.logits, out.past_key_values)

-    inference_func_config = InferenceFuncConfig()
+    inference_func_config = InferenceFuncConfig(num_beams=num_beams)
     for key in inference_func_config.__dataclass_fields__.keys():
         setattr(inference_func_config, key, getattr(raw_model.config, key))
     transformer_config = TransformerModelConfig(
@@ -171,6 +232,7 @@ def get_model(model_name: str,
               dtype=jnp.float16,
               dummy=False,
               batch_size=1,
+              num_beams=1,
               decoding_length_per_step=1,
               num_micro_batches=1,
               support_output_attentions=False,
@@ -191,9 +253,9 @@ def get_model(model_name: str,
             f"Cannot support num_micro_batches > 1 in autoregressive mode.")

     if "gpt" in model_name:
-        return get_hf_gpt_model(model_name, device)
+        return get_hf_gpt_model(model_name, device, num_beams)
     if "facebook/opt" in model_name:
-        return get_hf_opt_model(model_name, device)
+        return get_hf_opt_model(model_name, device, num_beams)

     assert ("jax/opt" in model_name or "alpa/opt" in model_name)
     name = model_name.split("-")[1].upper()
@@ -220,7 +282,7 @@ def get_model(model_name: str,

         # load params
         params = load_params_np(params_aval, path, config, dummy)
-        init_cache = init_cache_np(config, batch_size=1)
+        init_cache = init_cache_np(config, batch_size=batch_size * num_beams)
         params, init_cache = jax.tree_map(jnp.array, (params, init_cache))
     else:
         assert "alpa/opt" in model_name
@@ -238,9 +300,11 @@ def get_model(model_name: str,
             seq_len=config.max_target_positions,
             vocab_size=config.vocab_size)

+        if autoregressive:
+            assert batch_size == 1, "we only support batch_size = 1 for autoregressive!"
         executable, params_aval = get_pipeshard_executable(
             config,
-            batch_size=batch_size,
+            batch_size=batch_size * num_beams,
             num_micro_batches=num_micro_batches,
             decoding_length_per_step=decoding_length_per_step,
             support_output_attentions=support_output_attentions,
@@ -253,7 +317,7 @@ def get_model(model_name: str,
         if autoregressive:
             init_cache = init_cache_dis_array(executable,
                                               config,
-                                              1,
+                                              batch_size * num_beams,
                                               dummy=dummy)
             set_skip_shard_args_check(init_cache)
             executable.sync()
@@ -292,7 +356,7 @@ def inference_func(input_ids,
         return InferenceFuncOutput(logits_step, output.attention_cache,
                                    output.hidden_states, output.attentions)

-    inference_func_config = InferenceFuncConfig()
+    inference_func_config = InferenceFuncConfig(num_beams=num_beams)
     return WrappedInferenceFunc(inference_func, inference_func_config,
                                 executable, transformer_config)

diff --git a/examples/opt_serving/textgen_demo.py b/examples/opt_serving/textgen_demo.py
index 49be90a88..5c9c34101 100644
--- a/examples/opt_serving/textgen_demo.py
+++ b/examples/opt_serving/textgen_demo.py
@@ -7,16 +7,21 @@
 tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False)
 tokenizer.add_bos_token = False

+num_beams = 1
 # Load the model
 model = get_model(model_name="alpa/opt-2.7b",
                   device="cuda",
-                  path="/home/ubuntu/opt_weights/")
+                  path="/home/ubuntu/efs/parax-proj/",
+                  num_beams=num_beams)

 # Generate
 prompt = "Paris is the capital city of"

 input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
-output = model.generate(input_ids=input_ids, max_length=256, do_sample=True)
+output = model.generate(input_ids=input_ids,
+                        max_length=256,
+                        do_sample=True,
+                        num_beams=num_beams)
 generated_string = tokenizer.batch_decode(output, skip_special_tokens=True)

 print(generated_string)
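Usage sketch (not part of the patch): with num_beams now threaded through get_model and forwarded to model.generate, beam-search decoding can be exercised as below. This is a minimal sketch assuming the same setup as textgen_demo.py; the weights path, the num_beams value, and the get_model import path are illustrative assumptions.

```python
from transformers import AutoTokenizer

# Assumed import path; adjust to how get_model is imported in your setup.
from examples.opt_serving.model.wrapper import get_model

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False)
tokenizer.add_bos_token = False

num_beams = 2  # illustrative; batch_size stays 1 in autoregressive mode
model = get_model(model_name="alpa/opt-2.7b",
                  device="cuda",
                  path="/path/to/opt_weights/",  # placeholder weights path
                  num_beams=num_beams)

input_ids = tokenizer("Paris is the capital city of",
                      return_tensors="pt").input_ids.to("cuda")

# With num_beams > 1, HuggingFace beam search calls _reorder_cache() between
# decoding steps; this patch implements it by reordering the (distributed)
# attention cache with index_select.
output = model.generate(input_ids=input_ids,
                        max_length=256,
                        num_beams=num_beams)
print(tokenizer.batch_decode(output, skip_special_tokens=True))
```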
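For the non-distributed fallback in single_element_reorder_cache, the patch registers an "index-select" JAX primitive in wrapper.py. Below is a minimal sketch of that path on a plain device array, assuming the module is importable as shown; the shapes and index values are made up for illustration.

```python
import jax.numpy as jnp

# Assumed import path for the primitive wrapper added in this patch.
from examples.opt_serving.model.wrapper import jax_index_select

# A fake per-layer cache slice for 4 beams: shape (num_beams, head_dim).
cache_slice = jnp.arange(8.0).reshape(4, 2)

# Suppose the beam scorer kept beams 2 and 0, each continued twice.
beam_idx = jnp.array([2, 2, 0, 0], dtype=jnp.int32)

# Reorder rows along dim 0 to follow the surviving beams, which is what
# _reorder_cache does for every cache tensor that is neither a torch tensor
# nor an alpa DistributedArray.
reordered = jax_index_select(cache_slice, beam_idx, 0)
print(reordered)  # rows 2, 2, 0, 0 of cache_slice
```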