
Don't use _get_clones #2270

Open
@ebsmothers

Description

We've used the _get_clones utility a lot for cleaner instantiation of our TransformerDecoder from a single layer. However, we shouldn't use it in cases that require random initialization and don't subsequently override their params with a state dict load (i.e. what we do for LoRA when we're not resuming from an intermediate checkpoint). Because _get_clones copies an already-initialized module, every clone starts from the same parameter values; if nothing later overwrites them with loaded weights, all of our layers end up identical. The following script demonstrates this.

import torch

from torchtune.modules.peft import LoRALinear
from torchtune.modules.transformer import _get_clones


def main():
    # Construct each LoRALinear separately: every layer gets its own random init.
    loras_loop = [
        LoRALinear(in_dim=16, out_dim=8, rank=4, alpha=1.0) for _ in range(4)
    ]

    # Clone a single initialized LoRALinear: every layer shares the same init values.
    loras_cloned = _get_clones(
        LoRALinear(in_dim=16, out_dim=8, rank=4, alpha=1.0), 4
    )

    # Compare the LoRA A weights of the first and last layer in each case.
    loop_max_diff = torch.max(
        torch.abs(loras_loop[0].lora_a.weight - loras_loop[3].lora_a.weight)
    )
    cloned_max_diff = torch.max(
        torch.abs(loras_cloned[0].lora_a.weight - loras_cloned[3].lora_a.weight)
    )

    print(f"Max diff in for loop: {loop_max_diff}")
    print(f"Max diff with _get_clones: {cloned_max_diff}")


if __name__ == "__main__":
    main()


...

Max diff in for loop: 0.4612632691860199
Max diff with _get_clones: 0.0
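
For reference, here is a minimal sketch of a pattern that avoids this, not necessarily the fix this issue will land on: instead of cloning an already-initialized module (which, as the identical values above show, just replicates its weights), accept a zero-argument factory and call it once per layer so each layer runs its own random init. build_layers and layer_factory below are hypothetical names for illustration, not torchtune APIs.

import copy

from torch import nn


def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
    # Reference behavior: deepcopy replicates the already-initialized
    # parameters, so every clone starts with identical weights.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])


def build_layers(layer_factory, n: int) -> nn.ModuleList:
    # Hypothetical alternative: call the factory once per layer so each
    # layer gets its own random initialization.
    return nn.ModuleList([layer_factory() for _ in range(n)])


# Usage: pass a callable that builds a fresh layer, not an already-built instance.
# layers = build_layers(lambda: LoRALinear(in_dim=16, out_dim=8, rank=4, alpha=1.0), n=4)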

Labels: best practice (Things we should be doing but aren't), community help wanted (We would love the community's help completing this issue)
