Skip to content

Commit 5ad352e

Browse files
committed
reuse max_num_seqs for max_batch_size
1 parent ff77ac6 commit 5ad352e

File tree

5 files changed

+2
-8
lines changed

5 files changed

+2
-8
lines changed

verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,6 @@ actor_rollout_ref:
216216
data_parallel_size: 1
217217
expert_parallel_size: 1
218218
pipeline_model_parallel_size: 1
219-
max_batch_size: 256
220219
max_num_batched_tokens: 8192
221220
max_model_len: null
222221
max_num_seqs: 1024

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,6 @@ actor_rollout_ref:
207207
data_parallel_size: 1
208208
expert_parallel_size: 1
209209
pipeline_model_parallel_size: 1
210-
max_batch_size: 256
211210
max_num_batched_tokens: 8192
212211
max_model_len: null
213212
max_num_seqs: 1024

verl/trainer/config/rollout/rollout.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,6 @@ expert_parallel_size: 1
5858
# PP size for rollout.
5959
pipeline_model_parallel_size: 1
6060

61-
# max batch size for rollout
62-
max_batch_size: 256
63-
6461
# max number of tokens in a batch
6562
max_num_batched_tokens: 8192
6663

verl/workers/config/rollout.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,6 @@ class RolloutConfig(BaseConfig):
151151
max_num_batched_tokens: int = 8192
152152
logprobs_mode: Optional[str] = "processed_logprobs"
153153
scheduling_policy: Optional[str] = "fcfs"
154-
max_batch_size: int = 256
155154

156155
# TODO: enable train_kwargs
157156
# train_sampling_config: SamplingConfig = field(default_factory=SamplingConfig)

verl/workers/rollout/trtllm_rollout/trtllm_async_server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ async def launch_server(self):
109109
cuda_graph_config = CudaGraphConfig(
110110
enable_padding=True,
111111
batch_sizes=self.config.cudagraph_capture_sizes,
112-
max_batch_size=0 if self.config.cudagraph_capture_sizes else self.config.max_batch_size,
112+
max_batch_size=0 if self.config.cudagraph_capture_sizes else self.config.max_num_seqs,
113113
)
114114

115115
per_worker_gpu_share = 1.0 / self.max_colocate_count
@@ -122,7 +122,7 @@ async def launch_server(self):
122122
"kv_cache_config": kv_cache_config,
123123
"cuda_graph_config": cuda_graph_config,
124124
"max_seq_len": self.config.max_model_len,
125-
"max_batch_size": self.config.max_batch_size,
125+
"max_batch_size": self.config.max_num_seqs,
126126
"max_num_tokens": self.config.max_num_batched_tokens,
127127
"tensor_parallel_size": self.config.tensor_model_parallel_size,
128128
"trust_remote_code": self.model_config.trust_remote_code,

0 commit comments

Comments (0)