Skip to content

Commit b5dfaa2

Browse files
authored
[revert] revert the parallel state change (#1442)
1 parent e79dd7a commit b5dfaa2

File tree

19 files changed

+1664
-1420
lines changed

19 files changed

+1664
-1420
lines changed

examples/train_infer_mismatch_helper/mis.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
import torch
44

5-
from slime.backends.training_utils.parallel import ParallelState
6-
75
# NOTE:
86
# - `compute_mis_weights` is a lightweight, standalone function that is useful to unit-test on CPU.
97
# - `compute_mis_weights_with_cp` depends on Megatron context-parallel utilities, which are heavy and may not be
@@ -318,7 +316,6 @@ def compute_mis_weights_with_cp(
318316
loss_masks: list[torch.Tensor],
319317
total_lengths: list[int],
320318
response_lengths: list[int],
321-
parallel_state: ParallelState,
322319
**kwargs: Any,
323320
) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]:
324321
"""
@@ -335,17 +332,17 @@ def compute_mis_weights_with_cp(
335332
is_metrics: The metrics for the importance sampling weights, a dict of flattened tensors.
336333
"""
337334
# Lazy import to avoid importing Megatron dependencies when only `compute_mis_weights` is used.
338-
from slime.backends.training_utils.cp_utils import all_gather_with_cp, slice_log_prob_with_cp
335+
from slime.backends.megatron_utils.cp_utils import all_gather_with_cp, slice_log_prob_with_cp
339336

340337
# Gather cp slice from other cp ranks
341338
full_rollout_log_probs = [
342-
all_gather_with_cp(log_prob, total_length, response_length, parallel_state)
339+
all_gather_with_cp(log_prob, total_length, response_length)
343340
for log_prob, total_length, response_length in zip(
344341
rollout_log_probs, total_lengths, response_lengths, strict=False
345342
)
346343
]
347344
full_old_log_probs = [
348-
all_gather_with_cp(old_log_prob, total_length, response_length, parallel_state)
345+
all_gather_with_cp(old_log_prob, total_length, response_length)
349346
for old_log_prob, total_length, response_length in zip(
350347
train_log_probs, total_lengths, response_lengths, strict=False
351348
)
@@ -365,7 +362,7 @@ def slice_cp_and_concat(
365362
) -> torch.Tensor:
366363
values = [
367364
# TODO: A rename of this function?
368-
slice_log_prob_with_cp(values[i], total_lengths[i], response_lengths[i], parallel_state)
365+
slice_log_prob_with_cp(values[i], total_lengths[i], response_lengths[i])
369366
for i in range(len(values))
370367
]
371368
return torch.cat(values, dim=0)

scripts/models/qwen3-coder-30B-A3B-Instruct.sh

Lines changed: 0 additions & 1 deletion
This file was deleted.

slime/backends/fsdp_utils/__init__.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,7 @@
11
import logging
22

3-
try:
4-
from torch.distributed.fsdp import fully_shard # noqa: F401
5-
6-
_FSDP_AVAILABLE = True
7-
except ImportError as e:
8-
logging.warning(f"FSDP backend dependencies not available: {e}")
9-
_FSDP_AVAILABLE = False
10-
11-
if _FSDP_AVAILABLE:
12-
from .actor import FSDPTrainRayActor
13-
from .arguments import load_fsdp_args
14-
else:
15-
16-
def _raise_import_error(*args, **kwargs):
17-
raise ImportError(
18-
"FSDP backend is not available. "
19-
"Please ensure PyTorch with FSDP2 support is installed. "
20-
"For installation instructions, refer to: https://pytorch.org/docs/stable/distributed.fsdp.fully_shard.html"
21-
)
22-
23-
FSDPTrainRayActor = _raise_import_error
24-
load_fsdp_args = _raise_import_error
3+
from .actor import FSDPTrainRayActor
4+
from .arguments import load_fsdp_args
255

266
__all__ = ["load_fsdp_args", "FSDPTrainRayActor"]
277

0 commit comments

Comments
 (0)