Commit d1964bc

Merge branch 'main' into ransmith_int8_symmetric

rasmith committed Oct 18, 2024
2 parents: 90a3e0f + 1658370
Showing 4 changed files with 96 additions and 46 deletions.
2 changes: 1 addition & 1 deletion .buildkite/run-amd-test.sh

@@ -5,7 +5,7 @@ set -o pipefail
 echo "--- Confirming Clean Initial State"
 while true; do
     sleep 3
-    if grep -q clean ${BUILDKITE_META_DATA_RESET_TARGET}; then
+    if grep -q clean ${BUILDKITE_AGENT_META_DATA_RESET_TARGET}; then
         echo "GPUs state is \"clean\""
         break
     fi
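The loop above gates the AMD test job until an out-of-band reset process marks the GPUs as clean; the fix switches to BUILDKITE_AGENT_META_DATA_RESET_TARGET, presumably the variable name the Buildkite agent actually exports. A minimal Python sketch of the same polling pattern, with a hypothetical fallback path:

import os
import time

# The CI script takes this path from BUILDKITE_AGENT_META_DATA_RESET_TARGET;
# the fallback here is hypothetical.
state_file = os.environ.get("BUILDKITE_AGENT_META_DATA_RESET_TARGET",
                            "/tmp/gpu_state")

# Poll until the reset process writes "clean" into the state file.
while True:
    time.sleep(3)
    try:
        with open(state_file) as f:
            if "clean" in f.read():
                print('GPUs state is "clean"')
                break
    except FileNotFoundError:
        continue  # state file not created yet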
2 changes: 1 addition & 1 deletion Dockerfile.rocm

@@ -173,7 +173,7 @@ RUN cd vllm \
     && python3 -m pip install -r requirements-rocm.txt \
     && python3 setup.py clean --all \
     && if [ ${USE_CYTHON} -eq "1" ]; then python3 setup_cython.py build_ext --inplace; fi \
-    && python3 setup.py bdist_wheel --dist-dir=dist
+    && SCCACHE_IDLE_TIMEOUT=1800 python3 setup.py bdist_wheel --dist-dir=dist
 # Build gradlib
 RUN cd vllm/gradlib \
     && python3 setup.py clean --all && python3 setup.py bdist_wheel --dist-dir=dist
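The only change here raises sccache's idle timeout for the wheel build: the sccache server exits after roughly ten idle minutes by default, so a long non-compile stretch mid-build could lose the warm cache server. A sketch of the same invocation outside Docker (assumes a vLLM checkout with setup.py and sccache configured as the compiler wrapper):

import os
import subprocess

env = dict(os.environ)
# Keep the sccache server alive through 30 minutes of inactivity.
env["SCCACHE_IDLE_TIMEOUT"] = "1800"
subprocess.run(
    ["python3", "setup.py", "bdist_wheel", "--dist-dir=dist"],
    check=True,
    env=env,
)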
68 changes: 51 additions & 17 deletions vllm/attention/backends/utils.py

@@ -7,7 +7,7 @@

 from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
                             AttentionState)
-from vllm.utils import async_tensor_h2d, make_tensor_with_pad
+from vllm.utils import async_tensor_h2d, is_hip, make_tensor_with_pad

 if TYPE_CHECKING:
     from vllm.worker.model_runner_base import ModelRunnerBase
@@ -218,9 +218,18 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
             input_block_tables = self.runner.graph_block_tables[:batch_size]
+            max_blocks = input_block_tables.shape[1]
             for i, block_table in enumerate(self.block_tables):
                 if block_table:
-                    input_block_tables[i, :len(block_table)] = block_table
+                    num_blocks = len(block_table)
+                    if num_blocks <= max_blocks:
+                        input_block_tables[i, :num_blocks] = block_table
+                    else:
+                        # It may be possible to have more blocks allocated due
+                        # to lookahead slots of multi-step, however, they are
+                        # not used anyway, so can be safely ignored.
+                        input_block_tables[
+                            i, :max_blocks] = block_table[:max_blocks]
             block_tables = torch.from_numpy(input_block_tables).to(
                 device, non_blocking=True)
         else:
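The new clamp matters during CUDA-graph capture, where the padded block table has a fixed width. A runnable numpy toy of the logic above, with made-up sizes:

import numpy as np

max_blocks = 4                                   # graph table width
input_block_tables = np.zeros((2, max_blocks), dtype=np.int32)
block_tables = [[7, 8], [1, 2, 3, 4, 5, 6]]      # second row overflows

for i, block_table in enumerate(block_tables):
    num_blocks = len(block_table)
    if num_blocks <= max_blocks:
        input_block_tables[i, :num_blocks] = block_table
    else:
        # Extra blocks come from multi-step lookahead slots and are unused,
        # so truncating them is safe.
        input_block_tables[i, :max_blocks] = block_table[:max_blocks]

print(input_block_tables)
# [[7 8 0 0]
#  [1 2 3 4]]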
@@ -325,11 +334,19 @@ def graph_capture_get_metadata_for_batch(
         if is_encoder_decoder_model:
             # The encoder decoder model works only with XFormers backend.
             # Assert the same.
-            assert self.runner.attn_backend.get_name() == "xformers", \
-                f"Expected attn_backend name to be 'xformers', but "\
-                f" got '{self.runner.attn_backend.get_name()}'"
-            self._update_captured_metadata_for_enc_dec_model(
-                batch_size=batch_size, attn_metadata=attn_metadata)
+            if is_hip():
+                assert (
+                    self.runner.attn_backend.get_name() == "rocm-flash-attn"
+                ), (f"Expected attn_backend name to be 'rocm-flash-attn', but "
+                    f" got '{self.runner.attn_backend.get_name()}'")
+                self._update_captured_metadata_for_enc_dec_model(
+                    batch_size=batch_size, attn_metadata=attn_metadata)
+            else:
+                assert self.runner.attn_backend.get_name() == "xformers", \
+                    f"Expected attn_backend name to be 'xformers', but "\
+                    f" got '{self.runner.attn_backend.get_name()}'"
+                self._update_captured_metadata_for_enc_dec_model(
+                    batch_size=batch_size, attn_metadata=attn_metadata)

         return attn_metadata
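The same ROCm/CUDA split is applied to all three CUDA-graph helpers in this file (this method and the two below): on HIP the encoder-decoder path requires the "rocm-flash-attn" backend, elsewhere "xformers". A standalone sketch of the shared check (hypothetical helper names, not vLLM API):

def expected_enc_dec_backend(on_hip: bool) -> str:
    # Exactly one attention backend per platform supports encoder-decoder
    # models under CUDA-graph capture.
    return "rocm-flash-attn" if on_hip else "xformers"

def check_backend(backend_name: str, on_hip: bool) -> None:
    expected = expected_enc_dec_backend(on_hip)
    assert backend_name == expected, (
        f"Expected attn_backend name to be '{expected}', "
        f"but got '{backend_name}'")

check_backend("rocm-flash-attn", on_hip=True)  # passes
check_backend("xformers", on_hip=False)        # passes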

@@ -345,11 +362,19 @@ def get_graph_input_buffers(
         if is_encoder_decoder_model:
             # The encoder decoder model works only with XFormers backend.
             # Assert the same.
-            assert self.runner.attn_backend.get_name() == "xformers", \
-                f"Expected attn_backend name to be 'xformers', but "\
-                f" got '{self.runner.attn_backend.get_name()}'"
-            self._add_additonal_input_buffers_for_enc_dec_model(
-                attn_metadata=attn_metadata, input_buffers=input_buffers)
+            if is_hip():
+                assert (
+                    self.runner.attn_backend.get_name() == "rocm-flash-attn"
+                ), (f"Expected attn_backend name to be 'rocm-flash-attn', but "
+                    f" got '{self.runner.attn_backend.get_name()}'")
+                self._add_additonal_input_buffers_for_enc_dec_model(
+                    attn_metadata=attn_metadata, input_buffers=input_buffers)
+            else:
+                assert self.runner.attn_backend.get_name() == "xformers", \
+                    f"Expected attn_backend name to be 'xformers', but "\
+                    f" got '{self.runner.attn_backend.get_name()}'"
+                self._add_additonal_input_buffers_for_enc_dec_model(
+                    attn_metadata=attn_metadata, input_buffers=input_buffers)
         return input_buffers

     def prepare_graph_input_buffers(
@@ -364,11 +389,20 @@ def prepare_graph_input_buffers(
         if is_encoder_decoder_model:
             # The encoder decoder model works only with XFormers backend.
             # Assert the same.
-            assert self.runner.attn_backend.get_name() == "xformers", \
-                f"Expected attn_backend name to be 'xformers', but "\
-                f" got '{self.runner.attn_backend.get_name()}'"
-            self._prepare_input_buffers_for_enc_dec_model(
-                attn_metadata, input_buffers)
+
+            if is_hip():
+                assert (
+                    self.runner.attn_backend.get_name() == "rocm-flash-attn"
+                ), (f"Expected attn_backend name to be 'rocm-flash-attn', but "
+                    f" got '{self.runner.attn_backend.get_name()}'")
+                self._prepare_input_buffers_for_enc_dec_model(
+                    attn_metadata, input_buffers)
+            else:
+                assert self.runner.attn_backend.get_name() == "xformers", \
+                    f"Expected attn_backend name to be 'xformers', but "\
+                    f" got '{self.runner.attn_backend.get_name()}'"
+                self._prepare_input_buffers_for_enc_dec_model(
+                    attn_metadata, input_buffers)

     def begin_forward(self, model_input) -> None:
         return
70 changes: 43 additions & 27 deletions vllm/model_executor/models/dbrx.py

@@ -18,7 +18,8 @@
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
@@ -82,33 +83,45 @@ def __init__(

     # Define custom weight loader for dbrx model
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
-                      weight_name: str):
+                      weight_name: str, param_name: str):
         tp_rank = get_tensor_model_parallel_rank()
         param_data = param.data
         shard_size = self.intermediate_size
         shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
         # DBRX uses GLU for each experts.
         # GLU has 3 linear layers: w1, v1 and w2.
-        if weight_name.endswith("w1."):
-            loaded_weight = torch.reshape(
-                loaded_weight,
-                [-1, self.intermediate_size * self.tp_size, self.d_model],
-            )
-            param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
-        if weight_name.endswith("v1."):
-            loaded_weight = torch.reshape(
-                loaded_weight,
-                [-1, self.intermediate_size * self.tp_size, self.d_model],
-            )
-            param_data[:,
-                       shard_size:2 * shard_size, :] = loaded_weight[:,
-                                                                     shard, :]
-        if weight_name.endswith("w2."):
-            loaded_weight = torch.reshape(
-                loaded_weight,
-                [-1, self.intermediate_size * self.tp_size, self.d_model],
-            ).transpose(1, 2)
-            param_data[:] = loaded_weight[:, :, shard]
+        if weight_name.endswith("w1"):
+            if param_name.endswith("weight"):
+                loaded_weight = torch.reshape(
+                    loaded_weight,
+                    [-1, self.intermediate_size * self.tp_size, self.d_model],
+                )
+                param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
+            elif param_name.endswith("weight_scale"):
+                param_data[:, 0] = loaded_weight
+            else:
+                param_data = loaded_weight
+        if weight_name.endswith("v1"):
+            if param_name.endswith("weight"):
+                loaded_weight = torch.reshape(
+                    loaded_weight,
+                    [-1, self.intermediate_size * self.tp_size, self.d_model],
+                )
+                param_data[:, shard_size:2 *
+                           shard_size, :] = loaded_weight[:, shard, :]
+            elif param_name.endswith("weight_scale"):
+                param_data[:, 1] = loaded_weight
+            else:
+                param_data[:] = loaded_weight
+        if weight_name.endswith("w2"):
+            if param_name.endswith("weight"):
+                loaded_weight = torch.reshape(
+                    loaded_weight,
+                    [-1, self.intermediate_size * self.tp_size, self.d_model],
+                ).transpose(1, 2)
+                param_data[:] = loaded_weight[:, :, shard]
+            else:
+                param_data[:] = loaded_weight


 class DbrxMoE(nn.Module):
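The loader now keys on both the checkpoint tensor name (w1/v1/w2) and the parameter kind: full "weight" tensors are reshaped per expert and sharded along the intermediate dimension, while per-expert "weight_scale" vectors from quantized checkpoints (this branch targets symmetric int8) land in column 0 (w1) or column 1 (v1) of a small scale parameter. A toy walk-through of the w1 "weight" branch with made-up sizes, in plain torch outside vLLM:

import torch

num_experts, d_model = 2, 4
tp_size, tp_rank, shard_size = 2, 0, 3       # shard_size = per-rank width
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

# Fused w13 parameter: w1 fills rows [0, shard_size), v1 the second half.
param_data = torch.zeros(num_experts, 2 * shard_size, d_model)

# Full (unsharded) w1 checkpoint tensor, reshaped to one slab per expert.
loaded_weight = torch.arange(
    num_experts * shard_size * tp_size * d_model, dtype=torch.float32)
loaded_weight = loaded_weight.reshape(-1, shard_size * tp_size, d_model)

param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
print(param_data.shape)  # torch.Size([2, 6, 4])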
@@ -409,13 +422,13 @@ def sample(
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

         expert_params_mapping = [(
-            "w13_" if weight_name in ["w1", "v1"] else "w2_",
-            f"mlp.{weight_name}.",
+            "w13" if weight_name in ["w1", "v1"] else "w2",
+            f"mlp.{weight_name}",
         ) for weight_name in ["w1", "v1", "w2"]]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
-            if name.endswith(("w1", "v1", "w2")):
-                name = name + ".weight"
+            if name.endswith(("w1", "w2", "v1")):
+                name = name + "_weight"
             for param_name, weight_name in expert_params_mapping:
                 if weight_name not in name:
                     continue
@@ -424,11 +437,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, weight_name)
+                weight_loader(param, loaded_weight, weight_name, name)
                 break
             else:
                 if is_pp_missing_parameter(name, self):
                     continue
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
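Two things change in load_weights: expert tensor names now get a "_weight" suffix (so bare checkpoint keys like "...mlp.w1" match fused parameters such as w13_weight, and the full name, passed as the new fourth argument, lets the loader tell weights from scales), and the fallback path runs maybe_remap_kv_scale_name, which remaps legacy kv-cache scale names onto their current parameters and skips entries with no target. A toy check of the matching step (the exact key rewrite happens in code elided from this diff):

expert_params_mapping = [(
    "w13" if weight_name in ["w1", "v1"] else "w2",
    f"mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]]

name = "transformer.blocks.0.ffn.experts.mlp.v1"   # hypothetical key
if name.endswith(("w1", "w2", "v1")):
    name = name + "_weight"

matches = [p for p, w in expert_params_mapping if w in name]
print(name, matches)
# transformer.blocks.0.ffn.experts.mlp.v1_weight ['w13']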
