
Hand-derived memory-efficient super lazy PyTorch VJPs for training LLMs on a laptop, all using one op (bundled scaled matmuls).


Author: Henry Ndubuaku

Overview

I mean, do not train or fine-tune LLMs on your laptop. Training is done at much higher precision than inference (float32 or bfloat16), and additional memory is needed for the gradients, optimizer states, and activations for the batch, so expect roughly 4-6x the model size. As a rough rule, that is around 8-24GB of RAM per 1B params.
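
For context, here is a quick back-of-the-envelope estimate (an illustrative sketch only; it assumes bf16 weights and gradients with fp32 AdamW moments, and ignores activations and any fp32 master copy):

    # Rough memory estimate for fully fine-tuning a 1.5B-param model with AdamW.
    # Illustrative assumptions: bf16 weights and grads, fp32 AdamW moments.
    params = 1.5e9
    weights = params * 2   # bf16 -> 2 bytes per param
    grads = params * 2     # bf16 gradients
    adam_m = params * 4    # fp32 first moment
    adam_v = params * 4    # fp32 second moment
    total_gb = (weights + grads + adam_m + adam_v) / 1e9
    print(f"~{total_gb:.0f} GB before activations")  # ~18 GB, i.e. ~12 GB per 1B params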

HOWEVER, if you must do so on a laptop for whatever weird reason, this library implements most language models such that only the weights for the current layer are loaded into RAM at a time, and it implements LoRA fine-tuning such that the frozen params are memory-mapped rather than loaded.
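
The gist of memory-mapping the frozen side looks roughly like this (a conceptual sketch, not the library's actual loading code; the file name, shapes, dtype, and alpha/r scaling are made up for illustration):

    import numpy as np
    import torch

    # Frozen base weight stays on disk; the OS pages it in only when touched.
    # "base_weight.npy" is hypothetical, assumed float32 with shape (d_out, d_in).
    frozen_w = torch.from_numpy(np.load("base_weight.npy", mmap_mode="r"))  # read-only view

    # Only the small LoRA adapters live in RAM as trainable tensors.
    r, alpha, d_in, d_out = 8, 32, 1536, 1536
    lora_a = torch.nn.Parameter(torch.randn(d_in, r) * 0.01)
    lora_b = torch.nn.Parameter(torch.zeros(r, d_out))

    def lora_linear(x):
        # Frozen path plus low-rank update, scaled by alpha / r.
        return x @ frozen_w.T + (x @ lora_a @ lora_b) * (alpha / r)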

Note the following:

  1. Compute intensity = computation time / communication time, and maximising this means maximising GPU utilisation.
  2. Many computations in transformer models can be parallelised, QKV projections for example.
  3. Most operations in transformers follow the signature A @ B * Scale, A.K.A. a scaled dot-product.
  4. Q @ K.T / sqrt(dimK) is obviously equivalent to Q @ K.T * dimK^(-1/2).
  5. But Lora_A @ Lora_B = Lora_A @ Lora_B * 1, also A * B = I @ A * B, and so on (see the sketch after this list).
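
In other words, a single primitive of the form A @ B * scale covers all of the above (a minimal sketch; scaled_matmul below is an illustrative stand-in, not the library's actual kernel):

    import torch

    def scaled_matmul(a, b, scale=1.0):
        # The one primitive: A @ B * scale.
        return a @ b * scale

    # Attention scores: Q @ K.T / sqrt(dimK) == Q @ K.T * dimK**-0.5
    q, k = torch.randn(4, 64), torch.randn(4, 64)
    scores = scaled_matmul(q, k.T, 64 ** -0.5)

    # LoRA update: Lora_A @ Lora_B is just a scaled matmul with scale = 1.
    lora_a, lora_b = torch.randn(64, 8), torch.randn(8, 64)
    delta_w = scaled_matmul(lora_a, lora_b)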

We expressed the transformer forward pass and the backward vector-Jacobian products for each layer as a bunch of scaled matmuls, which are bundled together and executed in parallel across different CPU cores as C++ extensions to bypass the GIL. This concept also paves the way for an upcoming feature where each bundle could be distributed across your friends' laptops, such that each of them only executes one operation called Bundled Scaled Matmul. You're welcome.
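
To give a flavour of what "hand-derived VJPs as scaled matmuls" means, here is a single scaled matmul with its backward written out by hand (a conceptual sketch only; the library bundles many of these and dispatches them through C++ extensions rather than one Python Function per op):

    import torch

    class ScaledMatmul(torch.autograd.Function):
        # Forward: Y = (A @ B) * s. The backward pass is itself two scaled matmuls:
        #   dL/dA = (dL/dY @ B.T) * s   and   dL/dB = (A.T @ dL/dY) * s
        @staticmethod
        def forward(ctx, a, b, s):
            ctx.save_for_backward(a, b)
            ctx.s = s
            return a @ b * s

        @staticmethod
        def backward(ctx, grad_y):
            a, b = ctx.saved_tensors
            return grad_y @ b.T * ctx.s, a.T @ grad_y * ctx.s, None

    # Verify the hand-derived VJPs against numerical gradients.
    a = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
    b = torch.randn(8, 16, dtype=torch.double, requires_grad=True)
    assert torch.autograd.gradcheck(lambda x, y: ScaledMatmul.apply(x, y, 0.5), (a, b))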

Limitations

  1. Gradient accumulation, gradient checkpointing and lazy execution trade time complexity for memory efficiency, but you have no choice, do you? (A generic example of the trade-off follows this list.)
  2. Yeah... your laptop will definitely heat up. GPUs burn up at data centers and cost so much to cool; your laptop is not special.
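
As an example of that trade-off, plain gradient accumulation keeps only one micro-batch of activations alive at a time, at the cost of more steps (a generic PyTorch sketch, not this library's code; the tiny linear model and random data are placeholders):

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(16, 2)  # placeholder model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    data = [(torch.randn(4, 16), torch.randint(0, 2, (4,))) for _ in range(32)]

    accum_steps = 8  # effective batch = accum_steps * micro-batch size
    optimizer.zero_grad()
    for step, (x, y) in enumerate(data):
        loss = F.cross_entropy(model(x), y) / accum_steps  # scale so grads average correctly
        loss.backward()                                    # grads accumulate into .grad
        if (step + 1) % accum_steps == 0:
            optimizer.step()                               # one update per accum_steps micro-batches
            optimizer.zero_grad()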

Supported Models

  • deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
  • Qwen/Qwen2.5-0.5B
  • Qwen/Qwen2.5-0.5B-Instruct
  • Qwen/Qwen2.5-1.5B
  • Qwen/Qwen2.5-1.5B-Instruct
  • Qwen/Qwen2.5-3B
  • Qwen/Qwen2.5-3B-Instruct

Getting Started

  1. pip install sllm-lib
  2. Initialize the model:
    from sllm.nn import SuperLazyLanguageModel
    from sllm.config import Config
    
    config = Config(
        model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        lora_alpha=32,
        lora_r=8,
        lora_dropout=0.1,
    )
    
    model = SuperLazyLanguageModel(config)
    
    # Train like a normal PyTorch model
  3. You can also use SLLM's built-in training utilities:
    import torch
    from datasets import load_dataset
    
    from sllm.nn import SuperLazyLanguageModel
    from sllm.train import sft, prepare_dataset
    
    torch.manual_seed(42)
    
    name = "Qwen/Qwen2-0.5B-Instruct"
    dataset = load_dataset("yahma/alpaca-cleaned", split="train[:200]")
    
    dataset = prepare_dataset(
       model_name=name, 
       instructions=dataset["instruction"], 
       responses=dataset["output"], 
       inputs=dataset["input"],
       max_seq_len=256,
    )
    
    model = SuperLazyLanguageModel(
       name=name, 
       lora_alpha=32, 
       lora_r=8, 
       lora_dropout=0.1,
    )
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    sft(model=model, dataset=dataset, optimizer=optimizer, batch_size=8, epochs=3)

Contributing

Whether you’re improving documentation, optimizing kernels, or adding new features, your help is invaluable.

  1. Create a feature branch (git checkout -b feature/awesome-improvement).
  2. Commit your changes (git commit -m 'Add awesome feature').
  3. Push to the branch (git push origin feature/awesome-improvement).
  4. Open a Pull Request.
