Skip to content

Commit f8236e4

Browse files
authored
missing device (#232)
1 parent ceeef3e commit f8236e4

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

quantize.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
515515

516516

517517
def replace_embedding_weight_only_grouped_int8_per_channel(
518-
module, device, bitwidth: int = 8, groupsize: Optional[int] = None, packed=False
518+
module,
519+
device,
520+
bitwidth: int = 8,
521+
groupsize: Optional[int] = None,
522+
packed=False
519523
):
520524
for name, child in module.named_children():
521525
# print(f"name: {name}")
@@ -535,7 +539,7 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
535539
)
536540
else:
537541
replace_embedding_weight_only_grouped_int8_per_channel(
538-
child, bitwidth, groupsize, packed
542+
child, device, bitwidth, groupsize, packed
539543
)
540544

541545

0 commit comments

Comments
 (0)