Gradient explosion #8

Description

@tanjiarui

Hi, I recently ran your modified version of tutel and trained a simple model to verify the code. No matter how I tune it, though, the loss goes to NaN. Could you help me check whether the hyperparameters are wrong or whether the problem is in the model structure? Thanks.

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch import nn
from torch.optim import Adam
from tutel import moe


# Define multi-head attention
class MultiHeadAttention(nn.Module):
	def __init__(self, dim, heads=8):
		super().__init__()
		self.heads = heads
		self.scale = dim ** -0.5
		self.to_qkv = nn.Linear(dim, dim * 3)
		self.out_proj = nn.Linear(dim, dim)

	def forward(self, x):
		B, T, D = x.shape
		qkv = self.to_qkv(x).chunk(3, dim=-1)
		q, k, v = [t.reshape(B, T, self.heads, D // self.heads).transpose(1, 2) for t in qkv]
		attn_weights = (q @ k.transpose(-2, -1)) * self.scale
		attn = attn_weights.softmax(dim=-1)
		out = (attn @ v).transpose(1, 2).reshape(B, T, D)
		return self.out_proj(out)


# Transformer block + GMoE FFN
class TransformerBlock(nn.Module):
	def __init__(self, dim=512, num_experts=4, hidden_dim=2048):
		super().__init__()
		self.attn = MultiHeadAttention(dim)
		self.norm1 = nn.LayerNorm(dim)
		self.norm2 = nn.LayerNorm(dim)
		self.norm3 = nn.LayerNorm(dim)
		self.moe_layer = moe.moe_layer(
			gate_type={'type': 'top', 'k': 1, 'capacity_factor': 0, 'gate_noise': 1.0},
			experts={'type': 'ffn', 'count_per_node': num_experts, 'hidden_size_per_expert': hidden_dim, 'activation_fn': lambda x: nn.functional.relu(x)},
			group=2,
			model_dim=dim
		)
		self.ffn = nn.Sequential(nn.Linear(dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, dim))

	def forward(self, x):
		x = x + self.attn(self.norm1(x))
		moe_out = self.moe_layer(self.norm2(x))
		return x + moe_out

	# def forward(self, x):
	# 	x = x + self.attn(self.norm1(x))
	# 	x = x + self.norm3(self.ffn(self.norm2(x)))
	# 	return x


# 🧪 Simulated training data
def build_dataloader(batch_size=8):
	X = torch.randn(100, 16, 64)
	W = torch.randn(64, 3)
	Y = (X.mean(dim=1) @ W).argmax(dim=1)
	dataset = TensorDataset(X, Y)
	return DataLoader(dataset, batch_size=batch_size, shuffle=True)


# 🎯 Simple classification head (attached after the TransformerBlock)
class ClassificationModel(nn.Module):
	def __init__(self, dim=512, num_experts=4, hidden_dim=2048, num_classes=10):
		super().__init__()
		self.projection = nn.Linear(64, dim)
		self.transformer = TransformerBlock(dim, num_experts, hidden_dim)
		self.pool = nn.AdaptiveAvgPool1d(1)
		self.classifier = nn.Linear(dim, num_classes)

	def forward(self, x):
		x = self.projection(x)
		x = self.transformer(x)
		x = x.transpose(1, 2)  # [B, D, T]
		x = self.pool(x).squeeze(-1)  # [B, D]
		logits = self.classifier(x)  # [B, num_classes]
		return logits


# 🚀 Start training
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataloader = build_dataloader()
model = ClassificationModel(dim=128, num_experts=4, hidden_dim=1024, num_classes=3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=1e-4)
for epoch in range(1, 200):
	epoch_loss = 0
	epoch_correct = 0
	epoch_total = 0
	for x, y in dataloader:
		x, y = x.to(device), y.to(device)
		optimizer.zero_grad()
		# model.transformer.moe_layer.begin_record_routing()
		logits = model(x)
		loss = criterion(logits, y) + .0001 * model.transformer.moe_layer.l_aux
		loss.backward()

		# for p in model.parameters():
		# 	if not hasattr(p, 'skip_allreduce') and p.grad is not None:
		# 		p.grad = net.simple_all_reduce(p.grad)

		optimizer.step()

		prediction = logits.argmax(dim=1, keepdim=True)
		correct = prediction.eq(y.view_as(prediction)).sum().item()
		total_items = int(logits.size(0))

		epoch_loss += loss.item()
		epoch_correct += correct
		epoch_total += total_items

	# if epoch % 10 == 0:
	# 	model.transformer.moe_layer.adaptive_update_experts()
	# 	model.transformer.moe_layer.end_record_routing()
	if epoch % 10 == 0:
		print(f'epoch {epoch}: loss = {epoch_loss / len(dataloader):.4f}, accuracy = {epoch_correct / epoch_total * 100:.2f}%')

Whether I use gate_type={'type': 'top', 'k': 1, 'capacity_factor': 0, 'gate_noise': 1.0} or gate_type={'type': 'gated_multi_gate', 'max_expert_num': 16}, the loss always ends up as NaN.
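One way to narrow down where the NaN first appears is to enable PyTorch's anomaly detection and check the two loss terms separately before calling backward. This is only a debugging sketch layered on the training loop above (it reuses model, dataloader, criterion, optimizer and device from the script), not a fix:

# Debugging sketch: locate the first non-finite value.
# torch.autograd.set_detect_anomaly(True) reports which backward op produced NaN/Inf;
# the isfinite checks show whether the task loss or the MoE auxiliary loss (l_aux) blows up first.
torch.autograd.set_detect_anomaly(True)

for x, y in dataloader:
	x, y = x.to(device), y.to(device)
	optimizer.zero_grad()
	logits = model(x)
	task_loss = criterion(logits, y)
	aux_loss = model.transformer.moe_layer.l_aux
	if not torch.isfinite(task_loss) or not torch.isfinite(aux_loss):
		print(f'non-finite loss: task={task_loss.item()}, l_aux={aux_loss.item()}')
		break
	loss = task_loss + 1e-4 * aux_loss
	loss.backward()
	optimizer.step()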

Gate types:  ['LinearTopKGate']
4
epoch 10: loss = nan, accuracy = 36.00%
epoch 20: loss = nan, accuracy = 36.00%
epoch 30: loss = nan, accuracy = 36.00%
epoch 40: loss = nan, accuracy = 36.00%
epoch 50: loss = nan, accuracy = 36.00%
epoch 60: loss = nan, accuracy = 36.00%
epoch 70: loss = nan, accuracy = 36.00%
epoch 80: loss = nan, accuracy = 36.00%
epoch 90: loss = nan, accuracy = 36.00%
epoch 100: loss = nan, accuracy = 36.00%
epoch 110: loss = nan, accuracy = 36.00%
epoch 120: loss = nan, accuracy = 36.00%
epoch 130: loss = nan, accuracy = 36.00%
epoch 140: loss = nan, accuracy = 36.00%
epoch 150: loss = nan, accuracy = 36.00%
epoch 160: loss = nan, accuracy = 36.00%
epoch 170: loss = nan, accuracy = 36.00%
epoch 180: loss = nan, accuracy = 36.00%
epoch 190: loss = nan, accuracy = 36.00%
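Since the title mentions exploding gradients, one common mitigation worth trying (purely a suggestion, not something the original script does) is to clip the global gradient norm between loss.backward() and optimizer.step() in the inner loop above:

loss.backward()
# Assumption: max_norm=1.0 is an arbitrary starting point, not a tuned value.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()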
