-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
106 lines (74 loc) · 3.16 KB
/
test.py
File metadata and controls
106 lines (74 loc) · 3.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import torch.nn.functional as F
from torch import nn
from math import sqrt
from FlashMetal import FlashAttentionForward, FlashAttentionBackward, fetchPipeline
# Fetch the Metal pipelines once at import time.
# NOTE(review): neither handle is referenced again in this file — presumably the
# call matters for its side effect (compiling/caching the kernels); confirm.
forward_pip, backward_pip = fetchPipeline()

n_embed = 64
n_heads = 8

# Random leaf tensors laid out as (batch, heads, seq, head_dim) on the MPS device.
# They are leaves with requires_grad=True so .grad is populated by backward().
q = torch.randn(2, n_heads, 1024, 8, requires_grad=True, device="mps")
k = torch.randn(2, n_heads, 1024, 8, requires_grad=True, device="mps")
v = torch.randn(2, n_heads, 1024, 8, requires_grad=True, device="mps")

# --- Reference pass: plain PyTorch causal attention -------------------------
s = q @ k.transpose(-1, -2)
# Scale by sqrt(head_dim), read from the tensor instead of repeating the literal 8.
s /= sqrt(q.size(-1))

# Lower-triangular ones: position i may only attend to positions <= i.
# ones_like(s) already lives on s's device, so no extra device move is needed.
mask = torch.tril(torch.ones_like(s))
# Entries above the diagonal become -inf so softmax gives them zero weight.
s_masked = s.masked_fill(mask == 0, float("-inf"))
P = F.softmax(s_masked, dim=-1)
o_test = P @ v

loss = o_test.mean()
loss.backward()
print(q.grad)

# Clear the reference gradients on *all three* leaves before the kernel pass.
# Resetting only q.grad would leave k.grad / v.grad holding the sum of the
# reference and kernel gradients after the second backward().
q.grad = k.grad = v.grad = None

# Disabled debugging aids: forward output and a hand-written softmax backward
# for dQ, kept for comparing against the kernel when something looks off.
#print(o_test)
#dO = torch.randn_like(q, device='mps')
#o1 = (s_masked @ v)
#dP = torch.matmul(dO, v.transpose(-1, -2))
#dS = P * (dP - torch.sum(dP * P, dim=-1, keepdim=True))
#print(torch.matmul(dS, k))
class FlashAttentionAutograd(torch.autograd.Function):
    """Expose the FlashMetal attention kernels to autograd.

    ``forward`` calls ``FlashAttentionForward`` and saves everything
    ``FlashAttentionBackward`` needs; ``backward`` returns one gradient per
    input, in the order (query, key, value).
    """

    @staticmethod
    def forward(ctx, query, key, value):
        """Run the forward kernel and save tensors for the backward pass.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            query: 4-D tensor unpacked as (batch, heads, seq, head_dim).
            key: Tensor passed to the kernel alongside ``query``.
            value: Tensor passed to the kernel; the output buffer is
                allocated with its shape and dtype.

        Returns:
            The ``out`` tensor returned by ``FlashAttentionForward``.
        """
        # NOTE(review): assumes the kernels index dense row-major buffers —
        # FlashMetal's stride handling is not visible here. contiguous() is a
        # no-op for tensors that are already dense, so this is safe either way.
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()

        batch_size, num_heads, n_seq, _ = query.size()

        # Scratch buffers follow the inputs' device rather than a hard-coded
        # 'mps'. No requires_grad on `out`: autograd attaches the graph to
        # whatever forward() returns.
        out = torch.empty_like(value)
        row_max = torch.empty((batch_size, num_heads, n_seq), device=query.device)
        row_sum = torch.empty_like(row_max)

        out, row_max, row_sum = FlashAttentionForward(
            query, key, value, out, row_max, row_sum
        )
        ctx.save_for_backward(query, key, value, out, row_max, row_sum)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Run the backward kernel.

        Args:
            ctx: Autograd context holding the tensors saved in ``forward``.
            grad_output: Gradient of the loss w.r.t. the forward output.

        Returns:
            Tuple ``(grad_query, grad_key, grad_value)``.
        """
        query, key, value, out, row_max, row_sum = ctx.saved_tensors

        # Upstream gradients are often expanded views (e.g. mean().backward()
        # yields a stride-0 tensor); materialize before handing to the kernel.
        grad_output = grad_output.contiguous()

        # Zero-initialized accumulators on the same device as the saved inputs.
        out_dQ = torch.zeros_like(query)
        out_dK = torch.zeros_like(key)
        out_dV = torch.zeros_like(value)

        # NOTE(review): row_sum precedes row_max here, the reverse of the
        # forward call — kept as-is; confirm against FlashAttentionBackward.
        grad_query, grad_key, grad_value = FlashAttentionBackward(
            query, key, value, out, grad_output,
            out_dQ, out_dK, out_dV, row_sum, row_max,
        )
        return grad_query, grad_key, grad_value
# --- Kernel pass: same inputs and loss, routed through FlashAttentionAutograd ---
out = FlashAttentionAutograd.apply(q, k, v)

# Disabled forward check against the reference output o_test:
#print(out)
#diff = out - o_test
#print(diff)

# Same scalar loss as the reference pass, so the printed q.grad can be
# eyeballed against the one printed earlier.
loss = out.mean()
loss.backward()
print(q.grad)
class MHAttention(nn.Module):
    """Multi-head attention layer built on ``FlashAttentionAutograd``.

    One bias-free linear layer produces the stacked Q, K and V projections;
    each is reshaped to (B, n_heads, T, head_size) and passed to the kernel.
    """

    def __init__(self):
        super().__init__()
        self.head_size = 8
        qkv_width = 3 * n_heads * self.head_size
        self.batch_qkv_matrices = nn.Linear(n_embed, qkv_width, bias=False)
        # NOTE(review): registered but never applied in forward().
        self.projection = nn.Linear(n_embed, n_embed)
        # self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Project ``x`` to per-head Q/K/V and run the attention kernel.

        Args:
            x: 3-D input unpacked as (B, T, C).

        Returns:
            The kernel output for inputs shaped (B, n_heads, T, head_size).
            NOTE(review): heads are not merged back to (B, T, C) here.
        """
        batch, seq_len, _ = x.shape
        width_per_tensor = self.head_size * n_heads

        # Split the stacked projection into Q, K, V, each (B, T, head_size * n_heads),
        # then expose the head axis and move it ahead of time: (B, n_heads, T, head_size).
        stacked = self.batch_qkv_matrices(x)
        q, k, v = stacked.split(width_per_tensor, dim=-1)
        q, k, v = (
            part.view(batch, seq_len, n_heads, self.head_size).transpose(1, 2)
            for part in (q, k, v)
        )

        return FlashAttentionAutograd.apply(q, k, v)
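# Disabled end-to-end smoke test for MHAttention, kept for reference.
# NOTE(review): x's feature size (384) does not match n_embed (64), which is the
# input width of MHAttention's QKV Linear, and `test` in the last line is never
# defined in this file — fix both before re-enabling.
#x = torch.randn(1,1024, 384, device="mps")
#mh = MHAttention().to("mps")
#out = mh(x)
#loss = torch.mean(out)
#loss.backward()
#print("Gradient check passed:", test)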