Skip to content

Commit

Permalink
[Core][Optimization] change python dict to pytorch tensor (vllm-proje…
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored and dtrifiro committed May 7, 2024
1 parent 4824f6e commit fd9fef5
Show file tree
Hide file tree
Showing 19 changed files with 77 additions and 81 deletions.
2 changes: 1 addition & 1 deletion csrc/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ void swap_blocks(
void copy_blocks(
std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const std::map<int64_t, std::vector<int64_t>>& block_mapping);
torch::Tensor& block_mapping);

void reshape_and_cache(
torch::Tensor& key,
Expand Down
20 changes: 5 additions & 15 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ __global__ void copy_blocks_kernel(
void copy_blocks(
std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
torch::Tensor& block_mapping) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
Expand All @@ -114,26 +114,16 @@ void copy_blocks(
key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
}
// Create block mapping array.
std::vector<int64_t> block_mapping_vec;
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
for (int64_t dst_block_number : pair.second) {
block_mapping_vec.push_back(src_block_number);
block_mapping_vec.push_back(dst_block_number);
}
}
int64_t* block_mapping_array = block_mapping_vec.data();
int num_pairs = block_mapping_vec.size() / 2;

// block_mapping is a 2D tensor with shape (num_pairs, 2).
int num_pairs = block_mapping.size(0);

// Move the data structures to the GPU.
// NOTE: This synchronizes the CPU and GPU.
torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
torch::Tensor block_mapping_tensor = torch::from_blob(
block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);

// Launch the kernel.
const int numel_per_block = key_caches[0][0].numel();
Expand All @@ -146,7 +136,7 @@ void copy_blocks(
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping_tensor.data_ptr<int64_t>(),
block_mapping.data_ptr<int64_t>(),
numel_per_block);
}));
}
Expand Down
20 changes: 6 additions & 14 deletions csrc/cpu/cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ template <typename scalar_t>
void copy_blocks_cpu_impl(
std::vector<torch::Tensor> &key_caches,
std::vector<torch::Tensor> &value_caches,
const std::vector<std::pair<int64_t, int64_t>> mapping_pairs,
const torch::Tensor& mapping_pairs,
const int element_num_per_block, const int layer_num) {
const size_t pair_num = mapping_pairs.size();
const size_t pair_num = mapping_pairs.size(0);
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2)
for (int layer = 0; layer < layer_num; ++layer) {
for (size_t pair = 0; pair < pair_num; ++pair) {
int64_t source_offset = element_num_per_block * mapping_pairs[pair].first;
int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
int64_t target_offset =
element_num_per_block * mapping_pairs[pair].second;
element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
scalar_t *source_ptr = key_cache_ptr + source_offset;
scalar_t *target_ptr = key_cache_ptr + target_offset;
Expand Down Expand Up @@ -83,26 +83,18 @@ void reshape_and_cache_cpu_impl(

void copy_blocks(std::vector<torch::Tensor> &key_caches,
std::vector<torch::Tensor> &value_caches,
const std::map<int64_t, std::vector<int64_t>> &block_mapping) {
torch::Tensor& block_mapping) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
return;
}

std::vector<std::pair<int64_t, int64_t>> mapping_pairs;
mapping_pairs.reserve(block_mapping.size());
for (const auto &pair : block_mapping) {
for (const auto &dst : pair.second) {
mapping_pairs.emplace_back(pair.first, dst);
}
}

const int element_num_per_block = key_caches[0][0].numel();
VLLM_DISPATCH_FLOATING_TYPES(
key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, mapping_pairs,
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
element_num_per_block, num_layers);
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
});
Expand Down
8 changes: 4 additions & 4 deletions tests/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots):
# Both should be preempted, not swapped.
assert output.blocks_to_swap_out == {}
# Nothing is copied.
assert output.blocks_to_copy == {}
assert output.blocks_to_copy == []


