Skip to content

Commit b9edbad

Browse files
committed
implement reverse scan in the reference implementation; add a test checking the backward pass via the reverse scan
1 parent 300654c commit b9edbad

File tree

2 files changed

+54
-12
lines changed

2 files changed

+54
-12
lines changed

accelerated_scan/ref.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,11 @@ def merge(lefts: torch.Tensor, rights: torch.Tensor) -> torch.Tensor:
4848
def scan(
4949
gates: torch.Tensor,
5050
tokens: torch.Tensor,
51-
mul=torch.mul,
52-
add=torch.add,
53-
zeros_like=torch.zeros_like
51+
mul: Callable = torch.mul,
52+
add: Callable = torch.add,
53+
zeros_like: Callable = torch.zeros_like,
54+
ones_like: Callable = torch.ones_like,
55+
reverse: bool = False
5456
) -> torch.Tensor:
5557
"""Solve a first-order recurrence relation using a reference torch implementation:
5658
@@ -65,13 +67,16 @@ def scan(
6567
mul (callable): multiplication function, defaults to torch.mul
6668
add (callable): addition function, defaults to torch.add
6769
zeros_like (callable): function to create a tensor of zeros like the input, defaults to torch.zeros_like
70+
ones_like (callable): function to create a tensor of ones like the input, defaults to torch.ones_like
71+
reverse (bool): whether to solve the recurrence in reverse order, defaults to False
6872
6973
Returns:
7074
(torch.Tensor): shape (B, C, T)
7175
"""
7276
B,C,T = tokens.size()
7377
level = int(math.log2(T))
74-
return add(mul(scan1(gates, tokens, mul, add, zeros_like, level=level), gates), tokens)
78+
_, x = scan1(gates, tokens, mul, add, zeros_like, ones_like, level=level, reverse=reverse)
79+
return add(mul(x, gates), tokens)
7580

7681

7782
def scan1(
@@ -80,19 +85,29 @@ def scan1(
8085
mul: Callable,
8186
add: Callable,
8287
zeros_like: Callable,
83-
level: int
88+
ones_like: Callable,
89+
level: int,
90+
reverse: bool = False
8491
):
85-
left_gates, right_gates = split(gates)
86-
left_x, right_x = split(tokens)
92+
if reverse:
93+
right_gates, left_gates = split(gates)
94+
right_x, left_x = split(tokens)
95+
else:
96+
left_gates, right_gates = split(gates)
97+
left_x, right_x = split(tokens)
8798

8899
# up: sum together
89100
gates = mul(left_gates, right_gates)
90101
tokens = add(mul(right_gates, left_x), right_x)
91102

92103
if level == 1:
93-
root_x = zeros_like(tokens)
104+
root_gates, root_x = ones_like(tokens), zeros_like(tokens)
94105
else:
95-
root_x = scan1(gates, tokens, mul, add, zeros_like, level=level-1)
106+
root_gates, root_x = scan1(gates, tokens, mul, add, zeros_like, ones_like, level=level-1, reverse=reverse)
96107

97-
# down: left is root, right is left (+) right
98-
return merge(root_x, add(mul(root_x, left_gates), left_x))
108+
if reverse:
109+
# down: right is root, left is left (+) right
110+
return merge(mul(root_gates, left_gates), root_gates), merge(add(mul(root_x, left_gates), left_x), root_x)
111+
else:
112+
# down: left is root, right is left (+) right
113+
return merge(root_gates, mul(root_gates, left_gates)), merge(root_x, add(mul(root_x, left_gates), left_x))

tests/test_eq.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,31 @@ def test_eq_backward(scan, seed, seqlen, dtype):
5959

6060
assert torch.allclose(gates_grad, gates_ref.grad, atol=atol[dtype])
6161
assert torch.allclose(tokens_grad, tokens_ref.grad, atol=atol[dtype])
62-
62+
63+
64+
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("seqlen", seqlens)
def test_eq_ref_reverse(seed, seqlen):
    """Check that the backward pass of ``scan_ref`` matches a reverse scan.

    For the recurrence c[t] = f[t] * c[t-1] + x[t], the gradient dl/dx is
    itself a first-order recurrence run right-to-left with the gates shifted
    forward by one step, and dl/df[t] = dl/dx[t] * c[t-1].  Both identities
    are verified against autograd.

    Args:
        seed (int): RNG seed for reproducible random inputs.
        seqlen (int): sequence length T of the scanned inputs.
    """
    generator = torch.Generator().manual_seed(seed)
    B, C, T = 1, 1, seqlen
    f = torch.randn(B, C, T, generator=generator, requires_grad=True)
    x = torch.randn(B, C, T, generator=generator, requires_grad=True)

    c = scan_ref(f, x)

    # Upstream gradient of c.sum() with respect to c is all ones.
    dldc = torch.ones_like(c)

    # Gates shifted left by one (f[t+1]), padded with 1 at the right edge.
    fpx = torch.cat([f, torch.ones_like(f[:, :, :1])], dim=-1)[:, :, 1:].contiguous()
    dcdx = scan_ref(fpx, dldc, reverse=True)
    # States shifted right by one (c[t-1]), padded with 0 at the left edge.
    cp = torch.cat([torch.zeros_like(c[:, :, :1]), c], dim=-1)[:, :, :-1].contiguous()
    dcdf = dcdx * cp

    c.sum().backward()

    # Compare against autograd; surface the max abs error on failure instead
    # of printing it unconditionally.
    x_err = (x.grad - dcdx).abs().max()
    assert torch.allclose(x.grad, dcdx, atol=1e-5), f"x grad max error {x_err}"
    f_err = (f.grad - dcdf).abs().max()
    assert torch.allclose(f.grad, dcdf, atol=2e-5), f"f grad max error {f_err}"

0 commit comments

Comments (0)