Skip to content

Commit 5919078

Browse files
Update train_gpt.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 5c7925c commit 5919078

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

train_gpt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ def quantize_state_dict_ternary(state_dict: dict[str, Tensor], threshold_scale:
425425
# all zeros
426426
scales[name] = torch.tensor(0.0)
427427
ternary[name] = torch.zeros_like(tt, dtype=torch.int8)
428-
stats["ternary_payload_bytes"] += tensor_nbytes(ternary[name])
428+
stats["ternary_payload_bytes"] += tensor_nbytes(ternary[name]) + tensor_nbytes(scales[name])
429429
continue
430430
thr = threshold_scale * max_abs
431431
s = max_abs if max_abs > 0 else 1.0

0 commit comments

Comments
 (0)