Skip to content

Commit 452bf8d

Browse files
seanprime7 and Angelogeb committed
Update 26.01.07
- 37d4a50ff89b6bce498dd86e57f02f68849ccd9d by Anxhelo Xhebraj <axhebraj@nvidia.com> Co-authored-by: Anxhelo Xhebraj <axhebraj@nvidia.com> Signed-off-by: Sean Lee <selee@nvidia.com> GitOrigin-RevId: 37d4a50ff89b6bce498dd86e57f02f68849ccd9d
1 parent d43e360 commit 452bf8d

File tree

5 files changed

+223
-14
lines changed

5 files changed

+223
-14
lines changed

scripts/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@ pytest tests/test_mpmd_array.py
2222
if [ $(nvidia-smi -L | wc -l) -ge 8 ]; then
2323
N_PROCS=2 N_GPUS=4 COMMAND="python -u tests/test_reshard_utils.py" ./scripts/local_mc.sh
2424
N_PROCS=2 N_GPUS=4 COMMAND="python -u examples/mpmd_reshard.py" ./scripts/local_mc.sh
25+
N_PROCS=2 N_GPUS=4 COMMAND="python -u tests/test_dime2.py" ./scripts/local_mc.sh
2526
fi

src/jaxpp/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");

src/jaxpp/dlpack.py

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -65,26 +65,51 @@ class DLDevice(ctypes.Structure):
6565
("device_id", ctypes.c_int),
6666
]
6767

68-
68+
# https://github.com/dmlc/dlpack/blob/6ea9b3eb64c881f614cd4537f95f0e125a35555c/include/dlpack/dlpack.h#L141-L182
6969
class DLDataTypeCode(ctypes.c_uint8):
    """An integer that encodes the category of DLTensor elements' data type.

    The class attributes are the numeric type codes; ``str(code)`` maps a
    code instance to a short lowercase name such as ``"float"`` or
    ``"float8_e4m3fn"``.
    """

    kDLInt = 0
    kDLUInt = 1
    kDLFloat = 2
    kDLOpaqueHandle = 3
    # Former name of kDLOpaqueHandle, kept so existing references keep working.
    kDLOpaquePointer = kDLOpaqueHandle
    kDLBfloat = 4
    kDLComplex = 5
    kDLBool = 6
    # FP8 data types
    kDLFloat8_e3m4 = 7
    kDLFloat8_e4m3 = 8
    kDLFloat8_e4m3b11fnuz = 9
    kDLFloat8_e4m3fn = 10
    kDLFloat8_e4m3fnuz = 11
    kDLFloat8_e5m2 = 12
    kDLFloat8_e5m2fnuz = 13
    kDLFloat8_e8m0fnu = 14
    # FP6 data types
    kDLFloat6_e2m3fn = 15
    kDLFloat6_e3m2fn = 16
    # FP4 data types
    kDLFloat4_e2m1fn = 17

    # Code -> short name table, built once at class creation instead of on
    # every __str__ call.
    _NAMES = {
        kDLInt: "int",
        kDLUInt: "uint",
        kDLFloat: "float",
        kDLOpaqueHandle: "void_p",
        kDLBfloat: "bfloat",
        kDLComplex: "complex",
        kDLBool: "bool",
        kDLFloat8_e3m4: "float8_e3m4",
        kDLFloat8_e4m3: "float8_e4m3",
        kDLFloat8_e4m3b11fnuz: "float8_e4m3b11fnuz",
        kDLFloat8_e4m3fn: "float8_e4m3fn",
        kDLFloat8_e4m3fnuz: "float8_e4m3fnuz",
        kDLFloat8_e5m2: "float8_e5m2",
        kDLFloat8_e5m2fnuz: "float8_e5m2fnuz",
        kDLFloat8_e8m0fnu: "float8_e8m0fnu",
        kDLFloat6_e2m3fn: "float6_e2m3fn",
        kDLFloat6_e3m2fn: "float6_e3m2fn",
        kDLFloat4_e2m1fn: "float4_e2m1fn",
    }

    def __str__(self):
        """Return the short name for this code.

        Codes that are not in the table yield ``"unknown(<value>)"`` rather
        than raising, so printing an unexpected tensor dtype never fails.
        """
        return self._NAMES.get(self.value, f"unknown({self.value})")
