Skip to content

Commit 4d7a0b8

Browse files
committed
test: add unit test for shared memory
1 parent 519124f commit 4d7a0b8

File tree

3 files changed

+212
-9
lines changed

3 files changed

+212
-9
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions
5+
# are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of NVIDIA CORPORATION nor the names of its
12+
# contributors may be used to endorse or promote products derived
13+
# from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
import unittest
28+
29+
import numpy
30+
import tritonclient.utils as utils
31+
import tritonclient.utils.shared_memory as shm
32+
33+
34+
class SharedMemoryTest(unittest.TestCase):
35+
"""
36+
Testing shared memory utilities
37+
"""
38+
39+
def setUp(self):
40+
self.shm_handles = []
41+
42+
def tearDown(self):
43+
for shm_handle in self.shm_handles:
44+
# [NOTE] wrapper for old implementation that will fail
45+
try:
46+
shm.destroy_shared_memory_region(shm_handle)
47+
except shm.SharedMemoryException as ex:
48+
if "unlink" in str(ex):
49+
pass
50+
else:
51+
raise ex
52+
53+
def test_lifecycle(self):
54+
cpu_tensor = numpy.ones([4, 4], dtype=numpy.float32)
55+
byte_size = 64
56+
self.shm_handles.append(
57+
shm.create_shared_memory_region("shm_name", "shm_key", byte_size)
58+
)
59+
60+
self.assertEqual(len(shm.mapped_shared_memory_regions()), 1)
61+
62+
# Set data from Numpy array
63+
shm.set_shared_memory_region(self.shm_handles[0], [cpu_tensor])
64+
shm_tensor = shm.get_contents_as_numpy(
65+
self.shm_handles[0], numpy.float32, [4, 4]
66+
)
67+
68+
self.assertTrue(numpy.allclose(cpu_tensor, shm_tensor))
69+
70+
shm.destroy_shared_memory_region(self.shm_handles.pop(0))
71+
72+
def test_set_region_offset(self):
73+
large_tensor = numpy.ones([4, 4], dtype=numpy.float32)
74+
large_size = 64
75+
self.shm_handles.append(
76+
shm.create_shared_memory_region("shm_name", "shm_key", large_size)
77+
)
78+
shm.set_shared_memory_region(self.shm_handles[0], [large_tensor])
79+
small_tensor = numpy.zeros([2, 4], dtype=numpy.float32)
80+
small_size = 32
81+
shm.set_shared_memory_region(
82+
self.shm_handles[0], [small_tensor], offset=large_size - small_size
83+
)
84+
shm_tensor = shm.get_contents_as_numpy(
85+
self.shm_handles[0], numpy.float32, [2, 4], offset=large_size - small_size
86+
)
87+
88+
self.assertTrue(numpy.allclose(small_tensor, shm_tensor))
89+
90+
# [NOTE] current impl will fail
91+
def test_set_region_oversize(self):
92+
large_tensor = numpy.ones([4, 4], dtype=numpy.float32)
93+
small_size = 32
94+
self.shm_handles.append(
95+
shm.create_shared_memory_region("shm_name", "shm_key", small_size)
96+
)
97+
with self.assertRaises(shm.SharedMemoryException):
98+
shm.set_shared_memory_region(self.shm_handles[0], [large_tensor])
99+
100+
def test_duplicate_key(self):
101+
# [NOTE] change in behavior:
102+
# previous: okay to create shared memory region of the same key with different size
103+
# and the behavior is not being study clearly.
104+
# now: only allow create by default, flag may be set to return the same handle if
105+
# existed, warning will be print if size is different
106+
self.shm_handles.append(
107+
shm.create_shared_memory_region("shm_name", "shm_key", 32)
108+
)
109+
with self.assertRaises(shm.SharedMemoryException):
110+
self.shm_handles.append(
111+
shm.create_shared_memory_region("shm_name", "shm_key", 32)
112+
)
113+
114+
# Get handle to the same shared memory region but with larger size requested,
115+
# check if actual size is checked
116+
self.shm_handles.append(
117+
shm.create_shared_memory_region("shm_name", "shm_key", 64, create=False)
118+
)
119+
120+
self.assertEqual(len(shm.mapped_shared_memory_regions()), 1)
121+
122+
large_tensor = numpy.ones([4, 4], dtype=numpy.float32)
123+
small_size = 32
124+
# [NOTE] current impl will fail
125+
with self.assertRaises(shm.SharedMemoryException):
126+
shm.set_shared_memory_region(self.shm_handles[-1], [large_tensor])
127+
128+
# [NOTE] current impl will fail
129+
def test_destroy_duplicate(self):
130+
# [NOTE] change in behavior:
131+
# previous: raise exception if underlying shared memory has been unlinked
132+
# now: the exception will be suppressed to align with Windows behavior, unless
133+
# explicitly toggled
134+
self.shm_handles.append(
135+
shm.create_shared_memory_region("shm_name", "shm_key", 64)
136+
)
137+
self.shm_handles.append(
138+
shm.create_shared_memory_region("shm_name", "shm_key", 32, create=False)
139+
)
140+
self.shm_handles.append(
141+
shm.create_shared_memory_region("shm_name", "shm_key", 32, create=False)
142+
)
143+
self.assertEqual(len(shm.mapped_shared_memory_regions()), 1)
144+
145+
shm.destroy_shared_memory_region(self.shm_handles.pop(0))
146+
self.assertEqual(len(shm.mapped_shared_memory_regions()), 0)
147+
148+
shm.destroy_shared_memory_region(self.shm_handles.pop(0))
149+
with self.assertRaises(shm.SharedMemoryException):
150+
shm.destroy_shared_memory_region(
151+
self.shm_handles.pop(0), raise_unlink_exception=True
152+
)
153+
154+
def test_numpy_bytes(self):
155+
int_tensor = numpy.arange(start=0, stop=16, dtype=numpy.int32)
156+
bytes_tensor = numpy.array(
157+
[str(x).encode("utf-8") for x in int_tensor.flatten()], dtype=object
158+
)
159+
bytes_tensor = bytes_tensor.reshape(int_tensor.shape)
160+
bytes_tensor_serialized = utils.serialize_byte_tensor(bytes_tensor)
161+
byte_size = utils.serialized_byte_size(bytes_tensor_serialized)
162+
163+
self.shm_handles.append(
164+
shm.create_shared_memory_region("shm_name", "shm_key", byte_size)
165+
)
166+
167+
# Set data from Numpy array
168+
shm.set_shared_memory_region(self.shm_handles[0], [bytes_tensor_serialized])
169+
170+
shm_tensor = shm.get_contents_as_numpy(
171+
self.shm_handles[0],
172+
numpy.object_,
173+
[
174+
16,
175+
],
176+
)
177+
178+
self.assertTrue(numpy.array_equal(bytes_tensor, shm_tensor))
179+
180+
181+
if __name__ == "__main__":
182+
unittest.main()

