Skip to content

Commit 2c19674

Browse files
authored
[bugfix] Fix duplicate 'load_format' argument being passed in rollout (#7312)
1 parent 6662245 commit 2c19674

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

swift/llm/infer/rollout.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -394,9 +394,13 @@ def get_infer_engine(args: RolloutArguments, template=None, **kwargs):
394394
engine_kwargs = kwargs.get('engine_kwargs', {})
395395
# for RL rollout model weight sync
396396
engine_kwargs.update({'worker_extension_cls': 'swift.llm.infer.rollout.WeightSyncWorkerExtension'})
397-
# Use load_format from engine_kwargs if provided, otherwise default to 'dummy'
398-
if 'load_format' not in engine_kwargs:
399-
engine_kwargs['load_format'] = 'dummy'
397+
398+
# For RL rollout, we use 'dummy' load_format to prevent vLLM from loading weights from disk,
399+
# as they will be synced from the trainer process.
400+
# This will accelerate the rollout speed.
401+
load_format = engine_kwargs.pop('load_format', 'dummy')
402+
kwargs['load_format'] = load_format
403+
400404
if args.vllm_use_async_engine and args.vllm_data_parallel_size > 1:
401405
engine_kwargs['data_parallel_size'] = args.vllm_data_parallel_size
402406
kwargs['engine_kwargs'] = engine_kwargs

0 commit comments

Comments
 (0)