def test_decode_swap_beam_search():
Expand Down Expand Up @@ -618,7 +618,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots):
# Both should be preempted, not swapped.
assert output.blocks_to_swap_out == expected_swap_mapping
# Nothing is copied.
assert output.blocks_to_copy == {}
assert output.blocks_to_copy == []


def test_schedule_decode_blocks_to_copy_update():
Expand Down Expand Up @@ -650,7 +650,7 @@ def test_schedule_decode_blocks_to_copy_update():
assert output.blocks_to_swap_out == {}
# Since append_slot returns the source -> dist mapping, it should
# applied.
assert output.blocks_to_copy == {2: [3]}
assert output.blocks_to_copy == [(2, 3)]


def test_schedule_swapped_simple():
Expand Down Expand Up @@ -853,7 +853,7 @@ def test_schedule_swapped_blocks_to_copy():
assert len(remaining_swapped) == 0
assert len(output.decode_seq_groups) == 1
assert len(output.prefill_seq_groups) == 0
assert output.blocks_to_copy == {2: [3]}
assert output.blocks_to_copy == [(2, 3)]


def test_scheduling_budget():
Expand Down
21 changes: 12 additions & 9 deletions tests/kernels/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ def test_copy_blocks(
src_blocks = random.sample(range(num_blocks), num_mappings)
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
block_mapping = {}
block_mapping = []
for i in range(num_mappings):
src = src_blocks[i]
dst1 = dst_blocks[2 * i]
dst2 = dst_blocks[2 * i + 1]
block_mapping[src] = [dst1, dst2]
block_mapping.append((src, dst1))
block_mapping.append((src, dst2))

# Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
Expand All @@ -81,15 +82,17 @@ def test_copy_blocks(
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

# Call the copy blocks kernel.
ops.copy_blocks(key_caches, value_caches, block_mapping)
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device=device).view(-1, 2)
ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)

# Run the reference implementation.
for src, dsts in block_mapping.items():
for dst in dsts:
for cloned_key_cache in cloned_key_caches:
cloned_key_cache[dst].copy_(cloned_key_cache[src])
for cloned_value_cache in cloned_value_caches:
cloned_value_cache[dst].copy_(cloned_value_cache[src])
for src, dst in block_mapping:
for cloned_key_cache in cloned_key_caches:
cloned_key_cache[dst].copy_(cloned_key_cache[src])
for cloned_value_cache in cloned_value_caches:
cloned_value_cache[dst].copy_(cloned_value_cache[src])

# Compare the results.
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
Expand Down
2 changes: 1 addition & 1 deletion tests/worker/test_swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_swap() -> None:
seq_group_metadata_list=[],
blocks_to_swap_in={},
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy={},
blocks_to_copy=[],
)
worker.execute_model(execute_model_req=execute_model_req)

Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def swap_blocks(
@abstractmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
src_to_dists: torch.Tensor,
) -> None:
raise NotImplementedError

Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def swap_blocks(
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)

Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def swap_blocks(
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
src_to_dists: torch.Tensor,
) -> None:
raise NotImplementedError

Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def swap_blocks(
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)

Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/torch_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def swap_blocks(
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)

Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/xformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def swap_blocks(
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)

Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/ops/paged_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def swap_blocks(
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
src_to_dists: torch.Tensor,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
Expand Down
41 changes: 20 additions & 21 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from vllm.lora.request import LoRARequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
from vllm.utils import merge_dicts

logger = init_logger(__name__)

Expand Down Expand Up @@ -122,8 +121,8 @@ class SchedulerOutputs:
blocks_to_swap_in: Dict[int, int]
# Blocks to swap out. Dict of GPU -> CPU block number.
blocks_to_swap_out: Dict[int, int]
# Blocks to copy. Source to a list of dest blocks.
blocks_to_copy: Dict[int, List[int]]
# Blocks to copy. Source to dest block.
blocks_to_copy: List[Tuple[int, int]]
# Sequence groups that are going to be ignored.
ignored_seq_groups: List[SequenceGroup]
# The number of slots for lookahead decoding.
Expand Down Expand Up @@ -177,7 +176,7 @@ class SchedulerRunningOutputs:
# The blocks to swap out.
blocks_to_swap_out: Dict[int, int]
# The blocks to copy.
blocks_to_copy: Dict[int, List[int]]
blocks_to_copy: List[Tuple[int, int]]
# The number of slots for lookahead decoding.
num_lookahead_slots: int

Expand All @@ -189,7 +188,7 @@ def create_empty(cls) -> "SchedulerRunningOutputs":
preempted=[],
swapped_out=[],
blocks_to_swap_out={},
blocks_to_copy={},
blocks_to_copy=[],
num_lookahead_slots=0,
)

