@@ -436,6 +436,43 @@ def forward(self, x):
         weight, bias, signal = weights_manual_cast(self, x, weight_fn=dequantize_tensor, bias_fn=None, skip_bias_dtype=True)
         with main_stream_worker(weight, bias, signal):
             return torch.nn.functional.linear(x, weight, bias)
+
+class Embedding(torch.nn.Embedding):
+    def __init__(self, *args, **kwargs):
+        kwargs['device'] = current_device
+        super().__init__(*args, **kwargs)
+        self.parameters_manual_cast = current_manual_cast_enabled
+        self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
+        self.bias = None
+
+    def reset_parameters(self):
+        self.bias = None
+        return None
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        if hasattr(self, 'dummy'):
+            computation_dtype = self.dummy.dtype
+            if computation_dtype not in [torch.float16, torch.bfloat16]:
+                # GGUF cast only supports 16-bit dtypes; anything else is super slow
+                computation_dtype = torch.float16
+            if prefix + 'weight' in state_dict:
+                self.weight = state_dict[prefix + 'weight'].to(device=self.dummy.device)
+                self.weight.computation_dtype = computation_dtype
+            del self.dummy
+        else:
+            if prefix + 'weight' in state_dict:
+                self.weight = state_dict[prefix + 'weight']
+        return
+
+    def _apply(self, fn, recurse=True):
+        for k, p in self.named_parameters(recurse=False, remove_duplicate=True):
+            setattr(self, k, utils.tensor2parameter(fn(p)))
+        return self
+
+    def forward(self, x):
+        weight, bias, signal = weights_manual_cast(self, x, weight_fn=dequantize_tensor, skip_weight_dtype=True, skip_bias_dtype=True)
+        with main_stream_worker(weight, bias, signal):
+            return torch.nn.functional.embedding(x, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
 
 
 @contextlib.contextmanager
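
For context on how an override like the `Embedding` class above typically takes effect: the class is swapped into `torch.nn` while a model is being built, so every embedding layer constructed inside that window picks up the GGUF-aware loading logic. The sketch below is a minimal illustration of that assumed monkey-patch pattern only; the `using_patched_embedding` helper and the `ops` handle are hypothetical, not code from this diff (the diff's own activation path is the `contextlib.contextmanager` that begins right after this hunk).

```python
# Minimal sketch, assuming a monkey-patch activation pattern.
# `ops` stands in for the module this diff touches; `using_patched_embedding`
# is a hypothetical helper, not part of this PR.
import contextlib
import torch

@contextlib.contextmanager
def using_patched_embedding(ops):
    original = torch.nn.Embedding
    # Any module built inside the block instantiates the override, which
    # defers weight loading and dtype selection to _load_from_state_dict.
    torch.nn.Embedding = ops.Embedding
    try:
        yield
    finally:
        # Always restore the stock class, even if model construction raises.
        torch.nn.Embedding = original
```

With such a helper, `with using_patched_embedding(ops): model = MyModel()` would make every `nn.Embedding` inside `MyModel` go through the quantized-weight path above instead of allocating and loading a dense float weight.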