Skip to content

Commit 66f3c75

Browse files
committed
refactor: remove C shared memory shim
1 parent f791cd4 commit 66f3c75

File tree

9 files changed

+66
-465
lines changed

9 files changed

+66
-465
lines changed

src/python/library/CMakeLists.txt

-1
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,6 @@ add_custom_target(
9696
if (NOT WIN32)
9797
# Can generate linux specific wheel file on linux systems only.
9898
set(LINUX_WHEEL_DEPENDS
99-
cshm
10099
${WHEEL_DEPENDS}
101100
)
102101

src/python/library/build_wheel.py

-4
Original file line number | Diff line number | Diff line change
@@ -174,10 +174,6 @@ def sed(pattern, replace, source, dest=None):
174174
"tritonclient/utils/shared_memory",
175175
os.path.join(FLAGS.whl_dir, "tritonclient/utils/shared_memory"),
176176
)
177-
shutil.copyfile(
178-
"tritonclient/utils/libcshm.so",
179-
os.path.join(FLAGS.whl_dir, "tritonclient/utils/shared_memory/libcshm.so"),
180-
)
181177
cpdir(
182178
"tritonclient/utils/cuda_shared_memory",
183179
os.path.join(FLAGS.whl_dir, "tritonclient/utils/cuda_shared_memory"),

src/python/library/setup.py

-2
Original file line number | Diff line number | Diff line change
@@ -76,8 +76,6 @@ def req_file(filename, folder="requirements"):
7676
extras_require["all"] = list(chain(extras_require.values()))
7777

7878
platform_package_data = []
79-
if PLATFORM_FLAG != "any":
80-
platform_package_data += ["libcshm.so"]
8179

8280
data_files = [
8381
("", ["LICENSE.txt"]),

src/python/library/tests/test_shared_memory.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,7 @@ def test_lifecycle(self):
6565
def test_invalid_create_shm(self):
6666
# Raises error since tried to create invalid system shared memory region
6767
with self.assertRaisesRegex(
68-
shm.SharedMemoryException, "unable to initialize the size"
68+
shm.SharedMemoryException, "unable to create the shared memory region"
6969
):
7070
self.shm_handles.append(
7171
shm.create_shared_memory_region("dummy_data", "/dummy_data", -1)
@@ -110,7 +110,7 @@ def test_duplicate_key(self):
110110
)
111111
with self.assertRaisesRegex(
112112
shm.SharedMemoryException,
113-
"unable to create the shared memory region, already exists",
113+
"unable to create the shared memory region",
114114
):
115115
self.shm_handles.append(
116116
shm.create_shared_memory_region(

src/python/library/tritonclient/utils/CMakeLists.txt

-14
Original file line number | Diff line number | Diff line change
@@ -28,20 +28,6 @@ configure_file(__init__.py __init__.py COPYONLY)
2828
configure_file(_dlpack.py _dlpack.py COPYONLY)
2929
configure_file(_shared_memory_tensor.py _shared_memory_tensor.py COPYONLY)
3030

31-
if(NOT WIN32)
32-
file(COPY shared_memory DESTINATION .)
33-
34-
#
35-
# libcshm.so
36-
#
37-
add_library(cshm SHARED shared_memory/shared_memory.cc)
38-
if(${TRITON_ENABLE_GPU})
39-
target_compile_definitions(cshm PUBLIC TRITON_ENABLE_GPU=1)
40-
target_link_libraries(cshm PUBLIC CUDA::cudart)
41-
endif() # TRITON_ENABLE_GPU
42-
target_link_libraries(cshm PRIVATE rt)
43-
endif() # WIN32
44-
4531
if(NOT WIN32)
4632
configure_file(shared_memory/__init__.py shared_memory/__init__.py COPYONLY)
4733
configure_file(cuda_shared_memory/__init__.py cuda_shared_memory/__init__.py COPYONLY)

src/python/library/tritonclient/utils/shared_memory/__init__.py

+64-193
Original file line number | Diff line number | Diff line change
@@ -29,67 +29,11 @@
2929
import os
3030
import struct
3131
import warnings
32-
from ctypes import *
32+
from multiprocessing import shared_memory as mpshm
3333

3434
import numpy as np
35-
import pkg_resources
36-
37-
38-
class _utf8(object):
39-
@classmethod
40-
def from_param(cls, value):
41-
if value is None:
42-
return None
43-
elif isinstance(value, bytes):
44-
return value
45-
else:
46-
return value.encode("utf8")
47-
48-
49-
_cshm_lib = "cshm" if os.name == "nt" else "libcshm.so"
50-
_cshm_path = pkg_resources.resource_filename(
51-
"tritonclient.utils.shared_memory", _cshm_lib
52-
)
53-
_cshm = cdll.LoadLibrary(_cshm_path)
54-
55-
_cshm_shared_memory_region_create = _cshm.SharedMemoryRegionCreate
56-
_cshm_shared_memory_region_create.restype = c_int
57-
_cshm_shared_memory_region_create.argtypes = [_utf8, _utf8, c_uint64, POINTER(c_void_p)]
58-
_cshm_shared_memory_region_set = _cshm.SharedMemoryRegionSet
59-
_cshm_shared_memory_region_set.restype = c_int
60-
_cshm_shared_memory_region_set.argtypes = [c_void_p, c_uint64, c_uint64, c_void_p]
61-
_cshm_get_shared_memory_handle_info = _cshm.GetSharedMemoryHandleInfo
62-
_cshm_get_shared_memory_handle_info.restype = c_int
63-
_cshm_get_shared_memory_handle_info.argtypes = [
64-
c_void_p,
65-
POINTER(c_char_p),
66-
POINTER(c_char_p),
67-
POINTER(c_int),
68-
POINTER(c_uint64),
69-
POINTER(c_uint64),
70-
]
71-
_cshm_shared_memory_region_destroy = _cshm.SharedMemoryRegionDestroy
72-
_cshm_shared_memory_region_destroy.restype = c_int
73-
_cshm_shared_memory_region_destroy.argtypes = [c_void_p]
74-
75-
mapped_shm_regions = []
76-
_key_mapping = {}
77-
78-
79-
def _raise_if_error(errno):
80-
"""
81-
Raise SharedMemoryException if 'err' is non-success.
82-
Otherwise return nothing.
83-
"""
84-
if errno.value != 0:
85-
ex = SharedMemoryException(errno)
86-
raise ex
87-
return
8835

89-
90-
def _raise_error(msg):
91-
ex = SharedMemoryException(msg)
92-
raise ex
36+
_key_mapping = {}
9337

9438

9539
class SharedMemoryRegion:
@@ -100,7 +44,7 @@ def __init__(
10044
) -> None:
10145
self._triton_shm_name = triton_shm_name
10246
self._shm_key = shm_key
103-
self._c_handle = c_void_p()
47+
self._mpsm_handle = None
10448

10549

10650
def create_shared_memory_region(triton_shm_name, shm_key, byte_size, create_only=False):
@@ -130,49 +74,34 @@ def create_shared_memory_region(triton_shm_name, shm_key, byte_size, create_only
13074
SharedMemoryException
13175
If unable to create the shared memory region.
13276
"""
133-
134-
if create_only and shm_key in mapped_shm_regions:
135-
raise SharedMemoryException(
136-
"unable to create the shared memory region, already exists"
137-
)
138-
13977
shm_handle = SharedMemoryRegion(triton_shm_name, shm_key)
140-
# Has been created
141-
if shm_key in _key_mapping:
142-
shm_handle._c_handle = _key_mapping[shm_key][0]
143-
_key_mapping[shm_key][1] += 1
144-
# check on the size
145-
shm_fd = c_int()
146-
region_offset = c_uint64()
147-
shm_byte_size = c_uint64()
148-
shm_addr = c_char_p()
149-
c_shm_key = c_char_p()
150-
_raise_if_error(
151-
c_int(
152-
_cshm_get_shared_memory_handle_info(
153-
shm_handle._c_handle,
154-
byref(shm_addr),
155-
byref(c_shm_key),
156-
byref(shm_fd),
157-
byref(region_offset),
158-
byref(shm_byte_size),
159-
)
160-
)
161-
)
162-
if byte_size > shm_byte_size.value:
163-
warnings.warn(
164-
f"reusing shared memory region with key '{shm_key}', region size is {shm_byte_size.value} instead of requested {byte_size}"
165-
)
166-
else:
167-
_raise_if_error(
168-
c_int(
169-
_cshm_shared_memory_region_create(
170-
triton_shm_name, shm_key, byte_size, byref(shm_handle._c_handle)
171-
)
78+
# Check whether the region exists before creating it
79+
if not create_only:
80+
try:
81+
shm_handle._mpsm_handle = mpshm.SharedMemory(shm_key)
82+
if shm_key not in _key_mapping:
83+
_key_mapping[shm_key] = [False, 0]
84+
_key_mapping[shm_key][1] += 1
85+
except FileNotFoundError:
86+
pass
87+
if shm_handle._mpsm_handle is None:
88+
try:
89+
shm_handle._mpsm_handle = mpshm.SharedMemory(
90+
shm_key, create=True, size=byte_size
17291
)
92+
except Exception as ex:
93+
raise SharedMemoryException(
94+
"unable to create the shared memory region"
95+
) from ex
96+
if shm_key not in _key_mapping:
97+
_key_mapping[shm_key] = [False, 0]
98+
_key_mapping[shm_key][0] = True
99+
_key_mapping[shm_key][1] += 1
100+
101+
if byte_size > shm_handle._mpsm_handle.size:
102+
warnings.warn(
103+
f"reusing shared memory region with key '{shm_key}', region size is {shm_handle._mpsm_handle.size} instead of requested {byte_size}"
173104
)
174-
_key_mapping[shm_key] = [shm_handle._c_handle, 1]
175-
mapped_shm_regions.append(shm_key)
176105

177106
return shm_handle
178107

@@ -197,41 +126,33 @@ def set_shared_memory_region(shm_handle, input_values, offset=0):
197126
"""
198127

199128
if not isinstance(input_values, (list, tuple)):
200-
_raise_error("input_values must be specified as a list/tuple of numpy arrays")
129+
raise SharedMemoryException(
130+
"input_values must be specified as a list/tuple of numpy arrays"
131+
)
201132
for input_value in input_values:
202133
if not isinstance(input_value, np.ndarray):
203-
_raise_error("each element of input_values must be a numpy array")
134+
raise SharedMemoryException(
135+
"each element of input_values must be a numpy array"
136+
)
204137

205-
offset_current = offset
206-
for input_value in input_values:
207-
input_value = np.ascontiguousarray(input_value).flatten()
208-
if input_value.dtype == np.object_:
209-
input_value = input_value.item()
210-
byte_size = np.dtype(np.byte).itemsize * len(input_value)
211-
_raise_if_error(
212-
c_int(
213-
_cshm_shared_memory_region_set(
214-
shm_handle._c_handle,
215-
c_uint64(offset_current),
216-
c_uint64(byte_size),
217-
cast(input_value, c_void_p),
218-
)
138+
try:
139+
for input_value in input_values:
140+
if input_value.dtype == np.object_:
141+
byte_size = len(input_value.item())
142+
shm_handle._mpsm_handle.buf[offset : offset + byte_size] = (
143+
input_value.item()
219144
)
220-
)
221-
else:
222-
byte_size = input_value.size * input_value.itemsize
223-
_raise_if_error(
224-
c_int(
225-
_cshm_shared_memory_region_set(
226-
shm_handle._c_handle,
227-
c_uint64(offset_current),
228-
c_uint64(byte_size),
229-
input_value.ctypes.data_as(c_void_p),
230-
)
145+
offset += byte_size
146+
else:
147+
shm_tensor_view = np.ndarray(
148+
input_value.shape,
149+
input_value.dtype,
150+
buffer=shm_handle._mpsm_handle.buf[offset:],
231151
)
232-
)
233-
offset_current += byte_size
234-
return
152+
shm_tensor_view[:] = input_value[:]
153+
offset += input_value.nbytes
154+
except Exception as ex:
155+
raise SharedMemoryException("unable to set the shared memory region") from ex
235156

236157

237158
def get_contents_as_numpy(shm_handle, datatype, shape, offset=0):
@@ -256,42 +177,13 @@ def get_contents_as_numpy(shm_handle, datatype, shape, offset=0):
256177
The numpy array generated using the contents of the specified shared
257178
memory region.
258179
"""
259-
shm_fd = c_int()
260-
region_offset = c_uint64()
261-
byte_size = c_uint64()
262-
shm_addr = c_char_p()
263-
shm_key = c_char_p()
264-
_raise_if_error(
265-
c_int(
266-
_cshm_get_shared_memory_handle_info(
267-
shm_handle._c_handle,
268-
byref(shm_addr),
269-
byref(shm_key),
270-
byref(shm_fd),
271-
byref(region_offset),
272-
byref(byte_size),
273-
)
274-
)
275-
)
276-
start_pos = region_offset.value + offset
277180
if (datatype != np.object_) and (datatype != np.bytes_):
278-
requested_byte_size = np.prod(shape) * np.dtype(datatype).itemsize
279-
cval_len = start_pos + requested_byte_size
280-
if byte_size.value < cval_len:
281-
_raise_error(
282-
"The size of the shared memory region is insufficient to provide numpy array with requested size"
283-
)
284-
if cval_len == 0:
285-
result = np.empty(shape, dtype=datatype)
286-
else:
287-
val_buf = cast(shm_addr, POINTER(c_byte * cval_len))[0]
288-
val = np.frombuffer(val_buf, dtype=datatype, offset=start_pos)
289-
290-
# Reshape the result to the appropriate shape.
291-
result = np.reshape(val, shape)
181+
result = np.ndarray(
182+
shape, datatype, buffer=shm_handle._mpsm_handle.buf[offset:]
183+
)
292184
else:
293-
str_offset = start_pos
294-
val_buf = cast(shm_addr, POINTER(c_byte * byte_size.value))[0]
185+
str_offset = offset
186+
val_buf = shm_handle._mpsm_handle.buf
295187
ii = 0
296188
strs = list()
297189
while (ii % np.prod(shape) != 0) or (ii == 0):
@@ -319,7 +211,7 @@ def mapped_shared_memory_regions():
319211
The list of mapped system shared memory regions.
320212
"""
321213

322-
return mapped_shm_regions
214+
return list(_key_mapping.keys())
323215

324216

325217
def destroy_shared_memory_region(shm_handle):
@@ -341,38 +233,17 @@ def destroy_shared_memory_region(shm_handle):
341233
# fail, a re-attempt could result in a segfault. Secondarily, if we
342234
# fail to delete a region, we should not report it back to the user
343235
# as a valid memory region.
236+
shm_handle._mpsm_handle.close()
344237
_key_mapping[shm_handle._shm_key][1] -= 1
345238
if _key_mapping[shm_handle._shm_key][1] == 0:
346-
mapped_shm_regions.remove(shm_handle._shm_key)
347-
_key_mapping.pop(shm_handle._shm_key)
348-
_raise_if_error(c_int(_cshm_shared_memory_region_destroy(shm_handle._c_handle)))
239+
try:
240+
if _key_mapping[shm_handle._shm_key][0]:
241+
shm_handle._mpsm_handle.unlink()
242+
finally:
243+
_key_mapping.pop(shm_handle._shm_key)
349244

350245

351246
class SharedMemoryException(Exception):
352-
"""Exception indicating non-Success status.
353-
354-
Parameters
355-
----------
356-
err : c_void_p
357-
Pointer to an Error that should be used to initialize the exception.
358-
359-
"""
247+
"""Exception type for shared memory related error."""
360248

361-
def __init__(self, err):
362-
self.err_code_map = {
363-
-2: "unable to get shared memory descriptor",
364-
-3: "unable to initialize the size",
365-
-4: "unable to read/mmap the shared memory region",
366-
-5: "unable to unlink the shared memory region",
367-
-6: "unable to munmap the shared memory region",
368-
-7: "unable to set the shared memory region",
369-
}
370-
self._msg = None
371-
if type(err) == str:
372-
self._msg = err
373-
elif err.value != 0 and err.value in self.err_code_map:
374-
self._msg = self.err_code_map[err.value]
375-
376-
def __str__(self):
377-
msg = super().__str__() if self._msg is None else self._msg
378-
return msg
249+
pass

0 commit comments

Comments
 (0)