src/python/library/tritonclient/utils/shared_memory/__init__.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def _raise_error(msg):
9090
raise ex
9191

9292

93-
def create_shared_memory_region(triton_shm_name, shm_key, byte_size):
93+
def create_shared_memory_region(triton_shm_name, shm_key, byte_size, create=True):
9494
"""Creates a system shared memory region with the specified name and size.
9595
9696
Parameters
@@ -113,6 +113,11 @@ def create_shared_memory_region(triton_shm_name, shm_key, byte_size):
113113
If unable to create the shared memory region.
114114
"""
115115

116+
if create and shm_key in mapped_shm_regions:
117+
raise SharedMemoryException(
118+
"unable to create the shared memory region, already exists"
119+
)
120+
116121
shm_handle = c_void_p()
117122
_raise_if_error(
118123
c_int(
@@ -121,7 +126,9 @@ def create_shared_memory_region(triton_shm_name, shm_key, byte_size):
121126
)
122127
)
123128
)
124-
mapped_shm_regions.append(shm_key)
129+
130+
if create:
131+
mapped_shm_regions.append(shm_key)
125132

126133
return shm_handle
127134

@@ -271,7 +278,7 @@ def mapped_shared_memory_regions():
271278
return mapped_shm_regions
272279

273280

274-
def destroy_shared_memory_region(shm_handle):
281+
def destroy_shared_memory_region(shm_handle, raise_unlink_exception=False):
275282
"""Unlink a system shared memory region with the specified handle.
276283
277284
Parameters
@@ -306,8 +313,20 @@ def destroy_shared_memory_region(shm_handle):
306313
# fail, a re-attempt could result in a segfault. Secondarily, if we
307314
# fail to delete a region, we should not report it back to the user
308315
# as a valid memory region.
309-
mapped_shm_regions.remove(shm_key.value.decode("utf-8"))
310-
_raise_if_error(c_int(_cshm_shared_memory_region_destroy(shm_handle)))
316+
try:
317+
mapped_shm_regions.remove(shm_key.value.decode("utf-8"))
318+
except ValueError:
319+
# okay if mapped_shm_regions doesn't have the key as there can be
320+
# destroy call on handles with the same shared memory key
321+
pass
322+
try:
323+
_raise_if_error(c_int(_cshm_shared_memory_region_destroy(shm_handle)))
324+
except SharedMemoryException as ex:
325+
# Suppress unlink exception except when explicitly allow to raise
326+
if not raise_unlink_exception and "unlink" in str(ex):
327+
pass
328+
else:
329+
raise ex
311330
return
312331

313332

@@ -328,6 +347,7 @@ def __init__(self, err):
328347
-4: "unable to read/mmap the shared memory region",
329348
-5: "unable to unlink the shared memory region",
330349
-6: "unable to munmap the shared memory region",
350+
-7: "unable to set the shared memory region",
331351
}
332352
self._msg = None
333353
if type(err) == str:

src/python/library/tritonclient/utils/shared_memory/shared_memory.cc

+5-4
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,11 @@ int
108108
SharedMemoryRegionSet(
109109
void* shm_handle, size_t offset, size_t byte_size, const void* data)
110110
{
111-
void* shm_addr =
112-
reinterpret_cast<SharedMemoryHandle*>(shm_handle)->base_addr_;
113-
char* shm_addr_offset = reinterpret_cast<char*>(shm_addr);
114-
std::memcpy(shm_addr_offset + offset, data, byte_size);
111+
auto shm = reinterpret_cast<SharedMemoryHandle*>(shm_handle);
112+
if (shm->byte_size_ < (offset + byte_size)) {
113+
return -7;
114+
}
115+
std::memcpy(shm->base_addr_ + offset, data, byte_size);
115116
return 0;
116117
}
117118

0 commit comments

Comments
 (0)