Skip to content

Commit f4c90b6

Browse files
committed
test
1 parent ef170fd commit f4c90b6

File tree

11 files changed

+761
-34
lines changed

11 files changed

+761
-34
lines changed

lzero/Tool.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import torch
2+
import numpy as np
3+
from typing import List, Tuple
4+
5+
def compute_gradient_conflicts(gradients: List[torch.Tensor]) -> dict:
    """
    Compute pairwise conflict metrics between multiple gradients.

    Args:
        gradients: List of gradient tensors; all must share the same shape.

    Returns:
        Dict with:
            - 'cosine_similarity_matrix': (n, n) pairwise cosine similarities.
            - 'avg_conflict_score' / 'max_conflict_score': mean / max of the
              negated off-diagonal cosine similarities (0.0 if fewer than
              two gradients).
            - 'dot_product_matrix': (n, n) pairwise dot products.
            - 'gradient_norms': list of L2 norms, one per gradient.
            - 'num_conflicting_pairs': number of unordered pairs whose dot
              product is negative (i.e. the gradients point in opposing
              directions).
            - 'avg_conflict_intensity': mean magnitude of those negative dot
              products (0.0 when there are none).

    Raises:
        ValueError: If ``gradients`` is empty or the shapes differ.
    """
    # Explicit exceptions instead of `assert`: asserts are stripped
    # under `python -O`, silently disabling input validation.
    if not gradients:
        raise ValueError("gradients must be a non-empty list")
    if any(g.shape != gradients[0].shape for g in gradients):
        raise ValueError("all gradients must have the same shape")

    n_gradients = len(gradients)
    results = {}

    # Stack flattened gradients once so all pairwise metrics come from
    # vectorized matmuls instead of O(n^2) Python loops.
    flat = torch.stack([g.flatten() for g in gradients])  # (n, d)

    # 1. Pairwise cosine-similarity matrix (normalize rows, then Gram matrix).
    normed = torch.nn.functional.normalize(flat, dim=1)
    cosine_sim_matrix = normed @ normed.t()
    results['cosine_similarity_matrix'] = cosine_sim_matrix

    # 2. Conflict scores: negated cosine similarity, diagonal excluded.
    mask = ~torch.eye(n_gradients, dtype=torch.bool)
    conflict_scores = -cosine_sim_matrix[mask]
    if conflict_scores.numel() > 0:
        results['avg_conflict_score'] = conflict_scores.mean().item()
        results['max_conflict_score'] = conflict_scores.max().item()
    else:
        # A single gradient has no pairs, hence no conflict; the original
        # `.max()` on an empty tensor would raise here.
        results['avg_conflict_score'] = 0.0
        results['max_conflict_score'] = 0.0

    # 3. Pairwise dot-product matrix.
    dot_product_matrix = flat @ flat.t()
    results['dot_product_matrix'] = dot_product_matrix

    # 4. L2 norm of each gradient.
    results['gradient_norms'] = [torch.norm(g).item() for g in gradients]

    # 5. Conflict intensity over unordered pairs (upper triangle only):
    # a negative dot product marks a conflicting pair.
    negative_dot_products = [
        -dot_product_matrix[i, j].item()
        for i in range(n_gradients)
        for j in range(i + 1, n_gradients)
        if dot_product_matrix[i, j] < 0
    ]
    results['num_conflicting_pairs'] = len(negative_dot_products)
    results['avg_conflict_intensity'] = (
        float(np.mean(negative_dot_products)) if negative_dot_products else 0.0
    )

    return results
66+
67+
# Usage example
def example_usage():
    """Demonstrate gradient-conflict analysis on three random gradients."""
    # Fixed seed so the demo output is reproducible.
    torch.manual_seed(42)
    demo_grads = [torch.randn(100) for _ in range(3)]

    analysis = compute_gradient_conflicts(demo_grads)

    print("梯度冲突分析结果:")
    for line in (
        f"平均冲突得分: {analysis['avg_conflict_score']:.4f}",
        f"最大冲突得分: {analysis['max_conflict_score']:.4f}",
        f"冲突梯度对数量: {analysis['num_conflicting_pairs']}",
        f"平均冲突强度: {analysis['avg_conflict_intensity']:.4f}",
        f"梯度范数: {analysis['gradient_norms']}",
    ):
        print(line)
    print("\n余弦相似度矩阵:")
    print(analysis['cosine_similarity_matrix'])
