|
| 1 | +# Copyright 2024 Bytedance Ltd. and/or its affiliates |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +"""Test that apply_fsdp2's module selection handles peft-wrapped models. |
| 15 | +
|
| 16 | +peft wraps embed_tokens in a ModulesToSaveWrapper, so isinstance(module, nn.Embedding) |
| 17 | +fails. Without name-based matching, embed_tokens + lm_head land in the root FSDP unit, |
| 18 | +causing OOM from oversized allgather. These tests verify the module selection logic |
| 19 | +works for: (1) vanilla models, (2) peft-wrapped models, (3) tied embeddings. |
| 20 | +""" |
| 21 | + |
| 22 | +import unittest |
| 23 | +from types import SimpleNamespace |
| 24 | + |
| 25 | +import torch.nn as nn |
| 26 | + |
| 27 | +from verl.utils.fsdp_utils import _select_fsdp2_wrap_targets |
| 28 | + |
| 29 | + |
| 30 | +class MockDecoderLayer(nn.Module): |
| 31 | + """Simulates a transformer decoder layer (e.g. Qwen3DecoderLayer).""" |
| 32 | + |
| 33 | + def __init__(self, hidden_size=64): |
| 34 | + super().__init__() |
| 35 | + self.self_attn = nn.Linear(hidden_size, hidden_size) |
| 36 | + self.mlp = nn.Linear(hidden_size, hidden_size) |
| 37 | + |
| 38 | + |
| 39 | +class MockModulesToSaveWrapper(nn.Module): |
| 40 | + """Simulates peft's ModulesToSaveWrapper around nn.Embedding. |
| 41 | +
|
| 42 | + peft wraps modules listed in modules_to_save (like embed_tokens) in this wrapper, |
| 43 | + which breaks isinstance(module, nn.Embedding) checks. |
| 44 | + """ |
| 45 | + |
| 46 | + def __init__(self, original_module): |
| 47 | + super().__init__() |
| 48 | + self.original_module = original_module |
| 49 | + self.weight = original_module.weight # peft exposes weight |
| 50 | + |
| 51 | + |
| 52 | +class MockCausalLM(nn.Module): |
| 53 | + """Simulates a causal LM with embed_tokens, decoder layers, and lm_head.""" |
| 54 | + |
| 55 | + _no_split_modules = ["MockDecoderLayer"] |
| 56 | + |
| 57 | + def __init__(self, vocab_size=1000, hidden_size=64, num_layers=2, tie_word_embeddings=False): |
| 58 | + super().__init__() |
| 59 | + self.config = SimpleNamespace(tie_word_embeddings=tie_word_embeddings) |
| 60 | + self.model = nn.Module() |
| 61 | + self.model.embed_tokens = nn.Embedding(vocab_size, hidden_size) |
| 62 | + self.model.layers = nn.ModuleList([MockDecoderLayer(hidden_size) for _ in range(num_layers)]) |
| 63 | + self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False) |
| 64 | + |
| 65 | + if tie_word_embeddings: |
| 66 | + self.lm_head.weight = self.model.embed_tokens.weight |
| 67 | + |
| 68 | + |
| 69 | +class TestFSDP2PeftWrapping(unittest.TestCase): |
| 70 | + """Test module selection in apply_fsdp2 for vanilla and peft-wrapped models.""" |
| 71 | + |
| 72 | + def _get_wrapped_names(self, model, cls_names): |
| 73 | + """Return names of modules selected for wrapping.""" |
| 74 | + selected = _select_fsdp2_wrap_targets(model, cls_names) |
| 75 | + # _select_fsdp2_wrap_targets returns module objects; map back to names |
| 76 | + module_to_name = {id(m): n for n, m in model.named_modules()} |
| 77 | + return [module_to_name[id(m)] for m in selected] |
| 78 | + |
| 79 | + def test_vanilla_model_wraps_layers_and_embedding(self): |
| 80 | + """Vanilla model (no peft): embed_tokens matched by isinstance, layers by class name.""" |
| 81 | + model = MockCausalLM(tie_word_embeddings=False) |
| 82 | + names = self._get_wrapped_names(model, ["MockDecoderLayer"]) |
| 83 | + |
| 84 | + self.assertIn("model.embed_tokens", names) |
| 85 | + self.assertIn("lm_head", names) |
| 86 | + self.assertTrue(any("layers.0" in n for n in names)) |
| 87 | + self.assertTrue(any("layers.1" in n for n in names)) |
| 88 | + |
| 89 | + def test_peft_wrapped_model_wraps_embed_tokens_by_name(self): |
| 90 | + """peft-wrapped model: embed_tokens fails isinstance but is matched by name.""" |
| 91 | + model = MockCausalLM(tie_word_embeddings=False) |
| 92 | + original_embed = model.model.embed_tokens |
| 93 | + model.model.embed_tokens = MockModulesToSaveWrapper(original_embed) |
| 94 | + |
| 95 | + names = self._get_wrapped_names(model, ["MockDecoderLayer"]) |
| 96 | + |
| 97 | + self.assertIn("model.embed_tokens", names) |
| 98 | + self.assertIn("lm_head", names) |
| 99 | + self.assertTrue(any("layers.0" in n for n in names)) |
| 100 | + |
| 101 | + def test_tied_embeddings_skips_name_based_wrapping(self): |
| 102 | + """With tie_word_embeddings=True, embed_tokens/lm_head are NOT wrapped separately.""" |
| 103 | + model = MockCausalLM(tie_word_embeddings=True) |
| 104 | + names = self._get_wrapped_names(model, ["MockDecoderLayer"]) |
| 105 | + |
| 106 | + self.assertNotIn("model.embed_tokens", names) |
| 107 | + self.assertNotIn("lm_head", names) |
| 108 | + self.assertTrue(any("layers.0" in n for n in names)) |
| 109 | + |
| 110 | + def test_peft_wrapped_tied_embeddings_skips_wrapping(self): |
| 111 | + """peft + tied embeddings: name-based matching is disabled, no wrapping.""" |
| 112 | + model = MockCausalLM(tie_word_embeddings=True) |
| 113 | + original_embed = model.model.embed_tokens |
| 114 | + model.model.embed_tokens = MockModulesToSaveWrapper(original_embed) |
| 115 | + |
| 116 | + names = self._get_wrapped_names(model, ["MockDecoderLayer"]) |
| 117 | + |
| 118 | + self.assertNotIn("model.embed_tokens", names) |
| 119 | + self.assertNotIn("lm_head", names) |
| 120 | + |
| 121 | + def test_no_duplicate_wrapping_for_vanilla_embedding(self): |
| 122 | + """Vanilla nn.Embedding should not be wrapped twice (by isinstance AND by name).""" |
| 123 | + model = MockCausalLM(tie_word_embeddings=False) |
| 124 | + names = self._get_wrapped_names(model, ["MockDecoderLayer"]) |
| 125 | + |
| 126 | + embed_count = sum(1 for n in names if n == "model.embed_tokens") |
| 127 | + self.assertEqual(embed_count, 1, f"embed_tokens wrapped {embed_count} times, expected 1") |
| 128 | + |
| 129 | + |
| 130 | +if __name__ == "__main__": |
| 131 | + unittest.main() |
0 commit comments