
Commit 0c05fd1

add GGUF on Embedding layer (PR 2963 from main forge2)
lllyasviel#2963
1 parent cea47aa commit 0c05fd1

File tree

1 file changed: +37 -0 lines changed


backend/operations.py

Lines changed: 37 additions & 0 deletions
@@ -436,6 +436,43 @@ def forward(self, x):
         weight, bias, signal = weights_manual_cast(self, x, weight_fn=dequantize_tensor, bias_fn=None, skip_bias_dtype=True)
         with main_stream_worker(weight, bias, signal):
             return torch.nn.functional.linear(x, weight, bias)
+
+class Embedding(torch.nn.Embedding):
+    def __init__(self, *args, **kwargs):
+        kwargs['device'] = current_device
+        super().__init__(*args, **kwargs)
+        self.parameters_manual_cast = current_manual_cast_enabled
+        self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
+        self.bias = None
+
+    def reset_parameters(self):
+        self.bias = None
+        return None
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        if hasattr(self, 'dummy'):
+            computation_dtype = self.dummy.dtype
+            if computation_dtype not in [torch.float16, torch.bfloat16]:
+                # GGUF cast only supports 16bits otherwise super slow
+                computation_dtype = torch.float16
+            if prefix + 'weight' in state_dict:
+                self.weight = state_dict[prefix + 'weight'].to(device=self.dummy.device)
+                self.weight.computation_dtype = computation_dtype
+            del self.dummy
+        else:
+            if prefix + 'weight' in state_dict:
+                self.weight = state_dict[prefix + 'weight']
+        return
+
+    def _apply(self, fn, recurse=True):
+        for k, p in self.named_parameters(recurse=False, remove_duplicate=True):
+            setattr(self, k, utils.tensor2parameter(fn(p)))
+        return self
+
+    def forward(self, x):
+        weight, bias, signal = weights_manual_cast(self, x, weight_fn=dequantize_tensor, skip_weight_dtype=True, skip_bias_dtype=True)
+        with main_stream_worker(weight, bias, signal):
+            return torch.nn.functional.embedding(x, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
 
 
 @contextlib.contextmanager
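
For context, a minimal standalone sketch of the idea behind this Embedding override: keep the embedding weight in its storage form and only dequantize/cast it to a 16-bit computation dtype at lookup time, rather than materializing a compute-dtype copy up front. The fake_dequantize helper and LazyCastEmbedding class below are illustrative stand-ins, not the dequantize_tensor / weights_manual_cast / main_stream_worker machinery from backend/operations.py.

import torch

def fake_dequantize(weight, dtype=torch.float16):
    # Stand-in for a GGUF dequantizer: the "quantized" storage here is plain
    # float32, and dequantization is just a cast to the 16-bit computation dtype.
    return weight.to(dtype)

class LazyCastEmbedding(torch.nn.Embedding):
    def forward(self, x):
        # Cast/dequantize the stored weight only for this lookup instead of
        # keeping a full compute-dtype copy of the embedding table resident.
        weight = fake_dequantize(self.weight, torch.float16)
        return torch.nn.functional.embedding(
            x, weight, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse)

emb = LazyCastEmbedding(num_embeddings=8, embedding_dim=4)
print(emb(torch.tensor([1, 2, 3])).dtype)  # torch.float16

In the committed code, the same lookup instead goes through weights_manual_cast with dequantize_tensor as the weight function, and _load_from_state_dict pins the computation dtype to float16/bfloat16 because, per the inline comment, GGUF casting to anything other than a 16-bit dtype is very slow.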
