-
Notifications
You must be signed in to change notification settings - Fork 18
Open
Description
你好,最近把你们模改的tutel拿来运行了一下,跑了一个简单的模型来验证代码。但是不管怎么调整都出现了损失消失的问题。可以帮我看一下是超参数不对还是模型结构的问题吗?谢谢
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch import nn
from torch.optim import Adam
from tutel import moe
# Multi-head attention definition
class MultiHeadAttention(nn.Module):
    """Standard multi-head self-attention with a fused QKV projection.

    Args:
        dim: model width; must be divisible by ``heads``.
        heads: number of attention heads.
    """

    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = dim ** -0.5
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x):
        batch, seq_len, width = x.shape
        head_dim = width // self.heads

        # One projection produces Q, K and V; split them apart afterwards.
        query, key, value = self.to_qkv(x).chunk(3, dim=-1)

        def split_heads(t):
            # (B, T, D) -> (B, heads, T, D/heads)
            return t.reshape(batch, seq_len, self.heads, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        scores = query @ key.transpose(-2, -1)
        weights = (scores * self.scale).softmax(dim=-1)

        # Merge heads back: (B, heads, T, D/heads) -> (B, T, D).
        context = (weights @ value).transpose(1, 2).reshape(batch, seq_len, width)
        return self.out_proj(context)
# Transformer block + GMoE FFN
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention followed by a Tutel MoE FFN.

    ``self.norm3`` and the dense ``self.ffn`` are not used by ``forward`` but
    are kept so the module's parameter set (and any existing checkpoints /
    external references) stay compatible with the original definition.

    Args:
        dim: model width fed to attention and the MoE layer.
        num_experts: experts per node in the MoE FFN.
        hidden_dim: hidden size of each expert (and of the dense FFN).
    """

    def __init__(self, dim=512, num_experts=4, hidden_dim=2048):
        super().__init__()
        self.attn = MultiHeadAttention(dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        # BUG FIX: the original passed ``group=2``. Tutel's ``group`` argument
        # expects a torch.distributed process group (or None for the default
        # group), not an int; handing it ``2`` corrupts expert routing and is
        # the most likely source of the reported NaN losses. For a
        # single-process run, simply omit it.
        # NOTE(review): ``capacity_factor=0`` requests adaptive/unbounded
        # capacity in Tutel — confirm against the installed Tutel version.
        self.moe_layer = moe.moe_layer(
            gate_type={'type': 'top', 'k': 1, 'capacity_factor': 0, 'gate_noise': 1.0},
            experts={
                'type': 'ffn',
                'count_per_node': num_experts,
                'hidden_size_per_expert': hidden_dim,
                # Named function instead of ``lambda x: relu(x)`` so the
                # module remains picklable.
                'activation_fn': nn.functional.relu,
            },
            model_dim=dim,
        )
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        # Pre-norm residual attention, then pre-norm residual MoE FFN.
        x = x + self.attn(self.norm1(x))
        return x + self.moe_layer(self.norm2(x))
# 🧪 Synthetic training data
def build_dataloader(batch_size=8):
    """Build a DataLoader of 100 random sequences with linearly-derived labels.

    Each sample is a (16, 64) float tensor; its label (0..2) is the argmax of
    a fixed random linear map applied to the sequence's time-averaged features.
    """
    features = torch.randn(100, 16, 64)
    label_weights = torch.randn(64, 3)
    labels = torch.argmax(features.mean(dim=1) @ label_weights, dim=1)
    return DataLoader(TensorDataset(features, labels), batch_size=batch_size, shuffle=True)
# 🎯 Simple classification head (attached after the TransformerBlock)
class ClassificationModel(nn.Module):
    """Project raw features, run one transformer block, mean-pool, classify.

    Args:
        dim: transformer model width.
        num_experts: experts per node for the MoE FFN inside the block.
        hidden_dim: expert / FFN hidden size.
        num_classes: number of output logits.
        input_dim: width of the raw input features. Generalized from the
            hard-coded 64 in the original; the default preserves behavior.
    """

    def __init__(self, dim=512, num_experts=4, hidden_dim=2048, num_classes=10, input_dim=64):
        super().__init__()
        self.projection = nn.Linear(input_dim, dim)
        self.transformer = TransformerBlock(dim, num_experts, hidden_dim)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(dim, num_classes)

    def forward(self, x):
        """Map (B, T, input_dim) input to (B, num_classes) logits."""
        x = self.projection(x)
        x = self.transformer(x)
        x = x.transpose(1, 2)         # (B, D, T) so AdaptiveAvgPool1d pools over time
        x = self.pool(x).squeeze(-1)  # (B, D)
        return self.classifier(x)     # (B, num_classes)
# 🚀 Start training
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataloader = build_dataloader()
model = ClassificationModel(dim=128, num_experts=4, hidden_dim=1024, num_classes=3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=1e-4)
for epoch in range(1, 200):
epoch_loss = 0
epoch_correct = 0
epoch_total = 0
for x, y in dataloader:
x, y = x.to(device), y.to(device)
optimizer.zero_grad()
# model.transformer.moe_layer.begin_record_routing()
logits = model(x)
loss = criterion(logits, y) + .0001 * model.transformer.moe_layer.l_aux
loss.backward()
# for p in model.parameters():
# if not hasattr(p, 'skip_allreduce') and p.grad is not None:
# p.grad = net.simple_all_reduce(p.grad)
optimizer.step()
prediction = logits.argmax(dim=1, keepdim=True)
correct = prediction.eq(y.view_as(prediction)).sum().item()
total_items = int(logits.size(0))
epoch_loss += loss.item()
epoch_correct += correct
epoch_total += total_items
# if epoch % 10 == 0:
# model.transformer.moe_layer.adaptive_update_experts()
# model.transformer.moe_layer.end_record_routing()
if epoch % 10 == 0:
        print(f'epoch {epoch}: loss = {epoch_loss / len(dataloader):.4f}, accuracy = {epoch_correct / epoch_total * 100:.2f}%')

不管参数 gate_type={'type': 'top', 'k': 1, 'capacity_factor': 0, 'gate_noise': 1.0} 还是 gate_type={'type': 'gated_multi_gate', 'max_expert_num': 16},运行出来的损失都是 NaN(损失消失)。
Gate types: ['LinearTopKGate']
4
epoch 10: loss = nan, accuracy = 36.00%
epoch 20: loss = nan, accuracy = 36.00%
epoch 30: loss = nan, accuracy = 36.00%
epoch 40: loss = nan, accuracy = 36.00%
epoch 50: loss = nan, accuracy = 36.00%
epoch 60: loss = nan, accuracy = 36.00%
epoch 70: loss = nan, accuracy = 36.00%
epoch 80: loss = nan, accuracy = 36.00%
epoch 90: loss = nan, accuracy = 36.00%
epoch 100: loss = nan, accuracy = 36.00%
epoch 110: loss = nan, accuracy = 36.00%
epoch 120: loss = nan, accuracy = 36.00%
epoch 130: loss = nan, accuracy = 36.00%
epoch 140: loss = nan, accuracy = 36.00%
epoch 150: loss = nan, accuracy = 36.00%
epoch 160: loss = nan, accuracy = 36.00%
epoch 170: loss = nan, accuracy = 36.00%
epoch 180: loss = nan, accuracy = 36.00%
epoch 190: loss = nan, accuracy = 36.00%

Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels