Your current environment
The output of `python collect_env.py`:
Collecting environment information...
PyTorch version: 2.4.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.22.4
Libc version: glibc-2.31
Python version: 3.11.10 (main, Oct 3 2024, 07:29:13) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-4.18.0-513.5.1.el8_9.x86_64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
Nvidia driver version: 555.42.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.6
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU(s): 32
On-line CPU(s) list: 0-31
Thread(s) per core: 1
Core(s) per socket: 16
Socket(s): 2
NUMA node(s): 8
Vendor ID: AuthenticAMD
CPU family: 25
Model: 1
Model name: AMD EPYC 7313 16-Core Processor
Stepping: 1
CPU MHz: 3517.887
BogoMIPS: 5988.81
Virtualization: AMD-V
L1d cache: 1 MiB
L1i cache: 1 MiB
L2 cache: 16 MiB
L3 cache: 256 MiB
NUMA node0 CPU(s): 0-3
NUMA node1 CPU(s): 4-7
NUMA node2 CPU(s): 8-11
NUMA node3 CPU(s): 12-15
NUMA node4 CPU(s): 16-19
NUMA node5 CPU(s): 20-23
NUMA node6 CPU(s): 24-27
NUMA node7 CPU(s): 28-31
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd amd_ppin brs arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.6.77
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] pyzmq==26.2.0
[pip3] torch==2.4.0+cu121
[pip3] torchvision==0.19.0+cu121
[pip3] transformers==4.46.0.dev0
[pip3] transformers-stream-generator==0.0.4
[pip3] triton==3.0.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.1.3.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.0.2.54 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.2.106 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.4.5.107 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.1.0.106 pypi_0 pypi
[conda] nvidia-ml-py 12.560.30 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.20.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.1.105 pypi_0 pypi
[conda] pyzmq 26.2.0 py311h7deb3e3_3 conda-forge
[conda] torch 2.4.0+cu121 pypi_0 pypi
[conda] torchvision 0.19.0+cu121 pypi_0 pypi
[conda] transformers 4.46.0.dev0 pypi_0 pypi
[conda] transformers-stream-generator 0.0.4 pypi_0 pypi
[conda] triton 3.0.0 pypi_0 pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.3.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
        GPU0  GPU1  GPU2  GPU3  NIC0  NIC1  NIC2  NIC3  CPU Affinity  NUMA Affinity  GPU NUMA ID
GPU0     X    NV12  NV12  NV12  SYS   PXB   SYS   SYS   12-15         3              N/A
GPU1    NV12   X    NV12  NV12  PXB   SYS   SYS   SYS   4-7           1              N/A
GPU2    NV12  NV12   X    NV12  PXB   SYS   SYS   SYS   4-7           1              N/A
GPU3    NV12  NV12  NV12   X    SYS   SYS   SYS   PXB   28-31         7              N/A
NIC0    SYS   PXB   PXB   SYS    X    SYS   SYS   SYS
NIC1    PXB   SYS   SYS   SYS   SYS    X    SYS   SYS
NIC2    SYS   SYS   SYS   SYS   SYS   SYS    X    SYS
NIC3    SYS   SYS   SYS   PXB   SYS   SYS   SYS    X
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
NIC Legend:
NIC0: mlx5_0
NIC1: mlx5_1
NIC2: mlx5_2
NIC3: mlx5_3
Model Input Dumps
No response
🐛 Describe the bug
I get incoherent generation outputs when using offline vLLM inference with video input. This happens with both URLs and local paths, with both the 7B and 72B models, and with or without tensor parallelism. The same setup works well (gives coherent answers) with text-only or text+image input, but not with video. The video outputs are also very different from those generated by transformers with the same arguments.
```python
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info

MODEL_PATH = "Qwen/Qwen2-VL-7B-Instruct"

llm = LLM(
    model=MODEL_PATH,
    limit_mm_per_prompt={"image": 10, "video": 10},
    # tensor_parallel_size=4,
    tensor_parallel_size=1,
)

sampling_params = SamplingParams(
    temperature=0.1,
    top_p=0.001,
    repetition_penalty=1.05,
    max_tokens=256,
    stop_token_ids=[],
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {
        "role": "user",
        "content": [
            {
                "type": "video",
                "video": "https://ptchallenge-workshop.github.io/media/vis.mp4",
                "min_pixels": 224 * 224,
                # "max_pixels": 1280 * 28 * 28,
                "total_pixels": 16384 * 28 * 28,
                "fps": 2.0,
            },
            {"type": "text", "text": "Describe the video."},
        ],
    },
]
# For video input, you can pass following values instead:
# "type": "video",
# "video": "<video URL>",

processor = AutoProcessor.from_pretrained(MODEL_PATH)
prompt = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
image_inputs, video_inputs = process_vision_info(messages)

# Only pass the modalities that are actually present.
mm_data = {}
if image_inputs is not None:
    mm_data["image"] = image_inputs
if video_inputs is not None:
    mm_data["video"] = video_inputs

llm_inputs = {
    "prompt": prompt,
    "multi_modal_data": mm_data,
}

outputs = llm.generate([llm_inputs], sampling_params=sampling_params)
generated_text = outputs[0].outputs[0].text
print("#" * 50 + "\n" + "Qwen repo Video url output with total_pixels:", generated_text)
```
Output:
INFO 10-26 20:24:11 llm_engine.py:237] Initializing an LLM engine (v0.6.3.post1) with config: model='Qwen/Qwen2-VL-7B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2-VL-7B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=Qwen/Qwen2-VL-7B-Instruct, num_scheduler_steps=1, chunked_prefill_enabled=False multi_step_stream_outputs=True, enable_prefix_caching=False, use_async_output_proc=True, use_cached_outputs=False, mm_processor_kwargs=None)
[rank0]:[W1026 20:24:28.958384079 ProcessGroupGloo.cpp:712] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
INFO 10-26 20:24:28 model_runner.py:1056] Starting to load model Qwen/Qwen2-VL-7B-Instruct...
INFO 10-26 20:24:30 weight_utils.py:243] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/5 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 20% Completed | 1/5 [00:01<00:07, 1.96s/it]
Loading safetensors checkpoint shards: 40% Completed | 2/5 [00:06<00:09, 3.27s/it]
Loading safetensors checkpoint shards: 60% Completed | 3/5 [00:10<00:07, 3.56s/it]
Loading safetensors checkpoint shards: 80% Completed | 4/5 [00:13<00:03, 3.64s/it]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:24<00:00, 6.19s/it]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:24<00:00, 4.90s/it]
INFO 10-26 20:24:56 model_runner.py:1067] Loading model weights took 15.5083 GB
INFO 10-26 20:25:06 gpu_executor.py:122] # GPU blocks: 56587, # CPU blocks: 4681
INFO 10-26 20:25:06 gpu_executor.py:126] Maximum concurrency for 32768 tokens per request: 27.63x
INFO 10-26 20:25:09 model_runner.py:1395] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 10-26 20:25:09 model_runner.py:1399] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 10-26 20:25:34 model_runner.py:1523] Graph capturing finished in 25 secs.
Processed prompts: 100%|█| 1/1 [00:05<00:00, 5.39s/it, est. speed input: 2541.48 toks/s, output
##################################################
Qwen repo Video url output with total_pixels: The: in helpful photo helpful image: Image helpful helpful
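To reason about how much visual input the video entry above (`min_pixels`, `total_pixels`, `fps`) feeds the model, here is a minimal sketch of the frame-resizing and token arithmetic. It is modeled on `smart_resize` in `qwen_vl_utils`, but it is a reimplementation, not the library code. The constants are assumptions to verify against the model's `config.json` and the installed `qwen_vl_utils`: patch size 14, 2x2 spatial merge, temporal patch of 2 frames, and a per-frame cap of `768 * 28 * 28` pixels. `per_frame_max_pixels` and `video_token_count` are hypothetical helper names, and the 32-frame 1280x720 clip in the usage example is hypothetical, not measured from `vis.mp4`.

```python
import math

# Assumed Qwen2-VL vision constants (verify against config.json / qwen_vl_utils).
PATCH_SIZE = 14                       # ViT patch edge in pixels
SPATIAL_MERGE = 2                     # 2x2 patches merged into one LLM token
TEMPORAL_PATCH = 2                    # consecutive frames per temporal patch
FACTOR = PATCH_SIZE * SPATIAL_MERGE   # 28: resized H/W must be multiples of this
VIDEO_MAX_PIXELS = 768 * 28 * 28      # assumed default per-frame pixel cap


def per_frame_max_pixels(total_pixels, nframes, min_pixels):
    """Spread the total pixel budget over the sampled frames (hypothetical helper)."""
    budget = total_pixels / nframes * TEMPORAL_PATCH
    return max(min(VIDEO_MAX_PIXELS, budget), int(min_pixels * 1.05))


def smart_resize(height, width, min_pixels, max_pixels, factor=FACTOR):
    """Round (height, width) to multiples of `factor`, keeping the pixel count
    within [min_pixels, max_pixels] while roughly preserving aspect ratio."""
    h_bar = max(factor, round(height / factor) * factor)
    w_bar = max(factor, round(width / factor) * factor)
    if h_bar * w_bar > max_pixels:
        # Too large: scale down, then floor to the grid.
        beta = math.sqrt(height * width / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        # Too small: scale up, then ceil to the grid.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar


def video_token_count(nframes, height, width):
    """LLM placeholder tokens for one video whose frames are already resized."""
    if nframes % TEMPORAL_PATCH or height % FACTOR or width % FACTOR:
        raise ValueError("frames/size must be multiples of the patch grid")
    return (nframes // TEMPORAL_PATCH) * (height // FACTOR) * (width // FACTOR)


if __name__ == "__main__":
    # Hypothetical clip: 32 sampled frames at 1280x720.
    budget = per_frame_max_pixels(16384 * 28 * 28, 32, 224 * 224)
    h, w = smart_resize(720, 1280, 224 * 224, budget)
    print(h, w, video_token_count(32, h, w))  # prints: 560 1008 11520
```

Under these assumptions the per-frame budget is capped at `768 * 28 * 28`, so `total_pixels` only bites for clips with many sampled frames. Comparing such an estimate with the input token count vLLM reports could help show whether both backends see the same amount of video.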
For transformers, I used the default code shown in the Qwen repo, which is indeed very similar. I checked other issues and commits, and from my understanding this feature is supported; the only difference between the implementations seems to be minimal (#8408 (comment)).
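When comparing against transformers, one cheap sanity check is to confirm that the templated prompt contains exactly one vision placeholder block per item passed in `multi_modal_data`. The sketch below assumes the Qwen2-VL chat template renders each video as `<|vision_start|><|video_pad|><|vision_end|>` (and each image with `<|image_pad|>`). `count_placeholders` and `check_mm_data` are hypothetical helpers, and `EXAMPLE_PROMPT` is hand-written to mirror the assumed template, not captured from the run above.

```python
import re

# Assumed Qwen2-VL special tokens; verify against the tokenizer's chat template.
VISION_BLOCK = re.compile(
    r"<\|vision_start\|>(<\|image_pad\|>|<\|video_pad\|>)<\|vision_end\|>"
)

# Hand-written prompt mirroring the assumed template for the messages above.
EXAMPLE_PROMPT = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
    "Describe the video.<|im_end|>\n<|im_start|>assistant\n"
)


def count_placeholders(prompt: str) -> dict:
    """Count image/video placeholder blocks in a chat-templated prompt."""
    counts = {"image": 0, "video": 0}
    for match in VISION_BLOCK.finditer(prompt):
        kind = "image" if match.group(1) == "<|image_pad|>" else "video"
        counts[kind] += 1
    return counts


def check_mm_data(prompt: str, mm_data: dict) -> None:
    """Raise if the number of placeholders and multimodal inputs disagree."""
    counts = count_placeholders(prompt)
    for kind in ("image", "video"):
        items = mm_data.get(kind, [])
        n_items = len(items) if isinstance(items, list) else 1
        if counts[kind] != n_items:
            raise ValueError(
                f"{kind}: {counts[kind]} placeholder(s) but {n_items} input(s)"
            )


if __name__ == "__main__":
    print(count_placeholders(EXAMPLE_PROMPT))  # prints: {'image': 0, 'video': 1}
    check_mm_data(EXAMPLE_PROMPT, {"video": ["frames"]})  # passes silently
```

In the real script, one would call `check_mm_data(prompt, mm_data)` just before `llm.generate`. A mismatch would point at prompt construction rather than at the model itself.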
The code above follows the example in the Qwen repo (https://github.com/QwenLM/Qwen2-VL?tab=readme-ov-file#inference-locally), which also seems to be what the vLLM docs recommend.