Commit 95e17d5

[fsdp] fix: wrap embed_tokens/lm_head by name for peft models (#5516)
1 parent 339cad0 commit 95e17d5

2 files changed: 156 additions (+), 6 deletions (−)
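For context on the failure mode named in the commit title, here is a minimal sketch of why isinstance-based selection misses peft-wrapped embeddings. The wrapper class below is a hypothetical stand-in for peft's ModulesToSaveWrapper, mirroring how the new test file models it:

```python
# Hypothetical stand-in for peft's ModulesToSaveWrapper (the real class lives
# in peft; this mirrors how the test file below models it).
import torch.nn as nn


class ModulesToSaveWrapperStandIn(nn.Module):
    def __init__(self, original_module):
        super().__init__()
        self.original_module = original_module
        self.weight = original_module.weight


embed = nn.Embedding(1000, 64)
print(isinstance(embed, nn.Embedding))  # True: matched before peft wrapping

wrapped = ModulesToSaveWrapperStandIn(embed)
# False: the isinstance check no longer fires, so without name-based matching
# embed_tokens lands in the root FSDP unit, alongside lm_head.
print(isinstance(wrapped, nn.Embedding))
```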
New test file: 131 additions, 0 deletions

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test that apply_fsdp2's module selection handles peft-wrapped models.

peft wraps embed_tokens in a ModulesToSaveWrapper, so isinstance(module, nn.Embedding)
fails. Without name-based matching, embed_tokens + lm_head land in the root FSDP unit,
causing OOM from oversized allgather. These tests verify the module selection logic
works for: (1) vanilla models, (2) peft-wrapped models, (3) tied embeddings.
"""

import unittest
from types import SimpleNamespace

import torch.nn as nn

from verl.utils.fsdp_utils import _select_fsdp2_wrap_targets


class MockDecoderLayer(nn.Module):
    """Simulates a transformer decoder layer (e.g. Qwen3DecoderLayer)."""

    def __init__(self, hidden_size=64):
        super().__init__()
        self.self_attn = nn.Linear(hidden_size, hidden_size)
        self.mlp = nn.Linear(hidden_size, hidden_size)


class MockModulesToSaveWrapper(nn.Module):
    """Simulates peft's ModulesToSaveWrapper around nn.Embedding.

    peft wraps modules listed in modules_to_save (like embed_tokens) in this wrapper,
    which breaks isinstance(module, nn.Embedding) checks.
    """

    def __init__(self, original_module):
        super().__init__()
        self.original_module = original_module
        self.weight = original_module.weight  # peft exposes weight


class MockCausalLM(nn.Module):
    """Simulates a causal LM with embed_tokens, decoder layers, and lm_head."""

    _no_split_modules = ["MockDecoderLayer"]

    def __init__(self, vocab_size=1000, hidden_size=64, num_layers=2, tie_word_embeddings=False):
        super().__init__()
        self.config = SimpleNamespace(tie_word_embeddings=tie_word_embeddings)
        self.model = nn.Module()
        self.model.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.model.layers = nn.ModuleList([MockDecoderLayer(hidden_size) for _ in range(num_layers)])
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        if tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight


class TestFSDP2PeftWrapping(unittest.TestCase):
    """Test module selection in apply_fsdp2 for vanilla and peft-wrapped models."""

    def _get_wrapped_names(self, model, cls_names):
        """Return names of modules selected for wrapping."""
        selected = _select_fsdp2_wrap_targets(model, cls_names)
        # _select_fsdp2_wrap_targets returns module objects; map back to names
        module_to_name = {id(m): n for n, m in model.named_modules()}
        return [module_to_name[id(m)] for m in selected]

    def test_vanilla_model_wraps_layers_and_embedding(self):
        """Vanilla model (no peft): embed_tokens matched by isinstance, layers by class name."""
        model = MockCausalLM(tie_word_embeddings=False)
        names = self._get_wrapped_names(model, ["MockDecoderLayer"])

        self.assertIn("model.embed_tokens", names)
        self.assertIn("lm_head", names)
        self.assertTrue(any("layers.0" in n for n in names))
        self.assertTrue(any("layers.1" in n for n in names))

    def test_peft_wrapped_model_wraps_embed_tokens_by_name(self):
        """peft-wrapped model: embed_tokens fails isinstance but is matched by name."""
        model = MockCausalLM(tie_word_embeddings=False)
        original_embed = model.model.embed_tokens
        model.model.embed_tokens = MockModulesToSaveWrapper(original_embed)

        names = self._get_wrapped_names(model, ["MockDecoderLayer"])

        self.assertIn("model.embed_tokens", names)
        self.assertIn("lm_head", names)
        self.assertTrue(any("layers.0" in n for n in names))

    def test_tied_embeddings_skips_name_based_wrapping(self):
        """With tie_word_embeddings=True, embed_tokens/lm_head are NOT wrapped separately."""
        model = MockCausalLM(tie_word_embeddings=True)
        names = self._get_wrapped_names(model, ["MockDecoderLayer"])

        self.assertNotIn("model.embed_tokens", names)
        self.assertNotIn("lm_head", names)
        self.assertTrue(any("layers.0" in n for n in names))

    def test_peft_wrapped_tied_embeddings_skips_wrapping(self):
        """peft + tied embeddings: name-based matching is disabled, no wrapping."""
        model = MockCausalLM(tie_word_embeddings=True)
        original_embed = model.model.embed_tokens
        model.model.embed_tokens = MockModulesToSaveWrapper(original_embed)

        names = self._get_wrapped_names(model, ["MockDecoderLayer"])

        self.assertNotIn("model.embed_tokens", names)
        self.assertNotIn("lm_head", names)

    def test_no_duplicate_wrapping_for_vanilla_embedding(self):
        """Vanilla nn.Embedding should not be wrapped twice (by isinstance AND by name)."""
        model = MockCausalLM(tie_word_embeddings=False)
        names = self._get_wrapped_names(model, ["MockDecoderLayer"])

        embed_count = sum(1 for n in names if n == "model.embed_tokens")
        self.assertEqual(embed_count, 1, f"embed_tokens wrapped {embed_count} times, expected 1")


if __name__ == "__main__":
    unittest.main()
```
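For a concrete feel of what the helper returns, a small sketch reusing the mocks above; the printed list is the result the test assertions imply, not captured output:

```python
# Run the selection helper against the vanilla mock model from the tests.
model = MockCausalLM(tie_word_embeddings=False)
targets = _select_fsdp2_wrap_targets(model, ["MockDecoderLayer"])

target_ids = {id(t) for t in targets}
names = sorted(n for n, m in model.named_modules() if id(m) in target_ids)
print(names)
# Expected per the assertions:
# ['lm_head', 'model.embed_tokens', 'model.layers.0', 'model.layers.1']
```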

verl/utils/fsdp_utils.py

Lines changed: 25 additions & 6 deletions
```diff
@@ -507,6 +507,30 @@ class FSDPModuleABC(ABC, orig_fsdp_module):
 fully_shard_module.FSDPModule = orig_fsdp_module


+def _select_fsdp2_wrap_targets(model, fsdp_transformer_layer_cls_to_wrap):
+    """Select modules to wrap individually with fully_shard in FSDP2.
+
+    Matches transformer layers by class name, and embed_tokens/lm_head by name
+    (with isinstance fallback). Name-based matching is needed because peft wraps
+    embed_tokens in ModulesToSaveWrapper, breaking isinstance(module, nn.Embedding).
+    When tie_word_embeddings is True, embed_tokens and lm_head share weights and
+    must not be wrapped separately.
+    """
+    _tie = getattr(model.config, "tie_word_embeddings", False)
+    _wrap_by_name = set() if _tie else {"embed_tokens", "lm_head"}
+
+    modules = []
+    for name, module in model.named_modules():
+        leaf_name = name.rsplit(".", 1)[-1] if "." in name else name
+        if (
+            module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap
+            or (isinstance(module, nn.Embedding) and not _tie)
+            or (leaf_name in _wrap_by_name and hasattr(module, "weight"))
+        ):
+            modules.append(module)
+    return modules
+
+
 def apply_fsdp2(model, fsdp_kwargs, config):
     """model: AutoModelForCausalLM"""
     assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
```
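One design note on the tied-embeddings guard in `_select_fsdp2_wrap_targets`: with tie_word_embeddings=True the embedding and the output projection hold the same parameter object, so wrapping them as two FSDP2 units would shard a single tensor twice. A minimal illustration, mirroring MockCausalLM's tying logic from the test file:

```python
# Tied embeddings share one parameter object; sharding embed_tokens and
# lm_head as separate FSDP2 units would shard that single tensor twice.
import torch.nn as nn

embed_tokens = nn.Embedding(1000, 64)
lm_head = nn.Linear(64, 1000, bias=False)
lm_head.weight = embed_tokens.weight  # weight tying, as in MockCausalLM

print(lm_head.weight is embed_tokens.weight)  # True: one shared tensor
```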
```diff
@@ -521,12 +545,7 @@ def apply_fsdp2(model, fsdp_kwargs, config):

     assert len(fsdp_transformer_layer_cls_to_wrap) > 0 and fsdp_transformer_layer_cls_to_wrap[0] is not None

-    modules = []
-    for name, module in model.named_modules():
-        if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or (
-            isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings
-        ):
-            modules.append(module)
+    modules = _select_fsdp2_wrap_targets(model, fsdp_transformer_layer_cls_to_wrap)

     for idx, module in enumerate(modules):
         # if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
```
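Downstream, apply_fsdp2 turns each selected module into its own FSDP2 unit. The loop body is truncated in this diff, so the following is only a hedged sketch of the usual fully_shard pattern, not verl's exact code; the import path for fully_shard also varies across torch releases:

```python
# Hedged sketch of how the selected targets are typically consumed (assumed
# pattern, not verl's exact apply_fsdp2 body). On recent torch, fully_shard
# is importable from torch.distributed.fsdp; older 2.4/2.5 builds expose it
# under torch.distributed._composable.fsdp instead.
from torch.distributed.fsdp import fully_shard

from verl.utils.fsdp_utils import _select_fsdp2_wrap_targets


def apply_fsdp2_sketch(model, fsdp_kwargs, layer_cls_to_wrap):
    """Shard each selected submodule, then the root model."""
    for module in _select_fsdp2_wrap_targets(model, layer_cls_to_wrap):
        fully_shard(module, **fsdp_kwargs)  # each target becomes its own unit
    fully_shard(model, **fsdp_kwargs)  # root unit owns the remaining params
```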
