Skip to content

Commit 0114625

Browse files
committed
fix: fix test
1 parent 8ae8ca9 commit 0114625

File tree

4 files changed

+41
-29
lines changed

4 files changed

+41
-29
lines changed

tests/workers/rollout/test_sglang_async_rollout_sf_tools.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message
3535
from verl.workers.rollout.sglang_rollout.async_sglang_rollout import AsyncSGLangRollout
3636

37-
sandbox_url = ""
37+
sandbox_url = "https://sd04qmtd8e6v9i08l9l00.apigateway-cn-beijing.volceapi.com/run_code"
3838

3939

4040
def get_sandbox_fusion_data():
@@ -185,27 +185,28 @@ def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, sandbo
185185
req_list = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)
186186
assert len(req_list) == 1
187187
assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING
188-
assert req_list[0].tools == [
189-
OpenAIFunctionToolSchema(
190-
type="function",
191-
function=OpenAIFunctionSchema(
192-
name="code_interpreter",
193-
description="A tool for executing code.",
194-
parameters=OpenAIFunctionParametersSchema(
195-
type="object",
196-
properties={
197-
"code": OpenAIFunctionPropertySchema(
198-
type="string",
199-
description="The code to execute.",
200-
enum=None,
201-
)
202-
},
203-
required=["code"],
204-
),
205-
strict=False,
188+
assert len(req_list[0].tools) == 1
189+
print("------------")
190+
print(type(req_list[0].tools[0]))
191+
assert req_list[0].tools[0] == OpenAIFunctionToolSchema(
192+
type="function",
193+
function=OpenAIFunctionSchema(
194+
name="code_interpreter",
195+
description="A tool for executing code.",
196+
parameters=OpenAIFunctionParametersSchema(
197+
type="object",
198+
properties={
199+
"code": OpenAIFunctionPropertySchema(
200+
type="string",
201+
description="The code to execute.",
202+
enum=None,
203+
)
204+
},
205+
required=["code"],
206206
),
207-
)
208-
]
207+
strict=False,
208+
),
209+
)
209210

210211
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
211212
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)

verl/tools/base_tool.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from typing import Any, Optional, Tuple
15+
from typing import Any, Optional, Protocol, Tuple, runtime_checkable
1616
from uuid import uuid4
1717

1818
from .schemas import OpenAIFunctionToolSchema
@@ -84,3 +84,9 @@ async def release(self, instance_id: str, **kwargs) -> None:
8484
instance_id: The instance id of the tool.
8585
"""
8686
pass
87+
88+
89+
@runtime_checkable
90+
class DatasetIrrelevantTool(Protocol):
91+
def dataset_irrelevant(self) -> bool:
92+
return False

verl/tools/sandbox_fusion_tools.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from verl.utils.reward_score.sandbox_fusion.utils import _process_single_case
2828

29-
from .base_tool import BaseTool
29+
from .base_tool import BaseTool, DatasetIrrelevantTool
3030
from .schemas import OpenAIFunctionToolSchema
3131

3232
logger = logging.getLogger(__name__)
@@ -93,7 +93,7 @@ def init_execution_pool(num_workers: int, enable_global_rate_limit=True, rate_li
9393
# return ray.util.multiprocessing.Pool(processes=num_workers)
9494

9595

96-
class SandboxFusionTool(BaseTool):
96+
class SandboxFusionTool(BaseTool, DatasetIrrelevantTool):
9797
"""A tool for executing the code using sandbox fusion image.
9898
9999
- `to_openai_function_tool_schema`: return the tool schema in OpenAI format.
@@ -141,6 +141,9 @@ def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
141141
def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
142142
return self.tool_schema
143143

144+
def dataset_irrelevant(self):
145+
return True
146+
144147
async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str:
145148
if instance_id is None:
146149
instance_id = str(uuid4())

verl/workers/rollout/sglang_rollout/async_sglang_rollout.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
from verl import DataProto
4141
from verl.third_party.sglang import parallel_state as sglang_ps
42-
from verl.tools.base_tool import BaseTool
42+
from verl.tools.base_tool import BaseTool, DatasetIrrelevantTool
4343
from verl.tools.schemas import OpenAIFunctionCallSchema, OpenAIFunctionParsedSchema, OpenAIFunctionToolCall
4444
from verl.utils.debug import GPUMemoryLogger
4545
from verl.utils.model import compute_position_id_with_mask
@@ -121,7 +121,7 @@ def _init_distributed_env(self, device_mesh_cpu, **kwargs):
121121
os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0"
122122
os.environ["MEGATRON_IMPORT_TIMERS"] = "0"
123123
train_tp = kwargs.get("train_tp", None)
124-
num_tp_per_train_tp = train_tp // tensor_parallel_size
124+
num_tp_per_train_tp = train_tp // self.tensor_parallel_size
125125
sglang_ps.initialize_parallel_state(
126126
tensor_model_parallel_size=self.tensor_parallel_size,
127127
num_tp_per_train_tp=num_tp_per_train_tp,
@@ -555,7 +555,7 @@ async def calc_reward_and_release_fn(name: str, tool: BaseTool):
555555
return _req
556556

557557
async def _handle_engine_call(self, _req: AsyncRolloutRequest, do_sample: bool, is_validate: bool, **kwargs) -> dict:
558-
generation_prompt = _req.get_generation_prompt(self.tokenizer)
558+
generation_prompt_ids = _req.get_generation_prompt(self.tokenizer)
559559
max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1)
560560
if not do_sample:
561561
kwargs = dict(
@@ -737,8 +737,10 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: in
737737
_tools_kwargs[k] = data_tools_kwargs[k]
738738
# add for dataset-irrelevant tools
739739
for tool_key in self._tool_map.keys():
740-
_tools_kwargs[tool_key] = {}
741-
_tool_schemas.append(self._tool_map[tool_key].get_openai_tool_schema())
740+
# TODO: redesign this logic
741+
if tool_key not in _tools_kwargs and isinstance(self._tool_map[tool_key], DatasetIrrelevantTool) and self._tool_map[tool_key].dataset_irrelevant():
742+
_tools_kwargs[tool_key] = {}
743+
_tool_schemas.append(self._tool_map[tool_key].get_openai_tool_schema())
742744
prompt_with_chat_template = self.tokenizer.apply_chat_template(
743745
conversation=raw_prompt,
744746
tools=[tool.model_dump() for tool in _tool_schemas],

0 commit comments

Comments
 (0)