Skip to content

Commit

Permalink
refactor: remove C shared memory shim
Browse files Browse the repository at this point in the history
  • Loading branch information
GuanLuo committed Oct 23, 2024
1 parent f791cd4 commit 66f3c75
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 465 deletions.
1 change: 0 additions & 1 deletion src/python/library/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ add_custom_target(
if (NOT WIN32)
# Can generate linux specific wheel file on linux systems only.
set(LINUX_WHEEL_DEPENDS
cshm
${WHEEL_DEPENDS}
)

Expand Down
4 changes: 0 additions & 4 deletions src/python/library/build_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,6 @@ def sed(pattern, replace, source, dest=None):
"tritonclient/utils/shared_memory",
os.path.join(FLAGS.whl_dir, "tritonclient/utils/shared_memory"),
)
shutil.copyfile(
"tritonclient/utils/libcshm.so",
os.path.join(FLAGS.whl_dir, "tritonclient/utils/shared_memory/libcshm.so"),
)
cpdir(
"tritonclient/utils/cuda_shared_memory",
os.path.join(FLAGS.whl_dir, "tritonclient/utils/cuda_shared_memory"),
Expand Down
2 changes: 0 additions & 2 deletions src/python/library/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ def req_file(filename, folder="requirements"):
extras_require["all"] = list(chain(extras_require.values()))

platform_package_data = []
if PLATFORM_FLAG != "any":
platform_package_data += ["libcshm.so"]

data_files = [
("", ["LICENSE.txt"]),
Expand Down
4 changes: 2 additions & 2 deletions src/python/library/tests/test_shared_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_lifecycle(self):
def test_invalid_create_shm(self):
# Raises error since tried to create invalid system shared memory region
with self.assertRaisesRegex(
shm.SharedMemoryException, "unable to initialize the size"
shm.SharedMemoryException, "unable to create the shared memory region"
):
self.shm_handles.append(
shm.create_shared_memory_region("dummy_data", "/dummy_data", -1)
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_duplicate_key(self):
)
with self.assertRaisesRegex(
shm.SharedMemoryException,
"unable to create the shared memory region, already exists",
"unable to create the shared memory region",
):
self.shm_handles.append(
shm.create_shared_memory_region(
Expand Down
14 changes: 0 additions & 14 deletions src/python/library/tritonclient/utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,6 @@ configure_file(__init__.py __init__.py COPYONLY)
configure_file(_dlpack.py _dlpack.py COPYONLY)
configure_file(_shared_memory_tensor.py _shared_memory_tensor.py COPYONLY)

if(NOT WIN32)
file(COPY shared_memory DESTINATION .)

#
# libcshm.so
#
add_library(cshm SHARED shared_memory/shared_memory.cc)
if(${TRITON_ENABLE_GPU})
target_compile_definitions(cshm PUBLIC TRITON_ENABLE_GPU=1)
target_link_libraries(cshm PUBLIC CUDA::cudart)
endif() # TRITON_ENABLE_GPU
target_link_libraries(cshm PRIVATE rt)
endif() # WIN32

if(NOT WIN32)
configure_file(shared_memory/__init__.py shared_memory/__init__.py COPYONLY)
configure_file(cuda_shared_memory/__init__.py cuda_shared_memory/__init__.py COPYONLY)
Expand Down
257 changes: 64 additions & 193 deletions src/python/library/tritonclient/utils/shared_memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,67 +29,11 @@
import os
import struct
import warnings
from ctypes import *
from multiprocessing import shared_memory as mpshm

import numpy as np
import pkg_resources


class _utf8(object):
@classmethod
def from_param(cls, value):
if value is None:
return None
elif isinstance(value, bytes):
return value
else:
return value.encode("utf8")


_cshm_lib = "cshm" if os.name == "nt" else "libcshm.so"
_cshm_path = pkg_resources.resource_filename(
"tritonclient.utils.shared_memory", _cshm_lib
)
_cshm = cdll.LoadLibrary(_cshm_path)

_cshm_shared_memory_region_create = _cshm.SharedMemoryRegionCreate
_cshm_shared_memory_region_create.restype = c_int
_cshm_shared_memory_region_create.argtypes = [_utf8, _utf8, c_uint64, POINTER(c_void_p)]
_cshm_shared_memory_region_set = _cshm.SharedMemoryRegionSet
_cshm_shared_memory_region_set.restype = c_int
_cshm_shared_memory_region_set.argtypes = [c_void_p, c_uint64, c_uint64, c_void_p]
_cshm_get_shared_memory_handle_info = _cshm.GetSharedMemoryHandleInfo
_cshm_get_shared_memory_handle_info.restype = c_int
_cshm_get_shared_memory_handle_info.argtypes = [
c_void_p,
POINTER(c_char_p),
POINTER(c_char_p),
POINTER(c_int),
POINTER(c_uint64),
POINTER(c_uint64),
]
_cshm_shared_memory_region_destroy = _cshm.SharedMemoryRegionDestroy
_cshm_shared_memory_region_destroy.restype = c_int
_cshm_shared_memory_region_destroy.argtypes = [c_void_p]

mapped_shm_regions = []
_key_mapping = {}


def _raise_if_error(errno):
"""
Raise SharedMemoryException if 'err' is non-success.
Otherwise return nothing.
"""
if errno.value != 0:
ex = SharedMemoryException(errno)
raise ex
return


def _raise_error(msg):
ex = SharedMemoryException(msg)
raise ex
_key_mapping = {}


class SharedMemoryRegion:
Expand All @@ -100,7 +44,7 @@ def __init__(
) -> None:
self._triton_shm_name = triton_shm_name
self._shm_key = shm_key
self._c_handle = c_void_p()
self._mpsm_handle = None


def create_shared_memory_region(triton_shm_name, shm_key, byte_size, create_only=False):
Expand Down Expand Up @@ -130,49 +74,34 @@ def create_shared_memory_region(triton_shm_name, shm_key, byte_size, create_only
SharedMemoryException
If unable to create the shared memory region.
"""

if create_only and shm_key in mapped_shm_regions:
raise SharedMemoryException(
"unable to create the shared memory region, already exists"
)

shm_handle = SharedMemoryRegion(triton_shm_name, shm_key)
# Has been created
if shm_key in _key_mapping:
shm_handle._c_handle = _key_mapping[shm_key][0]
_key_mapping[shm_key][1] += 1
# check on the size
shm_fd = c_int()
region_offset = c_uint64()
shm_byte_size = c_uint64()
shm_addr = c_char_p()
c_shm_key = c_char_p()
_raise_if_error(
c_int(
_cshm_get_shared_memory_handle_info(
shm_handle._c_handle,
byref(shm_addr),
byref(c_shm_key),
byref(shm_fd),
byref(region_offset),
byref(shm_byte_size),
)
)
)
if byte_size > shm_byte_size.value:
warnings.warn(
f"reusing shared memory region with key '{shm_key}', region size is {shm_byte_size.value} instead of requested {byte_size}"
)
else:
_raise_if_error(
c_int(
_cshm_shared_memory_region_create(
triton_shm_name, shm_key, byte_size, byref(shm_handle._c_handle)
)
# Check whether the region exists before creating it
if not create_only:
try:
shm_handle._mpsm_handle = mpshm.SharedMemory(shm_key)
if shm_key not in _key_mapping:
_key_mapping[shm_key] = [False, 0]
_key_mapping[shm_key][1] += 1
except FileNotFoundError:

Check notice

Code scanning / CodeQL

Empty except Note

'except' clause does nothing but pass and there is no explanatory comment.
pass
if shm_handle._mpsm_handle is None:
try:
shm_handle._mpsm_handle = mpshm.SharedMemory(
shm_key, create=True, size=byte_size
)
except Exception as ex:
raise SharedMemoryException(
"unable to create the shared memory region"
) from ex
if shm_key not in _key_mapping:
_key_mapping[shm_key] = [False, 0]
_key_mapping[shm_key][0] = True
_key_mapping[shm_key][1] += 1

if byte_size > shm_handle._mpsm_handle.size:
warnings.warn(
f"reusing shared memory region with key '{shm_key}', region size is {shm_handle._mpsm_handle.size} instead of requested {byte_size}"
)
_key_mapping[shm_key] = [shm_handle._c_handle, 1]
mapped_shm_regions.append(shm_key)

return shm_handle

Expand All @@ -197,41 +126,33 @@ def set_shared_memory_region(shm_handle, input_values, offset=0):
"""

if not isinstance(input_values, (list, tuple)):
_raise_error("input_values must be specified as a list/tuple of numpy arrays")
raise SharedMemoryException(
"input_values must be specified as a list/tuple of numpy arrays"
)
for input_value in input_values:
if not isinstance(input_value, np.ndarray):
_raise_error("each element of input_values must be a numpy array")
raise SharedMemoryException(
"each element of input_values must be a numpy array"
)

offset_current = offset
for input_value in input_values:
input_value = np.ascontiguousarray(input_value).flatten()
if input_value.dtype == np.object_:
input_value = input_value.item()
byte_size = np.dtype(np.byte).itemsize * len(input_value)
_raise_if_error(
c_int(
_cshm_shared_memory_region_set(
shm_handle._c_handle,
c_uint64(offset_current),
c_uint64(byte_size),
cast(input_value, c_void_p),
)
try:
for input_value in input_values:
if input_value.dtype == np.object_:
byte_size = len(input_value.item())
shm_handle._mpsm_handle.buf[offset : offset + byte_size] = (
input_value.item()
)
)
else:
byte_size = input_value.size * input_value.itemsize
_raise_if_error(
c_int(
_cshm_shared_memory_region_set(
shm_handle._c_handle,
c_uint64(offset_current),
c_uint64(byte_size),
input_value.ctypes.data_as(c_void_p),
)
offset += byte_size
else:
shm_tensor_view = np.ndarray(
input_value.shape,
input_value.dtype,
buffer=shm_handle._mpsm_handle.buf[offset:],
)
)
offset_current += byte_size
return
shm_tensor_view[:] = input_value[:]
offset += input_value.nbytes
except Exception as ex:
raise SharedMemoryException("unable to set the shared memory region") from ex


def get_contents_as_numpy(shm_handle, datatype, shape, offset=0):
Expand All @@ -256,42 +177,13 @@ def get_contents_as_numpy(shm_handle, datatype, shape, offset=0):
The numpy array generated using the contents of the specified shared
memory region.
"""
shm_fd = c_int()
region_offset = c_uint64()
byte_size = c_uint64()
shm_addr = c_char_p()
shm_key = c_char_p()
_raise_if_error(
c_int(
_cshm_get_shared_memory_handle_info(
shm_handle._c_handle,
byref(shm_addr),
byref(shm_key),
byref(shm_fd),
byref(region_offset),
byref(byte_size),
)
)
)
start_pos = region_offset.value + offset
if (datatype != np.object_) and (datatype != np.bytes_):
requested_byte_size = np.prod(shape) * np.dtype(datatype).itemsize
cval_len = start_pos + requested_byte_size
if byte_size.value < cval_len:
_raise_error(
"The size of the shared memory region is insufficient to provide numpy array with requested size"
)
if cval_len == 0:
result = np.empty(shape, dtype=datatype)
else:
val_buf = cast(shm_addr, POINTER(c_byte * cval_len))[0]
val = np.frombuffer(val_buf, dtype=datatype, offset=start_pos)

# Reshape the result to the appropriate shape.
result = np.reshape(val, shape)
result = np.ndarray(
shape, datatype, buffer=shm_handle._mpsm_handle.buf[offset:]
)
else:
str_offset = start_pos
val_buf = cast(shm_addr, POINTER(c_byte * byte_size.value))[0]
str_offset = offset
val_buf = shm_handle._mpsm_handle.buf
ii = 0
strs = list()
while (ii % np.prod(shape) != 0) or (ii == 0):
Expand Down Expand Up @@ -319,7 +211,7 @@ def mapped_shared_memory_regions():
The list of mapped system shared memory regions.
"""

return mapped_shm_regions
return list(_key_mapping.keys())


def destroy_shared_memory_region(shm_handle):
Expand All @@ -341,38 +233,17 @@ def destroy_shared_memory_region(shm_handle):
# fail, a re-attempt could result in a segfault. Secondarily, if we
# fail to delete a region, we should not report it back to the user
# as a valid memory region.
shm_handle._mpsm_handle.close()
_key_mapping[shm_handle._shm_key][1] -= 1
if _key_mapping[shm_handle._shm_key][1] == 0:
mapped_shm_regions.remove(shm_handle._shm_key)
_key_mapping.pop(shm_handle._shm_key)
_raise_if_error(c_int(_cshm_shared_memory_region_destroy(shm_handle._c_handle)))
try:
if _key_mapping[shm_handle._shm_key][0]:
shm_handle._mpsm_handle.unlink()
finally:
_key_mapping.pop(shm_handle._shm_key)


class SharedMemoryException(Exception):
"""Exception indicating non-Success status.
Parameters
----------
err : c_void_p
Pointer to an Error that should be used to initialize the exception.
"""
"""Exception type for shared memory related error."""

def __init__(self, err):
self.err_code_map = {
-2: "unable to get shared memory descriptor",
-3: "unable to initialize the size",
-4: "unable to read/mmap the shared memory region",
-5: "unable to unlink the shared memory region",
-6: "unable to munmap the shared memory region",
-7: "unable to set the shared memory region",
}
self._msg = None
if type(err) == str:
self._msg = err
elif err.value != 0 and err.value in self.err_code_map:
self._msg = self.err_code_map[err.value]

def __str__(self):
msg = super().__str__() if self._msg is None else self._msg
return msg
pass
Loading

0 comments on commit 66f3c75

Please sign in to comment.