Expand All @@ -209,7 +208,7 @@ class SchedulerSwappedInOutputs:
# The blocks to swap in.
blocks_to_swap_in: Dict[int, int]
# The blocks to copy.
blocks_to_copy: Dict[int, List[int]]
blocks_to_copy: List[Tuple[int, int]]
# The number of slots for lookahead decoding.
num_lookahead_slots: int
# Infeasible sequence groups.
Expand All @@ -221,7 +220,7 @@ def create_empty(cls) -> "SchedulerSwappedInOutputs":
decode_seq_groups=[],
prefill_seq_groups=[],
blocks_to_swap_in={},
blocks_to_copy={},
blocks_to_copy=[],
num_lookahead_slots=0,
infeasible_seq_groups=[],
)
Expand Down Expand Up @@ -394,7 +393,7 @@ def _schedule_running(
"""
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {}
blocks_to_copy: List[Tuple[int, int]] = []

decode_seq_groups: List[ScheduledSequenceGroup] = []
prefill_seq_groups: List[ScheduledSequenceGroup] = []
Expand Down Expand Up @@ -511,7 +510,7 @@ def _schedule_swapped(
"""
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {}
blocks_to_copy: List[Tuple[int, int]] = []
decode_seq_groups: List[ScheduledSequenceGroup] = []
prefill_seq_groups: List[ScheduledSequenceGroup] = []
now = time.time()
Expand Down Expand Up @@ -794,8 +793,8 @@ def _schedule_default(self) -> SchedulerOutputs:
num_batched_tokens=budget.num_batched_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
swapped_in.blocks_to_copy),
blocks_to_copy=running_scheduled.blocks_to_copy +
swapped_in.blocks_to_copy,
ignored_seq_groups=prefills.ignored_seq_groups +
swapped_in.infeasible_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots,
Expand Down Expand Up @@ -882,8 +881,8 @@ def _schedule_chunked_prefill(self):
num_batched_tokens=budget.num_batched_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
swapped_in.blocks_to_copy),
blocks_to_copy=running_scheduled.blocks_to_copy +
swapped_in.blocks_to_copy,
ignored_seq_groups=prefills.ignored_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots,
running_queue_size=len(self.running),
Expand Down Expand Up @@ -1011,27 +1010,27 @@ def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
def _append_slots(
self,
seq_group: SequenceGroup,
blocks_to_copy: Dict[int, List[int]],
blocks_to_copy: List[Tuple[int, int]],
) -> None:
"""Appends new slots to the sequences in the given sequence group.
Args:
seq_group (SequenceGroup): The sequence group containing the
sequences to append slots to.
blocks_to_copy (Dict[int, List[int]]): A dictionary mapping source
block indices to lists of destination block indices. This
dictionary is updated with the new source and destination block
indices for the appended slots.
blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two
ints, the first int is the source block index, and the second
int is the destination block index. This list is updated with
the new source and destination block indices for the appended
slots.
"""
num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)

for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
cows = self.block_manager.append_slots(seq, num_lookahead_slots)

for src, dests in cows.items():
if src not in blocks_to_copy:
blocks_to_copy[src] = []
blocks_to_copy[src].extend(dests)
for dest in dests:
blocks_to_copy.append((src, dest))

def _preempt(
self,
Expand Down
Loading

0 comments on commit fd9fef5

Please sign in to comment.