89114

90115

@@ -112,11 +137,22 @@ class DLDataType(ctypes.Structure):
112137
"uint32": (DLDataTypeCode.kDLUInt, 32, 1),
113138
"uint64": (DLDataTypeCode.kDLUInt, 64, 1),
114139
"float16": (DLDataTypeCode.kDLFloat, 16, 1),
115-
"bfloat16": (DLDataTypeCode.kDLBfloat, 16, 1), # Added
140+
"bfloat16": (DLDataTypeCode.kDLBfloat, 16, 1),
116141
"float32": (DLDataTypeCode.kDLFloat, 32, 1),
117142
"float64": (DLDataTypeCode.kDLFloat, 64, 1),
118143
"complex64": (DLDataTypeCode.kDLComplex, 64, 1),
119-
"complex128": (DLDataTypeCode.kDLComplex, 128, 1)
144+
"complex128": (DLDataTypeCode.kDLComplex, 128, 1),
145+
# FP4 types
146+
"float4_e2m1fn": (DLDataTypeCode.kDLFloat4_e2m1fn, 4, 1),
147+
# FP8 types
148+
"float8_e3m4": (DLDataTypeCode.kDLFloat8_e3m4, 8, 1),
149+
"float8_e4m3": (DLDataTypeCode.kDLFloat8_e4m3, 8, 1),
150+
"float8_e4m3b11fnuz": (DLDataTypeCode.kDLFloat8_e4m3b11fnuz, 8, 1),
151+
"float8_e4m3fn": (DLDataTypeCode.kDLFloat8_e4m3fn, 8, 1),
152+
"float8_e4m3fnuz": (DLDataTypeCode.kDLFloat8_e4m3fnuz, 8, 1),
153+
"float8_e5m2": (DLDataTypeCode.kDLFloat8_e5m2, 8, 1),
154+
"float8_e5m2fnuz": (DLDataTypeCode.kDLFloat8_e5m2fnuz, 8, 1),
155+
"float8_e8m0fnu": (DLDataTypeCode.kDLFloat8_e8m0fnu, 8, 1),
120156
}
121157

122158
REV_MAP = {v: k for k, v in TYPE_MAP.items()}
@@ -165,7 +201,7 @@ class DLManagedTensor(ctypes.Structure):
165201

166202

167203
class NcclDataType(ctypes.c_uint8):
168-
# https://github.com/NVIDIA/nccl/blob/559b70f86c190a0d8f67f0d7a0f2c9810dd1e8c7/src/nccl.h.in#L190-L205C3
204+
# https://github.com/NVIDIA/nccl/blob/1e0c869c39bb33f1034cb9920bd2a8a8406f04a3/src/nccl.h.in#L328-L341
169205
ncclInt8 = 0
170206
ncclUint8 = 1
171207
ncclInt32 = 2
@@ -176,6 +212,8 @@ class NcclDataType(ctypes.c_uint8):
176212
ncclFloat32 = 7
177213
ncclFloat64 = 8
178214
ncclBfloat16 = 9
215+
ncclFloat8e4m3 = 10
216+
ncclFloat8e5m2 = 11
179217

180218
TYPE_MAP = {
181219
"bool": ncclUint8,
@@ -189,6 +227,8 @@ class NcclDataType(ctypes.c_uint8):
189227
"float32": ncclFloat32,
190228
"float64": ncclFloat64,
191229
"bfloat16": ncclBfloat16,
230+
"float8_e4m3fn": ncclFloat8e4m3,
231+
"float8_e5m2": ncclFloat8e5m2,
192232
}
193233

194234

