diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index f70cb5df..1f6d4a77 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -141,10 +141,15 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context) -> genera
         batch_id = 0
         if batch is not None:
             for_concat = len(self.cache) > 0
-            # Prefill and generate first token
-            output_tokens, input_token_info, decode_errors, forward_time_ns = self.model.generate_token(
-                batch, first=True, for_concat=for_concat,
-            )
+            try:
+                # Prefill and generate first token
+                output_tokens, input_token_info, decode_errors, forward_time_ns = self.model.generate_token(
+                    batch, first=True, for_concat=for_concat,
+                )
+            except:
+                self._free_paged_sequences(batch, None)
+                raise
+
             if hasattr(batch, "past_key_values"):
                 clean_attribute("past_key_values", batch.past_key_values)
             if not is_healthcheck:
@@ -206,7 +211,12 @@ async def NextToken(self, request: generate_pb2.NextTokenRequest, context) -> ge
         # Ensure batches are garbage-collected post-concatenation
         del batches
 
-        output_tokens, _, errors, forward_time_ns = self.model.generate_token(batch)
+        try:
+            output_tokens, _, errors, forward_time_ns = self.model.generate_token(batch)
+        except:
+            self._free_paged_sequences(batch, None)
+            raise
+
         self.cache.set(batch)
 
         return generate_pb2.NextTokenResponse(
diff --git a/server/text_generation_server/utils/paged.py b/server/text_generation_server/utils/paged.py
index fbf6d0f5..392109ef 100644
--- a/server/text_generation_server/utils/paged.py
+++ b/server/text_generation_server/utils/paged.py
@@ -169,7 +169,12 @@ def prepare_inputs_with_speculation(
         child_sequence_ids_flattened.extend(child_sequence_ids)
 
     # add n_adds tokens to each candidate
-    cache_data = kv_cache_manager.allocate_tokens(num_tokens_per_sequence, child_sequence_ids_flattened)
+    try:
+        cache_data = kv_cache_manager.allocate_tokens(num_tokens_per_sequence, child_sequence_ids_flattened)
+    except:
+        kv_cache_manager.free_sequences(child_sequence_ids_flattened)
+        raise
+
     position_ids = cache_data.position_ids
 
     # Get candidate set of speculations