29
29
import os
30
30
import struct
31
31
import warnings
32
- from ctypes import *
32
+ from multiprocessing import shared_memory as mpshm
33
33
34
34
import numpy as np
35
- import pkg_resources
36
-
37
-
38
- class _utf8 (object ):
39
- @classmethod
40
- def from_param (cls , value ):
41
- if value is None :
42
- return None
43
- elif isinstance (value , bytes ):
44
- return value
45
- else :
46
- return value .encode ("utf8" )
47
-
48
-
49
- _cshm_lib = "cshm" if os .name == "nt" else "libcshm.so"
50
- _cshm_path = pkg_resources .resource_filename (
51
- "tritonclient.utils.shared_memory" , _cshm_lib
52
- )
53
- _cshm = cdll .LoadLibrary (_cshm_path )
54
-
55
- _cshm_shared_memory_region_create = _cshm .SharedMemoryRegionCreate
56
- _cshm_shared_memory_region_create .restype = c_int
57
- _cshm_shared_memory_region_create .argtypes = [_utf8 , _utf8 , c_uint64 , POINTER (c_void_p )]
58
- _cshm_shared_memory_region_set = _cshm .SharedMemoryRegionSet
59
- _cshm_shared_memory_region_set .restype = c_int
60
- _cshm_shared_memory_region_set .argtypes = [c_void_p , c_uint64 , c_uint64 , c_void_p ]
61
- _cshm_get_shared_memory_handle_info = _cshm .GetSharedMemoryHandleInfo
62
- _cshm_get_shared_memory_handle_info .restype = c_int
63
- _cshm_get_shared_memory_handle_info .argtypes = [
64
- c_void_p ,
65
- POINTER (c_char_p ),
66
- POINTER (c_char_p ),
67
- POINTER (c_int ),
68
- POINTER (c_uint64 ),
69
- POINTER (c_uint64 ),
70
- ]
71
- _cshm_shared_memory_region_destroy = _cshm .SharedMemoryRegionDestroy
72
- _cshm_shared_memory_region_destroy .restype = c_int
73
- _cshm_shared_memory_region_destroy .argtypes = [c_void_p ]
74
-
75
- mapped_shm_regions = []
76
- _key_mapping = {}
77
-
78
-
79
- def _raise_if_error (errno ):
80
- """
81
- Raise SharedMemoryException if 'err' is non-success.
82
- Otherwise return nothing.
83
- """
84
- if errno .value != 0 :
85
- ex = SharedMemoryException (errno )
86
- raise ex
87
- return
88
35
89
-
90
- def _raise_error (msg ):
91
- ex = SharedMemoryException (msg )
92
- raise ex
36
+ _key_mapping = {}
93
37
94
38
95
39
class SharedMemoryRegion :
@@ -100,7 +44,7 @@ def __init__(
100
44
) -> None :
101
45
self ._triton_shm_name = triton_shm_name
102
46
self ._shm_key = shm_key
103
- self ._c_handle = c_void_p ()
47
+ self ._mpsm_handle = None
104
48
105
49
106
50
def create_shared_memory_region (triton_shm_name , shm_key , byte_size , create_only = False ):
@@ -130,49 +74,34 @@ def create_shared_memory_region(triton_shm_name, shm_key, byte_size, create_only
130
74
SharedMemoryException
131
75
If unable to create the shared memory region.
132
76
"""
133
-
134
- if create_only and shm_key in mapped_shm_regions :
135
- raise SharedMemoryException (
136
- "unable to create the shared memory region, already exists"
137
- )
138
-
139
77
shm_handle = SharedMemoryRegion (triton_shm_name , shm_key )
140
- # Has been created
141
- if shm_key in _key_mapping :
142
- shm_handle ._c_handle = _key_mapping [shm_key ][0 ]
143
- _key_mapping [shm_key ][1 ] += 1
144
- # check on the size
145
- shm_fd = c_int ()
146
- region_offset = c_uint64 ()
147
- shm_byte_size = c_uint64 ()
148
- shm_addr = c_char_p ()
149
- c_shm_key = c_char_p ()
150
- _raise_if_error (
151
- c_int (
152
- _cshm_get_shared_memory_handle_info (
153
- shm_handle ._c_handle ,
154
- byref (shm_addr ),
155
- byref (c_shm_key ),
156
- byref (shm_fd ),
157
- byref (region_offset ),
158
- byref (shm_byte_size ),
159
- )
160
- )
161
- )
162
- if byte_size > shm_byte_size .value :
163
- warnings .warn (
164
- f"reusing shared memory region with key '{ shm_key } ', region size is { shm_byte_size .value } instead of requested { byte_size } "
165
- )
166
- else :
167
- _raise_if_error (
168
- c_int (
169
- _cshm_shared_memory_region_create (
170
- triton_shm_name , shm_key , byte_size , byref (shm_handle ._c_handle )
171
- )
78
+ # Check whether the region exists before creating it
79
+ if not create_only :
80
+ try :
81
+ shm_handle ._mpsm_handle = mpshm .SharedMemory (shm_key )
82
+ if shm_key not in _key_mapping :
83
+ _key_mapping [shm_key ] = [False , 0 ]
84
+ _key_mapping [shm_key ][1 ] += 1
85
+ except FileNotFoundError :
86
+ pass
87
+ if shm_handle ._mpsm_handle is None :
88
+ try :
89
+ shm_handle ._mpsm_handle = mpshm .SharedMemory (
90
+ shm_key , create = True , size = byte_size
172
91
)
92
+ except Exception as ex :
93
+ raise SharedMemoryException (
94
+ "unable to create the shared memory region"
95
+ ) from ex
96
+ if shm_key not in _key_mapping :
97
+ _key_mapping [shm_key ] = [False , 0 ]
98
+ _key_mapping [shm_key ][0 ] = True
99
+ _key_mapping [shm_key ][1 ] += 1
100
+
101
+ if byte_size > shm_handle ._mpsm_handle .size :
102
+ warnings .warn (
103
+ f"reusing shared memory region with key '{ shm_key } ', region size is { shm_handle ._mpsm_handle .size } instead of requested { byte_size } "
173
104
)
174
- _key_mapping [shm_key ] = [shm_handle ._c_handle , 1 ]
175
- mapped_shm_regions .append (shm_key )
176
105
177
106
return shm_handle
178
107
@@ -197,41 +126,33 @@ def set_shared_memory_region(shm_handle, input_values, offset=0):
197
126
"""
198
127
199
128
if not isinstance (input_values , (list , tuple )):
200
- _raise_error ("input_values must be specified as a list/tuple of numpy arrays" )
129
+ raise SharedMemoryException (
130
+ "input_values must be specified as a list/tuple of numpy arrays"
131
+ )
201
132
for input_value in input_values :
202
133
if not isinstance (input_value , np .ndarray ):
203
- _raise_error ("each element of input_values must be a numpy array" )
134
+ raise SharedMemoryException (
135
+ "each element of input_values must be a numpy array"
136
+ )
204
137
205
- offset_current = offset
206
- for input_value in input_values :
207
- input_value = np .ascontiguousarray (input_value ).flatten ()
208
- if input_value .dtype == np .object_ :
209
- input_value = input_value .item ()
210
- byte_size = np .dtype (np .byte ).itemsize * len (input_value )
211
- _raise_if_error (
212
- c_int (
213
- _cshm_shared_memory_region_set (
214
- shm_handle ._c_handle ,
215
- c_uint64 (offset_current ),
216
- c_uint64 (byte_size ),
217
- cast (input_value , c_void_p ),
218
- )
138
+ try :
139
+ for input_value in input_values :
140
+ if input_value .dtype == np .object_ :
141
+ byte_size = len (input_value .item ())
142
+ shm_handle ._mpsm_handle .buf [offset : offset + byte_size ] = (
143
+ input_value .item ()
219
144
)
220
- )
221
- else :
222
- byte_size = input_value .size * input_value .itemsize
223
- _raise_if_error (
224
- c_int (
225
- _cshm_shared_memory_region_set (
226
- shm_handle ._c_handle ,
227
- c_uint64 (offset_current ),
228
- c_uint64 (byte_size ),
229
- input_value .ctypes .data_as (c_void_p ),
230
- )
145
+ offset += byte_size
146
+ else :
147
+ shm_tensor_view = np .ndarray (
148
+ input_value .shape ,
149
+ input_value .dtype ,
150
+ buffer = shm_handle ._mpsm_handle .buf [offset :],
231
151
)
232
- )
233
- offset_current += byte_size
234
- return
152
+ shm_tensor_view [:] = input_value [:]
153
+ offset += input_value .nbytes
154
+ except Exception as ex :
155
+ raise SharedMemoryException ("unable to set the shared memory region" ) from ex
235
156
236
157
237
158
def get_contents_as_numpy (shm_handle , datatype , shape , offset = 0 ):
@@ -256,42 +177,13 @@ def get_contents_as_numpy(shm_handle, datatype, shape, offset=0):
256
177
The numpy array generated using the contents of the specified shared
257
178
memory region.
258
179
"""
259
- shm_fd = c_int ()
260
- region_offset = c_uint64 ()
261
- byte_size = c_uint64 ()
262
- shm_addr = c_char_p ()
263
- shm_key = c_char_p ()
264
- _raise_if_error (
265
- c_int (
266
- _cshm_get_shared_memory_handle_info (
267
- shm_handle ._c_handle ,
268
- byref (shm_addr ),
269
- byref (shm_key ),
270
- byref (shm_fd ),
271
- byref (region_offset ),
272
- byref (byte_size ),
273
- )
274
- )
275
- )
276
- start_pos = region_offset .value + offset
277
180
if (datatype != np .object_ ) and (datatype != np .bytes_ ):
278
- requested_byte_size = np .prod (shape ) * np .dtype (datatype ).itemsize
279
- cval_len = start_pos + requested_byte_size
280
- if byte_size .value < cval_len :
281
- _raise_error (
282
- "The size of the shared memory region is insufficient to provide numpy array with requested size"
283
- )
284
- if cval_len == 0 :
285
- result = np .empty (shape , dtype = datatype )
286
- else :
287
- val_buf = cast (shm_addr , POINTER (c_byte * cval_len ))[0 ]
288
- val = np .frombuffer (val_buf , dtype = datatype , offset = start_pos )
289
-
290
- # Reshape the result to the appropriate shape.
291
- result = np .reshape (val , shape )
181
+ result = np .ndarray (
182
+ shape , datatype , buffer = shm_handle ._mpsm_handle .buf [offset :]
183
+ )
292
184
else :
293
- str_offset = start_pos
294
- val_buf = cast ( shm_addr , POINTER ( c_byte * byte_size . value ))[ 0 ]
185
+ str_offset = offset
186
+ val_buf = shm_handle . _mpsm_handle . buf
295
187
ii = 0
296
188
strs = list ()
297
189
while (ii % np .prod (shape ) != 0 ) or (ii == 0 ):
@@ -319,7 +211,7 @@ def mapped_shared_memory_regions():
319
211
The list of mapped system shared memory regions.
320
212
"""
321
213
322
- return mapped_shm_regions
214
+ return list ( _key_mapping . keys ())
323
215
324
216
325
217
def destroy_shared_memory_region (shm_handle ):
@@ -341,38 +233,17 @@ def destroy_shared_memory_region(shm_handle):
341
233
# fail, a re-attempt could result in a segfault. Secondarily, if we
342
234
# fail to delete a region, we should not report it back to the user
343
235
# as a valid memory region.
236
+ shm_handle ._mpsm_handle .close ()
344
237
_key_mapping [shm_handle ._shm_key ][1 ] -= 1
345
238
if _key_mapping [shm_handle ._shm_key ][1 ] == 0 :
346
- mapped_shm_regions .remove (shm_handle ._shm_key )
347
- _key_mapping .pop (shm_handle ._shm_key )
348
- _raise_if_error (c_int (_cshm_shared_memory_region_destroy (shm_handle ._c_handle )))
239
+ try :
240
+ if _key_mapping [shm_handle ._shm_key ][0 ]:
241
+ shm_handle ._mpsm_handle .unlink ()
242
+ finally :
243
+ _key_mapping .pop (shm_handle ._shm_key )
349
244
350
245
351
246
class SharedMemoryException (Exception ):
352
- """Exception indicating non-Success status.
353
-
354
- Parameters
355
- ----------
356
- err : c_void_p
357
- Pointer to an Error that should be used to initialize the exception.
358
-
359
- """
247
+ """Exception type for shared memory related error."""
360
248
361
- def __init__ (self , err ):
362
- self .err_code_map = {
363
- - 2 : "unable to get shared memory descriptor" ,
364
- - 3 : "unable to initialize the size" ,
365
- - 4 : "unable to read/mmap the shared memory region" ,
366
- - 5 : "unable to unlink the shared memory region" ,
367
- - 6 : "unable to munmap the shared memory region" ,
368
- - 7 : "unable to set the shared memory region" ,
369
- }
370
- self ._msg = None
371
- if type (err ) == str :
372
- self ._msg = err
373
- elif err .value != 0 and err .value in self .err_code_map :
374
- self ._msg = self .err_code_map [err .value ]
375
-
376
- def __str__ (self ):
377
- msg = super ().__str__ () if self ._msg is None else self ._msg
378
- return msg
249
+ pass
0 commit comments