Skip to content

Commit 570b9e8

Browse files
committed
formate
1 parent 982fc3b commit 570b9e8

13 files changed

+68
-20
lines changed

tests/plugin_contracts/test_plugin_buffer_filter_contracts.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
if "transformers" not in sys.modules:
2121
mod = types.ModuleType("transformers")
2222
mod.AutoTokenizer = type("AutoTokenizer", (), {"from_pretrained": staticmethod(lambda *args, **kwargs: object())})
23-
mod.AutoProcessor = type("AutoProcessor", (), {"from_pretrained": staticmethod(lambda *args, **kwargs: (_ for _ in ()).throw(OSError()))})
23+
mod.AutoProcessor = type(
24+
"AutoProcessor",
25+
(),
26+
{"from_pretrained": staticmethod(lambda *args, **kwargs: (_ for _ in ()).throw(OSError()))},
27+
)
2428
mod.PreTrainedTokenizerBase = type("PreTrainedTokenizerBase", (), {})
2529
mod.ProcessorMixin = type("ProcessorMixin", (), {})
2630
sys.modules["transformers"] = mod

tests/plugin_contracts/test_plugin_custom_convert_samples_to_train_data_contracts.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,15 @@ def assert_runtime_callsite_is_stable() -> None:
7474
def assert_custom_convert_output_matches_expected(fn) -> None:
7575
samples = [make_sample(0, 0.5), make_sample(1, 1.5)]
7676
train_data = fn(type("Args", (), {})(), samples)
77-
required_keys = {"tokens", "response_lengths", "rewards", "raw_reward", "truncated", "sample_indices", "loss_masks"}
77+
required_keys = {
78+
"tokens",
79+
"response_lengths",
80+
"rewards",
81+
"raw_reward",
82+
"truncated",
83+
"sample_indices",
84+
"loss_masks",
85+
}
7886
assert isinstance(train_data, dict)
7987
assert required_keys <= set(train_data)
8088
assert all(len(train_data[key]) == len(samples) for key in required_keys)

tests/plugin_contracts/test_plugin_custom_eval_rollout_log_contracts.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020

2121
NUM_GPUS = 0
2222
ENV_PREFIX = "SLIME_CONTRACT_"
23-
REFERENCE_CUSTOM_EVAL_ROLLOUT_LOG_PATH = "plugin_contracts.test_plugin_custom_eval_rollout_log_contracts.reference_custom_eval_rollout_log"
23+
REFERENCE_CUSTOM_EVAL_ROLLOUT_LOG_PATH = (
24+
"plugin_contracts.test_plugin_custom_eval_rollout_log_contracts.reference_custom_eval_rollout_log"
25+
)
2426

2527
from slime.utils.types import Sample
2628

@@ -38,7 +40,9 @@ def run_contract_test_file() -> None:
3840
parser.add_argument("--custom-eval-rollout-log-function-path", default=None)
3941
args, remaining = parser.parse_known_args()
4042
if args.custom_eval_rollout_log_function_path:
41-
os.environ[contract_env_name("CUSTOM_EVAL_ROLLOUT_LOG_FUNCTION_PATH")] = args.custom_eval_rollout_log_function_path
43+
os.environ[contract_env_name("CUSTOM_EVAL_ROLLOUT_LOG_FUNCTION_PATH")] = (
44+
args.custom_eval_rollout_log_function_path
45+
)
4246
raise SystemExit(pytest.main([__file__, *remaining]))
4347

4448

tests/plugin_contracts/test_plugin_custom_reward_post_process_contracts.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020
NUM_GPUS = 0
2121
ENV_PREFIX = "SLIME_CONTRACT_"
22-
REFERENCE_CUSTOM_REWARD_POST_PROCESS_PATH = "plugin_contracts.test_plugin_custom_reward_post_process_contracts.reference_reward_post_process"
22+
REFERENCE_CUSTOM_REWARD_POST_PROCESS_PATH = (
23+
"plugin_contracts.test_plugin_custom_reward_post_process_contracts.reference_reward_post_process"
24+
)
2325

2426
from slime.utils.types import Sample
2527

@@ -72,9 +74,7 @@ def test_custom_reward_post_process_callsite_is_stable():
7274
def test_custom_reward_post_process_path_aligns_with_expected_format():
7375
from slime.utils.misc import load_function
7476

75-
fn = load_function(
76-
get_contract_path("CUSTOM_REWARD_POST_PROCESS_PATH", REFERENCE_CUSTOM_REWARD_POST_PROCESS_PATH)
77-
)
77+
fn = load_function(get_contract_path("CUSTOM_REWARD_POST_PROCESS_PATH", REFERENCE_CUSTOM_REWARD_POST_PROCESS_PATH))
7878
assert_custom_reward_post_process_output_matches_expected(fn)
7979

8080

tests/plugin_contracts/test_plugin_custom_rm_contracts.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,9 @@ async def test_custom_rm_default_batched_branch_is_stable():
9696
async def test_custom_rm_path_aligns_with_single_sample_format():
9797
rm_fn = importlib.import_module("plugin_contracts.test_plugin_custom_rm_contracts").reference_single_rm
9898
assert_single_rm_signature_matches_expected(rm_fn)
99-
reward = await async_rm(make_args(custom_rm_path=get_contract_path("CUSTOM_RM_PATH", REFERENCE_SINGLE_RM_PATH)), make_sample(3))
99+
reward = await async_rm(
100+
make_args(custom_rm_path=get_contract_path("CUSTOM_RM_PATH", REFERENCE_SINGLE_RM_PATH)), make_sample(3)
101+
)
100102
assert isinstance(reward, (int, float))
101103

102104

tests/plugin_contracts/test_plugin_custom_rollout_log_contracts.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020

