diff --git a/csrc/cache.h b/csrc/cache.h
index 4c142ce17f1b9..10871b3670bac 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -13,7 +13,7 @@ void swap_blocks(
 void copy_blocks(
   std::vector<torch::Tensor>& key_caches,
   std::vector<torch::Tensor>& value_caches,
-  const std::map<int64_t, std::vector<int64_t>>& block_mapping);
+  torch::Tensor& block_mapping);
 
 void reshape_and_cache(
   torch::Tensor& key,
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 42f884c76c620..1e02f7fcbae4c 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -97,7 +97,7 @@ __global__ void copy_blocks_kernel(
 void copy_blocks(
   std::vector<torch::Tensor>& key_caches,
   std::vector<torch::Tensor>& value_caches,
-  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
+  torch::Tensor& block_mapping) {
   int num_layers = key_caches.size();
   TORCH_CHECK(num_layers == value_caches.size());
   if (num_layers == 0) {
@@ -114,17 +114,9 @@ void copy_blocks(
     key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
     value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
   }
-  // Create block mapping array.
-  std::vector<int64_t> block_mapping_vec;
-  for (const auto& pair : block_mapping) {
-    int64_t src_block_number = pair.first;
-    for (int64_t dst_block_number : pair.second) {
-      block_mapping_vec.push_back(src_block_number);
-      block_mapping_vec.push_back(dst_block_number);
-    }
-  }
-  int64_t* block_mapping_array = block_mapping_vec.data();
-  int num_pairs = block_mapping_vec.size() / 2;
+
+  // block_mapping is a 2D tensor with shape (num_pairs, 2).
+  int num_pairs = block_mapping.size(0);
 
   // Move the data structures to the GPU.
   // NOTE: This synchronizes the CPU and GPU.
@@ -132,8 +124,6 @@ void copy_blocks(
     key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
   torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
     value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
-  torch::Tensor block_mapping_tensor = torch::from_blob(
-    block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
 
   // Launch the kernel.
   const int numel_per_block = key_caches[0][0].numel();
@@ -146,7 +136,7 @@ void copy_blocks(
       vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
         key_cache_ptrs_tensor.data_ptr<int64_t>(),
         value_cache_ptrs_tensor.data_ptr<int64_t>(),
-        block_mapping_tensor.data_ptr<int64_t>(),
+        block_mapping.data_ptr<int64_t>(),
         numel_per_block);
     }));
 }
diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp
index 7849a5df991b1..95e3f11900fde 100644
--- a/csrc/cpu/cache.cpp
+++ b/csrc/cpu/cache.cpp
@@ -8,16 +8,16 @@ template <typename scalar_t>
 void copy_blocks_cpu_impl(
     std::vector<torch::Tensor> &key_caches,
     std::vector<torch::Tensor> &value_caches,
-    const std::vector<std::pair<int64_t, int64_t>> mapping_pairs,
+    const torch::Tensor& mapping_pairs,
     const int element_num_per_block, const int layer_num) {
-  const size_t pair_num = mapping_pairs.size();
+  const size_t pair_num = mapping_pairs.size(0);
   const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
 #pragma omp parallel for collapse(2)
   for (int layer = 0; layer < layer_num; ++layer) {
     for (size_t pair = 0; pair < pair_num; ++pair) {
-      int64_t source_offset = element_num_per_block * mapping_pairs[pair].first;
+      int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
       int64_t target_offset =
-          element_num_per_block * mapping_pairs[pair].second;
+          element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
       scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
       scalar_t *source_ptr = key_cache_ptr + source_offset;
       scalar_t *target_ptr = key_cache_ptr + target_offset;
@@ -83,26 +83,18 @@ void reshape_and_cache_cpu_impl(
 
 void copy_blocks(std::vector<torch::Tensor> &key_caches,
                  std::vector<torch::Tensor> &value_caches,
-                 const std::map<int64_t, std::vector<int64_t>> &block_mapping) {
+                 torch::Tensor& block_mapping) {
   int num_layers = key_caches.size();
   TORCH_CHECK(num_layers == value_caches.size());
   if (num_layers == 0) {
     return;
   }
 
-  std::vector<std::pair<int64_t, int64_t>> mapping_pairs;
-  mapping_pairs.reserve(block_mapping.size());
-  for (const auto &pair : block_mapping) {
-    for (const auto &dst : pair.second) {
-      mapping_pairs.emplace_back(pair.first, dst);
-    }
-  }
-
   const int element_num_per_block = key_caches[0][0].numel();
   VLLM_DISPATCH_FLOATING_TYPES(
       key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
         CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
-        copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, mapping_pairs,
+        copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
                                        element_num_per_block, num_layers);
         CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
       });
diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py
index 1358dffec8104..348169035ae97 100644
--- a/tests/core/test_scheduler.py
+++ b/tests/core/test_scheduler.py
@@ -568,7 +568,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots):
     # Both should be preempted, not swapped.
     assert output.blocks_to_swap_out == {}
     # Nothing is copied.
-    assert output.blocks_to_copy == {}
+    assert output.blocks_to_copy == []
 
 
 def test_decode_swap_beam_search():
@@ -618,7 +618,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots):
     # Both should be preempted, not swapped.
     assert output.blocks_to_swap_out == expected_swap_mapping
     # Nothing is copied.
-    assert output.blocks_to_copy == {}
+    assert output.blocks_to_copy == []
 
 
 def test_schedule_decode_blocks_to_copy_update():
@@ -650,7 +650,7 @@ def test_schedule_decode_blocks_to_copy_update():
     assert output.blocks_to_swap_out == {}
     # Since append_slot returns the source -> dist mapping, it should
     # applied.
-    assert output.blocks_to_copy == {2: [3]}
+    assert output.blocks_to_copy == [(2, 3)]
 
 
 def test_schedule_swapped_simple():
@@ -853,7 +853,7 @@ def test_schedule_swapped_blocks_to_copy():
     assert len(remaining_swapped) == 0
     assert len(output.decode_seq_groups) == 1
     assert len(output.prefill_seq_groups) == 0
-    assert output.blocks_to_copy == {2: [3]}
+    assert output.blocks_to_copy == [(2, 3)]
 
 
 def test_scheduling_budget():
diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index ca215bb75837a..94a577139596e 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -63,12 +63,13 @@ def test_copy_blocks(
     src_blocks = random.sample(range(num_blocks), num_mappings)
     remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
     dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
-    block_mapping = {}
+    block_mapping = []
     for i in range(num_mappings):
         src = src_blocks[i]
         dst1 = dst_blocks[2 * i]
         dst2 = dst_blocks[2 * i + 1]
-        block_mapping[src] = [dst1, dst2]
+        block_mapping.append((src, dst1))
+        block_mapping.append((src, dst2))
 
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
@@ -81,15 +82,17 @@ def test_copy_blocks(
     cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
 
     # Call the copy blocks kernel.
-    ops.copy_blocks(key_caches, value_caches, block_mapping)
+    block_mapping_tensor = torch.tensor(block_mapping,
+                                        dtype=torch.int64,
+                                        device=device).view(-1, 2)
+    ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
 
     # Run the reference implementation.
-    for src, dsts in block_mapping.items():
-        for dst in dsts:
-            for cloned_key_cache in cloned_key_caches:
-                cloned_key_cache[dst].copy_(cloned_key_cache[src])
-            for cloned_value_cache in cloned_value_caches:
-                cloned_value_cache[dst].copy_(cloned_value_cache[src])
+    for src, dst in block_mapping:
+        for cloned_key_cache in cloned_key_caches:
+            cloned_key_cache[dst].copy_(cloned_key_cache[src])
+        for cloned_value_cache in cloned_value_caches:
+            cloned_value_cache[dst].copy_(cloned_value_cache[src])
 
     # Compare the results.
     for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py
index 07bcd343a96a6..4d2d3add27d59 100644
--- a/tests/worker/test_swap.py
+++ b/tests/worker/test_swap.py
@@ -59,7 +59,7 @@ def test_swap() -> None:
         seq_group_metadata_list=[],
         blocks_to_swap_in={},
         blocks_to_swap_out=blocks_to_swap_out,
-        blocks_to_copy={},
+        blocks_to_copy=[],
     )
     worker.execute_model(execute_model_req=execute_model_req)
 
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 61c9c81d8a7b8..b2b6e7ac810e3 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -42,7 +42,7 @@ def swap_blocks(
     @abstractmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         raise NotImplementedError
 
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index fc7501ed5e91f..da672d5df6161 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -48,7 +48,7 @@ def swap_blocks(
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         PagedAttention.copy_blocks(kv_caches, src_to_dists)
 
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 8ab4b1f12ee36..2851cbe2396b2 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -48,7 +48,7 @@ def swap_blocks(
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         raise NotImplementedError
 
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index c411b3971b8f1..c3b522e63b4b8 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -46,7 +46,7 @@ def swap_blocks(
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         PagedAttention.copy_blocks(kv_caches, src_to_dists)
 
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index f75a279086a26..03825f6023f4c 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -44,7 +44,7 @@ def swap_blocks(
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         PagedAttention.copy_blocks(kv_caches, src_to_dists)
 
diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py
index 60f6d43f2eaa4..4c7fa71a2c78e 100644
--- a/vllm/attention/backends/xformers.py
+++ b/vllm/attention/backends/xformers.py
@@ -49,7 +49,7 @@ def swap_blocks(
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         PagedAttention.copy_blocks(kv_caches, src_to_dists)
 
diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py
index 00a0f10c0950b..6f7fd51c774f8 100644
--- a/vllm/attention/ops/paged_attn.py
+++ b/vllm/attention/ops/paged_attn.py
@@ -209,7 +209,7 @@ def swap_blocks(
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         key_caches = [kv_cache[0] for kv_cache in kv_caches]
         value_caches = [kv_cache[1] for kv_cache in kv_caches]
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index a9e0b05b8db67..de3ecd24e52db 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -13,7 +13,6 @@
 from vllm.lora.request import LoRARequest
 from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                            SequenceGroupMetadata, SequenceStatus)
-from vllm.utils import merge_dicts
 
 logger = init_logger(__name__)
 
@@ -122,8 +121,8 @@ class SchedulerOutputs:
     blocks_to_swap_in: Dict[int, int]
     # Blocks to swap out. Dict of GPU -> CPU block number.
     blocks_to_swap_out: Dict[int, int]
-    # Blocks to copy. Source to a list of dest blocks.
-    blocks_to_copy: Dict[int, List[int]]
+    # Blocks to copy. Source to dest block.
+    blocks_to_copy: List[Tuple[int, int]]
     # Sequence groups that are going to be ignored.
     ignored_seq_groups: List[SequenceGroup]
     # The number of slots for lookahead decoding.
@@ -177,7 +176,7 @@ class SchedulerRunningOutputs:
     # The blocks to swap out.
     blocks_to_swap_out: Dict[int, int]
     # The blocks to copy.
-    blocks_to_copy: Dict[int, List[int]]
+    blocks_to_copy: List[Tuple[int, int]]
     # The number of slots for lookahead decoding.
     num_lookahead_slots: int
 
@@ -189,7 +188,7 @@ def create_empty(cls) -> "SchedulerRunningOutputs":
             preempted=[],
             swapped_out=[],
             blocks_to_swap_out={},
-            blocks_to_copy={},
+            blocks_to_copy=[],
             num_lookahead_slots=0,
         )
 
@@ -209,7 +208,7 @@ class SchedulerSwappedInOutputs:
     # The blocks to swap in.
     blocks_to_swap_in: Dict[int, int]
     # The blocks to copy.
-    blocks_to_copy: Dict[int, List[int]]
+    blocks_to_copy: List[Tuple[int, int]]
     # The number of slots for lookahead decoding.
     num_lookahead_slots: int
     # Infeasible sequence groups.
@@ -221,7 +220,7 @@ def create_empty(cls) -> "SchedulerSwappedInOutputs":
             decode_seq_groups=[],
             prefill_seq_groups=[],
             blocks_to_swap_in={},
-            blocks_to_copy={},
+            blocks_to_copy=[],
             num_lookahead_slots=0,
             infeasible_seq_groups=[],
         )
@@ -394,7 +393,7 @@ def _schedule_running(
         """
         # Blocks that need to be swapped or copied before model execution.
         blocks_to_swap_out: Dict[int, int] = {}
-        blocks_to_copy: Dict[int, List[int]] = {}
+        blocks_to_copy: List[Tuple[int, int]] = []
         decode_seq_groups: List[ScheduledSequenceGroup] = []
         prefill_seq_groups: List[ScheduledSequenceGroup] = []
 
@@ -511,7 +510,7 @@ def _schedule_swapped(
         """
         # Blocks that need to be swapped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
-        blocks_to_copy: Dict[int, List[int]] = {}
+        blocks_to_copy: List[Tuple[int, int]] = []
         decode_seq_groups: List[ScheduledSequenceGroup] = []
         prefill_seq_groups: List[ScheduledSequenceGroup] = []
         now = time.time()
@@ -794,8 +793,8 @@ def _schedule_default(self) -> SchedulerOutputs:
             num_batched_tokens=budget.num_batched_tokens,
             blocks_to_swap_in=swapped_in.blocks_to_swap_in,
             blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
-            blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
-                                       swapped_in.blocks_to_copy),
+            blocks_to_copy=running_scheduled.blocks_to_copy +
+            swapped_in.blocks_to_copy,
             ignored_seq_groups=prefills.ignored_seq_groups +
             swapped_in.infeasible_seq_groups,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
@@ -882,8 +881,8 @@ def _schedule_chunked_prefill(self):
             num_batched_tokens=budget.num_batched_tokens,
             blocks_to_swap_in=swapped_in.blocks_to_swap_in,
             blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
-            blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
-                                       swapped_in.blocks_to_copy),
+            blocks_to_copy=running_scheduled.blocks_to_copy +
+            swapped_in.blocks_to_copy,
             ignored_seq_groups=prefills.ignored_seq_groups,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
             running_queue_size=len(self.running),
@@ -1011,17 +1010,18 @@ def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
     def _append_slots(
         self,
         seq_group: SequenceGroup,
-        blocks_to_copy: Dict[int, List[int]],
+        blocks_to_copy: List[Tuple[int, int]],
     ) -> None:
         """Appends new slots to the sequences in the given sequence group.
 
         Args:
             seq_group (SequenceGroup): The sequence group containing the
                 sequences to append slots to.
-            blocks_to_copy (Dict[int, List[int]]): A dictionary mapping source
-                block indices to lists of destination block indices. This
-                dictionary is updated with the new source and destination block
-                indices for the appended slots.
+            blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two
+                ints, the first int is the source block index, and the second
+                int is the destination block index. This list is updated with
+                the new source and destination block indices for the appended
+                slots.
         """
         num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
 
@@ -1029,9 +1029,8 @@ def _append_slots(
             cows = self.block_manager.append_slots(seq, num_lookahead_slots)
 
             for src, dests in cows.items():
-                if src not in blocks_to_copy:
-                    blocks_to_copy[src] = []
-                blocks_to_copy[src].extend(dests)
+                for dest in dests:
+                    blocks_to_copy.append((src, dest))
 
     def _preempt(
         self,
diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py
index b539a7beedbfe..817bd6d812e48 100644
--- a/vllm/distributed/communication_op.py
+++ b/vllm/distributed/communication_op.py
@@ -203,6 +203,9 @@ def broadcast_tensor_dict(
                                             group=metadata_group)
         async_handles = []
         for tensor in tensor_list:
+            if tensor.numel() == 0:
+                # Skip broadcasting empty tensors.
+                continue
             async_handles.append(
                 torch.distributed.broadcast(tensor,
                                             src=src,
@@ -224,6 +227,10 @@ def broadcast_tensor_dict(
             tensor = torch.empty(value.size,
                                  dtype=value.dtype,
                                  device="cuda")
+            if tensor.numel() == 0:
+                # Skip broadcasting empty tensors.
+                tensor_dict[key] = tensor
+                continue
             async_handle = torch.distributed.broadcast(tensor,
                                                        src=src,
                                                        async_op=True,
diff --git a/vllm/sequence.py b/vllm/sequence.py
index f2939eff7959b..b486d1fedebd3 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -2,7 +2,7 @@
 import copy
 import enum
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 from vllm.block import LogicalTokenBlock
 from vllm.lora.request import LoRARequest
@@ -745,8 +745,8 @@ class ExecuteModelRequest:
     blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
     # Blocks to swap out. Dict of GPU -> CPU block number.
     blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
-    # Blocks to copy. Source to a list of dest blocks.
-    blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
+    # Blocks to copy. Source to dest block.
+    blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
     # The number of slots for lookahead decoding.
     num_lookahead_slots: int = 0
     # The number of requests in the running queue.
diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py
index c34ee0648626b..26a60c652b6f4 100644
--- a/vllm/worker/cache_engine.py
+++ b/vllm/worker/cache_engine.py
@@ -77,7 +77,7 @@ def swap_out(self, src_to_dst: Dict[int, int]) -> None:
             self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
                                           src_to_dst)
 
-    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
+    def copy(self, src_to_dsts: torch.Tensor) -> None:
         self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
 
     @staticmethod
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index 4420d4cc9e12f..e1ef500ac07b8 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -248,9 +248,9 @@ def _init_cache_engine(self) -> None:
 
     def cache_copy(
         self,
-        blocks_to_copy: Dict[int, List[int]],
+        blocks_to_copy: torch.Tensor,
     ) -> None:
-        if blocks_to_copy:
+        if blocks_to_copy.numel() > 0:
             self.cache_engine.copy(blocks_to_copy)
 
     @torch.inference_mode()
@@ -269,6 +269,9 @@ def execute_model(
             num_seq_groups: int = len(seq_group_metadata_list)
             assert execute_model_req is not None
             blocks_to_copy = execute_model_req.blocks_to_copy
+            blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
+                                          device="cpu",
+                                          dtype=torch.int64).view(-1, 2)
             assert len(execute_model_req.blocks_to_swap_in) == 0
             assert len(execute_model_req.blocks_to_swap_out) == 0
             data: Dict[str, Any] = {
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 4add36e94f723..538332ad003f1 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -197,7 +197,7 @@ def cache_swap(
         self,
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
+        blocks_to_copy: torch.Tensor,
     ) -> None:
         # Issue cache operations.
         # TODO(woosuk): Profile swapping overhead and optimize if needed.
@@ -205,7 +205,7 @@ def cache_swap(
             self.cache_engine.swap_in(blocks_to_swap_in)
         if blocks_to_swap_out:
             self.cache_engine.swap_out(blocks_to_swap_out)
-        if blocks_to_copy:
+        if blocks_to_copy.numel() > 0:
             self.cache_engine.copy(blocks_to_copy)
 
     @torch.inference_mode()
@@ -225,7 +225,9 @@ def execute_model(
             num_seq_groups = len(seq_group_metadata_list)
            blocks_to_swap_in = execute_model_req.blocks_to_swap_in
             blocks_to_swap_out = execute_model_req.blocks_to_swap_out
-            blocks_to_copy = execute_model_req.blocks_to_copy
+            blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
+                                          device=self.device,
+                                          dtype=torch.int64).view(-1, 2)
             data: Dict[str, Any] = {
                 "num_seq_groups": num_seq_groups,
                 "blocks_to_swap_in": blocks_to_swap_in,
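
Note (illustration, not part of the patch): the change above replaces the old Dict[int, List[int]] source-to-destinations mapping with a flat list of (src, dst) pairs in the scheduler, which the worker then packs into an int64 tensor of shape (num_pairs, 2) before handing it to copy_blocks. The minimal Python sketch below shows that conversion and the reference copy semantics, mirroring the updated tests/kernels/test_cache.py; the caches here are small stand-in tensors with made-up shapes, not real vLLM KV caches.

import torch

# Stand-in KV caches: two layers, 4 blocks of 16 elements each (shapes are
# illustrative only; real vLLM caches have more dimensions).
key_caches = [torch.randn(4, 16) for _ in range(2)]
value_caches = [torch.randn(4, 16) for _ in range(2)]

# New scheduler output format: a flat list of (src, dst) pairs instead of
# the old {src: [dst, ...]} dict. One source may appear in several pairs.
blocks_to_copy = [(2, 3), (2, 1)]

# Worker-side conversion (as in worker.py / cpu_worker.py above): an int64
# tensor of shape (num_pairs, 2) that is passed to the copy_blocks kernel.
block_mapping = torch.tensor(blocks_to_copy, dtype=torch.int64).view(-1, 2)

# Reference semantics of copy_blocks for this layout, mirroring the updated
# test: copy the source block into the destination block in every layer.
for src, dst in block_mapping.tolist():
    for key_cache in key_caches:
        key_cache[dst].copy_(key_cache[src])
    for value_cache in value_caches:
        value_cache[dst].copy_(value_cache[src])

assert torch.equal(key_caches[0][3], key_caches[0][2])
assert torch.equal(value_caches[1][1], value_caches[1][2])

An empty mapping becomes a tensor with numel() == 0, which is why the workers above guard the copy with blocks_to_copy.numel() > 0 and the broadcast path skips zero-element tensors.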