Description
We've used the _get_clones utility a lot for cleaner instantiation of our TransformerDecoder from a single layer. However, we shouldn't use it for cases that require random initialization and don't subsequently override the params with a state dict load (i.e. what we do for LoRA when we're not resuming from an intermediate checkpoint). The reason: _get_clones copies a single already-initialized module, so the initial values are identical across all cloned layers. If we use _get_clones and don't subsequently load in weights to override the init values, all our layers have identical parameters. The repro script further below demonstrates this.
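For context, _get_clones follows the standard deepcopy-based cloning pattern (the same one used in torch.nn.modules.transformer). A rough sketch of that pattern, not necessarily the exact torchtune source:

```python
import copy

import torch.nn as nn


def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
    # Deep-copy one already-initialized module n times: every clone starts
    # out with identical parameter values.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
```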
```python
import torch

from torchtune.modules.transformer import _get_clones
from torchtune.modules.peft import LoRALinear


def main():
    # Construct each LoRALinear separately: each one runs its own random init
    loras_loop = [None] * 4
    for i in range(4):
        loras_loop[i] = LoRALinear(in_dim=16, out_dim=8, rank=4, alpha=1.0)

    # Clone a single LoRALinear: every clone shares the same init values
    loras_cloned = _get_clones(
        LoRALinear(in_dim=16, out_dim=8, rank=4, alpha=1.0), 4
    )

    loop_max_diff = torch.max(
        torch.abs(loras_loop[0].lora_a.weight - loras_loop[3].lora_a.weight)
    )
    cloned_max_diff = torch.max(
        torch.abs(loras_cloned[0].lora_a.weight - loras_cloned[3].lora_a.weight)
    )
    print(f"Max diff in for loop: {loop_max_diff}")
    print(f"Max diff with _get_clones: {cloned_max_diff}")


if __name__ == "__main__":
    main()
```
```
...
Max diff in for loop: 0.4612632691860199
Max diff with _get_clones: 0.0
```
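The 0.0 diff confirms the clones all carry the same init values. One possible fix, sketched under the assumption that we only need independently initialized layers (and no shared state between them): build the layer stack with a comprehension so each constructor call runs its own random init, rather than deep-copying one pre-initialized instance.

```python
import torch.nn as nn

from torchtune.modules.peft import LoRALinear

# Sketch of an alternative to _get_clones for the no-state-dict-load case:
# construct each layer fresh so every one gets an independent random init.
layers = nn.ModuleList(
    [LoRALinear(in_dim=16, out_dim=8, rank=4, alpha=1.0) for _ in range(4)]
)
```

Anywhere we do keep _get_clones, we'd either need to re-initialize the clones after copying or guarantee that a state dict load always overrides the params.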