@@ -219,9 +259,13 @@ def dlpack_nccl_args(dla) -> tuple[RawDataPointer, int, NcclDataType]:
219259
dltensor.dtype.bits,
220260
dltensor.dtype.lanes,
221261
)
262+
263+
if dtype_key not in DLDataType.REV_MAP:
264+
raise ValueError(f"Unsupported dtype: {dtype_key}")
265+
222266
dtype_name = DLDataType.REV_MAP[dtype_key]
223-
return (
224-
RawDataPointer(data_ptr),
225-
nelems,
226-
NcclDataType(NcclDataType.TYPE_MAP[dtype_name]),
227-
)
267+
if dtype_name not in NcclDataType.TYPE_MAP:
268+
raise ValueError(f"Unsupported dtype: {dtype_name}")
269+
nccl_dtype = NcclDataType.TYPE_MAP[dtype_name]
270+
271+
return RawDataPointer(data_ptr), nelems, NcclDataType(nccl_dtype)

tests/test_dime2.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
18+
import jax
19+
import jax.numpy as jnp
20+
import ml_dtypes
21+
import numpy as np
22+
from jax.sharding import PartitionSpec as P
23+
from parameterized import parameterized
24+
25+
import jaxpp.distributed_utils as jppdu
26+
from jaxpp.dime2 import send_or_recv
27+
28+
29+
class SendOrRecvTest(jppdu.JaxDistributedTest):
    """Round-trips a small array between two processes via ``send_or_recv``."""

    @parameterized.expand(
        [
            ("float32", jnp.float32, np.float32),
            ("bfloat16", jnp.bfloat16, ml_dtypes.bfloat16),
            ("float8_e4m3fn", jnp.float8_e4m3fn, ml_dtypes.float8_e4m3fn),
            ("float8_e5m2", jnp.float8_e5m2, ml_dtypes.float8_e5m2),
        ]
    )
    def test_send_or_recv(self, name, jax_dtype, np_dtype):
        """Process 0 sends eight values; every other process receives and checks them."""
        n_processes = jax.process_count()
        this_process = jax.process_index()
        n_local_devices = jax.local_device_count()

        # Arrange all devices as a (process, local device) grid and take row 0
        # as the sender mesh and row 1 as the receiver mesh. Each mesh keeps
        # every device of its row (assumes jax.devices() is grouped by
        # process -- TODO confirm).
        device_grid = np.array(jax.devices()).reshape(n_processes, n_local_devices)
        mesh_axes = ("mpmd", "x")
        sender_mesh = jax.sharding.Mesh(device_grid[0:1], axis_names=mesh_axes)
        receiver_mesh = jax.sharding.Mesh(device_grid[1:2], axis_names=mesh_axes)

        # Both sides shard the single array dimension along the "x" mesh axis.
        spec = P("x")
        sender_sharding = jax.sharding.NamedSharding(sender_mesh, spec)
        receiver_sharding = jax.sharding.NamedSharding(receiver_mesh, spec)

        global_shape = (8,)
        expected_values = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np_dtype)

        if process_index_is_sender := (this_process == 0):
            payload = jax.device_put(
                jnp.array(expected_values, dtype=jax_dtype), sender_sharding
            )
            (wait_send_finish,) = send_or_recv(
                [payload], [receiver_sharding], is_send=True
            )
            wait_send_finish()
            return

        # Receiver side: hand in a zero-filled buffer with the receiver's
        # sharding, then compare what arrives with the values that were sent.
        landing_buffer = jax.device_put(
            jnp.zeros(global_shape, dtype=jax_dtype), receiver_sharding
        )
        (enqueue_recv,) = send_or_recv(
            [landing_buffer], [sender_sharding], is_send=False
        )
        received_values = np.array(enqueue_recv())
        np.testing.assert_array_equal(
            received_values,
            expected_values,
            err_msg=f"Received data mismatch for dtype {name}",
        )
81+
82+
83+
# Script entry point: unittest.main is handed to jaxpp's distributed_main
# rather than called directly (presumably so the distributed runtime is set
# up before the tests run -- confirm in jaxpp.distributed_utils).
if __name__ == "__main__":
    jppdu.distributed_main(unittest.main)

