Converting GPT to Llama

This folder contains code for converting the GPT implementation from chapters 4 and 5 to Meta AI's Llama architecture in the following recommended reading order:

Using Llama 3.2 via the llms-from-scratch package

For an easy way to use the Llama 3.2 1B and 3B models, you can also use the llms-from-scratch PyPI package, which is based on the source code in this repository at pkg/llms_from_scratch.

1) Installation

pip install llms_from_scratch blobfile

(Note that blobfile is needed to load the tokenizer.)

2) Model and text generation settings

Specify which model to use:

MODEL_FILE = "llama3.2-1B-instruct.pth"
# MODEL_FILE = "llama3.2-1B-base.pth"
# MODEL_FILE = "llama3.2-3B-instruct.pth"
# MODEL_FILE = "llama3.2-3B-base.pth"

The following basic text generation settings can be adjusted by the user. Note that the recommended 8192-token context size requires approximately 3 GB of VRAM for the text generation example.

MODEL_CONTEXT_LENGTH = 8192  # Supports up to 131_072

# Text generation settings
if "instruct" in MODEL_FILE:
    PROMPT = "What do llamas eat?"
else:
    PROMPT = "Llamas eat"

MAX_NEW_TOKENS = 150
TEMPERATURE = 0.
TOP_K = 1
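
With TEMPERATURE = 0. and TOP_K = 1, decoding is effectively greedy: at every step, the most likely next token is chosen. The pure-Python sketch below shows how top-k filtering and temperature scaling typically interact when picking a token. The function name `sample_next_token` is hypothetical. This is not the package's `generate` implementation, which operates on PyTorch tensors.

```python
import math
import random


def sample_next_token(logits, temperature=0.0, top_k=None, rng=random):
    """Pick the next token ID from a list of raw logits (hypothetical helper)."""
    # Top-k filtering: keep the k largest logits, mask the rest with -inf
    if top_k is not None:
        top_k = min(top_k, len(logits))
        kth_largest = sorted(logits, reverse=True)[top_k - 1]
        logits = [x if x >= kth_largest else -math.inf for x in logits]

    # Temperature 0 (or below): greedy decoding, i.e., argmax
    if temperature <= 0.0:
        return max(range(len(logits)), key=lambda i: logits[i])

    # Temperature scaling followed by a numerically stable softmax numerator
    scaled = [x / temperature for x in logits]
    max_logit = max(scaled)
    weights = [math.exp(x - max_logit) for x in scaled]  # exp(-inf) == 0.0

    # Sample one token ID in proportion to its (unnormalized) probability
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


logits = [2.0, 0.5, 3.5, 1.0]
print(sample_next_token(logits, temperature=0.0))           # 2 (greedy)
print(sample_next_token(logits, temperature=1.0, top_k=1))  # 2 (only one candidate left)
print(sample_next_token(logits, temperature=1.0, top_k=2))  # 0 or 2 (random)
```

As the second call illustrates, TOP_K = 1 leaves a single candidate (barring ties), so the temperature no longer affects the result.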

3) Weight download and loading

This automatically downloads the weight file based on the model choice above:

import os
import urllib.request

url = f"https://huggingface.co/rasbt/llama-3.2-from-scratch/resolve/main/{MODEL_FILE}"

if not os.path.exists(MODEL_FILE):
    urllib.request.urlretrieve(url, MODEL_FILE)
    print(f"Downloaded to {MODEL_FILE}")
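
One caveat with calling urlretrieve directly on MODEL_FILE is that an interrupted download leaves a partial file behind. The os.path.exists check then treats that file as complete on the next run. The sketch below is a safer variant. It downloads to a temporary `.part` file next to the destination and renames the file only after the transfer finishes. The helper name `download_if_missing` and the `.part` suffix are hypothetical choices.

```python
import os
import urllib.request


def download_if_missing(url, dest):
    """Download url to dest unless dest already exists (hypothetical helper)."""
    if os.path.exists(dest):
        return False  # nothing to do

    tmp_path = dest + ".part"
    try:
        # Write into a temporary file so an interrupted download
        # never leaves a truncated file under the final name
        urllib.request.urlretrieve(url, tmp_path)
        # Move into place; atomic on POSIX when both paths are on the same filesystem
        os.replace(tmp_path, dest)
    finally:
        # Remove leftovers if the download failed before the rename
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"Downloaded to {dest}")
    return True


# Usage with the settings above:
# download_if_missing(url, MODEL_FILE)
```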

The model weights are then loaded as follows:

import torch
from llms_from_scratch.llama3 import Llama3Model

if "1B" in MODEL_FILE:
    from llms_from_scratch.llama3 import LLAMA32_CONFIG_1B as LLAMA32_CONFIG
elif "3B" in MODEL_FILE:
    from llms_from_scratch.llama3 import LLAMA32_CONFIG_3B as LLAMA32_CONFIG
else:
    raise ValueError("Incorrect model file name")

LLAMA32_CONFIG["context_length"] = MODEL_CONTEXT_LENGTH

model = Llama3Model(LLAMA32_CONFIG)
model.load_state_dict(torch.load(MODEL_FILE, weights_only=True, map_location="cpu"))

device = (
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("mps") if torch.backends.mps.is_available() else
    torch.device("cpu")
)
model.to(device)
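
The VRAM figure mentioned in step 2 depends largely on the size of the loaded weights. For that reason, it can help to check how much memory the parameters alone occupy. The sketch below multiplies each parameter's element count by its element size. The helper name `param_memory_gb` is hypothetical. Activations and registered buffers are not counted, so the result is a lower bound, not the peak that torch.cuda.max_memory_allocated() reports.

```python
import torch


def param_memory_gb(model):
    """Approximate memory held by a model's parameters, in GB (hypothetical helper)."""
    # parameters() yields a shared (tied) Parameter only once
    total_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    return total_bytes / (1024 ** 3)


# Self-contained check on a small layer: 10 * 5 weights + 5 biases = 55 float32 values
layer = torch.nn.Linear(10, 5)
print(sum(p.numel() for p in layer.parameters()))          # 55
print(f"{param_memory_gb(layer) * 1024 ** 3:.0f} bytes")   # 220 bytes

# With the model loaded above, this would be:
# print(f"Parameter memory: {param_memory_gb(model):.2f} GB")
```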

4) Initialize tokenizer

The following code downloads and initializes the tokenizer:

from llms_from_scratch.llama3 import Llama3Tokenizer, ChatFormat, clean_text

TOKENIZER_FILE = "tokenizer.model"

url = f"https://huggingface.co/rasbt/llama-3.2-from-scratch/resolve/main/{TOKENIZER_FILE}"

if not os.path.exists(TOKENIZER_FILE):
    urllib.request.urlretrieve(url, TOKENIZER_FILE)
    print(f"Downloaded to {TOKENIZER_FILE}")

tokenizer = Llama3Tokenizer(TOKENIZER_FILE)

if "instruct" in MODEL_FILE:
    tokenizer = ChatFormat(tokenizer)
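
For the instruct models, ChatFormat wraps the prompt in a chat template before tokenization. The exact token-level handling lives in pkg/llms_from_scratch/llama3.py. As a rough, string-level illustration only, the sketch below assembles the publicly documented Llama 3 instruct format for a single user turn. The function name `format_llama3_chat` is hypothetical. Whether the package prepends `<|begin_of_text|>` at this stage is an assumption worth verifying in the source.

```python
def format_llama3_chat(user_message, system_message=None):
    """Build a single-turn Llama 3 instruct prompt as a string (hypothetical helper)."""

    def header(role):
        # Each message starts with a role header followed by a blank line
        return f"<|start_header_id|>{role}<|end_header_id|>\n\n"

    parts = ["<|begin_of_text|>"]
    if system_message is not None:
        parts.append(header("system") + system_message + "<|eot_id|>")
    parts.append(header("user") + user_message + "<|eot_id|>")
    # End with an open assistant header so the model continues as the assistant
    parts.append(header("assistant"))
    return "".join(parts)


print(format_llama3_chat("What do llamas eat?"))
```

This is presumably also why step 5 applies clean_text to instruct-model output. The generated sequence is decoded together with the prompt, so the template markup in front of the answer has to be stripped.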

5) Generating text

Lastly, we can generate text via the following code:

import time

from llms_from_scratch.ch05 import (
    generate,
    text_to_token_ids,
    token_ids_to_text
)

torch.manual_seed(123)

start = time.time()

token_ids = generate(
    model=model,
    idx=text_to_token_ids(PROMPT, tokenizer).to(device),
    max_new_tokens=MAX_NEW_TOKENS,
    context_size=LLAMA32_CONFIG["context_length"],
    top_k=TOP_K,
    temperature=TEMPERATURE
)

total_time = time.time() - start
print(f"Time: {total_time:.2f} sec")
print(f"{int(len(token_ids[0])/total_time)} tokens/sec")

if torch.cuda.is_available():
    max_mem_bytes = torch.cuda.max_memory_allocated()
    max_mem_gb = max_mem_bytes / (1024 ** 3)
    print(f"Max memory allocated: {max_mem_gb:.2f} GB")

output_text = token_ids_to_text(token_ids, tokenizer)

if "instruct" in MODEL_FILE:
    output_text = clean_text(output_text)

print("\n\nOutput text:\n\n", output_text)

When using the Llama 3.2 1B Instruct model, the output should look similar to the one shown below:

Time: 3.17 sec
50 tokens/sec
Max memory allocated: 2.91 GB


Output text:

 Llamas are herbivores, which means they primarily eat plants. Their diet consists mainly of:

1. Grasses: Llamas love to graze on various types of grasses, including tall grasses and grassy meadows.
2. Hay: Llamas also eat hay, which is a dry, compressed form of grass or other plants.
3. Alfalfa: Alfalfa is a legume that is commonly used as a hay substitute in llama feed.
4. Other plants: Llamas will also eat other plants, such as clover, dandelions, and wild grasses.

It's worth noting that the specific diet of llamas can vary depending on factors such as the breed,
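
The generate function used in step 5 produces one token per forward pass. The pure-Python sketch below mirrors that loop in simplified form. A hypothetical `model_fn`, which maps a list of token IDs to next-token logits, stands in for the PyTorch model. On each step, the sketch crops the input to the last context_size tokens, picks the next token greedily (matching TEMPERATURE = 0.), and appends it. It also counts only the newly generated tokens for throughput. The tokens/sec value printed in step 5, by contrast, divides the full sequence length, prompt included, by the elapsed time.

```python
import time


def generate_greedy(model_fn, prompt_ids, max_new_tokens, context_size, eos_id=None):
    """Greedy autoregressive decoding over plain Python lists (simplified sketch)."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        # Only the most recent context_size tokens are fed to the model
        context = ids[-context_size:]
        logits = model_fn(context)
        # Greedy choice: index of the largest logit
        next_id = max(range(len(logits)), key=lambda i: logits[i])
        if eos_id is not None and next_id == eos_id:
            break  # stop early at the end-of-sequence token
        ids.append(next_id)
    return ids


# Hypothetical toy "model" with a 10-token vocabulary that always
# predicts (last token + 1) mod 10
def toy_model(context):
    logits = [0.0] * 10
    logits[(context[-1] + 1) % 10] = 1.0
    return logits


prompt = [3, 4]
start = time.perf_counter()
out = generate_greedy(toy_model, prompt, max_new_tokens=5, context_size=4)
elapsed = time.perf_counter() - start

num_new = len(out) - len(prompt)  # exclude prompt tokens from throughput
print(out)  # [3, 4, 5, 6, 7, 8, 9]
print(f"{num_new / max(elapsed, 1e-9):.0f} new tokens/sec")
```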

Pro tip 1: speed up inference with FlashAttention

Instead of using Llama3Model, you can use Llama3ModelFast as a drop-in replacement. For more information, I encourage you to inspect the pkg/llms_from_scratch/llama3.py code.

Llama3ModelFast replaces my from-scratch scaled dot-product code in the GroupedQueryAttention module with PyTorch's scaled_dot_product_attention function, which uses FlashAttention on Ampere GPUs or newer.

The following table shows a performance comparison on an A100:

| Model           | Tokens/sec | Memory  |
|-----------------|------------|---------|
| Llama3Model     | 50         | 2.91 GB |
| Llama3ModelFast | 58         | 2.85 GB |
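
To see what the swap amounts to, the sketch below computes causal grouped-query attention twice. The first pass uses an explicit mask and softmax, roughly in the spirit of a from-scratch implementation. The second uses torch.nn.functional.scaled_dot_product_attention with is_causal=True. The tensor sizes and the repeat_interleave-based key/value sharing are illustrative assumptions, not a copy of the GroupedQueryAttention module in llama3.py.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch, num_heads, num_kv_groups, seq_len, head_dim = 2, 8, 2, 6, 16
group_size = num_heads // num_kv_groups  # query heads sharing one key/value head

queries = torch.randn(batch, num_heads, seq_len, head_dim)
keys = torch.randn(batch, num_kv_groups, seq_len, head_dim)
values = torch.randn(batch, num_kv_groups, seq_len, head_dim)

# Grouped-query attention: expand each key/value head to serve group_size query heads
keys = keys.repeat_interleave(group_size, dim=1)
values = values.repeat_interleave(group_size, dim=1)


def manual_causal_attention(q, k, v):
    # Scaled dot-product scores: (batch, heads, seq_len, seq_len)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    # Mask out future positions (strictly above the diagonal)
    n = q.shape[-2]
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


manual = manual_causal_attention(queries, keys, values)
# Fused kernel; PyTorch selects a backend (e.g., FlashAttention) based on hardware and inputs
fused = F.scaled_dot_product_attention(queries, keys, values, is_causal=True)

print(manual.shape)                              # torch.Size([2, 8, 6, 16])
print(torch.allclose(manual, fused, atol=1e-5))  # True
```

Both paths compute the same result, which is what makes a drop-in replacement possible. When a FlashAttention backend is selected, the speed and memory gains come from the fused kernel not materializing the full seq_len × seq_len score matrix in GPU memory.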

Pro tip 2: speed up inference with compilation

For up to a 4× speed-up, replace

model.to(device)

with

model = torch.compile(model)
model.to(device)

Note: There is a significant multi-minute upfront cost when compiling, and the speed-up takes effect after the first generate call.

The following table shows a performance comparison on an A100 for subsequent generate calls:

| Model           | Tokens/sec | Memory  |
|-----------------|------------|---------|
| Llama3Model     | 156        | 3.12 GB |
| Llama3ModelFast | 159        | 2.84 GB |
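
The first call to a compiled model includes the compilation itself, so timing that call mixes a one-off cost with steady-state throughput. The framework-agnostic sketch below separates the two by running untimed warm-up calls before measuring. The helper `benchmark` and its `warmup` and `runs` parameters are hypothetical.

```python
import statistics
import time


def benchmark(fn, warmup=1, runs=3):
    """Time fn() after discarding warm-up calls (hypothetical helper)."""
    # Warm-up calls absorb one-off costs such as compilation or caching
    for _ in range(warmup):
        fn()

    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings), timings


# Toy stand-in whose first call is slow, mimicking a compilation step
state = {"compiled": False}


def toy_generate():
    if not state["compiled"]:
        time.sleep(0.2)  # simulated one-off "compilation" cost
        state["compiled"] = True
    time.sleep(0.01)     # simulated steady-state work


median_s, all_runs = benchmark(toy_generate, warmup=1, runs=3)
print(f"median: {median_s * 1000:.0f} ms over {len(all_runs)} runs")
```

With the real model, fn could be a lambda wrapping the generate(...) call from step 5. On CUDA, kernels launch asynchronously, so calling torch.cuda.synchronize() before reading the clock can make the timings more reliable.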