|
| 1 | +import torch |
| 2 | +import numpy as np |
| 3 | +from typing import List, Tuple |
| 4 | + |
def compute_gradient_conflicts(gradients: List[torch.Tensor]) -> dict:
    """Compute pairwise conflict metrics for a list of gradient tensors.

    Each gradient is flattened to a vector and every pair is compared by
    cosine similarity and by dot product. A pair whose dot product is
    negative is counted as conflicting.

    Args:
        gradients: Gradient tensors; all must have the same shape.

    Returns:
        A dict with these keys (in this order):
            cosine_similarity_matrix: (n, n) tensor of pairwise cosine
                similarities.
            avg_conflict_score: Mean of the negated off-diagonal cosine
                similarities; 0.0 when there are fewer than two gradients.
            max_conflict_score: Max of the negated off-diagonal cosine
                similarities; 0.0 when there are fewer than two gradients.
            dot_product_matrix: (n, n) tensor of pairwise dot products.
            gradient_norms: List with ``torch.norm`` of each gradient.
            num_conflicting_pairs: Number of unordered pairs with a negative
                dot product.
            avg_conflict_intensity: Mean magnitude of those negative dot
                products; 0.0 when no pair conflicts.

    Raises:
        ValueError: If the gradients do not all share the same shape.
    """
    n_gradients = len(gradients)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    # With an empty list the generator never runs, so gradients[0] is safe.
    if any(g.shape != gradients[0].shape for g in gradients):
        raise ValueError("梯度形状必须相同")

    # Flatten once up front rather than on every pairwise comparison.
    flat = [g.flatten() for g in gradients]

    cosine_sim_matrix = torch.zeros(n_gradients, n_gradients)
    dot_product_matrix = torch.zeros(n_gradients, n_gradients)
    negative_dot_products = []

    # Both metrics are symmetric, so evaluate each unordered pair once
    # (diagonal included) and mirror the value into the lower triangle.
    for i in range(n_gradients):
        for j in range(i, n_gradients):
            cos_sim = torch.cosine_similarity(flat[i], flat[j], dim=0)
            dot_prod = torch.dot(flat[i], flat[j])
            cosine_sim_matrix[i, j] = cosine_sim_matrix[j, i] = cos_sim
            dot_product_matrix[i, j] = dot_product_matrix[j, i] = dot_prod
            # A negative dot product between two distinct gradients is a
            # conflict; record its magnitude.
            if i != j and dot_prod < 0:
                negative_dot_products.append(-dot_prod.item())

    # With fewer than two gradients there are no off-diagonal entries, and
    # reducing an empty tensor with max() raises; report "no conflict".
    if n_gradients > 1:
        off_diagonal = ~torch.eye(n_gradients, dtype=torch.bool)
        conflict_scores = -cosine_sim_matrix[off_diagonal]
        avg_conflict_score = conflict_scores.mean().item()
        max_conflict_score = conflict_scores.max().item()
    else:
        avg_conflict_score = max_conflict_score = 0.0

    return {
        'cosine_similarity_matrix': cosine_sim_matrix,
        'avg_conflict_score': avg_conflict_score,
        'max_conflict_score': max_conflict_score,
        'dot_product_matrix': dot_product_matrix,
        'gradient_norms': [torch.norm(g).item() for g in gradients],
        'num_conflicting_pairs': len(negative_dot_products),
        # Always a plain float, whether or not any pair conflicts.
        'avg_conflict_intensity': (
            float(np.mean(negative_dot_products))
            if negative_dot_products
            else 0.0
        ),
    }
| 66 | + |
| 67 | +# 使用示例 |
def example_usage():
    """Run the conflict analysis on three seeded random gradients and print it."""
    # Fixed seed so the demo prints the same numbers on every run.
    torch.manual_seed(42)

    # Three independent 100-element random gradients, drawn in sequence.
    gradients = [torch.randn(100) for _ in range(3)]

    conflicts = compute_gradient_conflicts(gradients)

    # Scalar summary first, then the full pairwise similarity matrix.
    print("梯度冲突分析结果:")
    print(f"平均冲突得分: {conflicts['avg_conflict_score']:.4f}")
    print(f"最大冲突得分: {conflicts['max_conflict_score']:.4f}")
    print(f"冲突梯度对数量: {conflicts['num_conflicting_pairs']}")
    print(f"平均冲突强度: {conflicts['avg_conflict_intensity']:.4f}")
    print(f"梯度范数: {conflicts['gradient_norms']}")
    print("\n余弦相似度矩阵:")
    print(conflicts['cosine_similarity_matrix'])
| 88 | + |
| 89 | + |
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    example_usage()
0 commit comments