88+
89+
90+
# Script entry point: run the demo analysis when executed directly.
if __name__ == "__main__":
    example_usage()

lzero/entry/train_unizero_multitask_segment_ddp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ def train_unizero_multitask_segment_ddp(
521521
# 编译配置
522522
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
523523
# 创建共享的policy
524-
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
524+
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) # MOE
525525

526526
# 加载预训练模型(如果提供)
527527
if model_path is not None:

lzero/model/common.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,68 @@ def remove_hooks(self):
248248
self.forward_handler.remove()
249249
self.backward_handler.remove()
250250

251+
# # modified by tangjia
252+
# class ModelGradientHook:
253+
254+
255+
# def __init__(self):
256+
# """
257+
# Overview:
258+
# Class to capture gradients at model output.
259+
# """
260+
# self.output_grads = []
261+
262+
# def setup_hook(self, model):
263+
# # Hook to capture gradients at model output
264+
# self.backward_handler = model.register_full_backward_hook(self.backward_hook)
265+
266+
# def backward_hook(self, module, grad_input, grad_output):
267+
# with torch.no_grad():
268+
# # 保存输出梯度
269+
# if grad_output[0] is not None:
270+
# self.output_grads.append(grad_output[0].clone())
271+
272+
# def analyze(self):
273+
# if not self.output_grads:
274+
# return None
275+
276+
# # Calculate norms of output gradients
277+
# grad_norms = [torch.norm(g, p=2, dim=1).mean() for g in self.output_grads]
278+
# avg_grad_norm = torch.mean(torch.stack(grad_norms))
279+
# max_grad_norm = torch.max(torch.stack(grad_norms))
280+
# min_grad_norm = torch.min(torch.stack(grad_norms))
281+
282+
# # Clear stored data and delete tensors to free memory
283+
# self.clear_data()
284+
285+
# # Optionally clear CUDA cache
286+
# if torch.cuda.is_available():
287+
# torch.cuda.empty_cache()
288+
289+
# return avg_grad_norm, max_grad_norm, min_grad_norm
290+
291+
# def clear_data(self):
292+
# del self.output_grads[:]
293+
294+
# def remove_hooks(self):
295+
# self.backward_handler.remove()
296+
297+
# 使用示例
298+
# monitor = ModelGradientMonitor()
299+
# monitor.setup_hook(model)
300+
#
301+
# # 训练过程中...
302+
# loss.backward()
303+
#
304+
# # 获取梯度信息
305+
# grad_norm = monitor.get_gradient_norm()
306+
# grad_stats = monitor.get_gradient_stats()
307+
#
308+
# # 清理数据
309+
# monitor.clear_data()
310+
#
311+
# # 训练结束后移除hook
312+
# monitor.remove_hook()
251313

252314
class DownSample(nn.Module):
253315

lzero/model/unizero_model_multitask.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from easydict import EasyDict
77

88
from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \
9-
VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook
9+
VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook #,ModelGradientHook
1010
from .unizero_world_models.tokenizer import Tokenizer
1111
from .unizero_world_models.world_model_multitask import WorldModelMT
1212

@@ -189,6 +189,11 @@ def __init__(
189189
self.encoder_hook = FeatureAndGradientHook()
190190
self.encoder_hook.setup_hooks(self.representation_network)
191191

192+
# if True: # Fixme: for debug
193+
# # 增加对encoder的hook,监控传播到encoder 上的梯度
194+
# self.encoder_output_hook = ModelGradientHook()
195+
# self.encoder_output_hook.setup_hook(self.representation_network)
196+
192197
self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False, obs_type=world_model_cfg.obs_type)
193198
self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer)
194199
print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model')
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +0,0 @@
1-
from .transformer import Transformer, TransformerConfig

0 commit comments

Comments
 (0)