Skip to content

Commit

Permalink
Fix numpy and XGMI 1-hop detection (opendatahub-io#67)
Browse files: browse the repository at this point in the history
* Fix 1-hop XGMI detection

* Fix numpy versioning
  • Loading branch information
mawong-amd authored Jun 25, 2024
1 parent 17e6307 commit 3e7b0b6
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 22 deletions.
2 changes: 1 addition & 1 deletion requirements-common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ cmake >= 3.21
ninja # For faster builds.
psutil
sentencepiece # Required for LLaMA tokenizer.
numpy
numpy < 2.0.0
requests
py-cpuinfo
transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3.
Expand Down
39 changes: 18 additions & 21 deletions vllm/distributed/device_communicators/custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

try:
if is_hip():
from amdsmi import (AmdSmiException,
amdsmi_get_processor_handle_from_bdf, amdsmi_init,
amdsmi_shut_down, amdsmi_topo_get_link_type)
from amdsmi import (AmdSmiException, amdsmi_get_processor_handles,
amdsmi_init, amdsmi_shut_down,
amdsmi_topo_get_link_type)
else:
import pynvml

Expand Down Expand Up @@ -62,25 +62,22 @@ def _is_full_nvlink(device_ids: List[int], world_size) -> bool:
so it works on real physical device ids.
"""
if is_hip():
# get devices' BDF in order to get XGMI link info from amdsmi
bdf = custom_ar.get_device_bdf(torch.cuda.current_device())
all_bdf = [0] * world_size
dist.all_gather_object(all_bdf, bdf)
hsmi = [None] * world_size
try:
for i in range(world_size):
bdf_str = str(bytes(all_bdf[i]).decode("utf-8"))
hsmi[i] = amdsmi_get_processor_handle_from_bdf(bdf_str)
for i in range(world_size):
if i != 0:
link_type = amdsmi_topo_get_link_type(hsmi[0], hsmi[i])
# type is 2 for XGMI
if link_type['hops'] != 1 or link_type['type'] != 2:
# On ROCm, we instead query if GPUs are connected by 1-hop XGMI
handles = [amdsmi_get_processor_handles()[i] for i in device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
link_type = amdsmi_topo_get_link_type(
handle, peer_handle)
# type is 2 for XGMI
if link_type["hops"] != 1 or link_type["type"] != 2:
return False
except AmdSmiException as error:
logger.error(
"AMD link detection failed.",
exc_info=error)
return False
except AmdSmiException as e:
logger.warning(e)
return False
return True
else:
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
for i, handle in enumerate(handles):
Expand Down

0 comments on commit 3e7b0b6

Please sign in to comment.