@@ -241,16 +241,16 @@ def __init__(
241
241
self .impl = impl_builder .build (in_features , out_features , block_size = 128 , bias = bias is not None , dtype = dtype )
242
242
self .block_size = 128
243
243
self .fp8_dtype = fp8_dtype
244
- weight , scale , bias = self .create_weights (in_features , out_features , bias , dtype , device )
244
+ weight , weight_scale_inv , bias = self .create_weights (in_features , out_features , bias , dtype , device )
245
245
weight = torch .nn .Parameter (weight , requires_grad = False )
246
246
weight .weight_loader = self .weight_loader
247
- scale = torch .nn .Parameter (scale , requires_grad = False )
248
- scale .weight_loader = self .weight_loader
247
+ weight_scale_inv = torch .nn .Parameter (weight_scale_inv , requires_grad = False )
248
+ weight_scale_inv .weight_loader = self .weight_loader
249
249
if bias is not None :
250
250
bias = torch .nn .Parameter (bias , requires_grad = False )
251
251
bias .weight_loader = self .weight_loader
252
252
self .register_parameter ('weight' , weight )
253
- self .register_parameter ('scale ' , scale )
253
+ self .register_parameter ('weight_scale_inv ' , weight_scale_inv )
254
254
self .register_parameter ('bias' , bias )
255
255
256
256
self .in_features = in_features
@@ -302,27 +302,27 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
302
302
def create_weights (self , in_features : int , out_features : int , bias : bool , dtype : torch .dtype , device : torch .device ):
303
303
"""create weights."""
304
304
weight = torch .empty ((out_features , in_features ), dtype = self .fp8_dtype , device = device )
305
- scale = torch .empty ((div_up (out_features , self .block_size ), div_up (in_features , self .block_size )),
306
- dtype = torch .float32 ,
307
- device = device )
305
+ weight_scale_inv = torch .empty ((div_up (out_features , self .block_size ), div_up (in_features , self .block_size )),
306
+ dtype = torch .float32 ,
307
+ device = device )
308
308
if bias :
309
309
bias = torch .empty ((out_features , ), dtype = dtype , device = device )
310
310
else :
311
311
bias = None
312
- return weight , scale , bias
312
+ return weight , weight_scale_inv , bias
313
313
314
314
def update_weights (self ):
315
315
"""update weights."""
316
- weight , scale , bias = self .impl .update_weights (self .weight , self .scale , self .bias )
316
+ weight , weight_scale_inv , bias = self .impl .update_weights (self .weight , self .weight_scale_inv , self .bias )
317
317
weight = torch .nn .Parameter (weight , requires_grad = False )
318
318
self .weight .weight_loader = self .weight_loader
319
- scale = torch .nn .Parameter (scale , requires_grad = False )
320
- self .scale .weight_loader = self .weight_loader
319
+ weight_scale_inv = torch .nn .Parameter (weight_scale_inv , requires_grad = False )
320
+ self .weight_scale_inv .weight_loader = self .weight_loader
321
321
if bias is not None :
322
322
bias = torch .nn .Parameter (bias , requires_grad = False )
323
323
self .bias .weight_loader = self .weight_loader
324
324
self .register_parameter ('weight' , weight )
325
- self .register_parameter ('scale ' , scale )
325
+ self .register_parameter ('weight_scale_inv ' , weight_scale_inv )
326
326
self .register_parameter ('bias' , bias )
327
327
328
328
def forward (self , x ):
@@ -340,11 +340,11 @@ def forward(self, x):
340
340
if len (self .lora_adapters ) == 0 :
341
341
if self .dp_scatter :
342
342
_ , rank = get_tp_world_rank ()
343
- return self .impl .forward (x , self .weight , self .scale , self .bias , all_reduce , rank , tp_sizes )
343
+ return self .impl .forward (x , self .weight , self .weight_scale_inv , self .bias , all_reduce , rank , tp_sizes )
344
344
else :
345
- return self .impl .forward (x , self .weight , self .scale , self .bias , all_reduce )
345
+ return self .impl .forward (x , self .weight , self .weight_scale_inv , self .bias , all_reduce )
346
346
347
- out = self .impl .forward (x , self .weight , self .scale , self .bias , False )
347
+ out = self .impl .forward (x , self .weight , self .weight_scale_inv , self .bias , False )
348
348
for lora_adapter in self .lora_adapters .values ():
349
349
out = lora_adapter (x , out )
350
350
if all_reduce :
@@ -394,10 +394,10 @@ def __init__(self,
394
394
dp_gather = dp_gather )
395
395
self .weight .weight_loader = self .weight_loader
396
396
self .weight ._weight_type = 'qweight'
397
- self .scale .weight_loader = self .weight_loader
398
- self .scale ._weight_type = 'scales'
397
+ self .weight_scale_inv .weight_loader = self .weight_loader
398
+ self .weight_scale_inv ._weight_type = 'scales'
399
399
self .weight .weight_spliter = self .weight_spliter
400
- self .scale .weight_spliter = self .weight_spliter
400
+ self .weight_scale_inv .weight_spliter = self .weight_spliter
401
401
if self .bias is not None :
402
402
self .bias .weight_loader = self .weight_loader
403
403
self .bias .weight_spliter = self .weight_spliter
0 commit comments