Skip to content

Commit 397d6ca

Browse files
authored
[Doc] Include type annotation for cp_utils and model_provider (#468)
1 parent e075848 commit 397d6ca

File tree

2 files changed

+38
-27
lines changed

2 files changed

+38
-27
lines changed

slime/backends/megatron_utils/cp_utils.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union
1+
from typing import Callable, Union
22

33
import torch
44
import torch.distributed as dist
@@ -45,26 +45,26 @@ def get_logits_and_tokens_offset_with_cp(
4545

4646

4747
def get_sum_of_sample_mean(
48-
total_lengths,
49-
response_lengths,
50-
loss_masks,
48+
total_lengths: list[int],
49+
response_lengths: list[int],
50+
loss_masks: list[torch.Tensor],
5151
calculate_per_token_loss: bool = False,
52-
):
52+
) -> Callable[[torch.Tensor], torch.Tensor]:
5353
"""
5454
Calculate correct sample mean for CP
5555
"""
5656
cp_size = mpu.get_context_parallel_world_size()
5757
if cp_size == 1:
5858

59-
def sum_of_sample_mean(x: torch.Tensor):
59+
def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor:
6060
return sum(
6161
[
6262
(x_i * loss_mask_i).sum() / torch.clamp_min(loss_mask_i.sum(), 1)
6363
for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks)
6464
]
6565
)
6666

67-
def sum_of_token(x: torch.Tensor):
67+
def sum_of_token(x: torch.Tensor) -> torch.Tensor:
6868
return sum(
6969
[(x_i * loss_mask_i).sum() for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks)]
7070
)
@@ -82,7 +82,7 @@ def sum_of_token(x: torch.Tensor):
8282
chunked_loss_masks.append(torch.cat([loss_mask_0, loss_mask_1], dim=0))
8383
cp_chunk_lengths.append(chunked_loss_masks[i].size(0))
8484

85-
def sum_of_sample_mean(x):
85+
def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor:
8686
return sum(
8787
[
8888
(x_i * chunked_loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1)
@@ -92,7 +92,7 @@ def sum_of_sample_mean(x):
9292
]
9393
)
9494

