Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions verl/experimental/vla/workers/env/env_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
get_device_name,
)
from verl.utils.distributed import initialize_global_process_group_ray
from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig
from verl.utils.profiler import DistProfiler, DistProfilerExtension, PROFILER_TOOL_NAMES, ProfilerConfig


def put_tensor_cpu(data_dict):
Expand Down Expand Up @@ -90,7 +90,7 @@ def __init__(self, config: DictConfig):
# Initialize profiler
omega_profiler_config = config.train.get("profiler", {})
profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
if omega_profiler_config.get("tool", None) in PROFILER_TOOL_NAMES:
tool_config = omega_conf_to_dataclass(
omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
)
Expand Down
3 changes: 2 additions & 1 deletion verl/utils/profiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from ..device import is_npu_available
from ..import_utils import is_nvtx_available
from .config import build_sglang_profiler_args, build_vllm_profiler_args
from .config import PROFILER_TOOL_NAMES, build_sglang_profiler_args, build_vllm_profiler_args
from .performance import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
from .profile import DistProfiler, DistProfilerExtension, ProfilerConfig

Expand All @@ -29,6 +29,7 @@

__all__ = [
"GPUMemoryLogger",
"PROFILER_TOOL_NAMES",
"log_gpu_memory_usage",
"mark_start_range",
"mark_end_range",
Expand Down
3 changes: 3 additions & 0 deletions verl/utils/profiler/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

from verl.base_config import BaseConfig

# Supported profiler tool names (shared constant)
PROFILER_TOOL_NAMES: frozenset[str] = frozenset({"npu", "nsys", "torch", "torch_memory"})


@dataclass
class NsightToolConfig(BaseConfig):
Expand Down
4 changes: 2 additions & 2 deletions verl/workers/engine_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from verl.utils.flops_counter import FlopsCounter
from verl.utils.memory_utils import aggressive_empty_cache
from verl.utils.metric.utils import Metric
from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage
from verl.utils.profiler import DistProfiler, DistProfilerExtension, PROFILER_TOOL_NAMES, ProfilerConfig, log_gpu_memory_usage
from verl.utils.py_functional import append_to_dict
from verl.utils.tensordict_utils import maybe_fix_3d_position_ids
from verl.utils.torch_functional import allgather_dict_into_dict
Expand Down Expand Up @@ -441,7 +441,7 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
omega_profiler_config = config.ref.get("profiler", {})

profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
if omega_profiler_config.get("tool", None) in PROFILER_TOOL_NAMES:
tool_config = omega_conf_to_dataclass(
omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
)
Expand Down
6 changes: 3 additions & 3 deletions verl/workers/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
from verl.utils.import_utils import import_external_libs
from verl.utils.memory_utils import aggressive_empty_cache
from verl.utils.model import convert_weight_keys
from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer
from verl.utils.profiler import DistProfiler, DistProfilerExtension, PROFILER_TOOL_NAMES, ProfilerConfig, log_gpu_memory_usage, simple_timer
from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max
from verl.utils.py_functional import convert_to_regular_types

Expand Down Expand Up @@ -225,7 +225,7 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
# omega_profiler_config is DictConfig
# profiler_config is a ProfilerConfig dataclass
profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
if omega_profiler_config.get("tool", None) in PROFILER_TOOL_NAMES:
tool_config = omega_conf_to_dataclass(
omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
)
Expand Down Expand Up @@ -1276,7 +1276,7 @@ def __init__(self, config: FSDPCriticConfig):
Worker.__init__(self)
omega_profiler_config = config.get("profiler", {})
profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
if omega_profiler_config.get("tool", None) in PROFILER_TOOL_NAMES:
tool_config = omega_conf_to_dataclass(
omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
)
Expand Down
5 changes: 3 additions & 2 deletions verl/workers/megatron_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
DistProfiler,
DistProfilerExtension,
GPUMemoryLogger,
PROFILER_TOOL_NAMES,
ProfilerConfig,
log_gpu_memory_usage,
simple_timer,
Expand Down Expand Up @@ -330,7 +331,7 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
# omega_profiler_config is DictConfig
# profiler_config is a ProfilerConfig dataclass
profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
if omega_profiler_config.get("tool", None) in PROFILER_TOOL_NAMES:
tool_config = omega_conf_to_dataclass(
omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
)
Expand Down Expand Up @@ -996,7 +997,7 @@ def __init__(self, config: McoreCriticConfig):

omega_profiler_config = config.get("profiler", {})
profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
if omega_profiler_config.get("tool", None) in PROFILER_TOOL_NAMES:
tool_config = omega_conf_to_dataclass(
omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
)
Expand Down