re-use fused_topk function
divakar-amd committed Oct 16, 2024
1 parent 618663d commit d47b89c
Showing 1 changed file with 3 additions and 22 deletions: benchmarks/kernels/benchmark_mixtral_moe_rocm.py
@@ -14,7 +14,8 @@
 union_of_list_of_dicts)
 
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.fused_moe import (get_config_file_name,
+from vllm.model_executor.layers.fused_moe import (fused_topk,
+                                                  get_config_file_name,
                                                   invoke_fused_moe_kernel,
                                                   moe_align_block_size)
 
@@ -185,28 +186,8 @@ def run_timing(
     ]
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
-    topk_ = top_k
-    topk_weights = torch.empty(M,
-                               topk_,
-                               dtype=torch.float32,
-                               device=hidden_states.device)
-    topk_ids = torch.empty(M,
-                           topk_,
-                           dtype=torch.int32,
-                           device=hidden_states.device)
-    token_expert_indicies = torch.empty(M,
-                                        topk_,
-                                        dtype=torch.int32,
-                                        device=hidden_states.device)
-    ops.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        gating_output.float(),
-    )
-    del token_expert_indicies  # Not used. Will be used in the future.
 
-    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, top_k, True)
 
     intermediate_cache1 = torch.empty(
         (M, topk_ids.shape[1], N),
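For readers skimming the diff: the one-line fused_topk call replaces the hand-rolled routing block deleted above. Below is a minimal before/after sketch, assuming only what the diff itself shows (the ops.topk_softmax call from the deleted lines and the positional fused_topk(hidden_states, gating_output, topk, renormalize) signature from the added line); it is an illustration, not the vLLM implementation.

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import fused_topk


def manual_topk(hidden_states, gating_output, top_k):
    # The deleted path: allocate outputs, run the topk_softmax kernel,
    # then renormalize the selected expert weights per token.
    M, _ = hidden_states.shape
    topk_weights = torch.empty(M, top_k, dtype=torch.float32,
                               device=hidden_states.device)
    topk_ids = torch.empty(M, top_k, dtype=torch.int32,
                           device=hidden_states.device)
    token_expert_indicies = torch.empty(M, top_k, dtype=torch.int32,
                                        device=hidden_states.device)
    ops.topk_softmax(topk_weights, topk_ids, token_expert_indicies,
                     gating_output.float())
    del token_expert_indicies  # Not used by the benchmark.
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


def reused_topk(hidden_states, gating_output, top_k):
    # The new path: renormalize=True is assumed to cover the same
    # softmax + top-k + renormalization steps inside fused_topk.
    return fused_topk(hidden_states, gating_output, top_k, True)

Either way the benchmark only needs topk_weights and topk_ids; delegating to fused_topk re-uses the routing helper that already exists in vllm.model_executor.layers.fused_moe instead of duplicating it in the benchmark.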
