@@ -48,9 +48,11 @@ def merge(lefts: torch.Tensor, rights: torch.Tensor) -> torch.Tensor:
4848def scan (
4949 gates : torch .Tensor ,
5050 tokens : torch .Tensor ,
51- mul = torch .mul ,
52- add = torch .add ,
53- zeros_like = torch .zeros_like
51+ mul : Callable = torch .mul ,
52+ add : Callable = torch .add ,
53+ zeros_like : Callable = torch .zeros_like ,
54+ ones_like : Callable = torch .ones_like ,
55+ reverse : bool = False
5456) -> torch .Tensor :
5557 """Solve a first-order recurrence relation using a reference torch implementation:
5658
@@ -65,13 +67,16 @@ def scan(
6567 mul (callable): multiplication function, defaults to torch.mul
6668 add (callable): addition function, defaults to torch.add
6769 zeros_like (callable): function to create a tensor of zeros like the input, defaults to torch.zeros_like
70+ ones_like (callable): function to create a tensor of ones like the input, defaults to torch.ones_like
71+ reverse (bool): whether to solve the recurrence in reverse order, defaults to False
6872
6973 Returns:
7074 (torch.Tensor): shape (B, C, T)
7175 """
7276 B ,C ,T = tokens .size ()
7377 level = int (math .log2 (T ))
74- return add (mul (scan1 (gates , tokens , mul , add , zeros_like , level = level ), gates ), tokens )
78+ _ , x = scan1 (gates , tokens , mul , add , zeros_like , ones_like , level = level , reverse = reverse )
79+ return add (mul (x , gates ), tokens )
7580
7681
7782def scan1 (
@@ -80,19 +85,29 @@ def scan1(
8085 mul : Callable ,
8186 add : Callable ,
8287 zeros_like : Callable ,
83- level : int
88+ ones_like : Callable ,
89+ level : int ,
90+ reverse : bool = False
8491):
85- left_gates , right_gates = split (gates )
86- left_x , right_x = split (tokens )
92+ if reverse :
93+ right_gates , left_gates = split (gates )
94+ right_x , left_x = split (tokens )
95+ else :
96+ left_gates , right_gates = split (gates )
97+ left_x , right_x = split (tokens )
8798
8899 # up: sum together
89100 gates = mul (left_gates , right_gates )
90101 tokens = add (mul (right_gates , left_x ), right_x )
91102
92103 if level == 1 :
93- root_x = zeros_like (tokens )
104+ root_gates , root_x = ones_like ( tokens ), zeros_like (tokens )
94105 else :
95- root_x = scan1 (gates , tokens , mul , add , zeros_like , level = level - 1 )
106+ root_gates , root_x = scan1 (gates , tokens , mul , add , zeros_like , ones_like , level = level - 1 , reverse = reverse )
96107
97- # down: left is root, right is left (+) right
98- return merge (root_x , add (mul (root_x , left_gates ), left_x ))
108+ if reverse :
109+ # down: right is root, left is left (+) right
110+ return merge (mul (root_gates , left_gates ), root_gates ), merge (add (mul (root_x , left_gates ), left_x ), root_x )
111+ else :
112+ # down: left is root, right is left (+) right
113+ return merge (root_gates , mul (root_gates , left_gates )), merge (root_x , add (mul (root_x , left_gates ), left_x ))
0 commit comments