Skip to content

Commit a4bb460

Browse files
committed
ci fixes
1 parent fb4a83c commit a4bb460

File tree

4 files changed

+10
-8
lines changed

4 files changed

+10
-8
lines changed

tests/special_sanity/check_device_api_usage.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
"verl/workers/engine/veomni/transformer_impl.py", # appear in default device_name
4848
"verl/workers/rollout/vllm_rollout/vllm_async_server.py", # appear in config.cudagraph_capture_sizes
4949
"verl/workers/rollout/sglang_rollout/async_sglang_server.py", # manually set CUDA_VISIBLE_DEVICES
50+
"verl/workers/rollout/trtllm_rollout/trtllm_async_server.py", # appear in config.cudagraph_capture_sizes
5051
]
5152

5253
# directory or file path must contain keyword "nccl"

tests/special_sanity/check_pr_title.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
pr_title = os.environ.get("PR_TITLE", "").strip()
2020

2121
# Define rules
22-
allowed_modules = ["fsdp", "megatron", "sglang", "vllm", "rollout", "trainer"]
22+
allowed_modules = ["fsdp", "megatron", "sglang", "vllm", "trtllm", "rollout", "trainer"]
2323
allowed_modules += ["tests", "training_utils", "recipe", "hardware", "deployment"]
2424
allowed_modules += ["ray", "worker", "single_controller", "misc", "docker", "ci"]
2525
allowed_modules += ["perf", "model", "algo", "env", "tool", "ckpt", "doc", "data", "cfg", "reward"]

verl/trainer/ppo/ray_trainer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -929,12 +929,12 @@ def init_workers(self):
929929
else:
930930
rm_resource_pool = None
931931

932-
self.async_rollout_manager = AgentLoopManager(
933-
config=self.config,
934-
worker_group=self.actor_rollout_wg,
935-
rollout_resource_pool=actor_rollout_resource_pool,
936-
rm_resource_pool=rm_resource_pool,
937-
)
932+
self.async_rollout_manager = AgentLoopManager(
933+
config=self.config,
934+
worker_group=self.actor_rollout_wg,
935+
rollout_resource_pool=actor_rollout_resource_pool,
936+
rm_resource_pool=rm_resource_pool,
937+
)
938938

939939
def _save_checkpoint(self):
940940
from verl.utils.fs import local_mkdir_safe

verl/workers/rollout/trtllm_rollout/trtllm_async_server.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from verl.single_controller.ray import RayClassWithInitArgs, SubRayResourcePool
2727
from verl.utils.config import omega_conf_to_dataclass
28+
from verl.utils.device import is_cuda_available
2829
from verl.workers.config import HFModelConfig, RolloutConfig
2930
from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput
3031
from verl.workers.rollout.trtllm_rollout.trtllm_rollout import TRTLLMAsyncRollout
@@ -61,7 +62,7 @@ def __init__(
6162
bundle_indices: list[list[int]] = None,
6263
):
6364
os.environ["TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL"] = "1"
64-
assert torch.cuda.is_available(), "TRTLLM http server should run on GPU node"
65+
assert is_cuda_available, "TRTLLM http server should run on GPU node"
6566

6667
self.config: RolloutConfig = omega_conf_to_dataclass(config)
6768
self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig)

0 commit comments

Comments (0)