Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cuda graph + num-scheduler-steps bug fix #236

Merged
merged 3 commits into from
Oct 16, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion vllm/attention/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,17 @@ def build(self, seq_lens: List[int], query_lens: List[int],
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.runner.graph_block_tables[:batch_size]
max_blocks = input_block_tables.shape[1]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
num_blocks = len(block_table)
if num_blocks <= max_blocks:
input_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
input_block_tables[i, :max_blocks] = block_table[:max_blocks]
block_tables = torch.from_numpy(input_block_tables).to(
device, non_blocking=True)
else:
Expand Down
Loading