2121
NUM_GPUS = 0
2222
ENV_PREFIX = "SLIME_CONTRACT_"
23-
REFERENCE_CUSTOM_ROLLOUT_LOG_PATH = "plugin_contracts.test_plugin_custom_rollout_log_contracts.reference_custom_rollout_log"
23+
REFERENCE_CUSTOM_ROLLOUT_LOG_PATH = (
24+
"plugin_contracts.test_plugin_custom_rollout_log_contracts.reference_custom_rollout_log"
25+
)
2426

2527
from slime.utils.types import Sample
2628

tests/plugin_contracts/test_plugin_data_source_contracts.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
if "transformers" not in sys.modules:
2121
mod = types.ModuleType("transformers")
2222
mod.AutoTokenizer = type("AutoTokenizer", (), {"from_pretrained": staticmethod(lambda *args, **kwargs: object())})
23-
mod.AutoProcessor = type("AutoProcessor", (), {"from_pretrained": staticmethod(lambda *args, **kwargs: (_ for _ in ()).throw(OSError()))})
23+
mod.AutoProcessor = type(
24+
"AutoProcessor",
25+
(),
26+
{"from_pretrained": staticmethod(lambda *args, **kwargs: (_ for _ in ()).throw(OSError()))},
27+
)
2428
mod.PreTrainedTokenizerBase = type("PreTrainedTokenizerBase", (), {})
2529
mod.ProcessorMixin = type("ProcessorMixin", (), {})
2630
sys.modules["transformers"] = mod

tests/plugin_contracts/test_plugin_eval_function_contracts.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@
2424
if "transformers" not in sys.modules:
2525
mod = types.ModuleType("transformers")
2626
mod.AutoTokenizer = type("AutoTokenizer", (), {"from_pretrained": staticmethod(lambda *args, **kwargs: object())})
27-
mod.AutoProcessor = type("AutoProcessor", (), {"from_pretrained": staticmethod(lambda *args, **kwargs: (_ for _ in ()).throw(OSError()))})
27+
mod.AutoProcessor = type(
28+
"AutoProcessor",
29+
(),
30+
{"from_pretrained": staticmethod(lambda *args, **kwargs: (_ for _ in ()).throw(OSError()))},
31+
)
2832
mod.PreTrainedTokenizerBase = type("PreTrainedTokenizerBase", (), {})
2933
mod.ProcessorMixin = type("ProcessorMixin", (), {})
3034
sys.modules["transformers"] = mod

tests/plugin_contracts/test_plugin_generate_contracts.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,32 @@
1717
NUM_GPUS = 0
1818
ENV_PREFIX = "SLIME_CONTRACT_"
1919
REFERENCE_CUSTOM_GENERATE_PATH = "plugin_contracts.test_plugin_generate_contracts.custom_generate"
20-
REFERENCE_CUSTOM_GENERATE_WITH_EVAL_PATH = "plugin_contracts.test_plugin_generate_contracts.custom_generate_with_evaluation"
20+
REFERENCE_CUSTOM_GENERATE_WITH_EVAL_PATH = (
21+
"plugin_contracts.test_plugin_generate_contracts.custom_generate_with_evaluation"
22+
)
2123

2224

2325
def install_stubs() -> None:
2426
if "ray" not in sys.modules:
2527
ray_mod = types.ModuleType("ray")
26-
ray_mod._private = types.SimpleNamespace(services=types.SimpleNamespace(get_node_ip_address=lambda: "127.0.0.1"))
28+
ray_mod._private = types.SimpleNamespace(
29+
services=types.SimpleNamespace(get_node_ip_address=lambda: "127.0.0.1")
30+
)
2731
sys.modules["ray"] = ray_mod
2832
if "sglang_router" not in sys.modules:
2933
mod = types.ModuleType("sglang_router")
3034
mod.__version__ = "0.2.3"
3135
sys.modules["sglang_router"] = mod
3236
if "transformers" not in sys.modules:
3337
mod = types.ModuleType("transformers")
34-
mod.AutoTokenizer = type("AutoTokenizer", (), {"from_pretrained": staticmethod(lambda *args, **kwargs: object())})
35-
mod.AutoProcessor = type("AutoProcessor", (), {"from_pretrained": staticmethod(lambda *args, **kwargs: (_ for _ in ()).throw(OSError()))})
38+
mod.AutoTokenizer = type(
39+
"AutoTokenizer", (), {"from_pretrained": staticmethod(lambda *args, **kwargs: object())}
40+
)
41+
mod.AutoProcessor = type(
42+
"AutoProcessor",
43+
(),
44+
{"from_pretrained": staticmethod(lambda *args, **kwargs: (_ for _ in ()).throw(OSError()))},
45+
)
3646
mod.PreTrainedTokenizerBase = type("PreTrainedTokenizerBase", (), {})
3747
mod.ProcessorMixin = type("ProcessorMixin", (), {})
3848
sys.modules["transformers"] = mod

tests/plugin_contracts/test_plugin_rollout_all_samples_process_contracts.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
NUM_GPUS = 0
1616
ENV_PREFIX = "SLIME_CONTRACT_"
17-
REFERENCE_ROLLOUT_ALL_SAMPLES_PROCESS_PATH = "plugin_contracts.test_plugin_rollout_all_samples_process_contracts.reference_rollout_all_samples_process"
17+
REFERENCE_ROLLOUT_ALL_SAMPLES_PROCESS_PATH = (
18+
"plugin_contracts.test_plugin_rollout_all_samples_process_contracts.reference_rollout_all_samples_process"
19+
)
1820

1921
from slime.utils.misc import load_function
2022
from slime.utils.types import Sample

0 commit comments

Comments
 (0)