Free blocks in KVCacheManager upon error (IBM#96)
#### Motivation

We are seeing pods with spec. decoding getting restarted in BAM due to
health checks failing. Upon inspection of the logs, it looks like we are
running out of blocks and never recovering from it.
#### Modifications

I added a simple check so that, if something goes wrong when generating a
token, we free the blocks associated with that batch. I also had to
ensure that we free the child sequences that get created during
speculation if something goes wrong there too.

#### Result

I've verified that this allows us to recover from failures related to running
out of blocks. Hopefully after this fix, we don't see the inference
server getting restarted.

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
tdoublep authored May 14, 2024
1 parent 0734973 commit fb23def
Showing 2 changed files with 21 additions and 6 deletions.
20 changes: 15 additions & 5 deletions server/text_generation_server/server.py
@@ -141,10 +141,15 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context) -> generate_pb2.PrefillResponse:
batch_id = 0
if batch is not None:
for_concat = len(self.cache) > 0
# Prefill and generate first token
output_tokens, input_token_info, decode_errors, forward_time_ns = self.model.generate_token(
batch, first=True, for_concat=for_concat,
)
try:
# Prefill and generate first token
output_tokens, input_token_info, decode_errors, forward_time_ns = self.model.generate_token(
batch, first=True, for_concat=for_concat,
)
except:
self._free_paged_sequences(batch, None)
raise

if hasattr(batch, "past_key_values"):
clean_attribute("past_key_values", batch.past_key_values)
if not is_healthcheck:
@@ -206,7 +211,12 @@ async def NextToken(self, request: generate_pb2.NextTokenRequest, context) -> generate_pb2.NextTokenResponse:
# Ensure batches are garbage-collected post-concatenation
del batches

output_tokens, _, errors, forward_time_ns = self.model.generate_token(batch)
try:
output_tokens, _, errors, forward_time_ns = self.model.generate_token(batch)
except:
self._free_paged_sequences(batch, None)
raise

self.cache.set(batch)

return generate_pb2.NextTokenResponse(
7 changes: 6 additions & 1 deletion server/text_generation_server/utils/paged.py
@@ -169,7 +169,12 @@ def prepare_inputs_with_speculation(
child_sequence_ids_flattened.extend(child_sequence_ids)

# add n_adds tokens to each candidate
cache_data = kv_cache_manager.allocate_tokens(num_tokens_per_sequence, child_sequence_ids_flattened)
try:
cache_data = kv_cache_manager.allocate_tokens(num_tokens_per_sequence, child_sequence_ids_flattened)
except:
kv_cache_manager.free_sequences(child_sequence_ids_flattened)
raise

position_ids = cache_data.position_ids

# Get candidate set of speculations
