Commit 4a6f13e
Parent: a044aea

[trainer] fix checkpoint tracker (#467)

4 files changed: 24 additions & 21 deletions

Summary: find_latest_ckpt now returns the parsed tracker info together with the checkpoint path, the trainer restores best_val_reward_score and best_global_step when resuming, and the flash-attention utilities simplify the cu_seqlens construction and honor an explicitly passed is_causal.

tests/test_checkpoint.py: 2 additions & 2 deletions

@@ -35,9 +35,9 @@ def test_find_latest_ckpt(save_checkpoint_path):
     with open(os.path.join(save_checkpoint_path, CHECKPOINT_TRACKER), "w") as f:
         json.dump({"last_global_step": 10}, f, ensure_ascii=False, indent=2)

-    assert find_latest_ckpt(save_checkpoint_path) is None
+    assert find_latest_ckpt(save_checkpoint_path)[0] is None
     os.makedirs(os.path.join(save_checkpoint_path, "global_step_10"), exist_ok=True)
-    assert find_latest_ckpt(save_checkpoint_path) == os.path.join(save_checkpoint_path, "global_step_10")
+    assert find_latest_ckpt(save_checkpoint_path)[0] == os.path.join(save_checkpoint_path, "global_step_10")


 def test_remove_obsolete_ckpt(save_checkpoint_path):
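The `[0]` indexing is the point of the change: `find_latest_ckpt` now returns a `(path, tracker_info)` tuple instead of a bare path (see checkpoint_manager.py below). A minimal standalone sketch of the new contract, assuming the verl package is importable; the temp directory and step values are made up:

import json
import os
import tempfile

from verl.utils.checkpoint.checkpoint_manager import CHECKPOINT_TRACKER, find_latest_ckpt

save_path = tempfile.mkdtemp()  # fresh directory so both branches are exercised

# A tracker that points at step 10 while global_step_10/ does not exist yet:
# the fixed code returns (None, None) rather than a bare None.
with open(os.path.join(save_path, CHECKPOINT_TRACKER), "w") as f:
    json.dump({"last_global_step": 10}, f)

ckpt_path, tracker_info = find_latest_ckpt(save_path)
assert ckpt_path is None and tracker_info is None

# Once the step directory exists, the path and the parsed tracker dict
# come back together.
os.makedirs(os.path.join(save_path, "global_step_10"))
ckpt_path, tracker_info = find_latest_ckpt(save_path)
assert ckpt_path == os.path.join(save_path, "global_step_10")
assert tracker_info["last_global_step"] == 10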

verl/models/transformers/flash_attention_utils.py: 12 additions & 14 deletions

@@ -20,6 +20,7 @@

 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check
 from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10


@@ -43,19 +44,14 @@
 def prepare_fa2_from_position_ids(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
 ):
-    query = query.view(-1, query.size(-2), query.size(-1))
+    assert position_ids.ndim == 2  # (batch_size, seq_length)
+    query = query.contiguous().view(-1, query.size(-2), query.size(-1))
     key = key.contiguous().view(-1, key.size(-2), key.size(-1))
     value = value.contiguous().view(-1, value.size(-2), value.size(-1))
-    position_ids = position_ids.flatten()
-    indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
-    cu_seqlens = torch.cat(
-        (
-            indices_q[position_ids == 0],
-            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
-        )
-    )
+    position_ids = position_ids.view(-1)
+    cu_seqlens = F.pad((position_ids == 0).nonzero().view(-1), (0, 1), value=position_ids.size(0))
     max_length = cu_seqlens.diff().max()  # use cu_seqlens to infer max_length for qwen2vl mrope
-    return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))
+    return (query, key, value, (cu_seqlens, cu_seqlens), (max_length, max_length))


 def _custom_flash_attention_forward(

@@ -102,7 +98,7 @@ def _custom_flash_attention_forward(

     if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
         batch_size = query_states.size(0)
-        query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
+        query_states, key_states, value_states, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
             query_states, key_states, value_states, position_ids
         )
         cu_seqlens_q, cu_seqlens_k = cu_seq_lens

@@ -162,16 +158,18 @@ def flash_attention_forward(
     key = key.transpose(1, 2)
     value = value.transpose(1, 2)

-    # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
-    kwargs.pop("is_causal", None)
+    # FA2 uses the kwargs value if explicitly passed, otherwise it uses the module attribute
+    is_causal = kwargs.pop("is_causal", None)
+    if is_causal is None:
+        is_causal = getattr(module, "is_causal", True)

     attn_output = _custom_flash_attention_forward(
         query,
         key,
         value,
         attention_mask,
         query_length=q_len,
-        is_causal=module.is_causal,
+        is_causal=is_causal,
         dropout=dropout,
         softmax_scale=scaling,
         sliding_window=sliding_window,
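The new `cu_seqlens` construction replaces the `torch.cat` over `indices_q` with a pad of the sequence-start offsets, and the assertion pins down the expected 2-D `position_ids`. A minimal sketch with made-up values showing what it computes:

import torch
import torch.nn.functional as F

# Two sequences of lengths 3 and 2 packed into one row. position_ids restarts
# at 0 at each sequence boundary, which is what marks the boundaries below.
position_ids = torch.tensor([[0, 1, 2, 0, 1]])

flat = position_ids.view(-1)
starts = (flat == 0).nonzero().view(-1)                 # tensor([0, 3]): sequence start offsets
cu_seqlens = F.pad(starts, (0, 1), value=flat.size(0))  # tensor([0, 3, 5]): cumulative lengths
max_length = cu_seqlens.diff().max()                    # tensor(3): longest packed sequence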

verl/trainer/ray_trainer.py: 4 additions & 1 deletion

@@ -343,7 +343,10 @@ def _load_checkpoint(self) -> None:
         if self.config.trainer.load_checkpoint_path is not None:
             load_checkpoint_path = self.config.trainer.load_checkpoint_path
         elif self.config.trainer.find_last_checkpoint:
-            load_checkpoint_path = find_latest_ckpt(self.config.trainer.save_checkpoint_path)
+            load_checkpoint_path, tracker_info = find_latest_ckpt(self.config.trainer.save_checkpoint_path)
+            if tracker_info is not None:
+                self.best_val_reward_score = tracker_info.get("best_val_reward_score", 0.0)
+                self.best_global_step = tracker_info.get("best_global_step", 0)
         else:
             load_checkpoint_path = None
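For context, a sketch of the tracker file this code reads, written as a Python dict with illustrative values. Only `last_global_step` is required by `find_latest_ckpt`; the two `best_*` keys are optional, hence the `.get(...)` defaults above:

# Illustrative checkpoint tracker contents; the values are made up.
tracker_info = {
    "last_global_step": 120,        # used to locate global_step_120/
    "best_val_reward_score": 0.42,  # restored onto the trainer on resume
    "best_global_step": 90,         # restored onto the trainer on resume
}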

verl/utils/checkpoint/checkpoint_manager.py: 6 additions & 4 deletions

@@ -115,24 +115,26 @@ def get_checkpoint_tracker_filename(root_path: str) -> str:
     return os.path.join(root_path, CHECKPOINT_TRACKER)


-def find_latest_ckpt(path: str, directory_format: str = "global_step_{}") -> Optional[str]:
+def find_latest_ckpt(
+    path: str, directory_format: str = "global_step_{}"
+) -> tuple[Optional[str], Optional[dict[str, Any]]]:
     """
     Find the latest checkpoint in the save path.
     """
     tracker_file = get_checkpoint_tracker_filename(path)
     if not os.path.exists(tracker_file):
-        return None
+        return None, None

     with open(tracker_file, "rb") as f:
         checkpointer_tracker_info = json.load(f)

     ckpt_path = os.path.join(path, directory_format.format(checkpointer_tracker_info["last_global_step"]))
     if not os.path.exists(ckpt_path):
         print(f"Checkpoint does not exist: {ckpt_path}")
-        return None
+        return None, None

     print(f"Found latest checkpoint: {ckpt_path}, will resume from it. Turn off `find_last_checkpoint` to disable it.")
-    return ckpt_path
+    return ckpt_path, checkpointer_tracker_info


 def remove_obsolete_ckpt(
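One detail the hunk does not show: the new return annotation needs `Any` in scope. Since the old signature already used `Optional`, the module's typing import presumably now reads as follows (an assumption; the import block is outside this hunk):

from typing import Any, Optional  # Any is newly required by the tuple return annotation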