tests/test_dlpack.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import ctypes
17+
import unittest
18+
19+
import jax.numpy as jnp
20+
import ml_dtypes
21+
import numpy as np
22+
from jax._src.dlpack import to_dlpack
23+
from parameterized import parameterized
24+
25+
from jaxpp.dlpack import capsule_name, dlpack_nccl_args
26+
27+
# Load the CUDA runtime through ctypes. This runs at import time and raises
# OSError when the dynamic loader cannot find libcudart.so.
_libcudart = ctypes.CDLL("libcudart.so")
# Declare cudaMemcpy's signature explicitly: without argtypes, ctypes passes
# Python ints as C int, which would truncate 64-bit pointers.
_libcudart.cudaMemcpy.argtypes = [
    ctypes.c_void_p,  # dst
    ctypes.c_void_p,  # src
    ctypes.c_size_t,  # count (bytes)
    ctypes.c_int,  # kind
]
_libcudart.cudaMemcpy.restype = ctypes.c_int
# "kind" argument value selecting a device-to-host copy.
_cudaMemcpyDeviceToHost = 2
36+
37+
38+
def cuda_memcpy_to_host(device_ptr: int, num_bytes: int) -> bytes:
    """Copy ``num_bytes`` bytes from a device address into a host ``bytes`` object.

    Args:
        device_ptr: Address passed to ``cudaMemcpy`` as the copy source.
        num_bytes: Number of bytes to copy.

    Returns:
        The copied bytes.

    Raises:
        RuntimeError: If ``cudaMemcpy`` returns a non-zero status code.
    """
    # Host-side staging area that cudaMemcpy writes into.
    staging = (ctypes.c_uint8 * num_bytes)()
    status = _libcudart.cudaMemcpy(
        staging, device_ptr, num_bytes, _cudaMemcpyDeviceToHost
    )
    if status == 0:
        return bytes(staging)
    raise RuntimeError(f"cudaMemcpy failed with error {status}")
46+
47+
48+
class TestDlpackExport(unittest.TestCase):
    """Checks the NCCL arguments that ``dlpack_nccl_args`` derives from JAX arrays."""

    @parameterized.expand(
        [
            # (case name, JAX dtype, NumPy-compatible dtype, expected NCCL code).
            # The codes are the ncclDataType_t values mirrored by
            # jaxpp.dlpack.NcclDataType.
            ("float32", jnp.float32, np.float32, 7),  # ncclFloat32
            ("bfloat16", jnp.bfloat16, ml_dtypes.bfloat16, 9),  # ncclBfloat16
            (
                "float8_e4m3fn",
                jnp.float8_e4m3fn,
                ml_dtypes.float8_e4m3fn,
                10,  # ncclFloat8e4m3
            ),
            ("float8_e5m2", jnp.float8_e5m2, ml_dtypes.float8_e5m2, 11),  # ncclFloat8e5m2
        ]
    )
    def test_dlpack_export(self, name, jax_dtype, np_dtype, expected_nccl_code):
        """Export an array via DLPack and verify pointer, count and NCCL dtype."""
        x = jnp.array([1, 2, 3], dtype=jax_dtype)
        capsule = to_dlpack(x)
        self.assertEqual(capsule_name(capsule), "dltensor")
        data_ptr, count, nccl_dtype = dlpack_nccl_args(capsule)

        self.assertEqual(count, 3)
        # The dtype mapping is what this export path exists for; without this
        # check a wrong NCCL type for a dtype would go unnoticed.
        self.assertEqual(
            nccl_dtype.value,
            expected_nccl_code,
            msg=f"Unexpected NCCL dtype code for {name}",
        )

        # Read the device buffer back through the raw pointer and compare
        # element-wise with the values the array was built from.
        itemsize = np.dtype(np_dtype).itemsize
        raw_bytes = cuda_memcpy_to_host(data_ptr, count * itemsize)
        values = np.frombuffer(raw_bytes, dtype=np_dtype)

        np.testing.assert_array_equal(values, np.array([1, 2, 3], dtype=np_dtype))

    def test_unsupported_dtype(self):
        """A dtype with no NCCL mapping must raise ValueError("Unsupported dtype ...")."""
        x = jnp.array([1, 2, 3], dtype=jnp.float8_e4m3b11fnuz)
        capsule = to_dlpack(x)
        with self.assertRaises(ValueError) as ctx:
            dlpack_nccl_args(capsule)
        self.assertIn("Unsupported dtype", str(ctx.exception))
77+
78+
79+
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)