Skip to content

Commit fc5e043

Browse files
committed
avoid fsdp2 shard of lm_head to enable fused_kernels
1 parent f63061a commit fc5e043

10 files changed

Lines changed: 1071 additions & 13 deletions

File tree

apertus/launch/multinode_async_sandbox/_verl_training.sbatch

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
set -xeuo pipefail
1414

1515
clear_inherited_pyxis_options() {
16+
set +x
1617
local name
1718
while IFS='=' read -r name _; do
1819
case "${name}" in
1920
SLURM_SPANK__SLURM_SPANK_OPTION_pyxis_*) unset "${name}" ;;
2021
esac
2122
done < <(env)
23+
set -x
2224
}
2325

2426
clear_inherited_pyxis_options

apertus/launch/multinode_async_sandbox/launch.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ ASYNC_STEADY_WARMUP_STEPS="${ASYNC_STEADY_WARMUP_STEPS:-}"
7373
###############################################################################
7474

7575
# Set REASONING_GYM_DIR="" to install reasoning-gym from PyPI.
76-
REASONING_GYM_DIR="${REASONING_GYM_DIR:-${SCRATCH_HOME}/projects/r-gym}"
77-
TOOL_GYM_DIR="${TOOL_GYM_DIR:-${SCRATCH_HOME}/projects/tool-gym}"
78-
TOOL_GYM_FUNCTION_TOOL_PATH="${TOOL_GYM_FUNCTION_TOOL_PATH:-/capstor/store/cscs/swissai/infra01/reasoning/data/RL-prod/toolgym_test_v2/apertus_function_tools.py}"
76+
REASONING_GYM_DIR=""
77+
TOOL_GYM_DIR=""
78+
TOOL_GYM_FUNCTION_TOOL_PATH="${TOOL_GYM_FUNCTION_TOOL_PATH:-/capstor/store/cscs/swissai/infra01/reasoning/data/RL-prod/toolgym_test_v3/apertus_function_tools_v3.py}"
7979
SANDBOX_BACKEND="kubernetes" # kubernetes, codegym, or none
8080
KUBERNETES_SANDBOX_URL="https://sandbox-dev.swissai.svc.cscs.ch"
8181
CODE_GYM_DIR="" # ${SCRATCH_HOME}/projects/code-gym} # Not needed if using kubernetes

tests/checkpoint_engine/test_correctness_on_gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ async def test_nccl_checkpoint_engine(
4242
num_gpus_per_node=_ngpus,
4343
bucket_size_mb=128,
4444
check_allclose=True,
45-
model_path="~/models/Qwen/Qwen3-8B-Base",
45+
model_path="swiss-ai/Apertus-8B-Instruct-2509",
4646
):
4747
model_path = os.path.expanduser(model_path)
4848
ray.init(

tests/checkpoint_engine/test_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(self, config: TrainingWorkerConfig, checkpoint_engine_config: Check
3939
if torch.distributed.get_rank() == 0:
4040
engine_kwargs["is_master"] = True
4141
self.checkpoint_engine = CheckpointEngineRegistry.new(backend, bucket_size=bucket_size, **engine_kwargs)
42+
4243

4344
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
4445
async def update_weights(self, global_steps: int = None, mode: str = "auto"):
@@ -142,7 +143,7 @@ def check_weights(self):
142143
def create_trainer_worker_group(
143144
resource_pool: RayResourcePool, model_config: HFModelConfig, checkpoint_engine_config: CheckpointEngineConfig
144145
) -> RayWorkerGroup:
145-
engine_config = FSDPEngineConfig(forward_only=True, fsdp_size=resource_pool.world_size, strategy="fsdp")
146+
engine_config = FSDPEngineConfig(forward_only=True, fsdp_size=resource_pool.world_size, strategy="fsdp2")
146147
trainer_config = TrainingWorkerConfig(
147148
model_type="language_model",
148149
model_config=model_config,

verl/experimental/fully_async_policy/config/async.yaml

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,19 +66,29 @@ actor_rollout_ref:
6666
use_remove_padding: true
6767
enable_gradient_checkpointing: true
6868
use_shm: false
69+
use_fused_kernels: true
70+
fused_kernel_options:
71+
impl_backend: torch
6972
override_config:
7073
attn_implementation: flash_attention_3
7174

7275
actor:
7376
use_rollout_log_probs: ${oc.select:async_training.use_rollout_log_probs,true}
7477
ppo_mini_batch_size: 128 # NOTE: train_batch_size = require_batches * ppo_mini_batch_size
7578
use_dynamic_bsz: true
76-
ppo_max_token_len_per_gpu: 32768
79+
ppo_max_token_len_per_gpu: 16384
7780
ulysses_sequence_parallel_size: 1
7881
entropy_from_logits_with_chunking: false
7982
entropy_checkpointing: false
80-
81-
83+
profiler:
84+
enable: False
85+
all_ranks: False
86+
ranks: [0]
87+
tool: torch
88+
tool_config:
89+
torch:
90+
contents: [cpu, memory, cuda, shapes, stack]
91+
discrete: True
8292
optim:
8393
optimizer: _AdamW
8494
lr: 1e-6
@@ -92,6 +102,9 @@ actor_rollout_ref:
92102
strategy: fsdp2
93103
param_offload: false
94104
optimizer_offload: false
105+
reshard_after_forward: false
106+
forward_prefetch: true
107+
fsdp_size: 4
95108
entropy_from_logits_with_chunking: ${oc.select:actor_rollout_ref.actor.entropy_from_logits_with_chunking,false}
96109
entropy_checkpointing: ${oc.select:actor_rollout_ref.actor.entropy_checkpointing,false}
97110
model_dtype: bf16
@@ -164,6 +177,14 @@ trainer:
164177
rollout_data_dir: ${trainer.default_local_dir}/rollout/
165178
validation_data_dir: ${trainer.default_local_dir}/validation/
166179

180+
global_profiler:
181+
steps: [4]
182+
save_path: "/iopsstor/scratch/cscs/atazza/verl_profile"
183+
tool: torch
184+
tool_config:
185+
torch:
186+
contents: [cpu, memory, cuda, stack]
187+
167188
ray_kwargs:
168189
ray_init:
169190
num_cpus: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK,null}}

0 commit comments

Comments
 (0)