import warnings
from dataclasses import dataclass
from functools import reduce  # Required in Python 3
+from typing import Tuple, Optional

import torch

@@ -14,6 +15,12 @@ def prod(iterable):

tensor = torch.Tensor

+
+# The inverse transformations for the colTuring and colAmpere formats were contributed by Alex Borzunov:
+# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
+
+
+
"""
    This class pools outlier dimensions across layers.
    This is particularly important for small models where outlier features
@@ -48,6 +55,51 @@ def get_current_outlier_idx(self):
        return torch.Tensor(list(self.outliers)).to(torch.int64)


+def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]):
+    """
+    Compute a permutation of indices that inverts the specified (tiled) matrix transformation
+    :param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2]
+    :param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere
+    :note: we assume that transform_tile applies to a cpu-based int8 tensor of shape tile_size
+    :example: transform_tile function for the turing layout (bitsandbytes.functional as F)
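+        def transform_tile(tile: torch.Tensor) -> torch.Tensor:  # illustrative sketch, mirrors the transform used in backward()
+            return F.transform(tile.cuda(), from_order="row", to_order="col_turing")[0].to(tile.device)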
+    :returns: indices
+    """
+    d1, d2 = tile_size
+    assert 0 < d1 * d2 < 2**64
+    tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2)
+    # encode each position in tile as a tuple of <= 8 unique bytes
+    permuted_tile_indices = torch.zeros_like(tile_indices)
+    for i in range(8):
+        # select i-th byte, apply transformation and trace where each index ended up
+        ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256
+        sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous()
+        assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow"
+        permuted_tile_i = transform_tile(sample_tile_i)
+        ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128
+        permuted_tile_indices += ith_permuted_indices * (256**i)
+        if d1 * d2 < 256**i:
+            break  # if all indices fit in i bytes, stop early
+    return permuted_tile_indices
+
+def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
+    """
+    Undo a tiled permutation such as turing or ampere layout
+
+    :param permuted_tensor: torch tensor in a permuted layout
+    :param tile_indices: reverse transformation indices, from get_inverse_transform_indices
+    :return: contiguous row-major tensor
+    """
+    (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape
+    assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles"
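+    # view the tensor as one flattened tile per column, then scatter each element back to its
+    # row-major position inside the tile using the precomputed inverse indices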
+    tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t()
+    outputs = torch.empty_like(tensor)  # note: not using .index_copy because it was slower on cuda
+    outputs[tile_indices.flatten()] = tensor
+    outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows)
+    outputs = outputs.permute(3, 0, 2, 1)  # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols)
+    return outputs.reshape(rows, cols).contiguous()
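+
+# Usage sketch (illustrative only; assumes a CUDA device and a weight matrix W_int8 whose dimensions
+# are whole multiples of the tile size, so no padding is involved):
+#   to_turing = lambda x: F.transform(x.cuda(), from_order="row", to_order="col_turing")[0].to(x.device)
+#   indices = get_inverse_transform_indices(to_turing, tile_size=(8, 32))
+#   assert torch.equal(undo_layout(to_turing(W_int8), indices), W_int8)  # back in row-major order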
+
+
class MatMul8bit(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
@@ -171,6 +223,8 @@ def backward(ctx, grad_output):

@dataclass
class MatmulLtState:
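+    # tile_indices caches the inverse permutation used to undo the col_turing/col_ampere layout of CxB in
+    # backward; force_no_igemmlt forces the fallback path even on GPUs that support igemmlt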
+    tile_indices: Optional[torch.Tensor] = None
+    force_no_igemmlt: bool = False
    CB = None
    CxB = None
    SB = None
@@ -202,21 +256,32 @@ def reset_grads(self):
        self.SBt = None
        self.CBt = None

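+    # (rows, cols) of one tile in the GPU-specific weight layout; needed to invert that layout in backward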
+    def get_tile_size(self):
+        assert self.formatB in (
+            "col_turing",
+            "col_ampere",
+        ), f"please find this assert and manually enter tile size for {self.formatB}"
+        return (8, 32) if self.formatB == "col_turing" else (32, 32)
+

class MatMul8bitLt(torch.autograd.Function):
+    # forward is the same, but we added the fallback for pre-turing GPUs
+    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
+
    @staticmethod
-    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
-        # default to pytorch behavior if inputs are empty
+    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
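+        # igemmlt (the int8 matmul kernel) requires compute capability >= 7.5 (Turing); if that is not
+        # met, or force_no_igemmlt is set, use the fp16 fallback matmul path further below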
+        using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt
+        # default to pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B
            ctx.bias = bias
            if A.shape[-1] == B.shape[0]:
-                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
+                return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
            else:
-                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
+                return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)

        # 1. Quantize A
        # 2. Quantize B
@@ -235,9 +300,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
        # 1. Quantize A
        if len(A.shape) == 3:
            A = A.view(-1, A.shape[-1]).contiguous()
-        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
-            A.to(torch.float16), threshold=state.threshold
-        )
+        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)

        if state.threshold > 0.0 and coo_tensorA is not None:
            if state.has_fp16_weights:
@@ -248,12 +311,12 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
                state.subB = B[:, idx].t().contiguous()
                state.idx = idx
            else:
-                if state.CxB is None:
+                if state.CxB is None and using_igemmlt:
                    # B is in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
                    # we also need to convert it to the turing/ampere format
                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
        else:
-            if not state.has_fp16_weights and state.CxB is None:
+            if not state.has_fp16_weights and state.CxB is None and using_igemmlt:
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
            subA = None

@@ -273,7 +336,10 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
                    state.SCBt,
                    coo_tensorB,
                ) = F.double_quant(B.to(torch.float16))
-                state.CxB, state.SB = F.transform(CB, to_order=formatB)
+                if using_igemmlt:
+                    state.CxB, state.SB = F.transform(CB, to_order=formatB)
+                else:
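+                    # no igemmlt: keep B as plain row-major int8; the fallback matmul below uses state.CB directly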
+                    state.CB = CB
        else:
            has_grad = False

@@ -288,35 +354,43 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
            #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
            # else:
            #     state.idx = outlier_idx
-            outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
-            state.subB = (
-                (outliers * state.SCB.view(-1, 1) / 127.0)
-                .t()
-                .contiguous()
-                .to(A.dtype)
-            )
+            if state.CxB is not None:
+                outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
+            else:
+                outliers = state.CB[:, state.idx.long()].clone()
+
+            state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
            CA[:, state.idx.long()] = 0
            CAt[:, state.idx.long()] = 0
            subA = A[:, state.idx.long()]

-        shapeB = state.SB[0]
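+        # state.SB is only set on the igemmlt path; on the fallback path B keeps its row-major shape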
+        shapeB = state.SB[0] if state.SB else B.shape

        if len(input_shape) == 3:
            output_shape = (input_shape[0], input_shape[1], shapeB[0])
        else:
            output_shape = (input_shape[0], shapeB[0])

        # 3. Matmul
-        C32A, SA = F.transform(CA, "col32")
-        out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
-        # we apply the fused bias here
+        if using_igemmlt:
+            C32A, SA = F.transform(CA, "col32")
+            out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
+            if bias is None or bias.dtype == torch.float16:
+                # we apply the fused bias here
+                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+                output = output.to(A.dtype)
+            else:  # apply bias separately
+                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+                output = output.to(A.dtype).add_(bias)

-        if bias is None or bias.dtype == torch.float16:
-            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
-            output = output.to(A.dtype)
-        else:  # apply bias separately
-            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
-            output = output.to(A.dtype).add_(bias)
+        else:
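+            # pre-Turing / forced fallback: zero out the outlier columns (they are handled via subA in step 4),
+            # run a regular fp16 matmul with the int8 weights cast to fp16, then rescale by SCB / 127 to dequantize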
+            A_wo_outliers = A.clone()
+            if state.idx is not None:
+                A_wo_outliers[:, state.idx.long()] = 0
+            output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype))
+            output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))
+            if bias is not None:
+                output = output.add_(bias)

        # 4. Mixed-precision decomposition matmul
        if coo_tensorA is not None and subA is not None:
@@ -337,14 +411,13 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

-
-        clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
+        clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
        return clone_func(output.view(output_shape))

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
-            bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
+            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
        CAt, subA = ctx.tensors
@@ -359,9 +432,7 @@ def backward(ctx, grad_output):

        # Cast grad_output to fp16
        if len(grad_output.shape) == 3:
-            grad_output = grad_output.reshape(
-                -1, grad_output.shape[-1]
-            ).contiguous()
+            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
        if req_gradB:
@@ -376,17 +447,29 @@ def backward(ctx, grad_output):
            if state.CBt is not None:
                C32grad, Sgrad = F.transform(Cgrad, "col32")
                if state.CxBt is None:
-                    state.CxBt, state.SBt = F.transform(
-                        state.CBt, to_order=formatB, transpose=True
-                    )
+                    state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)

            elif state.CB is not None:
-                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
+                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
+                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
+            elif state.CxB is not None:
+
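+                # CxB holds the weights in the GPU-specific layout; invert that layout once, cache the
+                # indices on the state, and reuse them for every subsequent backward pass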
+                if state.tile_indices is None:
+                    order, tile_size = state.formatB, state.get_tile_size()
+                    transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
+                    with torch.no_grad():
+                        state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device)
+
+                CB = (
+                    undo_layout(state.CxB, state.tile_indices)
+                    .to(ctx.dtype_A)
+                    .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
+                )
                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
            else:
-                raise Exception('State must contain either CBt or CB matrix for backward')
+                raise Exception("State must contain either CBt or CB or CxB matrix for backward")

        return grad_A, grad_B, None, grad_bias, None
