
Commit 0a5e856

[perf] feat: support profiler in model engine and sft trainer (#4749)
### What does this PR do?

- As title

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
1 parent a090cd8 commit 0a5e856

File tree: 11 files changed (+178 -29 lines)
Lines changed: 77 additions & 0 deletions (new file)

```yaml
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.utils.profiler.ProfilerConfig

# Profiler tool; defaults to the same value as profiler.tool in the global config.
# choices: nsys, npu, torch
tool: torch

# Whether to enable profiling on the Actor.
enable: False

# Whether to profile all ranks.
all_ranks: False

# The ranks that will be profiled. [] or [0,1,...]
ranks: []

# Path where profiling results are saved.
save_path: "outputs/profile"

tool_config:
  npu:
    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
    _target_: verl.utils.profiler.config.NPUToolConfig

    # Contents to profile; can be empty.
    # options: npu, cpu, memory, shapes, module, stack
    contents: []

    # Collection level; optional values: level_none, level0, level1, level2.
    level: "level0"

    # Whether to automatically parse the data.
    analysis: True

    # True: each task has its own database; False: all tasks in one training step share one database.
    discrete: False

    name: npu

  nsys:
    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
    _target_: verl.utils.profiler.config.NsightToolConfig

    # True: each task has its own database; False: all tasks in one training step share one database.
    discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}

    name: nsight

  torch:
    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
    _target_: verl.utils.profiler.config.TorchProfilerToolConfig

    # Mini-batch at which profiling starts during training.
    # NOTICE: different from the global-steps config, which refers to iterations;
    # this field only refers to mini-batches.
    step_start: 0

    # Mini-batch at which profiling stops during training.
    step_end: null

    # Manual save.
    manual_save: True

    name: torch

  torch_memory:
    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
    _target_: verl.utils.profiler.config.TorchMemoryToolConfig

    # Maximum number of memory allocation entries to track.
    trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}

    # Stack trace depth for memory allocations.
    stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}

    name: torch_memory
```
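
The `_target_` entries above are what `verl.utils.omega_conf_to_dataclass` uses to materialize these sections into dataclasses. A minimal sketch of that round trip, assuming only the top-level keys shown above (the interpolated `tool_config` entries are omitted, and the import path follows the comment in the YAML):

```python
# Hedged sketch: build a ProfilerConfig from an OmegaConf node carrying _target_.
# Only keys visible in the YAML above are used; everything else is left at its default.
from omegaconf import OmegaConf

from verl.utils import omega_conf_to_dataclass  # path as referenced in the YAML comment

profiler_node = OmegaConf.create(
    {
        "_target_": "verl.utils.profiler.ProfilerConfig",
        "tool": "torch",
        "enable": False,
        "all_ranks": False,
        "ranks": [],
        "save_path": "outputs/profile",
    }
)
profiler_config = omega_conf_to_dataclass(profiler_node)
print(type(profiler_config).__name__, profiler_config.tool, profiler_config.ranks)
```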

verl/trainer/config/sft_trainer_engine.yaml

Lines changed: 3 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@ defaults:
   - model@model: hf_model
   - engine@engine: fsdp
   - optim@optim: fsdp
+  - profiler@profiler: profiler
   - _self_

 data:
@@ -78,3 +79,5 @@ trainer:

   nnodes: 1
   n_gpus_per_node: 1
+
+  profile_interval: [-1, -1]
```

verl/trainer/sft_trainer.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -96,6 +96,19 @@ def _build_config(self):
         self.engine_config = omega_conf_to_dataclass(self.config.engine)
         self.optimizer_config = omega_conf_to_dataclass(self.config.optim)
         self.checkpoint_config = omega_conf_to_dataclass(self.config.checkpoint)
+        self.profiler_config = omega_conf_to_dataclass(self.config.profiler)
+
+        # check profile interval
+        self.profiler_interval = self.config.trainer.profile_interval
+        self._validate_profiler_interval()
+
+    def _validate_profiler_interval(self):
+        assert len(self.profiler_interval) == 2
+        self.start_profile_step = self.profiler_interval[0]
+        self.end_profile_step = self.profiler_interval[1]
+        assert self.end_profile_step >= self.start_profile_step
+        if self.start_profile_step < 0:
+            assert self.end_profile_step < 0

     def _build_engine(self):
         from verl.workers.engine_workers import TrainingWorkerConfig
@@ -109,6 +122,7 @@ def _build_engine(self):
             engine_config=self.engine_config,
             optimizer_config=self.optimizer_config,
             checkpoint_config=self.checkpoint_config,
+            profiler_config=self.profiler_config,
         )

         self.training_client = TrainingWorker(config=config)
@@ -303,9 +317,15 @@ def fit(self):

             tu.assign_non_tensor(data, update_lr_scheduler=True, global_token_num=batch_seqlens)

+            # start profile in SPMD mode
+            if global_step == self.start_profile_step:
+                self.training_client.start_profile()
             # train for one batch
             output = self.training_client.train_batch(data=data)

+            if global_step == self.end_profile_step:
+                self.training_client.stop_profile()
+
             if self.engine.is_mp_src_rank_with_outputs():
                 metrics = tu.get(output, "metrics")

```
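
Read together with the `trainer.profile_interval: [-1, -1]` default added above, `_validate_profiler_interval` and the `fit()` hooks treat the interval as a `[start, end]` pair of global training steps: `start_profile()` fires when `global_step == start`, `stop_profile()` fires when `global_step == end`, and an all-negative pair keeps profiling disabled. A self-contained restatement of that rule (the helper name here is illustrative, not part of the commit):

```python
# Illustrative re-statement of the interval rule enforced by
# _validate_profiler_interval() in the diff above; parse_profile_interval
# is a hypothetical helper, not part of the commit.
def parse_profile_interval(profile_interval: list[int]) -> tuple[int, int]:
    assert len(profile_interval) == 2
    start, end = profile_interval
    assert end >= start
    if start < 0:
        # A negative start means profiling is disabled; end must be negative too.
        assert end < 0
    return start, end


print(parse_profile_interval([-1, -1]))  # (-1, -1): profiling disabled
print(parse_profile_interval([2, 5]))    # start profiling at step 2, stop at step 5
```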

verl/trainer/sft_trainer_ray.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -90,6 +90,19 @@ def _build_config(self):
         self.engine_config = omega_conf_to_dataclass(self.config.engine)
         self.optimizer_config = omega_conf_to_dataclass(self.config.optim)
         self.checkpoint_config = omega_conf_to_dataclass(self.config.checkpoint)
+        self.profiler_config = omega_conf_to_dataclass(self.config.profiler)
+
+        # check profile interval
+        self.profiler_interval = self.config.trainer.profile_interval
+        self._validate_profiler_interval()
+
+    def _validate_profiler_interval(self):
+        assert len(self.profiler_interval) == 2
+        self.start_profile_step = self.profiler_interval[0]
+        self.end_profile_step = self.profiler_interval[1]
+        assert self.end_profile_step >= self.start_profile_step
+        if self.start_profile_step < 0:
+            assert self.end_profile_step < 0

     def _build_engine(self):
         from verl.workers.engine_workers import TrainingWorkerConfig
@@ -103,6 +116,7 @@ def _build_engine(self):
             engine_config=self.engine_config,
             optimizer_config=self.optimizer_config,
             checkpoint_config=self.checkpoint_config,
+            profiler_config=self.profiler_config,
         )

         # create resource pool and worker group
@@ -279,10 +293,16 @@ def fit(self):

             tu.assign_non_tensor(data, update_lr_scheduler=True, global_token_num=batch_seqlens)

+            # start profile in SPMD mode
+            if global_step == self.start_profile_step:
+                self.training_client.start_profile()
             # train for one batch
             output = self.training_client.train_batch(data)
             output = output.get()

+            if global_step == self.end_profile_step:
+                self.training_client.stop_profile()
+
             metrics = tu.get(output, "metrics")

             # TODO: we can actually accumulate metrics for N steps and perform aggregate metrics
```

verl/utils/profiler/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -15,7 +15,7 @@
 from ..device import is_npu_available
 from ..import_utils import is_nvtx_available
 from .performance import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
-from .profile import DistProfiler, DistProfilerExtension, ProfilerConfig
+from .profile import DistProfiler, DistProfilerExtension, Profiler, ProfilerConfig

 # Select marker implementations by availability, but keep DistProfiler as our dispatcher
 if is_nvtx_available():
@@ -34,6 +34,7 @@
     "mark_annotate",
     "DistProfiler",
     "DistProfilerExtension",
+    "Profiler",
     "ProfilerConfig",
     "simple_timer",
     "marked_timer",
```

verl/utils/profiler/config.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -27,6 +27,7 @@ class NsightToolConfig(BaseConfig):

     "True for each task has its own database, False for all tasks in one training step share one database."
     discrete: bool = False
+    name: str = "nsight"

     def __post_init__(self) -> None:
         pass
@@ -43,6 +44,8 @@ class TorchProfilerToolConfig(BaseConfig):

     step_start: int = -1
     step_end: int = -1
+    manual_save: bool = True
+    name: str = "torch"

     def __post_init__(self) -> None:
         """config validation logics go here"""
@@ -61,6 +64,7 @@ class TorchMemoryToolConfig(BaseConfig):

     trace_alloc_max_entries: int = 100_000
     stack_depth: int = 32
+    name: str = "torch_memory"

     def __post_init__(self) -> None:
         """config validation logics go here"""
@@ -87,6 +91,8 @@ class NPUToolConfig(NsightToolConfig):
     # Whether to automatically parse the data.
     analysis: bool = False

+    name: str = "npu"
+
     def __post_init__(self) -> None:
         """config validation logics go here"""
         assert isinstance(self.contents, list), f"Profiler contents must be of type list, got {type(self.contents)}"
```
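
A quick sketch of the new fields in isolation, assuming only the dataclass fields visible in this diff and defaults for everything else:

```python
# Hedged sketch: every tool config now carries a stable `name`, and the torch
# tool additionally gains `manual_save`. Constructor arguments mirror the fields
# shown above; other fields and __post_init__ details are not assumed here.
from verl.utils.profiler.config import NsightToolConfig, TorchProfilerToolConfig

torch_tool = TorchProfilerToolConfig(step_start=0, step_end=2, manual_save=True)
assert torch_tool.name == "torch" and torch_tool.manual_save is True
assert NsightToolConfig().name == "nsight"
```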

verl/utils/profiler/profile.py

Lines changed: 31 additions & 13 deletions
```diff
@@ -40,10 +40,15 @@ class Profiler:
         config: Configuration object containing profiling parameters
     """

-    def __init__(self, config: ProfilerConfig, tool_config: Optional[TorchProfilerToolConfig] = None):
+    def __init__(
+        self, config: ProfilerConfig, tool_config: Optional[TorchProfilerToolConfig] = None, save_file_prefix=None
+    ):
         # note: if use_profile is not set, it will be set as None, so that all functions will be skipped
         if not config:
             config = ProfilerConfig(ranks=[], enable=False)
+
+        self.save_file_prefix = save_file_prefix
+
         if not tool_config:
             assert not config.enable, "tool_config must be provided when profiler is enabled"
             self.prof = None
@@ -56,7 +61,8 @@ def __init__(self, config: ProfilerConfig, tool_config: Optional[TorchProfilerTo
             self.rank = torch.distributed.get_rank()
             # we need to validate the config before using the profiler
             self._validate()
-            if self.rank in self.config.profile_ranks:
+
+            if self.rank in self.config.ranks or self.config.all_ranks:
                 print(f"[Profiler] Profiler init for rank {self.rank}")

                 self.prof = torch.profiler.profile(
@@ -74,11 +80,24 @@ def __init__(self, config: ProfilerConfig, tool_config: Optional[TorchProfilerTo
                     with_stack=True,
                 )

+    def _trace_handler(self, prof):
+        if not os.path.exists(self.config.save_path):
+            os.makedirs(self.config.save_path)
+
+        save_file_name = f"prof_rank-{self.rank}.json.gz"
+        if self.save_file_prefix is not None:
+            save_file_name = self.save_file_prefix + "_" + save_file_name
+        save_path = os.path.join(self.config.save_path, save_file_name)
+        print(f"[Profiler] Saving trace to {save_path}")
+        prof.export_chrome_trace(save_path)
+        self.enable = False
+        self.saved = True
+
     def _validate(self):
         if self.enable:
-            if self.config.profile_ranks is None:
+            if self.config.ranks is None:
                 print("[WARNING] Profile ranks is not set, default to rank 0")
-                self.config.profile_ranks = [0]
+                self.config.ranks = [0]
             assert self.tool_config.step_start >= 0, "[ERROR] Profile step start must be greater than 0"
             assert self.tool_config.step_end >= 0, "[ERROR] Profile step end must be greater than 0"
             assert self.tool_config.step_start < self.tool_config.step_end, (
@@ -99,18 +118,14 @@ def step(self):

     def stop(self):
         if self.check():
+            self.step()
             print(f"[Profiler] stopped for rank {self.rank}")
             self.prof.stop()
+            self.save()

     def save(self):
-        if self.prof is not None and not self.saved:
-            if not os.path.exists(self.config.save_path):
-                os.makedirs(self.config.save_path)
-            save_file_name = f"/prof_start_{self.config.step_start}_end_{self.config.step_end}_rank_{self.rank}.json"
-            print(f"[Profiler] Saving trace to {self.config.save_path + save_file_name}")
-            self.prof.export_chrome_trace(self.config.save_path + save_file_name)
-            self.enable = False
-            self.saved = True
+        if self.prof is not None and not self.saved and self.tool_config.manual_save:
+            self._trace_handler(prof=self.prof)

     def stop_and_save(self):
         if self.check():
@@ -188,7 +203,10 @@ def __init__(
     ):
         # Default config
         if not config:
-            config = ProfilerConfig(ranks=[], enable=False)
+            config = ProfilerConfig(ranks=[], enable=False, tool_config=None)
+
+        if tool_config is None:
+            tool_config = config.tool_config

         self._impl = None
         self._tool = getattr(config, "tool", None)
```
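
Putting the `Profiler` changes together: rank selection now honors `ranks`/`all_ranks`, traces are written by `_trace_handler` (optionally prefixed via `save_file_prefix`), and `stop()` now steps and saves on its own. A hedged usage sketch, assuming a `start()` method alongside the `step()`/`stop()` shown in the diff and an initialized `torch.distributed` process group:

```python
# Hedged sketch of driving the torch-backed Profiler directly. Field names follow
# ProfilerConfig / TorchProfilerToolConfig as shown in this commit; start() is
# assumed to exist next to step()/stop(), and the single-process "gloo" group is
# only there so torch.distributed.get_rank() works inside __init__.
import torch.distributed as dist

from verl.utils.profiler import Profiler, ProfilerConfig
from verl.utils.profiler.config import TorchProfilerToolConfig

if not dist.is_initialized():
    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

config = ProfilerConfig(enable=True, ranks=[0], save_path="outputs/profile")
tool_config = TorchProfilerToolConfig(step_start=0, step_end=2, manual_save=True)

profiler = Profiler(config, tool_config=tool_config, save_file_prefix="sft")
profiler.start()
for _ in range(3):
    # one training mini-batch would run here
    profiler.step()
profiler.stop()  # per the diff: stop() also calls step() and then save()
```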

verl/utils/seqlen_balancing.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -68,6 +68,7 @@ def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool)
     Note:
         When equal_size=True, len(seqlen_list) must be divisible by k_partitions.
     """
+
     # see: https://en.wikipedia.org/wiki/Largest_differencing_method
     class Set:
         def __init__(self) -> None:
```

verl/utils/torch_functional.py

Lines changed: 3 additions & 9 deletions
```diff
@@ -200,9 +200,7 @@ def logprobs_from_logits_v2(logits: torch.FloatTensor, labels: torch.Tensor) ->
     return logprobs_labels


-def clip_by_value(
-    x: torch.Tensor, tensor_min: torch.Tensor, tensor_max: torch.Tensor
-) -> torch.Tensor:
+def clip_by_value(x: torch.Tensor, tensor_min: torch.Tensor, tensor_max: torch.Tensor) -> torch.Tensor:
     """Clip tensor values to a range defined by tensor bounds.

     Extension of torch.clamp that supports tensor-valued min/max bounds
@@ -265,9 +263,7 @@ def entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 20
     return entropy


-def masked_sum(
-    values: torch.Tensor, mask: torch.Tensor, axis: int | tuple[int, ...] | None = None
-) -> torch.Tensor:
+def masked_sum(values: torch.Tensor, mask: torch.Tensor, axis: int | tuple[int, ...] | None = None) -> torch.Tensor:
     """Compute sum of tensor values where mask is True.

     NaN values outside the mask are replaced with zeros to prevent
@@ -389,9 +385,7 @@ def compute_grad_norm(model: nn.Module) -> float:
     return total_grad_square


-def broadcast_dict_tensor(
-    tensors: dict[str, torch.Tensor] | TensorDict, src: int, group
-) -> None:
+def broadcast_dict_tensor(tensors: dict[str, torch.Tensor] | TensorDict, src: int, group) -> None:
     """Broadcast all tensors in a dictionary from source rank to all ranks.

     Iterates over all tensors in the dictionary and broadcasts each one
```

verl/workers/config/engine.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -19,6 +19,7 @@
 from verl.base_config import BaseConfig
 from verl.trainer.config import CheckpointConfig

+from ...utils.profiler import ProfilerConfig
 from .model import HFModelConfig
 from .optimizer import OptimizerConfig

@@ -273,3 +274,4 @@ class TrainingWorkerConfig(BaseConfig):
     engine_config: EngineConfig = None
     optimizer_config: OptimizerConfig = None
     checkpoint_config: CheckpointConfig = None
+    profiler_config: ProfilerConfig = None
```
