Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <youkaichao@gmail.com>
  • Loading branch information
youkaichao committed Oct 27, 2024
1 parent 0068133 commit 5250a65
Showing 1 changed file with 34 additions and 12 deletions.
46 changes: 34 additions & 12 deletions tests/models/decoder_only/language/test_big_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from vllm.platforms import current_platform

from ...utils import check_outputs_equal
from ...utils import check_logprobs_close, check_outputs_equal

MODELS = [
"meta-llama/Llama-2-7b-hf",
Expand Down Expand Up @@ -43,18 +43,40 @@ def test_models(
dtype: str,
max_tokens: int,
) -> None:
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
if model == "openbmb/MiniCPM3-4B":
# the output becomes slightly different when upgrading to
# pytorch 2.5 . Changing to logprobs checks instead of exact
# output checks.
NUM_LOG_PROBS = 8
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, NUM_LOG_PROBS)

with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS)

check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
else:
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)

check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", MODELS)
Expand Down

0 comments on commit 5250a65

Please sign in to comment.