Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions verl/experimental/vla/workers/env/env_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
get_device_name,
)
from verl.utils.distributed import initialize_global_process_group_ray
from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig
from verl.utils.profiler import PROFILER_TOOL_NAMES, DistProfiler, DistProfilerExtension, ProfilerConfig


def put_tensor_cpu(data_dict):
Expand Down Expand Up @@ -90,7 +90,7 @@ def __init__(self, config: DictConfig):
# Initialize profiler
omega_profiler_config = config.train.get("profiler", {})
profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
if omega_profiler_config.get("tool", None) in PROFILER_TOOL_NAMES:
tool_config = omega_conf_to_dataclass(
omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
)
Expand Down
3 changes: 2 additions & 1 deletion verl/utils/profiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from ..device import is_npu_available
from ..import_utils import is_nvtx_available
from .config import build_sglang_profiler_args, build_vllm_profiler_args
from .config import PROFILER_TOOL_NAMES, build_sglang_profiler_args, build_vllm_profiler_args
from .performance import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
from .profile import DistProfiler, DistProfilerExtension, ProfilerConfig

Expand All @@ -29,6 +29,7 @@

__all__ = [
"GPUMemoryLogger",
"PROFILER_TOOL_NAMES",
"log_gpu_memory_usage",
"mark_start_range",
"mark_end_range",
Expand Down
3 changes: 3 additions & 0 deletions verl/utils/profiler/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

from verl.base_config import BaseConfig

# Supported profiler tool names (shared constant)
PROFILER_TOOL_NAMES = frozenset({"npu", "nsys", "torch", "torch_memory"})


@dataclass
class NsightToolConfig(BaseConfig):
Expand Down
10 changes: 8 additions & 2 deletions verl/workers/engine_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,13 @@
from verl.utils.flops_counter import FlopsCounter
from verl.utils.memory_utils import aggressive_empty_cache
from verl.utils.metric.utils import Metric
from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage
from verl.utils.profiler import (
PROFILER_TOOL_NAMES,
DistProfiler,
DistProfilerExtension,
ProfilerConfig,
log_gpu_memory_usage,
)
from verl.utils.py_functional import append_to_dict
from verl.utils.tensordict_utils import maybe_fix_3d_position_ids
from verl.utils.torch_functional import allgather_dict_into_dict
Expand Down Expand Up @@ -441,7 +447,7 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
omega_profiler_config = config.ref.get("profiler", {})

profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
if omega_profiler_config.get("tool", None) in PROFILER_TOOL_NAMES:
tool_config = omega_conf_to_dataclass(
omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
)
Expand Down
13 changes: 10 additions & 3 deletions verl/workers/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,14 @@
from verl.utils.import_utils import import_external_libs
from verl.utils.memory_utils import aggressive_empty_cache
from verl.utils.model import convert_weight_keys
from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer
from verl.utils.profiler import (
PROFILER_TOOL_NAMES,
DistProfiler,
DistProfilerExtension,
ProfilerConfig,
log_gpu_memory_usage,
simple_timer,
)
from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max
from verl.utils.py_functional import convert_to_regular_types

Expand Down Expand Up @@ -225,7 +232,7 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
# omega_profiler_config is DictConfig
# profiler_config is a ProfilerConfig dataclass
profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
if omega_profiler_config.get("tool", None) in PROFILER_TOOL_NAMES:
tool_config = omega_conf_to_dataclass(
omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
)
Expand Down Expand Up @@ -1276,7 +1283,7 @@ def __init__(self, config: FSDPCriticConfig):
Worker.__init__(self)
omega_profiler_config = config.get("profiler", {})
profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
if omega_profiler_config.get("tool", None) in PROFILER_TOOL_NAMES:
tool_config = omega_conf_to_dataclass(
omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
)
Expand Down
5 changes: 3 additions & 2 deletions verl/workers/megatron_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from verl.utils.memory_utils import aggressive_empty_cache
from verl.utils.model import get_hf_model_path, load_mcore_dist_weights, load_megatron_gptmodel_weights
from verl.utils.profiler import (
PROFILER_TOOL_NAMES,
DistProfiler,
DistProfilerExtension,
GPUMemoryLogger,
Expand Down Expand Up @@ -330,7 +331,7 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
# omega_profiler_config is DictConfig
# profiler_config is a ProfilerConfig dataclass
profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
if omega_profiler_config.get("tool", None) in PROFILER_TOOL_NAMES:
tool_config = omega_conf_to_dataclass(
omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
)
Expand Down Expand Up @@ -996,7 +997,7 @@ def __init__(self, config: McoreCriticConfig):

omega_profiler_config = config.get("profiler", {})
profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
if omega_profiler_config.get("tool", None) in PROFILER_TOOL_NAMES:
tool_config = omega_conf_to_dataclass(
omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
)
Expand Down