Skip to content

Commit 82bd187

Browse files
authored
[None][chore] update disagg readme and scripts for pipeline parallelism (NVIDIA#6875)
Signed-off-by: raayandhar <rdhar@nvidia.com>
1 parent 6c7813e commit 82bd187

6 files changed

Lines changed: 88 additions & 58 deletions

File tree

examples/disaggregated/slurm/benchmark/README.md

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,29 @@ It takes the following arguments in order:
3434

3535
1. `num_ctx_servers`: Number of context servers.
3636
2. `ctx_tp_size`: Tensor parallel size for context servers.
37-
3. `ctx_batch_size`: Max batch size for context servers.
38-
4. `ctx_max_num_tokens`: Max number of tokens for context servers.
39-
5. `ctx_enable_attention_dp`: `true` or `false` to enable attention DP for context servers.
40-
6. `num_gen_servers`: Number of generation servers.
41-
7. `gen_tp_size`: Tensor parallel size for generation servers.
42-
8. `gen_batch_size`: Max batch size for generation servers.
43-
9. `gen_max_num_tokens`: Max number of tokens for generation servers.
44-
10. `gen_enable_attention_dp`: `true` or `false` to enable attention DP for generation servers.
45-
11. `gen_gpu_memory_fraction`: GPU memory fraction for generation servers.
46-
12. `concurrency_list`: A space-separated list of concurrencies to test (e.g., "1 2 4 8").
47-
13. `sub_file`: A subdirectory name for logs.
37+
3. `ctx_pp_size`: Pipeline parallel size for context servers.
38+
4. `ctx_batch_size`: Max batch size for context servers.
39+
5. `ctx_max_num_tokens`: Max number of tokens for context servers.
40+
6. `ctx_enable_attention_dp`: `true` or `false` to enable attention DP for context servers.
41+
7. `num_gen_servers`: Number of generation servers.
42+
8. `gen_tp_size`: Tensor parallel size for generation servers.
43+
9. `gen_pp_size`: Pipeline parallel size for generation servers.
44+
10. `gen_batch_size`: Max batch size for generation servers.
45+
11. `gen_max_num_tokens`: Max number of tokens for generation servers.
46+
12. `gen_enable_attention_dp`: `true` or `false` to enable attention DP for generation servers.
47+
13. `gen_gpu_memory_fraction`: GPU memory fraction for generation servers.
48+
14. `eplb_num_slots`: Number of slots for eplb.
49+
15. `mtp_size`: Number of nextn layers for MTP.
50+
16. `concurrency`: Concurrency level for benchmarking.
51+
17. `isl`: Input sequence length.
52+
18. `osl`: Output sequence length.
53+
19. `multi_round`: Number of rounds for the benchmark.
54+
20. `streaming`: `true` or `false` for streaming mode.
55+
21. `container_image`: Container image to use.
56+
22. `mounts`: Container mounts.
57+
23. `workdir`: Working directory.
58+
24. `model_dir`: Model directory path.
59+
25. `trtllm_repo`: TensorRT-LLM repository path.
4860

4961
### `gen_yaml.py`
5062

@@ -90,5 +102,5 @@ This script orchestrates the execution of the benchmark client. It waits for the
90102
7. `disaggr_torch.slurm` starts the main `trtllm-serve` process.
91103
8. `disaggr_torch.slurm` runs `run_benchmark.sh` which waits for the server to be ready.
92104
9. `run_benchmark.sh` executes the benchmark at the specified concurrency level.
93-
10. After the benchmark, `run_benchmark.sh` and `disaggr_torch.slurm` attempt to kill the server and worker processes.
105+
10. After the benchmark, `run_benchmark.sh` and `disaggr_torch.slurm` attempt to kill the server and worker processes.
94106
11. Logs for each run are stored in a subdirectory specified by the `sub_file` parameter.

examples/disaggregated/slurm/benchmark/disaggr_torch.slurm

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,47 +10,51 @@
1010
# Context servers arguments
1111
num_ctx_servers=${1}
1212
ctx_tp_size=${2}
13-
ctx_batch_size=${3}
14-
ctx_max_num_tokens=${4}
15-
ctx_enable_attention_dp=${5}
16-
ctx_gpu_memory_fraction=${6}
13+
ctx_pp_size=${3}
14+
ctx_batch_size=${4}
15+
ctx_max_num_tokens=${5}
16+
ctx_enable_attention_dp=${6}
17+
ctx_gpu_memory_fraction=${7}
1718

1819
# Generation servers arguments
19-
num_gen_servers=${7}
20-
gen_tp_size=${8}
21-
gen_batch_size=${9}
22-
gen_max_num_tokens=${10}
23-
gen_enable_attention_dp=${11}
24-
gen_gpu_memory_fraction=${12}
20+
num_gen_servers=${8}
21+
gen_tp_size=${9}
22+
gen_pp_size=${10}
23+
gen_batch_size=${11}
24+
gen_max_num_tokens=${12}
25+
gen_enable_attention_dp=${13}
26+
gen_gpu_memory_fraction=${14}
2527

2628
# Other arguments
27-
eplb_num_slots=${13}
28-
mtp_size=${14}
29+
eplb_num_slots=${15}
30+
mtp_size=${16}
2931

3032
# Benchmarking arguments
31-
concurrency=${15}
32-
isl=${16}
33-
osl=${17}
34-
multi_round=${18}
35-
streaming=${19}
33+
concurrency=${17}
34+
isl=${18}
35+
osl=${19}
36+
multi_round=${20}
37+
streaming=${21}
3638

3739
# User specific arguments
38-
container_image=${20}
39-
mounts=${21}
40-
workdir=${22}
41-
model_dir=${23}
42-
benchmark_mode=${24}
43-
trtllm_repo=${25}
40+
container_image=${22}
41+
mounts=${23}
42+
workdir=${24}
43+
model_dir=${25}
44+
benchmark_mode=${26}
45+
trtllm_repo=${27}
4446

4547
echo "================= parameters ================="
4648
echo "num_ctx_servers: ${num_ctx_servers}"
4749
echo "ctx_tp_size: ${ctx_tp_size}"
50+
echo "ctx_pp_size: ${ctx_pp_size}"
4851
echo "ctx_batch_size: ${ctx_batch_size}"
4952
echo "ctx_max_num_tokens: ${ctx_max_num_tokens}"
5053
echo "ctx_enable_attention_dp: ${ctx_enable_attention_dp}"
5154
echo "ctx_gpu_memory_fraction: ${ctx_gpu_memory_fraction}"
5255
echo "num_gen_servers: ${num_gen_servers}"
5356
echo "gen_tp_size: ${gen_tp_size}"
57+
echo "gen_pp_size: ${gen_pp_size}"
5458
echo "gen_batch_size: ${gen_batch_size}"
5559
echo "gen_max_num_tokens: ${gen_max_num_tokens}"
5660
echo "gen_enable_attention_dp: ${gen_enable_attention_dp}"
@@ -83,8 +87,8 @@ full_logdir=${logdir}/ctx${num_ctx_servers}_gen${num_gen_servers}_dep${gen_tp_si
8387

8488
echo "concurrency: ${concurrency}"
8589

86-
ctx_gpus=$((num_ctx_servers * ctx_tp_size))
87-
gen_gpus=$((num_gen_servers * gen_tp_size))
90+
ctx_gpus=$((num_ctx_servers * ctx_tp_size * ctx_pp_size))
91+
gen_gpus=$((num_gen_servers * gen_tp_size * gen_pp_size))
8892

8993
echo "enable_attention_dp: ${ctx_enable_attention_dp}, ${gen_enable_attention_dp}, gpu_memory_fraction: ${gen_gpu_memory_fraction}"
9094

@@ -132,13 +136,15 @@ srun -l --container-name=${container_name} \
132136
--model ${model_dir} \
133137
--num_ctx_servers ${num_ctx_servers} \
134138
--ctx_tp_size ${ctx_tp_size} \
139+
--ctx_pp_size ${ctx_pp_size} \
135140
--ctx_batch_size ${ctx_batch_size} \
136141
--ctx_max_num_tokens ${ctx_max_num_tokens} \
137142
--ctx_max_seq_len ${ctx_max_seq_len} \
138143
--ctx_free_gpu_memory_fraction ${ctx_gpu_frac} \
139144
--cache_transceiver_max_num_tokens ${cache_transceiver_max_num_tokens} \
140145
--num_gen_servers ${num_gen_servers} \
141146
--gen_tp_size ${gen_tp_size} \
147+
--gen_pp_size ${gen_pp_size} \
142148
--gen_batch_size ${gen_batch_size} \
143149
--gen_max_num_tokens ${gen_max_num_tokens} \
144150
--gen_max_seq_len ${gen_max_seq_len} \

examples/disaggregated/slurm/benchmark/gen_yaml.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,15 @@ def gen_config_file(config_path: str,
123123
model_path: str,
124124
num_ctx_servers: int,
125125
ctx_tp_size: int,
126+
ctx_pp_size: int,
126127
ctx_batch_size: int,
127128
ctx_max_num_tokens: int,
128129
ctx_max_seq_len: int,
129130
ctx_free_gpu_memory_fraction: float,
130131
ctx_enable_attention_dp: bool,
131132
num_gen_servers: int,
132133
gen_tp_size: int,
134+
gen_pp_size: int,
133135
gen_batch_size: int,
134136
gen_max_num_tokens: int,
135137
gen_max_seq_len: int,
@@ -148,13 +150,15 @@ def gen_config_file(config_path: str,
148150
model_path: Path to the model
149151
num_ctx_servers: Number of context servers
150152
ctx_tp_size: Tensor parallel size for context servers
153+
ctx_pp_size: Pipeline parallel size for context servers
151154
ctx_batch_size: Batch size for context servers
152155
ctx_max_num_tokens: Max number of tokens for context servers
153156
ctx_max_seq_len: Max sequence length for context servers
154157
ctx_free_gpu_memory_fraction: Free GPU memory fraction for context servers
155158
ctx_enable_attention_dp: Enable attention DP for context servers
156159
num_gen_servers: Number of generation servers
157160
gen_tp_size: Tensor parallel size for generation servers
161+
gen_pp_size: Pipeline parallel size for generation servers
158162
gen_batch_size: Batch size for generation servers
159163
gen_max_num_tokens: Max number of tokens for generation servers
160164
gen_enable_attention_dp: Enable attention DP for generation servers
@@ -187,7 +191,7 @@ def gen_config_file(config_path: str,
187191
'tensor_parallel_size': ctx_tp_size,
188192
'moe_expert_parallel_size': ctx_tp_size,
189193
'enable_attention_dp': ctx_enable_attention_dp,
190-
'pipeline_parallel_size': 1,
194+
'pipeline_parallel_size': ctx_pp_size,
191195
'print_iter_log': True,
192196
'disable_overlap_scheduler': True,
193197
'kv_cache_config': {
@@ -205,7 +209,7 @@ def gen_config_file(config_path: str,
205209
'tensor_parallel_size': gen_tp_size,
206210
'moe_expert_parallel_size': gen_tp_size,
207211
'enable_attention_dp': gen_enable_attention_dp,
208-
'pipeline_parallel_size': 1,
212+
'pipeline_parallel_size': gen_pp_size,
209213
'max_batch_size': gen_batch_size,
210214
'max_num_tokens': gen_max_num_tokens,
211215
'max_seq_len': gen_max_seq_len,
@@ -237,15 +241,15 @@ def gen_config_file(config_path: str,
237241

238242
# Generate URLs for context and generation servers
239243
ctx_urls, task_nodes_offset = generate_urls("ctx", num_ctx_servers,
240-
ctx_tp_size, 1,
244+
ctx_tp_size, ctx_pp_size,
241245
max_tasks_per_node, nodes,
242246
task_nodes, node_ports)
243247
if num_ctx_servers > 0:
244248
config['context_servers']['urls'] = ctx_urls
245249

246-
gen_urls, _ = generate_urls("gen", num_gen_servers, gen_tp_size, 1,
247-
max_tasks_per_node, nodes, task_nodes,
248-
node_ports, task_nodes_offset)
250+
gen_urls, _ = generate_urls("gen", num_gen_servers, gen_tp_size,
251+
gen_pp_size, max_tasks_per_node, nodes,
252+
task_nodes, node_ports, task_nodes_offset)
249253
config['generation_servers']['urls'] = gen_urls
250254

251255
# set the hostname to the first node
@@ -300,6 +304,10 @@ def gen_config_file(config_path: str,
300304
type=int,
301305
required=True,
302306
help="Tensor parallel size for context servers")
307+
parser.add_argument("--ctx_pp_size",
308+
type=int,
309+
default=1,
310+
help="Pipeline parallel size for context servers")
303311
parser.add_argument("--ctx_batch_size",
304312
type=int,
305313
required=True,
@@ -328,6 +336,10 @@ def gen_config_file(config_path: str,
328336
type=int,
329337
required=True,
330338
help="Tensor parallel size for generation servers")
339+
parser.add_argument("--gen_pp_size",
340+
type=int,
341+
default=1,
342+
help="Pipeline parallel size for generation servers")
331343
parser.add_argument("--gen_batch_size",
332344
type=int,
333345
required=True,
@@ -372,11 +384,11 @@ def gen_config_file(config_path: str,
372384
args = parser.parse_args()
373385

374386
gen_config_file(args.config, args.model, args.num_ctx_servers,
375-
args.ctx_tp_size, args.ctx_batch_size,
387+
args.ctx_tp_size, args.ctx_pp_size, args.ctx_batch_size,
376388
args.ctx_max_num_tokens, args.ctx_max_seq_len,
377389
args.ctx_free_gpu_memory_fraction,
378390
args.ctx_enable_attention_dp, args.num_gen_servers,
379-
args.gen_tp_size, args.gen_batch_size,
391+
args.gen_tp_size, args.gen_pp_size, args.gen_batch_size,
380392
args.gen_max_num_tokens, args.gen_max_seq_len,
381393
args.gen_enable_attention_dp, args.gen_gpu_memory_fraction,
382394
args.eplb_num_slots, args.mtp_size, args.worker_start_port,

examples/disaggregated/slurm/benchmark/submit.sh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ streaming=true
2121
benchmark_mode=e2e
2222

2323
args=(
24-
1 4 4 4480 true "0.75" # Context servers arguments
25-
1 8 1024 1024 true "0.8" # Generation servers arguments
26-
0 0 # Other arguments
27-
$concurrency # Benchmarking arguments
24+
1 4 1 4 4480 true "0.75" # Context servers arguments
25+
1 8 1 1024 1024 true "0.8" # Generation servers arguments
26+
0 0 # Other arguments
27+
$concurrency # Benchmarking arguments
2828
$isl
2929
$osl
3030
$multi_round
3131
$streaming
32-
$container_image # User specific arguments
32+
$container_image # User specific arguments
3333
$mounts
3434
$workdir
3535
$model_dir

examples/wide_ep/slurm_scripts/submit_e2e.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ for b in 1 64 1024; do
3030
ntasks=$((total_node_num * ntasks_per_node))
3131

3232
args=(
33-
${ctx_num} 4 4 4480 true "0.85" # Context servers arguments
34-
1 16 1024 1024 true "0.7" # Generation servers arguments
33+
${ctx_num} 4 1 4 4480 true "0.85" # Context servers arguments
34+
1 16 1 1024 1024 true "0.7" # Generation servers arguments
3535
$eplb_num_slots $mtp_size # Other arguments
3636
$concurrency # Benchmarking arguments
3737
$isl
@@ -68,8 +68,8 @@ for b in 512; do
6868
eplb_num_slots=288
6969

7070
args=(
71-
${ctx_num} 4 4 4480 true "0.85" # Context servers arguments
72-
1 32 1024 1024 true "0.7" # Generation servers arguments
71+
${ctx_num} 4 1 4 4480 true "0.85" # Context servers arguments
72+
1 32 1 1024 1024 true "0.7" # Generation servers arguments
7373
$eplb_num_slots $mtp_size # Other arguments
7474
$concurrency # Benchmarking arguments
7575
$isl

examples/wide_ep/slurm_scripts/submit_gen_only.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ for b in 1 64 1024; do
3030
ntasks=$((total_node_num * ntasks_per_node))
3131

3232
args=(
33-
${ctx_num} 4 4 4480 true "0.85" # Context servers arguments
34-
1 16 1024 1024 true "0.7" # Generation servers arguments
33+
${ctx_num} 4 1 4 4480 true "0.85" # Context servers arguments
34+
1 16 1 1024 1024 true "0.7" # Generation servers arguments
3535
$eplb_num_slots $mtp_size # Other arguments
3636
$concurrency # Benchmarking arguments
3737
$isl
@@ -68,8 +68,8 @@ for b in 512; do
6868
eplb_num_slots=288
6969

7070
args=(
71-
${ctx_num} 4 4 4480 true "0.85" # Context servers arguments
72-
1 32 1024 1024 true "0.7" # Generation servers arguments
71+
${ctx_num} 4 1 4 4480 true "0.85" # Context servers arguments
72+
1 32 1 1024 1024 true "0.7" # Generation servers arguments
7373
$eplb_num_slots $mtp_size # Other arguments
7474
$concurrency # Benchmarking arguments
7575
$isl

0 commit comments

Comments (0)