From c7873aa94da682807fa126fecc8d3ce2dd78e4a8 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 23 Jun 2022 12:13:13 -0700 Subject: [PATCH] Merge PlacementSpec and LoadInfo (#540) --- alpa/create_state_parallel.py | 3 +- alpa/device_mesh.py | 30 ++++---- alpa/mesh_executable.py | 10 ++- alpa/parallel_plan.py | 2 + .../pipeline_parallel/pipeshard_executable.py | 27 +++---- alpa/pipeline_parallel/runtime_emitter.py | 36 +++++---- alpa/serialization.py | 77 +++++++------------ docs/tutorials/opt_serving.rst | 64 ++++++++++----- examples/opt_serving/README.md | 56 ++++++++++---- .../benchmark/benchmark_text_gen.py | 5 +- examples/opt_serving/model/opt_model.py | 34 ++++---- examples/opt_serving/model/wrapper.py | 2 +- .../scripts/convert_to_numpy_weights.py | 14 ++-- examples/opt_serving/textgen_demo.py | 23 ++++++ tests/runtime/test_dist_save_load.py | 12 +-- 15 files changed, 226 insertions(+), 169 deletions(-) create mode 100644 examples/opt_serving/textgen_demo.py diff --git a/alpa/create_state_parallel.py b/alpa/create_state_parallel.py index fc37e55bc..6902a639f 100644 --- a/alpa/create_state_parallel.py +++ b/alpa/create_state_parallel.py @@ -9,7 +9,8 @@ from alpa.device_mesh import ReplicatedDistributedArray, PhysicalDeviceMeshGroup from alpa.mesh_executable import (NormalMeshDriverExecutable, - GradAccMeshDriverExecutable, PlacementSpec) + GradAccMeshDriverExecutable) +from alpa.parallel_plan import PlacementSpec from alpa.pipeline_parallel.compile_executable import compile_pipeshard_executable_internal from alpa.pipeline_parallel.layer_construction import add_pipeline_marks_for_sliced_eqns from alpa.pipeline_parallel.pipeshard_executable import PipeshardDriverExecutable diff --git a/alpa/device_mesh.py b/alpa/device_mesh.py index 7d724301f..774cb3f15 100644 --- a/alpa/device_mesh.py +++ b/alpa/device_mesh.py @@ -52,6 +52,7 @@ from alpa.global_env import global_config from alpa.monkey_patch import set_override_backend from alpa.shard_parallel.auto_sharding import LogicalDeviceMesh +from alpa.parallel_plan import PlacementSpec from alpa.timer import timers from alpa.util import (benchmark_func, list_gpu_info, jax_tensor_to_cupy, cupy_to_jax_tensor, jax_tensor_set, @@ -290,9 +291,7 @@ def save_buffers(self, ckpt_dir: str, local_cache_dir: Union[str, None], for uuid in uuids: assert uuid in self.buffers - shard_names = [ - str(self.host_id) + "." 
+ str(i) for i in range(len(uuids)) - ] + shard_names = [f"shard_{self.host_id}.{i}" for i in range(len(uuids))] metadata = { "global_shape": global_shape, @@ -313,7 +312,7 @@ def save_buffers(self, ckpt_dir: str, local_cache_dir: Union[str, None], with open(os.path.join(save_dir, shard_name), "wb") as datafile: np.save(datafile, self.buffers[uuid]) - with open(os.path.join(save_dir, f".metadata{self.host_id}"), + with open(os.path.join(save_dir, f"metadata_{self.host_id}"), "wb") as metafile: pickle.dump(metadata, metafile) @@ -325,7 +324,7 @@ def load_buffers(self, ckpt_dir: str, uuids: Sequence[int], shard_indices: Sequence[Index], device_ids: Sequence[int]): assert len(uuids) > 0 metadatas = list( - filter(lambda fname: fname.startswith(".metadata"), + filter(lambda fname: fname.startswith("metadata"), os.listdir(ckpt_dir))) # pylint: disable=import-outside-toplevel from alpa.serialization import load_sharded_array @@ -1892,26 +1891,29 @@ def instantiate_nccl_group(self, src_mesh_id: int, dst_mesh_id: int): cg = self.collective_groups[src_mesh_id][dst_mesh_id] self._instantiate_nccl_group(cg) - def shard_args_to_arrays(self, load_infos: "LoadInfo", args: Sequence[Any]): + def shard_args_to_arrays(self, load_infos: PlacementSpec, + args: Sequence[Any]): rets = [] for info, arg in zip(load_infos, args): aval = info.aval - if info.is_replicated(): + if len(info.mesh_ids) == 1: + mesh = self.meshes[info.mesh_ids[0]] + spec = info.sharding_specs[0] + indices = pxla.spec_to_indices(aval.shape, spec) + rets.append( + mesh.shard_args_to_arrays((aval,), (indices,), (spec,), + (arg,))[0]) + else: meshes, arrays = [], [] - for mesh, spec in info.get_info(): + for mesh_id, spec in zip(info.mesh_ids, info.sharding_specs): + mesh = self.meshes[mesh_id] meshes.append(mesh) indices = pxla.spec_to_indices(aval.shape, spec) arrays.append( mesh.shard_args_to_arrays((aval,), (indices,), (spec,), (arg,))[0]) rets.append(ReplicatedDistributedArray(meshes, arrays)) - else: - mesh, spec = info.get_info() - indices = pxla.spec_to_indices(aval.shape, spec) - rets.append( - mesh.shard_args_to_arrays((aval,), (indices,), (spec,), - (arg,))[0]) return rets diff --git a/alpa/mesh_executable.py b/alpa/mesh_executable.py index 873dafedc..6edebe2d1 100644 --- a/alpa/mesh_executable.py +++ b/alpa/mesh_executable.py @@ -369,8 +369,9 @@ def launch_on_driver(self, *args, **kwargs): def get_input_placement_specs(self): """Return the preferred placement specs for input arguments.""" placement_specs = [ - PlacementSpec((self.physical_mesh.mesh_id,), (sharding_spec,)) - for sharding_spec in self.input_sharding_specs + PlacementSpec(aval, (self.physical_mesh.mesh_id,), + (sharding_spec,)) for aval, sharding_spec in zip( + self.avals, self.input_sharding_specs) ] return tree_unflatten(self.in_tree, placement_specs) @@ -822,8 +823,9 @@ def launch_on_driver(self, *args): def get_input_placement_specs(self): """Return the preferred placement specs for input arguments.""" placement_specs = [ - PlacementSpec((self.physical_mesh.mesh_id,), (sharding_spec,)) - for sharding_spec in self.global_arg_sharding_specs + PlacementSpec(aval, (self.physical_mesh.mesh_id,), + (sharding_spec,)) for aval, sharding_spec in zip( + self.avals, self.global_arg_sharding_specs) ] return tree_unflatten(self.in_tree, placement_specs) diff --git a/alpa/parallel_plan.py b/alpa/parallel_plan.py index 1fa7d9c0a..478e1d380 100644 --- a/alpa/parallel_plan.py +++ b/alpa/parallel_plan.py @@ -6,6 +6,7 @@ from typing import Sequence, Tuple import numpy as np +from 
jax.core import ShapedArray from jax.interpreters import pxla from jax.tree_util import PyTreeDef @@ -13,6 +14,7 @@ @dataclass class PlacementSpec: """Specify how a tensor is stored distributedly.""" + aval: ShapedArray mesh_ids: Sequence[int] sharding_specs: Sequence[pxla.ShardingSpec] diff --git a/alpa/pipeline_parallel/pipeshard_executable.py b/alpa/pipeline_parallel/pipeshard_executable.py index 73a2afa13..90ba51526 100644 --- a/alpa/pipeline_parallel/pipeshard_executable.py +++ b/alpa/pipeline_parallel/pipeshard_executable.py @@ -4,18 +4,19 @@ import time from typing import Optional, Sequence -from jax.tree_util import tree_map, tree_flatten, tree_unflatten, PyTreeDef +from jax.tree_util import tree_flatten, tree_unflatten, PyTreeDef import numpy as np import ray.exceptions from alpa.device_mesh import MeshHostWorker from alpa.global_env import global_config from alpa.device_mesh import PhysicalDeviceMeshGroup -from alpa.mesh_executable import ( - AllocZeroBufferWorkerExecutable, ConcatMeshWorkerExecutable, - MemzeroWorkerExecutable, PartialGradAccMeshWorkerExecutable, - next_mesh_executable_uuid, get_uuid_np_array, next_remote_buffer_uuid, - RemoteBufferRef, PlacementSpec) +from alpa.mesh_executable import (AllocZeroBufferWorkerExecutable, + ConcatMeshWorkerExecutable, + MemzeroWorkerExecutable, + PartialGradAccMeshWorkerExecutable, + next_mesh_executable_uuid, get_uuid_np_array, + next_remote_buffer_uuid, RemoteBufferRef) from alpa.pipeline_parallel.runtime_emitter import ( AllocateZeroWorkerExecutableConfig, ConcatWorkerExecutableConfig, ExecutableConfig, MemZeroWorkerExecutableConfig, @@ -57,7 +58,7 @@ def __init__(self, self.stages = pipeshard_config.xla_stages self.schedule = pipeshard_config.schedule self.flop_count = pipeshard_config.flop_count - self.load_info = pipeshard_config.load_info + self.input_placement_specs = pipeshard_config.input_placement_specs # List[stage_idx -> str] self.fully_optimized_hlo_texts = [] self.sharding_annotated_hlo_texts = ( @@ -204,12 +205,7 @@ def launch_on_driver(self, *args): def get_input_placement_specs(self): """Return the preferred placement specs for input arguments.""" - - def load_info_to_placement_spec(load_info): - return PlacementSpec([x.mesh_id for x in load_info.meshes], - load_info.specs) - - return tree_map(load_info_to_placement_spec, self.load_info) + return self.input_placement_specs def __call__(self, *args): """Fast call without signature matching.""" @@ -225,11 +221,6 @@ def __call__(self, *args): out = self.launch_on_driver(*args_flat) return tree_unflatten(self.out_tree, out) - ##### Load/Store Related Functions ##### - def get_load_info(self): - """Get the load info for model checkpoints.""" - return self.load_info - ##### Profiling and Debugging Related Functions ##### def get_execution_time_costs(self, warmup=2, diff --git a/alpa/pipeline_parallel/runtime_emitter.py b/alpa/pipeline_parallel/runtime_emitter.py index fd16814fd..fda751043 100644 --- a/alpa/pipeline_parallel/runtime_emitter.py +++ b/alpa/pipeline_parallel/runtime_emitter.py @@ -15,12 +15,12 @@ from alpa.device_mesh import (DistributedArray, PhysicalDeviceMeshGroup, ReplicatedDistributedArray) from alpa.mesh_executable import next_mesh_executable_uuid +from alpa.parallel_plan import PlacementSpec from alpa.pipeline_parallel.computation import XlaShardedPipelineComputation from alpa.pipeline_parallel.schedules import PipelineSchedule from alpa.pipeline_parallel.cross_mesh_resharding import ( CrossMeshCommunicator, SymbolicBroadcastReshardingTask, 
SymbolicReshardingTask, ReshardingTask) -from alpa.serialization import LoadInfo from alpa.util import (DisjointDict, OrderedSet, get_shard_shape, get_microbatch_sharding_spec, compile_concatenate) @@ -247,7 +247,7 @@ class PipeshardConfig: output_local_uuid_list: Sequence[Sequence[int]] outs_handler: Callable # Others - load_info: LoadInfo + input_placement_specs: Sequence[PlacementSpec] sharding_annotated_hlo_texts: Sequence[str] flop_count: int @@ -435,8 +435,8 @@ def compile(self): worker, used_outside, donated, instruction_lists) # Compile load info - load_info = self._compile_load_info(input_config.mesh_arg_indices, - input_shard_specs) + input_placement_specs = self._compile_input_placement_spec( + input_config.mesh_arg_indices, input_shard_specs) return PipeshardConfig( # Executable configs instruction_lists, @@ -455,7 +455,7 @@ def compile(self): output_local_uuid_list, outs_handler, # Others - load_info, + input_placement_specs, self.sharding_annotated_hlo_texts, self.flop_count) @@ -1013,23 +1013,27 @@ def outs_handler(mesh_group, bufs): return outs_handler - def _compile_load_info(self, mesh_arg_indices, input_shard_specs): + def _compile_input_placement_spec(self, mesh_arg_indices, + input_shard_specs): assert self.in_tree is not None - # build load_info_arr: flatten global index => LoadInfo object - load_info_arr = [None] * len(self.is_batch) + # build spec_arr: List[flatten global index -> PlacementSpec] + spec_arr = [None] * len(self.is_batch) for mesh_idx, physical_mesh in enumerate(self.mesh_group): for local_idx, global_idx in enumerate(mesh_arg_indices[mesh_idx]): - aval, mesh, spec = (self.global_invars[global_idx].aval, - physical_mesh, - input_shard_specs[mesh_idx][local_idx]) - if load_info_arr[global_idx] is None: - load_info_arr[global_idx] = LoadInfo(aval, [mesh], [spec]) + shard_spec = input_shard_specs[mesh_idx][local_idx] + if spec_arr[global_idx] is None: + spec_arr[global_idx] = PlacementSpec( + self.global_invars[global_idx].aval, + (physical_mesh.mesh_id,), (shard_spec,)) else: - load_info_arr[global_idx].add_replica(mesh, spec) + old_val = spec_arr[global_idx] + spec_arr[global_idx] = PlacementSpec( + old_val.aval, + old_val.mesh_ids + (physical_mesh.mesh_id,), + old_val.sharding_specs + (shard_spec,)) - # build load_info_arr - return tree_unflatten(self.in_tree, load_info_arr) + return tree_unflatten(self.in_tree, spec_arr) # TODO(yonghao): set empty buffer is not compatiable with local allgather @staticmethod diff --git a/alpa/serialization.py b/alpa/serialization.py index a3efd26f5..d64ad276a 100644 --- a/alpa/serialization.py +++ b/alpa/serialization.py @@ -6,18 +6,16 @@ import logging import os import pickle -from typing import Union, Sequence +from typing import Union from flax.serialization import to_state_dict, from_state_dict import jax -from jax.interpreters.pxla import ShardingSpec -from jax.core import ShapedArray from jax._src.tree_util import tree_flatten, tree_leaves, tree_unflatten, PyTreeDef import msgpack import numpy as np from alpa.device_mesh import (DistributedArray, ReplicatedDistributedArray, - PhysicalDeviceMesh) + get_global_virtual_physical_mesh) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -39,7 +37,7 @@ def _dfs_pytree(tree, prefix): def _save_unsharded_array(ckpt_dir, arr): os.makedirs(ckpt_dir, exist_ok=True) - shard_name = "0.0" + shard_name = "shard_0.0" metadata = { "global_shape": arr.shape, "dtype": arr.dtype, @@ -48,7 +46,7 @@ def _save_unsharded_array(ckpt_dir, arr): } with 
open(os.path.join(ckpt_dir, shard_name), "wb") as datafile: np.save(datafile, arr) - with open(os.path.join(ckpt_dir, ".metadata0"), "wb") as metafile: + with open(os.path.join(ckpt_dir, "metadata_0"), "wb") as metafile: pickle.dump(metadata, metafile) @@ -135,70 +133,49 @@ def save_checkpoint(ckpt_dir: Union[str, os.PathLike], metafile.write(msgpack.packb(metadata)) -class LoadInfo: - """ - A wrapper for the loading information. - """ - - def __init__(self, aval: ShapedArray, meshes: Sequence[PhysicalDeviceMesh], - specs: Sequence[ShardingSpec]): - assert len(meshes) == len(specs) - self.aval = aval - self.meshes = meshes - self.specs = specs - - def add_replica(self, mesh, spec): - self.meshes.append(mesh) - self.specs.append(spec) - - def get_info(self): - if self.is_replicated(): - return zip(self.meshes, self.specs) - else: - return self.meshes[0], self.specs[0] - - def is_replicated(self): - return len(self.meshes) > 1 - - def __str__(self): - return f"{self.aval}, {self.meshes[0].mesh_id}, {self.specs[0]}" - - def restore_checkpoint(ckpt_dir: Union[str, os.PathLike], step: int, - load_info: PyTreeDef): + placement_specs: PyTreeDef): """ Restore the specified checkpoint from `ckpt_dir` and reshard it - according to the `load_info`. + according to the `placement_specs`. Args: ckpt_dir: directory of checkpoints to restore from. If you do not have a shared filesystem, each host needs a copy of the checkpoint on its local disk at the same path. step: step number to load. - load_info: shardingSpec and deviceMesh placement info for loading. + placement_specs: shardingSpec and deviceMesh placement info + for loading. """ metapath = os.path.join(ckpt_dir, f"checkpoint_{step}") with open(metapath, "rb") as metafile: - metadata = from_state_dict(load_info, msgpack.unpackb(metafile.read())) + metadata = from_state_dict(placement_specs, + msgpack.unpackb(metafile.read())) state_paths, state_tree = tree_flatten(metadata) - flat_info = tree_leaves(load_info) + flat_info = tree_leaves(placement_specs) flat_load_state = [] + mesh_group = get_global_virtual_physical_mesh().launched_physical_mesh_group + assert mesh_group is not None + for path, info in zip(state_paths, flat_info): if info is None: logger.warning("Variable is not used, skip loading it") flat_load_state.append(None) - if info.is_replicated(): + if len(info.mesh_ids) == 1: + dist_arr = DistributedArray.load(os.path.join(ckpt_dir, + path), info.aval, + mesh_group[info.mesh_ids[0]], + info.sharding_specs[0]) + flat_load_state.append(dist_arr) + else: meshes, arrays = [], [] - for mesh, spec in info.get_info(): - meshes.append(mesh) - dist_arr = DistributedArray.load(os.path.join(ckpt_dir, path), - info.aval, mesh, spec) + for mesh_id, spec in zip(info.mesh_ids, info.sharding_specs): + meshes.append(mesh_group[mesh_id]) + dist_arr = DistributedArray.load(os.path.join(ckpt_dir, + path), info.aval, + mesh_group[mesh_id], spec) arrays.append(dist_arr) flat_load_state.append(ReplicatedDistributedArray(meshes, arrays)) - else: - mesh, spec = info.get_info() - dist_arr = DistributedArray.load(os.path.join(ckpt_dir, path), - info.aval, mesh, spec) - flat_load_state.append(dist_arr) + return tree_unflatten(state_tree, flat_load_state) diff --git a/docs/tutorials/opt_serving.rst b/docs/tutorials/opt_serving.rst index aa51fc512..17220daae 100644 --- a/docs/tutorials/opt_serving.rst +++ b/docs/tutorials/opt_serving.rst @@ -3,13 +3,11 @@ Serving OPT-175B using Alpa This tutorial provides guides to setup a serving system to serve the largest available 
pretrained language model OPT-175B.
-
As a serving system, Alpa provides the following unique advantages:
- **Support commodity hardware**: With Alpa, you can serve OPT-175B using your in-house GPU cluster, without needing the latest generations of A100 80GB GPUs nor fancy InfiniBand connections -- no hardware constraints!
-- **Flexible parallelism strategies**: Alpa will automatically figure out the appropriate model-parallelism strategies based on your cluster setup.
-
+- **Flexible parallelism strategies**: Alpa will automatically figure out the appropriate model-parallel strategies based on your cluster setup.
In this example, we use Alpa to serve the open-source OPT model, supporting all sizes ranging from 125M to 175B.
Specifically, Alpa provides:
@@ -20,15 +18,42 @@ Specifically, Alpa provides:
.. note::
- The trained OPT model weights can be obtained from `Metaseq download page `_. Usages of
+ The pre-trained OPT model weights can be obtained from `Metaseq download page `_. Usages of
the pretrained model weights are subject to their `license `_ .
.. note::
- You will need at least 350GB memory to to serve the OPT-175B model. You can also follow this guide to setup a serving system to serve smaller versions of OPT,
- such as OPT-66B, OPT-30B, etc. Pick an appropriate size from `OPT weight release page `_ based on
- your available resources.
-
+ You will need at least 350GB memory to serve the OPT-175B model. For example, you can use 4 x AWS p3.16xlarge instances, which provide 4 instances x 8 (GPU/instance) x 16 (GB/GPU) = 512 GB memory.
+ You can also follow this guide to setup a serving system to serve smaller versions of OPT, such as OPT-66B, OPT-30B, etc.
+ Pick an appropriate size from `OPT weight release page `_ based on your available resources.
+
+Demo
+----
+Use the huggingface/transformers interface and the Alpa backend for distributed inference on a Ray cluster.
+
+.. code:: python
+
+ from transformers import AutoTokenizer
+ from examples.opt_serving.model.wrapper import get_model
+
+ # Load the tokenizer. We have to use the 30B version because
+ # other versions have some issues. The 30B version works for all OPT models.
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False)
+ tokenizer.add_bos_token = False
+
+ # Load the model
+ model = get_model(model_name="alpa/opt-2.7b",
+ device="cuda",
+ path="/home/ubuntu/opt_weights/")
+
+ # Generate
+ prompt = "Paris is the capital city of "
+
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
+ output = model.generate(input_ids=input_ids, max_length=256, do_sample=True)
+ generated_string = tokenizer.batch_decode(output, skip_special_tokens=True)
+
+ print(generated_string)
Requirements
------------
@@ -57,44 +82,41 @@ There are two ways you can obtain the pretrained OPT weights.
then use our script `convert_to_numpy_weight.py `_ to convert it into Alpa-compatible formats.
2. We provide links to download the preprocessed 125M and 2.7B model below. For other sizes of OPT, please join `Alpa slack `_ to request a copy from the Alpa developer team.
+
- `OPT-125M weights `_
- `OPT-2.7B weights `_
+Run and Benchmark Generation in the Command Line
+------------------------------------------------
-Run Generation in Command Line
-------------------------------
-For a small model that can fit into one GPU, such as the OPT-125M, we can run single-GPU generation using either PyTorch backend or JAX backend.
-For examples:
-1.
Run generation using the 125M OPT model with PyTorch/HuggingFace backend:
.. code:: bash
cd benchmark
python3 benchmark_text_gen.py --model facebook/opt-125m --path [PATH_TO_WEIGHT]
-2. Run generation using the OPT-125M model with JAX backend in debug model to output the generated text:
+- Run generation using the 125M model with JAX backend in debug mode to output the generated text:
.. code:: bash
python3 benchmark_text_gen.py --model jax/opt-125m --path [PATH_TO_WEIGHT] --debug
-3. Run model-parallel generation using the 2.7B model with Alpa:
+- Run model-parallel generation using the 2.7B model with Alpa:
.. code:: bash
ray start --head
python3 benchmark_text_gen.py --model alpa/opt-2.7b --path [PATH_TO_WEIGHT] --debug
-4. Run distributed generation with the 175B model using Alpa; Note you will need >350Gb total GPU memory in the entire cluster to successfully run the inference.
+- Run distributed generation with the 175B model using Alpa. Note you will need >350GB total GPU memory in the entire cluster to successfully run the inference.
.. code:: bash
- # Remember to start ray on the entire cluster before running the generation
+ # Remember to start Ray on the entire cluster before running the generation
python3 benchmark_text_gen.py --model alpa/opt-175b --path [PATH_TO_WEIGHT] --debug
-Launch a web server to serve the OPT models
+Launch a Web Server to Serve the OPT models
-------------------------------------------
Launch the web server:
@@ -110,4 +132,4 @@ Then open ``https://[IP-ADDRESS]:10001`` in your browser to try out the model!
License
-------
-The Use of the OPT pretrained weights are subject to the `Model Licence `_ by Metaseq.
+The use of the OPT pretrained weights is subject to the `Model Licence `_ by Metaseq.
diff --git a/examples/opt_serving/README.md b/examples/opt_serving/README.md
index 9bfbaf581..2db338150 100644
--- a/examples/opt_serving/README.md
+++ b/examples/opt_serving/README.md
@@ -1,15 +1,45 @@
# Examples: Serving OPT-175B using Alpa
As a serving system, Alpa provides the following unique advantages:
- **Support commodity hardware**: With Alpa, you can serve OPT-175B using your in-house GPU cluster, without needing the latest generations of A100 80GB GPUs nor fancy InfiniBand connections -- no hardware constraints!
-- **Flexible parallelism strategies**: Alpa will automatically figure out the appropriate model-parallelism strategies based on your cluster setup.
+- **Flexible parallelism strategies**: Alpa will automatically figure out the appropriate model-parallel strategies based on your cluster setup.
In this example, we use Alpa to serve the open-source OPT model, supporting all sizes ranging from 125M to 175B.
-
Specifically, Alpa provides:
- A backend to perform model-parallel distributed inference for the large OPT models;
- A web frontend to collect and batch inference requests from users.
-**Note**: the trained OPT model weights can be obtained from [Metaseq](https://github.com/facebookresearch/metaseq), subject to their license.
+**Note**: the pre-trained OPT model weights can be obtained from [Metaseq](https://github.com/facebookresearch/metaseq), subject to their license.
+**Note**: You will need at least 350GB memory to serve the OPT-175B model. For example, you can use 4 x AWS p3.16xlarge instances,
+which provide 4 instances x 8 (GPU/instance) x 16 (GB/GPU) = 512 GB memory.
+You can also follow this guide to setup a serving system to serve smaller versions of OPT, such as OPT-66B, OPT-30B, etc.
+Pick an appropriate size from [OPT weight release page](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT) based on your available resources.
+
+## Demo
+Use the huggingface/transformers interface and the Alpa backend for distributed inference on a Ray cluster.
+
+```python
+from transformers import AutoTokenizer
+from examples.opt_serving.model.wrapper import get_model
+
+# Load the tokenizer. We have to use the 30B version because
+# other versions have some issues. The 30B version works for all OPT models.
+tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False)
+tokenizer.add_bos_token = False
+
+# Load the model
+model = get_model(model_name="alpa/opt-2.7b",
+ device="cuda",
+ path="/home/ubuntu/opt_weights/")
+
+# Generate
+prompt = "Paris is the capital city of "
+
+input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids=input_ids, max_length=256, do_sample=True)
+generated_string = tokenizer.batch_decode(output, skip_special_tokens=True)
+
+print(generated_string)
+```
## Requirements
1. Install Alpa following the [installation guide](https://alpa-projects.github.io/install.html).
@@ -36,16 +66,15 @@ then use our script [convert_to_numpy_weight.py](scripts/convert_to_numpy_weight
- [OPT-2.7B weights](https://drive.google.com/file/d/1ayIaKRhxF9osZWgcFG-3vSkjcepSWdQd/view?usp=sharing)
-## Run and benchmark generation in command line
+## Run and Benchmark Generation in the Command Line
-For a small model, we can run single-GPU generation using either PyTorch backend or JAX backend:
-
-Run generation using the 125M OPT model with PyTorch/HuggingFace backend:
+Run generation using the 125M model with PyTorch/HuggingFace backend:
```shell
cd benchmark
python3 benchmark_text_gen.py --model facebook/opt-125m
```
+
+Run generation using the 125M model with JAX backend in debug mode to output the generated text:
```shell
python3 benchmark_text_gen.py --model jax/opt-125m --path [PATH_TO_WEIGHT] --debug
```
@@ -57,14 +86,13 @@ ray start --head
python3 benchmark_text_gen.py --model alpa/opt-2.7b --path [PATH_TO_WEIGHT] --debug
```
-Run distributed generation using the 175B model with Alpa as below.
-Note you will need >350Gb total GPU memory in the entire cluster to successfully run the inference.
+Run distributed generation with the 175B model using Alpa. Note you will need >350GB total GPU memory in the entire cluster to successfully run the inference.
```shell
-# Remember to start ray on the entire cluster before running the generation
+# Remember to start Ray on the entire cluster before running the generation
python3 benchmark_text_gen.py --model alpa/opt-175b --path [PATH_TO_WEIGHT] --debug
```
-## Start a web server to serve the OPT models
+## Launch a Web Server to Serve the OPT models
Launch the web server:
```shell
@@ -76,11 +104,11 @@ Then open `https://[IP-ADDRESS]:10001` in your browser to try out the model!
## Code structure
-- [examples/opt_serving/benchmark](benchmark): Benchmark scripts for generation via command line.
+- [examples/opt_serving/benchmark](benchmark): Benchmark scripts for generation in the command line.
- [examples/opt_serving/dataset](dataset): Data loaders for serving.
- [examples/opt_serving/service](service): Model serving web server.
- [examples/opt_serving/generator.py](generator.py): Backend for web server. - [examples/opt_serving/interactive_hosted.py](interactive_hosted.py): Web server entry point. ## License -The Use of the OPT pretrained weights are subject to the [Model Licenc](https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/MODEL_LICENSE.md) by Metaseq. \ No newline at end of file +The use of the OPT pretrained weights are subject to the [Model License](https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/MODEL_LICENSE.md) by Metaseq. diff --git a/examples/opt_serving/benchmark/benchmark_text_gen.py b/examples/opt_serving/benchmark/benchmark_text_gen.py index 93f6df743..0966f35ce 100644 --- a/examples/opt_serving/benchmark/benchmark_text_gen.py +++ b/examples/opt_serving/benchmark/benchmark_text_gen.py @@ -164,6 +164,7 @@ dummy=args.dummy) load_time = time.time() - tic + # warm up input_ids = tokenizer("Paris is the capital city of", return_tensors="pt").input_ids.to(args.device) output = model.generate(input_ids=input_ids, @@ -176,9 +177,10 @@ L = model.transformer_config.L seq_len = model.transformer_config.seq_len vocab_size = model.transformer_config.vocab_size - num_gpus = alpa.get_global_cluster( ).num_devices if "alpa" in args.model else 1 + + # benchmark for i in range(n_iters): prompt = test_prompts[i] torch.manual_seed(8) @@ -238,3 +240,4 @@ f"{latency_32_tokens:.2f}" ] write_tsv(heads, values, "results.tsv") + diff --git a/examples/opt_serving/model/opt_model.py b/examples/opt_serving/model/opt_model.py index abb0199e2..f592771a2 100644 --- a/examples/opt_serving/model/opt_model.py +++ b/examples/opt_serving/model/opt_model.py @@ -738,7 +738,6 @@ def inference_step(params, batch): }) executable.dump_debug_info("tmp") - return executable, params @@ -827,14 +826,14 @@ def load_param(param_key, loaded_array): def load_params_dis_array(path, executable, params_aval, config, dummy=False): if dummy: alpa.global_config.use_dummy_value_for_benchmarking = True - params_info, _ = executable.get_load_info() + params_info, _ = executable.get_input_placement_specs() flat_args, in_tree = tree_flatten(params_aval) flat_info = tree_leaves(params_info) ret = executable.mesh_group.shard_args_to_arrays(flat_info, flat_args) alpa.global_config.use_dummy_value_for_benchmarking = False return ret - params_info, _ = executable.get_load_info() + params_info, _ = executable.get_input_placement_specs() prefix_to_flat_idx = {} ct = itertools.count() @@ -857,16 +856,29 @@ def dfs(dict_tree, result_dict, cur_prefix): flat_mesh_ids = [] flat_arrays = [] + mesh_group = executable.mesh_group + for info in flat_infos: aval = info.aval - if info.is_replicated(): + if len(info.mesh_ids) == 1: + mesh, spec = mesh_group[info.mesh_ids[0]], info.sharding_specs[0] + indices = pxla.spec_to_indices(aval.shape, spec) + buf_refs, buf_uuids = create_remote_buffer_refs(mesh, 1) + flat_shapes.append([aval.shape]) + flat_uuids.append([buf_uuids]) + flat_indices.append([indices]) + flat_mesh_ids.append([mesh.mesh_id]) + flat_arrays.append( + DistributedArray(mesh, aval, spec, buf_refs, indices)) + else: tmp_shapes = [] tmp_uuids = [] tmp_indices = [] tmp_mesh_ids = [] tmp_arrays = [] tmp_meshes = [] - for mesh, spec in info.get_info(): + for mesh_id, spec in zip(info.mesh_ids, info.sharding_specs): + mesh = mesh_group[mesh_id] indices = pxla.spec_to_indices(aval.shape, spec) buf_refs, buf_uuids = create_remote_buffer_refs(mesh, 1) array = DistributedArray(mesh, aval, spec, buf_refs, indices) @@ -882,16 +894,6 @@ def 
dfs(dict_tree, result_dict, cur_prefix): flat_mesh_ids.append(tuple(tmp_mesh_ids)) flat_arrays.append( ReplicatedDistributedArray(tmp_meshes, tmp_arrays)) - else: - mesh, spec = info.get_info() - indices = pxla.spec_to_indices(aval.shape, spec) - buf_refs, buf_uuids = create_remote_buffer_refs(mesh, 1) - flat_shapes.append([aval.shape]) - flat_uuids.append([buf_uuids]) - flat_indices.append([indices]) - flat_mesh_ids.append([mesh.mesh_id]) - flat_arrays.append( - DistributedArray(mesh, aval, spec, buf_refs, indices)) for m in executable.mesh_group.meshes: for w in m.workers: @@ -906,7 +908,7 @@ def dfs(dict_tree, result_dict, cur_prefix): def init_cache_dis_array(executable, config, batch_size, dummy=False): alpa.global_config.use_dummy_value_for_benchmarking = dummy cache = init_cache_np(config, batch_size) - _, batch_info = executable.get_load_info() + _, batch_info = executable.get_input_placement_specs() cache_info = batch_info["cache"] flat_args, in_tree = tree_flatten(cache) flat_info = tree_leaves(cache_info) diff --git a/examples/opt_serving/model/wrapper.py b/examples/opt_serving/model/wrapper.py index d19b27ecd..a6628d0f9 100644 --- a/examples/opt_serving/model/wrapper.py +++ b/examples/opt_serving/model/wrapper.py @@ -167,7 +167,7 @@ def inference_func(input_ids, def get_model(model_name, device, path, - autoregressive, + autoregressive=True, dtype=jnp.float16, dummy=False, batch_size=1, diff --git a/examples/opt_serving/scripts/convert_to_numpy_weights.py b/examples/opt_serving/scripts/convert_to_numpy_weights.py index d9505fdd5..6af206eb0 100644 --- a/examples/opt_serving/scripts/convert_to_numpy_weights.py +++ b/examples/opt_serving/scripts/convert_to_numpy_weights.py @@ -1,7 +1,7 @@ """Convert Metaseq's OPT model weights into Alpa numpy weights.""" -import numpy as np import os +import numpy as np from metaseq.file_io import torch_load_cpu @@ -16,14 +16,14 @@ def save_numpy(weight_dict, to_folder): np.save(g, t) -def worker_main(): - PATH = "/home/ubuntu/parax-efs/pycharm/opt/opt_metaseq_30000m/model/" +def worker_main(src_folder, dst_folder): # Path to the single - consolidated_weight = os.path.join(PATH, "restored.pt") + consolidated_weight = os.path.join(src_folder, "restored.pt") state = torch_load_cpu(consolidated_weight) - to_folder = "/home/ubuntu/parax-efs/pycharm/opt/raw_weights/30B_resharded/" - save_numpy(state["model"], to_folder) + save_numpy(state["model"], dst_folder) if __name__ == "__main__": - worker_main() + src_folder = "/home/ubuntu/parax-efs/pycharm/opt/opt_metaseq_30000m/model/" + dst_folder = "/home/ubuntu/parax-efs/pycharm/opt/raw_weights/30B_resharded/" + worker_main(src_folder, dst_folder) diff --git a/examples/opt_serving/textgen_demo.py b/examples/opt_serving/textgen_demo.py new file mode 100644 index 000000000..cc5a7cc2d --- /dev/null +++ b/examples/opt_serving/textgen_demo.py @@ -0,0 +1,23 @@ +"""Use huggingface/transformers interface and Alpa backend for distributed inference.""" +from transformers import AutoTokenizer +from examples.opt_serving.model.wrapper import get_model + +# Load the tokenizer. We have to use the 30B version because +# other versions have some issues. The 30B version works for all OPT models. 
+tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False) +tokenizer.add_bos_token = False + +# Load the model +model = get_model(model_name="alpa/opt-2.7b", + device="cuda", + path="/home/ubuntu/opt_weights/") + +# Generate +prompt = "Paris is the capital city of " + +input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids=input_ids, max_length=256, do_sample=True) +generated_string = tokenizer.batch_decode(output, skip_special_tokens=True) + +print(generated_string) + diff --git a/tests/runtime/test_dist_save_load.py b/tests/runtime/test_dist_save_load.py index dbe729358..b4babfe30 100644 --- a/tests/runtime/test_dist_save_load.py +++ b/tests/runtime/test_dist_save_load.py @@ -154,8 +154,8 @@ def test_jax_mlp_save_dist_load(self): executable = parallel_train_step.get_executable(jax_state, batch) # Restore checkpoint - state_ss, _ = executable.get_load_info() - load_state = restore_checkpoint(ckpt_dir, 1, state_ss) + state_ps, _ = executable.get_input_placement_specs() + load_state = restore_checkpoint(ckpt_dir, 1, state_ps) # Run after load serial_state = serial_train_step(jax_state, batch)[0] @@ -201,8 +201,8 @@ def test_distributed_mlp_uncached_save_load(self): save_checkpoint(ckpt_dir, parallel_state, 1) # Restore checkpoint - state_ss, _ = executable.get_load_info() - load_state = restore_checkpoint(ckpt_dir, 1, state_ss) + state_ps, _ = executable.get_input_placement_specs() + load_state = restore_checkpoint(ckpt_dir, 1, state_ps) # Run after load serial_state = serial_train_step(serial_state, batch)[0] @@ -269,8 +269,8 @@ def test_distributed_bert_cached_save_load(self): executable.sync_move_workers() # Restore checkpoint - state_ss, _ = executable.get_load_info() - load_state = restore_checkpoint(ckpt_dir, 1, state_ss) + state_ps, _ = executable.get_input_placement_specs() + load_state = restore_checkpoint(ckpt_dir, 1, state_ps) # Run after load serial_state = serial_train_step(serial_state, batch)[0]
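With this change, `get_input_placement_specs()` returns a pytree of `PlacementSpec` objects (each carrying `aval`, `mesh_ids`, and `sharding_specs`) that can be passed directly to `restore_checkpoint`, replacing the removed `LoadInfo` wrapper. A minimal sketch of the intended save/restore flow, mirroring the updated tests in this patch; `parallel_train_step`, `state`, `batch`, the checkpoint directory, and the step number are illustrative placeholders, and a launched cluster is assumed:

```python
# Sketch only: `parallel_train_step` stands for an @alpa.parallelize-decorated
# train step, and `state`/`batch` for its usual inputs.
from alpa.serialization import save_checkpoint, restore_checkpoint

ckpt_dir = "/tmp/alpa_ckpt"   # hypothetical checkpoint directory
step = 1

# Compile (or fetch) the executable for these inputs.
executable = parallel_train_step.get_executable(state, batch)

# Save the sharded training state as a distributed checkpoint.
save_checkpoint(ckpt_dir, state, step)

# Placement specs describe where each input lives (mesh ids + sharding specs).
state_placement_specs, _ = executable.get_input_placement_specs()

# Restore the checkpoint and reshard it according to the placement specs.
loaded_state = restore_checkpoint(ckpt_dir, step, state_placement_specs)
```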