[Frontend] Bad words sampling parameter #9717

Merged: 49 commits, Oct 26, 2024
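For context, this PR adds a bad_words option to SamplingParams: any listed word or phrase is kept out of the generated text, including multi-token phrases. A minimal usage sketch based on the tests added below (the model and prompt here are illustrative, not part of the PR):

from vllm import LLM, SamplingParams

llm = LLM(model="openai-community/gpt2")  # any supported model; chosen to match the tests

params = SamplingParams(
    temperature=0,
    bad_words=["years old"],  # multi-token phrases are blocked as whole sequences
)

outputs = llm.generate(["How old are you? I am 10"], sampling_params=params)
print(outputs[0].outputs[0].text)  # the completion should not contain "years old"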
Commits
feabb97
add bad words logits processor
Alvant Jun 27, 2024
f80b7d6
rename processor, add docstring
Alvant Jun 27, 2024
a152e66
add test for bad words
Alvant Jun 27, 2024
9f21847
add tests (versions with and without vllm runner)
Alvant Jun 29, 2024
c577c47
keep only vllm runner tests
Alvant Jun 29, 2024
b372c17
Merge branch 'main' into feature/bad_words_ids
Alvant Jun 29, 2024
e6bdadb
run yapf and ruff
Alvant Jun 29, 2024
a69329f
run yapf and ruff again
Alvant Jun 29, 2024
b9f8a5d
search for bad sequence among all generateds
Alvant Jun 29, 2024
ecefbf4
run yapf and ruff
Alvant Jun 29, 2024
05671cd
fix test for two bad tokens
Alvant Jun 29, 2024
44bd494
fix unused imports and vars
Alvant Jun 29, 2024
0fc7974
fix style
Alvant Jun 29, 2024
e1a18f1
fix format
Alvant Jun 29, 2024
5ab80a4
refine test for two token word
Alvant Jun 29, 2024
f838046
bad_words_ids -> bad_words (sync engine)
Alvant Jul 3, 2024
32f2e20
fix bad words and test
Alvant Jul 3, 2024
f93c4d7
fix for llama tokenizer, add llama in test
Alvant Jul 4, 2024
16c2dd4
run yapf and ruff
Alvant Jul 4, 2024
ad0d61c
add comment about two models
Alvant Jul 4, 2024
3389770
Merge pull request #1 from compressa-ai/feature/bad_words
Alvant Jul 4, 2024
9b1a1ac
fix style
Alvant Jul 4, 2024
924ed79
Merge pull request #2 from compressa-ai/feature/bad_words
Alvant Jul 4, 2024
ea4e02a
clarify comment about prefixes
Alvant Jul 4, 2024
be5d5c3
Merge branch 'main' into feature/bad_words_ids
Alvant Jul 10, 2024
bd86123
move logits stuff to logits file
Alvant Jul 10, 2024
308e76a
run yapf and ruff, fix import order
Alvant Jul 10, 2024
aae2ac5
clarify add prefix logic
Alvant Jul 10, 2024
9724dbc
fix is match ckeck for different type sequences
Alvant Jul 11, 2024
1f0938c
add process params to async engine
Alvant Jul 11, 2024
f2673a5
change type to tuple in bad words ids processor
Alvant Jul 11, 2024
8a6e88b
fix type for logits process to pass checks
Alvant Jul 11, 2024
2f2ea06
Merge branch 'main' into feature/bad_words_ids
Alvant Oct 24, 2024
7c0c60c
fix code style
Alvant Oct 24, 2024
136937b
init bad words as empty list, add in repr
Alvant Oct 26, 2024
cdfce02
fix bad words post init
Alvant Oct 26, 2024
3478d6d
fix code len style
Alvant Oct 26, 2024
0187b3d
fix yapf style
Alvant Oct 26, 2024
8eeb5ad
remove async preproc params
Alvant Oct 26, 2024
4de1e23
move bad words creation logic to build_logits_processors
Alvant Oct 26, 2024
f0fbadd
fix style
Alvant Oct 26, 2024
b185e5b
unify logit processor creation logic (add getter for bad words)
Alvant Oct 26, 2024
0232f72
fix import for openai logits
Alvant Oct 26, 2024
e37fa23
fix style (ruff or other)
Alvant Oct 26, 2024
eed58a7
move all bad words logits processor creation to separate file
Alvant Oct 26, 2024
4c6f54d
fix style (one of them)
Alvant Oct 26, 2024
64a3b80
handle case of mistral tokenizer in bad words logits
Alvant Oct 26, 2024
7edca9e
simplify bad words code a bit
Alvant Oct 26, 2024
1828554
Merge branch 'main' into feature/bad_words_ids2
Alvant Oct 26, 2024
185 changes: 185 additions & 0 deletions tests/samplers/test_no_bad_words.py
@@ -0,0 +1,185 @@
"""Make sure bad_words works.

Run `pytest tests/samplers/test_no_bad_words.py`.

"""
from typing import List, Optional

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams


def _generate(
    model: LLM,
    prompt: str,
    num_prompt_tokens: int,
    temperature: float = 0,
    bad_words: Optional[List[str]] = None,
) -> List[int]:
    sampling_params = SamplingParams(
        temperature=temperature,
        bad_words=bad_words,
    )

    # [([output_token_ids, ], [output_text, ]), ]
    output = model.generate([prompt], sampling_params=sampling_params)

    output_token_ids = output[0][0][0][num_prompt_tokens:]
    # [0] first (and only) request output
    # [0] token_ids (not text)
    # [0] first (and only) output completion

    return output_token_ids


class TestOneTokenBadWord:
    MODEL = "TheBloke/Llama-2-7B-fp16"

    PROMPT = "Hi! How are"
    TARGET_TOKEN = "you"

    def setup_method(self, method):
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
                                                       add_prefix_space=True)

        self.num_prompt_tokens = len(self._encode(self.PROMPT))
        self.target_token_id = self._encode(self.TARGET_TOKEN,
                                             add_special_tokens=False)[0]

    def test_one_token_bad_word(self, vllm_runner):
        with vllm_runner(self.MODEL) as llm:
            output_token_ids = self._generate(llm)
            assert output_token_ids[0] == self.target_token_id

            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN])
            assert self.target_token_id not in output_token_ids

    def _generate(self,
                  model: LLM,
                  bad_words: Optional[List[str]] = None) -> List[int]:
        return _generate(
            model=model,
            prompt=self.PROMPT,
            num_prompt_tokens=self.num_prompt_tokens,
            bad_words=bad_words,
        )

    def _encode(self,
                prompt: str,
                add_special_tokens: bool = True) -> List[int]:
        return self.tokenizer(prompt,
                              add_special_tokens=add_special_tokens).input_ids


class TestTwoTokenBadWord:
    # Another model (with a different tokenizer behaviour)
    MODEL = "openai-community/gpt2"

    PROMPT = "How old are you? I am 10"
    TARGET_TOKEN1 = "years"
    TARGET_TOKEN2 = "old"
    NEIGHBOUR_TOKEN2 = "older"

    def setup_method(self, method):
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
                                                       add_prefix_space=True)

        self.num_prompt_tokens = len(self._encode(self.PROMPT))
        self.target_token_id1 = self._encode(self.TARGET_TOKEN1,
                                             add_special_tokens=False)[0]
        self.target_token_id2 = self._encode(self.TARGET_TOKEN2,
                                             add_special_tokens=False)[0]
        self.neighbour_token_id2 = self._encode(self.NEIGHBOUR_TOKEN2,
                                                add_special_tokens=False)[0]

    def test_two_token_bad_word(self, vllm_runner):
        with vllm_runner(self.MODEL) as llm:
            output_token_ids = self._generate(llm)
            assert output_token_ids[:2] == [
                self.target_token_id1, self.target_token_id2
            ]

            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN1])
            assert self.target_token_id1 not in output_token_ids

            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN2])
            assert output_token_ids[0] == self.target_token_id1
            assert self.target_token_id2 not in output_token_ids

            output_token_ids = self._generate(
                llm, bad_words=[f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}'])
            assert output_token_ids[0] == self.target_token_id1
            assert output_token_ids[:2] != [
                self.target_token_id1, self.target_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.target_token_id2])
            # Model dependent behaviour
            assert output_token_ids[:2] == [
                self.target_token_id1, self.neighbour_token_id2
            ]

            output_token_ids = self._generate(
                llm,
                bad_words=[
                    f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}',
                    f'{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}'
                ])
            assert output_token_ids[0] == self.target_token_id1
            assert output_token_ids[:2] != [
                self.target_token_id1, self.target_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.target_token_id2])
            assert output_token_ids[:2] != [
                self.target_token_id1, self.neighbour_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.neighbour_token_id2])
            assert ((self.target_token_id2 in output_token_ids)
                    or (self.neighbour_token_id2 in output_token_ids))

    def _generate(self,
                  model: LLM,
                  bad_words: Optional[List[str]] = None) -> List[int]:
        return _generate(
            model=model,
            prompt=self.PROMPT,
            num_prompt_tokens=self.num_prompt_tokens,
            bad_words=bad_words,
        )

    @staticmethod
    def _contains(sequence: List[int], subsequence: List[int]) -> bool:
        searched = False

        for start in range(len(sequence)):
            end = start + len(subsequence)
            current_subsequence = sequence[start:end]

            if len(current_subsequence) < len(subsequence):
                continue

            searched = True

            assert len(current_subsequence) == len(subsequence)

            if current_subsequence == subsequence:
                return True

        assert searched, "All subsequences did not match in length..."

        return False

    def _encode(self,
                prompt: str,
                add_special_tokens: bool = True) -> List[int]:
        return self.tokenizer(prompt,
                              add_special_tokens=add_special_tokens).input_ids
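One detail both test classes lean on: the tokenizers are created with add_prefix_space=True, because BPE-style tokenizers generally assign different token ids to a word at the very start of a text and to the same word preceded by a space. That is also why get_bad_words_logits_processors (further below) encodes each bad word both with and without a leading space. A quick illustration using the same GPT-2 checkpoint as the second test; the printed ids are tokenizer-specific, so none are hard-coded here:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

# Typically two different single-token ids in GPT-2's BPE vocabulary.
print(tokenizer.encode("old", add_special_tokens=False))
print(tokenizer.encode(" old", add_special_tokens=False))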
13 changes: 11 additions & 2 deletions vllm/engine/llm_engine.py
@@ -26,14 +26,16 @@
     SequenceGroupOutputProcessor)
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.engine.output_processor.util import create_output_by_sequence_group
-from vllm.entrypoints.openai.logits_processors import get_logits_processors
+from vllm.entrypoints.openai.logits_processors import (
+    get_logits_processors as get_openai_logits_processors)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
                          EncoderDecoderInputs, InputRegistry, PromptType)
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
+from vllm.logits_process import get_bad_words_logits_processors
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
     get_local_guided_decoding_logits_processor)
@@ -1963,6 +1965,7 @@ def _build_logits_processors(
         logits_processors field. Returns the modified sampling params."""
 
         logits_processors = []
+
         if (guided_decoding := sampling_params.guided_decoding) is not None:
 
             logger.debug(
@@ -1984,7 +1987,7 @@
         if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
             tokenizer = self.get_tokenizer(lora_request=lora_request)
 
-            processors = get_logits_processors(
+            processors = get_openai_logits_processors(
                 logit_bias=sampling_params.logit_bias,
                 allowed_token_ids=sampling_params.allowed_token_ids,
                 tokenizer=tokenizer)
@@ -1994,6 +1997,12 @@
             sampling_params.logit_bias = None
             sampling_params.allowed_token_ids = None
 
+        if len(sampling_params.bad_words) > 0:
+            tokenizer = self.get_tokenizer(lora_request)
+            processors = get_bad_words_logits_processors(
+                bad_words=sampling_params.bad_words, tokenizer=tokenizer)
+            logits_processors.extend(processors)
+
         if logits_processors:
             if sampling_params.logits_processors is None:
                 sampling_params.logits_processors = logits_processors
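The engine change above means that setting bad_words is, in spirit, shorthand for building the processors yourself and attaching them to logits_processors. A sketch of that equivalence; the tokenizer checkpoint is illustrative, and inside the engine the tokenizer comes from get_tokenizer() rather than transformers directly:

from transformers import AutoTokenizer

from vllm import SamplingParams
from vllm.logits_process import get_bad_words_logits_processors

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

params = SamplingParams(temperature=0)

# Roughly what _build_logits_processors does when bad_words=["years old"] is set.
processors = get_bad_words_logits_processors(bad_words=["years old"],
                                             tokenizer=tokenizer)
params.logits_processors = (params.logits_processors or []) + processors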
119 changes: 119 additions & 0 deletions vllm/logits_process.py
@@ -0,0 +1,119 @@
from typing import Callable, List, Tuple, Union

import torch

from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer

LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
                        Callable[[List[int], List[int], torch.Tensor],
                                 torch.Tensor]]
"""LogitsProcessor is a function that takes a list
of previously generated tokens, the logits tensor
for the next token and, optionally, prompt tokens as a
first argument, and returns a modified tensor of logits
to sample from."""


def get_bad_words_logits_processors(
        bad_words: List[str],
        tokenizer: AnyTokenizer) -> List[LogitsProcessor]:
    bad_words_ids: List[List[int]] = list()

    for bad_word in bad_words:
        # To prohibit words both at the beginning
        # and in the middle of text
        # (related to add_prefix_space tokenizer parameter)
        for add_prefix_space in [False, True]:
            prefix = " " if add_prefix_space else ""
            prompt = prefix + bad_word.lstrip()

            if isinstance(tokenizer, MistralTokenizer):
                # Mistral tokenizers should not add special tokens
                prompt_token_ids = tokenizer.encode(prompt=prompt)
            else:
                prompt_token_ids = tokenizer.encode(text=prompt,
                                                    add_special_tokens=False)

            # If no space at the beginning
            # or if prefix space produces a new word token
            if (not add_prefix_space) or (
                    add_prefix_space
                    and prompt_token_ids[0] != bad_words_ids[-1][0]
                    and len(prompt_token_ids) == len(bad_words_ids[-1])):
                bad_words_ids.append(prompt_token_ids)

    return [NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids)]


class NoBadWordsLogitsProcessor:
    _SMALLEST_LOGIT = float("-inf")
    _NEUTRAL_LOGIT = 0.0

    def __init__(self, bad_words_ids: List[List[int]]):
        self.bad_words_ids = bad_words_ids
        self.word_bias: torch.FloatTensor = None

    def __call__(
        self,
        past_tokens_ids: Union[List[int], Tuple[int]],
        logits: torch.FloatTensor,
    ) -> torch.Tensor:
        if self.word_bias is None:
            self._init_word_bias(logits=logits)

        last_token_bias = torch.zeros_like(logits)

        for bad_word_ids in self.bad_words_ids:
            if len(bad_word_ids) == 1:  # 1-token words already processed
                continue

            if len(bad_word_ids) > len(past_tokens_ids) + 1:
                continue

            prefix_length = len(bad_word_ids) - 1
            last_token_id = bad_word_ids[-1]
            actual_prefix = past_tokens_ids[-prefix_length:]
            expected_prefix = bad_word_ids[:prefix_length]

            assert len(actual_prefix) == len(expected_prefix)

            is_match = tuple(actual_prefix) == tuple(expected_prefix)
            last_token_bias[last_token_id] += (self._SMALLEST_LOGIT if is_match
                                               else self._NEUTRAL_LOGIT)

        logits = logits + self.word_bias + last_token_bias

        return logits

    def _init_word_bias(self, logits: torch.FloatTensor) -> None:
        # Code based on NoBadWordsLogitsProcessor and SequenceBiasLogitsProcessor  # noqa: E501
        # from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py

        vocab_size = logits.shape[-1]

        self._check_token_ids_bounds(vocab_size=vocab_size)

        self.word_bias = torch.zeros((vocab_size, ),
                                     dtype=torch.float,
                                     device=logits.device)

        for bad_word_ids in self.bad_words_ids:
            if len(bad_word_ids) == 1:
                bad_word_id = bad_word_ids[-1]
                self.word_bias[bad_word_id] = self._SMALLEST_LOGIT

    def _check_token_ids_bounds(self, vocab_size: int) -> None:
        invalid_token_ids = []

        for bad_word_ids in self.bad_words_ids:
            for token_id in bad_word_ids:
                if token_id < 0 or token_id >= vocab_size:
                    invalid_token_ids.append(token_id)

        if len(invalid_token_ids) > 0:
            raise ValueError(
                f"The model vocabulary size is {vocab_size},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id < {vocab_size}.")
3 changes: 2 additions & 1 deletion vllm/model_executor/guided_decoding/__init__.py
@@ -1,6 +1,7 @@
 from typing import Optional
 
-from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
+from vllm.logits_process import LogitsProcessor
+from vllm.sampling_params import GuidedDecodingParams
 
 
 async def get_guided_decoding_logits_processor(
@@ -9,7 +9,8 @@
     build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
 from transformers import PreTrainedTokenizerBase
 
-from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
+from vllm.logits_process import LogitsProcessor
+from vllm.sampling_params import GuidedDecodingParams
 
 
 def get_local_lm_format_enforcer_guided_decoding_logits_processor(