Commit d77f7bf

address reviews
1 parent 332a7d5 commit d77f7bf

4 files changed: 37 additions, 19 deletions


guides/ipynb/writing_quantization_compatible_layers.ipynb

Lines changed: 11 additions & 5 deletions
@@ -140,10 +140,12 @@
 " )\n",
 " scale = ops.squeeze(scale, axis=0)\n",
 "\n",
+" kernel_shape = self._kernel.shape\n",
+"\n",
 " del self._kernel\n",
 "\n",
 " # Allocate INT8 variables. Discussed in the next section.\n",
-" self._int8_build(kernel_shape=self._kernel.shape)\n",
+" self._int8_build(kernel_shape)\n",
 "\n",
 " self._kernel.assign(quantized_kernel)\n",
 " self.scale.assign(scale)\n",
@@ -246,6 +248,10 @@
 "quantized variables allocated in `_int8_build(...)` and de-scales the output\n",
 "back to floating-point.\n",
 "\n",
+"The base `keras.Layer` class automatically dispatches to this method when the\n",
+"layer is quantized. Your regular call() method will be used for the\n",
+"full-precision forward pass.\n",
+"\n",
 "The INT8 path mirrors the float computation `y = x * w` but performs:\n",
 "\n",
 "1. Elementwise multiply using the quantized weight.\n",
@@ -320,7 +326,7 @@
 "\n",
 " del self._kernel\n",
 "\n",
-" self._int8_build(kernel_shape=kernel_shape)\n",
+" self._int8_build(kernel_shape)\n",
 "\n",
 " self._kernel.assign(quantized_kernel)\n",
 " self.scale.assign(scale)\n",
@@ -587,7 +593,7 @@
 "\n",
 " del self._kernel\n",
 "\n",
-" self._int8_build(kernel_shape=kernel_shape)\n",
+" self._int8_build(kernel_shape)\n",
 "\n",
 " self._kernel.assign(quantized_kernel)\n",
 " self.scale.assign(scale)\n",
@@ -717,8 +723,8 @@
 " - The axis you packed along (e.g., `_int4_pack_axis`).\n",
 " - The original (unpacked) length on that axis (e.g., `_original_input_dim` or\n",
 " `_original_length_along_pack_axis`).\n",
-"- In `call(...)`, compute with the quantized buffers and de-scale back to float\n",
-" at the end, wherever possible. This allows you to leverage optimized\n",
+"- In quantized call hooks, compute with the quantized buffers and de-scale back\n",
+" to float at the end, wherever possible. This allows you to leverage optimized\n",
 " low-precision kernels (e.g., cuBLAS INT8 GEMM).\n",
 "\n",
 "- INT4 specifics (packed nibbles)\n",

guides/md/writing_quantization_compatible_layers.md

Lines changed: 15 additions & 9 deletions
@@ -99,10 +99,12 @@ def quantize(self, mode, **kwargs):
         )
         scale = ops.squeeze(scale, axis=0)
 
+        kernel_shape = self._kernel.shape
+
         del self._kernel
 
         # Allocate INT8 variables. Discussed in the next section.
-        self._int8_build(kernel_shape=self._kernel.shape)
+        self._int8_build(kernel_shape)
 
         self._kernel.assign(quantized_kernel)
         self.scale.assign(scale)
@@ -172,6 +174,10 @@ The `_int8_call(...)` method implements a minimal INT8 forward path. It uses the
 quantized variables allocated in `_int8_build(...)` and de-scales the output
 back to floating-point.
 
+The base `keras.Layer` class automatically dispatches to this method when the
+layer is quantized. Your regular call() method will be used for the
+full-precision forward pass.
+
 The INT8 path mirrors the float computation `y = x * w` but performs:
 
 1. Elementwise multiply using the quantized weight.
@@ -227,7 +233,7 @@ class SimpleScale(Layer):
 
         del self._kernel
 
-        self._int8_build(kernel_shape=kernel_shape)
+        self._int8_build(kernel_shape)
 
         self._kernel.assign(quantized_kernel)
         self.scale.assign(scale)
@@ -288,8 +294,8 @@ print("SimpleScale INT8 sample:", y_int8[0].numpy())
 
 <div class="k-default-codeblock">
 ```
-SimpleScale FP32 sample: [-0.00756585 -0.0135909 -0.02137992 0.01047459]
-SimpleScale INT8 sample: [-0.00756123 -0.01362174 -0.02146736 0.01047461]
+SimpleScale FP32 sample: [ 0.00074363 -0.02807784 -0.0032404 -0.03456082]
+SimpleScale INT8 sample: [ 0.00074166 -0.0279077 -0.00322246 -0.03456089]
 ```
 </div>
 
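The refreshed samples show the FP32 and INT8 paths agreeing to roughly three decimal places. A quick hedged check of that agreement, assuming outputs named `y_fp32` and `y_int8` (names inferred from the printed labels; hypothetical if the guide uses others):

```python
import numpy as np

# Maximum elementwise deviation introduced by INT8 quantization.
max_err = np.max(np.abs(np.asarray(y_fp32) - np.asarray(y_int8)))
print(f"max |fp32 - int8| = {max_err:.2e}")  # roughly 1e-4 for these samples
```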
@@ -446,7 +452,7 @@ class SimpleScale(Layer):
 
         del self._kernel
 
-        self._int8_build(kernel_shape=kernel_shape)
+        self._int8_build(kernel_shape)
 
         self._kernel.assign(quantized_kernel)
         self.scale.assign(scale)
@@ -542,8 +548,8 @@ print("Loaded INT8 sample:", y_loaded[0].numpy())
 
 <div class="k-default-codeblock">
 ```
-SimpleScale INT8 sample: [ 0.02398201 -0.00298704 0.02251735 0.0029661 ]
-Loaded INT8 sample: [ 0.02398201 -0.00298704 0.02251735 0.0029661 ]
+SimpleScale INT8 sample: [-0.00047286 0.02926966 -0.00708966 0.03041461]
+Loaded INT8 sample: [-0.00047286 0.02926966 -0.00708966 0.03041461]
 
 /Users/jyotindersingh/miniconda3/envs/keras-io-env-3.12/lib/python3.12/site-packages/keras/src/models/model.py:472: UserWarning: Layer InputLayer does not have a `quantize` method implemented.
 warnings.warn(str(e))
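
The identical pre- and post-load samples come from a save/load round trip of the quantized model (the `simplescale_int8.keras` binary updated in this commit). A hedged sketch of that round trip, assuming a built Keras model `model` and input batch `x` as in the guide:

```python
import keras

model.quantize("int8")  # swaps eligible layers to INT8 variables in place
model.save("simplescale_int8.keras")

loaded = keras.saving.load_model("simplescale_int8.keras")
print("Loaded INT8 sample:", loaded(x)[0].numpy())  # matches pre-save output
```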
@@ -562,8 +568,8 @@ Here are concrete patterns you can reuse when making your own layers PTQ-friendly
 - The axis you packed along (e.g., `_int4_pack_axis`).
 - The original (unpacked) length on that axis (e.g., `_original_input_dim` or
   `_original_length_along_pack_axis`).
-- In `call(...)`, compute with the quantized buffers and de-scale back to float
-  at the end, wherever possible. This allows you to leverage optimized
+- In quantized call hooks, compute with the quantized buffers and de-scale back
+  to float at the end, wherever possible. This allows you to leverage optimized
   low-precision kernels (e.g., cuBLAS INT8 GEMM).
 
 - INT4 specifics (packed nibbles)

guides/writing_quantization_compatible_layers.py

Lines changed: 11 additions & 5 deletions
@@ -96,10 +96,12 @@ def quantize(self, mode, **kwargs):
         )
         scale = ops.squeeze(scale, axis=0)
 
+        kernel_shape = self._kernel.shape
+
         del self._kernel
 
         # Allocate INT8 variables. Discussed in the next section.
-        self._int8_build(kernel_shape=self._kernel.shape)
+        self._int8_build(kernel_shape)
 
         self._kernel.assign(quantized_kernel)
         self.scale.assign(scale)
@@ -171,6 +173,10 @@ def _int8_build(self, kernel_shape):
 quantized variables allocated in `_int8_build(...)` and de-scales the output
 back to floating-point.
 
+The base `keras.Layer` class automatically dispatches to this method when the
+layer is quantized. Your regular call() method will be used for the
+full-precision forward pass.
+
 The INT8 path mirrors the float computation `y = x * w` but performs:
 
 1. Elementwise multiply using the quantized weight.
@@ -223,7 +229,7 @@ def quantize(self, mode, **kwargs):
 
         del self._kernel
 
-        self._int8_build(kernel_shape=kernel_shape)
+        self._int8_build(kernel_shape)
 
         self._kernel.assign(quantized_kernel)
         self.scale.assign(scale)
@@ -429,7 +435,7 @@ def quantize(self, mode, **kwargs):
 
         del self._kernel
 
-        self._int8_build(kernel_shape=kernel_shape)
+        self._int8_build(kernel_shape)
 
         self._kernel.assign(quantized_kernel)
         self.scale.assign(scale)
@@ -532,8 +538,8 @@ def load_own_variables(self, store):
 - The axis you packed along (e.g., `_int4_pack_axis`).
 - The original (unpacked) length on that axis (e.g., `_original_input_dim` or
   `_original_length_along_pack_axis`).
-- In `call(...)`, compute with the quantized buffers and de-scale back to float
-  at the end, wherever possible. This allows you to leverage optimized
+- In quantized call hooks, compute with the quantized buffers and de-scale back
+  to float at the end, wherever possible. This allows you to leverage optimized
   low-precision kernels (e.g., cuBLAS INT8 GEMM).
 
 - INT4 specifics (packed nibbles)

simplescale_int8.keras

Binary file changed: -10.8 KB (contents not shown).
