Skip to content

Commit 454f581

Browse files
pathfinder-pf and pathfinder-fp authored
feat:add device indexes (#470)
Co-authored-by: pathfinder-fp <[email protected]>
1 parent 58f8e07 commit 454f581

File tree

7 files changed

+87
-6
lines changed

7 files changed

+87
-6
lines changed

.github/workflows/pr-test.yml

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,19 +90,25 @@ jobs:
9090
unit-test-4-tpu:
9191
needs: [check-changes]
9292
if: github.event.pull_request.draft == false && needs.check-changes.outputs.main_package == 'true'
93-
#runs-on: arc-runner-v6e-1
94-
runs-on: ubuntu-slim
93+
runs-on: arc-runner-v6e-4
9594
strategy:
9695
fail-fast: false
9796
matrix:
9897
part: [0, 1]
9998
steps:
100-
# TODO: please rename when implement unit test
101-
- name: Mock
99+
- name: Checkout code
100+
uses: actions/checkout@v5
101+
- name: Run unit test
102102
timeout-minutes: 30
103103
env:
104104
SGLANG_JAX_IS_IN_CI: true
105-
run: echo ${{ matrix.part }}
105+
run: |
106+
python3.12 -m venv .venv
107+
source .venv/bin/activate
108+
pip install uv
109+
uv pip install -e "python[all]"
110+
bash scripts/killall_sglang.sh
111+
python test/srt/run_suite.py --suite unit-test-tpu-v6e-4 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2
106112
107113
# =============================================== e2e-test ===============================================
108114
e2e-test-1-tpu:

python/sgl_jax/srt/managers/scheduler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,11 @@ def __init__(
218218
# init distribution
219219
if self.nnodes > 1:
220220
jax.distributed.initialize(server_args.dist_init_addr, self.nnodes, self.node_rank)
221-
self.mesh = create_device_mesh(ici_parallelism=[-1, self.tp_size], dcn_parallelism=[1, 1])
221+
self.mesh = create_device_mesh(
222+
ici_parallelism=[-1, self.tp_size],
223+
dcn_parallelism=[1, 1],
224+
device_indexes=server_args.device_indexes,
225+
)
222226

223227
TpWorkerClass = ModelWorkerClient if self.enable_overlap else ModelWorker
224228

python/sgl_jax/srt/server_args.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class ServerArgs:
6666

6767
# Runtime options
6868
device: str | None = None
69+
device_indexes: list[int] | None = None
6970
tp_size: int = 1
7071
ep_size: int = 1
7172
stream_interval: int = 1
@@ -216,6 +217,10 @@ def __post_init__(self):
216217
"1" if self.enable_deterministic_sampling else "0"
217218
)
218219

220+
if self.nnodes > 1 and self.device_indexes is not None:
221+
logger.warning("In a multi-machine scenario, device_indexes will be set to None.")
222+
self.device_indexes = None
223+
219224
@staticmethod
220225
def add_cli_args(parser: argparse.ArgumentParser):
221226
# Model and tokenizer
@@ -499,6 +504,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
499504
default=ServerArgs.device,
500505
help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.",
501506
)
507+
508+
parser.add_argument(
509+
"--device-indexes",
510+
type=int,
511+
nargs="+",
512+
help="The device indexes to use build mesh. Defaults is all if not specified.",
513+
)
514+
502515
parser.add_argument(
503516
"--tensor-parallel-size",
504517
"--tp-size",

python/sgl_jax/srt/utils/mesh_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ def create_device_mesh(
99
ici_parallelism: Sequence[int],
1010
dcn_parallelism: Sequence[int],
1111
devices=None,
12+
device_indexes: list[int] = None,
1213
num_slices: int = 1,
1314
allow_split_physical_axes: bool = True,
1415
use_explicit_sharding: bool = True,
@@ -17,6 +18,13 @@ def create_device_mesh(
1718
if devices is None:
1819
devices = jax.devices()
1920

21+
if device_indexes is not None:
22+
max_index = max(device_indexes)
23+
if max_index >= len(devices):
24+
raise RuntimeError("Device index out of range")
25+
devices_dict = {device.id: device for device in devices}
26+
devices = [devices_dict.get(i) for i in list(set(device_indexes))]
27+
2028
ici_parallelism = fill_unspecified_parallelism(ici_parallelism, len(devices))
2129
if num_slices > 1:
2230
dcn_parallelism = fill_unspecified_parallelism(dcn_parallelism, num_slices)

python/sgl_jax/test/test_mesh.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import unittest
2+
3+
from sgl_jax.srt.utils.mesh_utils import create_device_mesh
4+
from sgl_jax.test.test_utils import CustomTestCase
5+
6+
7+
class TestMesh(CustomTestCase):
    """Shape checks for meshes built by ``create_device_mesh``.

    The expected sizes are hard-coded for a host that exposes exactly four
    devices (this file is registered under the ``unit-test-tpu-v6e-4``
    suite). On a host with a different device count the ``tensor == 4``
    assertion, and the out-of-range case that uses index 4, will not hold.
    """

    def test_mesh_with_no_device_indexes(self):
        """With no ``device_indexes``, the tensor axis is expected to span all 4 devices."""
        mesh = create_device_mesh(ici_parallelism=[1, -1], dcn_parallelism=[1, 1])
        self.assertEqual(mesh.shape.get("data"), 1, "dp should be 1")
        self.assertEqual(mesh.shape.get("tensor"), 4, "tp should be 4")

    def test_mesh_with_device_indexes(self):
        """Selecting two device indexes should yield a tensor axis of size 2."""
        mesh = create_device_mesh(
            ici_parallelism=[1, -1], dcn_parallelism=[1, 1], device_indexes=[0, 1]
        )
        self.assertEqual(mesh.shape.get("data"), 1, "dp should be 1")
        self.assertEqual(mesh.shape.get("tensor"), 2, "tp should be 2")

    def test_mesh_with_duplicated_device_indexes(self):
        """Repeated indexes should be collapsed, so [0, 1, 0, 1] behaves like [0, 1]."""
        mesh = create_device_mesh(
            ici_parallelism=[1, -1], dcn_parallelism=[1, 1], device_indexes=[0, 1, 0, 1]
        )
        self.assertEqual(mesh.shape.get("data"), 1, "dp should be 1")
        self.assertEqual(mesh.shape.get("tensor"), 2, "tp should be 2")

    def test_mesh_with_large_device_indexes(self):
        """An index beyond the available devices must raise ``RuntimeError``."""
        # assertRaises fails the test when nothing is raised; the previous
        # try/except form passed silently in that case.
        with self.assertRaises(
            RuntimeError, msg="an out-of-range device index should raise RuntimeError"
        ):
            create_device_mesh(
                ici_parallelism=[1, -1], dcn_parallelism=[1, 1], device_indexes=[0, 4]
            )
36+
37+
38+
if __name__ == "__main__":
39+
unittest.main()

python/sgl_jax/test/test_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def create_device_mesh(
6060
ici_parallelism: Sequence[int],
6161
dcn_parallelism: Sequence[int],
6262
devices=None,
63+
device_indexes: list[int] = None,
6364
num_slices: int = 1,
6465
allow_split_physical_axes: bool = True,
6566
use_explicit_sharding: bool = True,
@@ -68,6 +69,13 @@ def create_device_mesh(
6869
if devices is None:
6970
devices = jax.devices()
7071

72+
if device_indexes is not None:
73+
max_index = max(device_indexes)
74+
if max_index >= len(devices):
75+
raise RuntimeError("Device index out of range")
76+
devices_dict = {device.id: device for device in devices}
77+
devices = [devices_dict.get(i) for i in list(set(device_indexes))]
78+
7179
ici_parallelism = fill_unspecified_parallelism(ici_parallelism, len(devices))
7280
if num_slices > 1:
7381
dcn_parallelism = fill_unspecified_parallelism(dcn_parallelism, num_slices)

test/srt/run_suite.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,9 @@ def run_one_file(filename):
168168
TestFile("python/sgl_jax/test/speculative/test_eagle_tree_build.py", 1),
169169
TestFile("python/sgl_jax/test/speculative/test_eagle_utils.py", 1),
170170
],
171+
"unit-test-tpu-v6e-4": [
172+
TestFile("python/sgl_jax/test/test_mesh.py", 1),
173+
],
171174
"kernel-performance-test-tpu-v6e-1": [
172175
TestFile("benchmark/kernels/flash_attention/bench_flashattention.py", 5),
173176
TestFile("benchmark/kernels/megablox_gmm/bench_megablox_gmm.py", 2),

0 commit comments

Comments
 (0)