95-
def sum_of_token(x: torch.Tensor):
95+
def sum_of_token(x: torch.Tensor) -> torch.Tensor:
9696
return sum(
9797
[
9898
(x_i * chunked_loss_mask).sum()
@@ -103,7 +103,7 @@ def sum_of_token(x: torch.Tensor):
103103
return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token
104104

105105

106-
def all_gather_with_cp(tensor: torch.Tensor, total_length: int, response_length: int):
106+
def all_gather_with_cp(tensor: torch.Tensor, total_length: int, response_length: int) -> torch.Tensor:
107107
"""
108108
Gather tensors across all ranks in the context parallel group.
109109
The first dimension of the output tensor will be the `response_length`.
@@ -122,7 +122,7 @@ def all_gather_with_cp(tensor: torch.Tensor, total_length: int, response_length:
122122
chunk_1 = tensor[logits_offset[0][1] - logits_offset[0][0] :]
123123
assert chunk_1.shape[0] == logits_offset[1][1] - logits_offset[1][0]
124124

125-
def zero(len):
125+
def zero(len: int) -> torch.Tensor:
126126
return torch.zeros(
127127
[len] + list(tensor.shape[1:]),
128128
dtype=tensor.dtype,
@@ -155,7 +155,7 @@ def zero(len):
155155
return full_tensor
156156

157157

158-
def slice_with_cp(tokens: torch.Tensor, pad_value):
158+
def slice_with_cp(tokens: torch.Tensor, pad_value: int) -> torch.Tensor:
159159
cp_rank = mpu.get_context_parallel_rank()
160160
cp_size = mpu.get_context_parallel_world_size()
161161

@@ -172,7 +172,11 @@ def slice_with_cp(tokens: torch.Tensor, pad_value):
172172
return torch.cat([tokens[start_1:end_1], tokens[start_2:end_2]])
173173

174174

175-
def slice_log_prob_with_cp(log_prob: Union[list[float], torch.Tensor], total_length: int, response_length: int):
175+
def slice_log_prob_with_cp(
176+
log_prob: Union[list[float], torch.Tensor],
177+
total_length: int,
178+
response_length: int,
179+
) -> Union[list[float], torch.Tensor]:
176180
assert len(log_prob) == response_length
177181

178182
cp_size = mpu.get_context_parallel_world_size()

slime/backends/megatron_utils/model_provider.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# Adapt from https://github.com/NVIDIA/Megatron-LM/blob/b1efb3c7126ef7615e8c333432d76e08038e17ff/pretrain_gpt.py
2+
import argparse
23
import inspect
34
from contextlib import nullcontext
4-
from typing import Optional
5+
from typing import Literal, Optional
56

67
import torch
78
from megatron.core import tensor_parallel
@@ -12,19 +13,20 @@
1213
get_gpt_layer_with_transformer_engine_spec,
1314
)
1415
from megatron.core.transformer.spec_utils import import_module
16+
from megatron.core.transformer.transformer_config import TransformerConfig
1517
from megatron.training.arguments import core_transformer_config_from_args
1618

1719

1820
# Adapt from https://github.com/volcengine/verl/blob/c3b20575d2bc815fcccd84bddb4c0401fc4b632b/verl/models/llama/megatron/layers/parallel_linear.py#L82
1921
class LinearForLastLayer(torch.nn.Linear):
2022
def __init__(
2123
self,
22-
input_size,
23-
output_size,
24+
input_size: int,
25+
output_size: int,
2426
*,
25-
config,
26-
bias=True,
27-
):
27+
config: TransformerConfig,
28+
bias: bool = True,
29+
) -> None:
2830
super().__init__(in_features=input_size, out_features=output_size, bias=bias)
2931
self.sequence_parallel = config.sequence_parallel
3032
if self.sequence_parallel:
@@ -36,19 +38,24 @@ def __init__(
3638

3739
def forward(
3840
self,
39-
input_,
40-
weight=None,
41-
runtime_gather_output=None,
42-
):
41+
input_: torch.Tensor,
42+
weight: Optional[torch.Tensor] = None,
43+
runtime_gather_output: Optional[bool] = None,
44+
) -> tuple[torch.Tensor, None]:
4345
logits = super().forward(input_)
4446
logits = logits.float()
4547
if self.sequence_parallel:
4648
logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
4749
return logits, None
4850

4951

50-
def get_model_provider_func(args, role: str = "actor"):
51-
def model_provider(pre_process=True, post_process=True, vp_stage: Optional[int] = None) -> GPTModel:
52+
def get_model_provider_func(
53+
args: argparse.Namespace,
54+
role: Literal["actor", "critic"] = "actor",
55+
):
56+
def model_provider(
57+
pre_process: bool = True, post_process: bool = True, vp_stage: Optional[int] = None
58+
) -> GPTModel:
5259
"""Builds the model.
5360
5461
If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model.
@@ -87,7 +94,7 @@ def oom_observer(device, alloc, device_alloc, device_free):
8794
torch._C._cuda_attach_out_of_memory_observer(oom_observer)
8895

8996
# Experimental loading arguments from yaml
90-
config = core_transformer_config_from_args(args)
97+
config: TransformerConfig = core_transformer_config_from_args(args)
9198

9299
if args.spec is not None:
93100
transformer_layer_spec = import_module(args.spec)
@@ -134,7 +141,7 @@ def oom_observer(device, alloc, device_alloc, device_free):
134141
# Check if fp8_model_init supports preserve_high_precision_init_val
135142
if "preserve_high_precision_init_val" in inspect.signature(fp8_model_init).parameters:
136143
build_model_context_args["preserve_high_precision_init_val"] = True
137-
except:
144+
except Exception:
138145
raise RuntimeError(
139146
"--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found."
140147
)

0 commit comments

Comments (0)