
mini_transformer

mini_transformer is a minimal library for rapidly building state-of-the-art Transformer architectures.

Install

1. Install the GPU build of PyTorch for better performance.

pip3 install torch --index-url https://download.pytorch.org/whl/cu124

2. Install mini_transformer

pip install mini_transformer

Usage

1. Use GQA attention (https://arxiv.org/abs/2305.13245)

import torch
from mini_transformer import Transformer

data = torch.randint(0, 6400, (2, 512))
start_pos = 0
mask = torch.triu(torch.full((data.shape[1], data.shape[1]), float('-inf')), 1)
model = Transformer(attn_name = 'gqa')
# When mlp_name == 'mlp', aux_loss is always 0,
# i.e. it has no effect.
out, aux_loss = model(data, start_pos, mask)
print(out.shape, aux_loss)
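For context on what GQA changes, here is a stdlib-only sketch (not this library's internals) of how query heads map onto a smaller set of shared key/value heads:

```python
# Sketch of the GQA head-sharing idea: n_q query heads are split into
# n_kv groups, and every query head in a group attends with the same
# cached key/value head, shrinking the KV cache by a factor of n_q / n_kv.
def kv_head_for_query_head(q_head: int, n_q_heads: int, n_kv_heads: int) -> int:
    group_size = n_q_heads // n_kv_heads  # query heads per shared KV head
    return q_head // group_size

# With 8 query heads and 2 KV heads, heads 0-3 share KV head 0
# and heads 4-7 share KV head 1.
groups = [kv_head_for_query_head(h, 8, 2) for h in range(8)]
```

n_kv_heads = n_q_heads recovers standard multi-head attention, and n_kv_heads = 1 recovers multi-query attention; GQA sits between the two.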

2. Use MLA attention (https://arxiv.org/abs/2412.19437)

import torch
from mini_transformer import Transformer

data = torch.randint(0, 6400, (2, 512))
start_pos = 0
mask = torch.triu(torch.full((data.shape[1], data.shape[1]), float('-inf')), 1)
model = Transformer(attn_name = 'mla')
# When mlp_name == 'mlp', aux_loss is always 0,
# i.e. it has no effect.
out, aux_loss = model(data, start_pos, mask)
print(out.shape, aux_loss)
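MLA shrinks the KV cache by caching one shared low-rank latent per token (plus a small decoupled RoPE key) instead of full per-head keys and values. A back-of-the-envelope sketch with illustrative numbers, which are not necessarily this library's defaults:

```python
def kv_cache_per_token(n_heads, head_dim, latent_rank=None, rope_dim=0):
    # Standard attention caches full K and V vectors for every head.
    if latent_rank is None:
        return 2 * n_heads * head_dim
    # MLA caches a single shared latent plus a small decoupled RoPE key.
    return latent_rank + rope_dim

mha = kv_cache_per_token(128, 128)                                # 32768 values/token
mla = kv_cache_per_token(128, 128, latent_rank=512, rope_dim=64)  # 576 values/token
```

With these (hypothetical) sizes the cache shrinks by roughly 57x per token, which is the main inference-time motivation for MLA.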

3. Use the MoE architecture

import torch
from mini_transformer import Transformer

data = torch.randint(0, 6400, (2, 512))
start_pos = 0
mask = torch.triu(torch.full((data.shape[1], data.shape[1]), float('-inf')), 1)
model = Transformer(attn_name = 'mla', mlp_name = 'moe')
# When mlp_name == 'moe' and the model is in training mode,
# the aux_loss is computed and should be added to the final loss
# to balance the load across the experts.
out, aux_loss = model(data, start_pos, mask)
print(out.shape, aux_loss)
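As a hedged illustration of what such an auxiliary loss balances (a Switch-Transformer-style formulation, not necessarily the exact one this library computes): aux = num_experts * sum_e f_e * p_e, where f_e is the fraction of tokens routed to expert e and p_e is the mean router probability for expert e. For top-1 routing it reaches its minimum of 1.0 when routing is perfectly uniform:

```python
def load_balancing_loss(router_probs, assignments, num_experts):
    # router_probs: per-token lists of router probabilities over experts
    # assignments: the expert index chosen for each token (top-1 routing)
    n_tokens = len(assignments)
    frac = [assignments.count(e) / n_tokens for e in range(num_experts)]
    mean_prob = [sum(tok[e] for tok in router_probs) / n_tokens
                 for e in range(num_experts)]
    return num_experts * sum(f * p for f, p in zip(frac, mean_prob))

# Perfectly uniform routing over 4 experts hits the minimum value 1.0.
uniform = load_balancing_loss([[0.25] * 4] * 8, [0, 1, 2, 3] * 2, 4)
```

Collapsing all tokens onto one expert pushes the value up (to 4.0 in this 4-expert example), which is the gradient signal that keeps the router balanced.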

4. Inference (with kv_cache)

import torch
from mini_transformer import Transformer

data = torch.randint(0, 6400, (2, 512))
start_pos = 0
mask = torch.triu(torch.full((data.shape[1], data.shape[1]), float('-inf')), 1)
model = Transformer(attn_name = 'mla')
out, _ = model(data, start_pos, mask, use_cache = True)
print(out.shape)

inference_data = torch.randint(0, 6400, (2, 1))
start_pos = data.shape[1]
out, _ = model(inference_data, start_pos, use_cache = True)
print(out.shape)
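The start_pos bookkeeping above mirrors how a KV cache typically advances: the first call prefills positions [0, 512), and each later call writes its single new token at start_pos = 512, 513, ... while attending over everything cached so far. A toy stdlib sketch of that contract (not this library's actual cache):

```python
class ToyKVCache:
    """Minimal stand-in for a per-layer KV cache indexed by position."""

    def __init__(self):
        self.keys, self.values = [], []

    def write(self, start_pos, new_keys, new_values):
        # Overwrite from start_pos onward, then append the new entries.
        del self.keys[start_pos:], self.values[start_pos:]
        self.keys += new_keys
        self.values += new_values
        return len(self.keys)  # total cached positions to attend over

cache = ToyKVCache()
prompt_len = cache.write(0, list(range(512)), list(range(512)))  # prefill
total = cache.write(prompt_len, ["k_new"], ["v_new"])            # one decode step
```

This is why the second model call above passes only the one new token together with start_pos = data.shape[1], and why no mask is needed: a single query position can already see the entire cached prefix.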

5. Inference (without kv_cache)

import torch
from mini_transformer import Transformer

# init
data = torch.randint(0, 6400, (2, 512))
start_pos = 0
mask = torch.triu(torch.full((data.shape[1], data.shape[1]), float('-inf')), 1)
model = Transformer(attn_name = 'mla')
out, _ = model(data, start_pos, mask)
print(out.shape)

# inference
data = torch.cat((data, torch.randint(0, 6400, (2, 1))), dim = 1)
start_pos = 0
mask = torch.triu(torch.full((data.shape[1], data.shape[1]), float('-inf')), 1)
out, _ = model(data, start_pos)
print(out.shape)
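Without the cache, every step re-encodes the full (growing) sequence from start_pos = 0, so decoding cost grows quadratically with the number of generated tokens. A quick arithmetic sketch of how many positions each strategy processes:

```python
def positions_processed(prompt_len, new_tokens, use_cache):
    if use_cache:
        # Prefill the prompt once, then one position per decode step.
        return prompt_len + new_tokens
    # Re-run the whole sequence for every generated token.
    return sum(prompt_len + t for t in range(1, new_tokens + 1))

without = positions_processed(512, 10, use_cache=False)     # 5175 positions
with_cache = positions_processed(512, 10, use_cache=True)   # 522 positions
```

The cached path is almost 10x cheaper after only 10 generated tokens here, and the gap widens as generation continues.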

Supported Models

1. ViT

import torch
from mini_transformer.model import VIT

data = torch.randn(2, 3, 224, 224)
model = VIT(num_classes = None) # image encoder (as in CLIP, SigLIP, Janus)
out = model.forward(data) # [b, d]
print(out.shape)

model = VIT(num_classes = 1000) # classification
out = model.forward(data) # [b, 1000]
print(out.shape)
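For intuition on the [b, d] and [b, 1000] shapes: a ViT first splits each image into non-overlapping square patches, each of which becomes one token; the patch size below (16x16) is a common choice but may not be this library's default:

```python
def num_patches(img_size, patch_size):
    # Non-overlapping square patches along each spatial dimension.
    per_side = img_size // patch_size
    return per_side * per_side

tokens = num_patches(224, 16)  # 14 * 14 = 196 patch tokens per image
```

The Transformer then runs over those patch tokens; pooling (or a class token) reduces them to one d-dimensional embedding per image, and the optional classification head maps that embedding to num_classes logits.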
