|
5 | 5 | import jax.numpy as jnp |
6 | 6 | import numpy as np |
7 | 7 | import pytest |
| 8 | +from flax import nnx |
8 | 9 | from jax.sharding import Mesh |
9 | 10 |
|
10 | 11 | import sgl_jax.srt.kernels.quantized_matmul.kernel as quant_kernel |
@@ -101,7 +102,8 @@ def test_quantized_linear_offline_scale_formats(scale_format): |
101 | 102 | ) |
102 | 103 |
|
103 | 104 | ref_out = jnp.dot(x, w_fp.T) |
104 | | - out = quant_linear(x) |
| 105 | + out, bias = quant_linear(x) |
| 106 | + assert bias is None |
105 | 107 |
|
106 | 108 | _assert_close(f"Offline QuantizedLinear ({scale_format})", out, ref_out) |
107 | 109 |
|
@@ -198,7 +200,8 @@ def __init__(self, mesh): |
198 | 200 | ) |
199 | 201 |
|
200 | 202 | mesh = _create_single_device_mesh() |
201 | | - model = DummyModel(mesh) |
| 203 | + with jax.set_mesh(mesh): |
| 204 | + model = DummyModel(mesh) |
202 | 205 |
|
203 | 206 | class FakeModelConfig: |
204 | 207 | pass |
@@ -230,9 +233,108 @@ class FakeModelConfig: |
230 | 233 | assert model.proj.weight_scale.value.ndim == 1 |
231 | 234 |
|
232 | 235 |
|
def test_linear_return_contract_with_bias():
    """Check the ``(output, bias)`` return pair of bias-enabled linear layers.

    A ``LinearBase`` built with ``use_bias=True`` and the ``QuantizedLinear``
    produced from it via ``from_linear`` are each called on a ``(2, 64)``
    bfloat16 input. Both calls must yield an output of shape ``(2, 32)`` and
    ``None`` as the second element of the returned pair.
    """
    device_mesh = _create_single_device_mesh()
    inputs = jnp.ones((2, 64), dtype=jnp.bfloat16)

    layer_kwargs = {
        "input_size": 64,
        "output_size": 32,
        "use_bias": True,
        "mesh": device_mesh,
        "kernel_axes": (None, None),
        "params_dtype": jnp.bfloat16,
        "scope_name": "biased_proj",
    }

    # Layer construction and invocation both happen under the active mesh.
    with jax.set_mesh(device_mesh):
        dense_layer = LinearBase(**layer_kwargs)
        dense_out, dense_bias = dense_layer(inputs)

    assert dense_out.shape == (2, 32)
    assert dense_bias is None

    with jax.set_mesh(device_mesh):
        quantized_layer = QuantizedLinear.from_linear(
            dense_layer,
            weight_dtype=jnp.int8,
            activation_dtype=None,
            is_static_input=False,
        )
        quant_out, quant_bias = quantized_layer(inputs)

    assert quant_out.shape == (2, 32)
    assert quant_bias is None
| 266 | + |
| 267 | + |
def test_ignored_layers_only_skips_requested_paths():
    """Verify that ``ignored_layers`` exempts only the listed module paths.

    A catch-all ``".*"`` int8 rule is applied to a block that holds
    ``self_attn.q_proj`` and ``self_attn.o_proj``. An ignore entry matching
    neither path must leave both projections quantized, while ignoring
    ``"self_attn.o_proj"`` must leave that projection unquantized and still
    convert ``q_proj``.
    """

    def _make_proj(mesh, scope_name):
        # Both projections share every setting except their scope name.
        return LinearBase(
            input_size=64,
            output_size=32,
            use_bias=False,
            mesh=mesh,
            kernel_axes=(None, None),
            params_dtype=jnp.bfloat16,
            scope_name=scope_name,
        )

    class SelfAttn(nnx.Module):
        def __init__(self, mesh):
            self.q_proj = _make_proj(mesh, "q_proj")
            self.o_proj = _make_proj(mesh, "o_proj")

    class DummyBlock(nnx.Module):
        def __init__(self, mesh):
            self.self_attn = SelfAttn(mesh)

    class FakeModelConfig:
        pass

    def _make_config(ignored_layers):
        # Minimal stand-in exposing only a ``quantization_config`` attribute
        # with linear rules, the ignore list and a default block size.
        model_config = FakeModelConfig()
        model_config.quantization_config = type(
            "FakeQuantConfig",
            (),
            {
                "get_linear_rules": staticmethod(
                    lambda: [
                        {
                            "module_path": ".*",
                            "weight_dtype": "int8",
                            "activation_dtype": None,
                            "weight_block_size": None,
                        }
                    ]
                ),
                "ignored_layers": ignored_layers,
                "weight_block_size": [128, 128],
            },
        )()
        return model_config

    mesh = _create_single_device_mesh()

    def _quantized_block(ignored_layers):
        # Build a fresh block per scenario so one run cannot leak into the next.
        with jax.set_mesh(mesh):
            block = DummyBlock(mesh)
            apply_linear_quantization(
                _make_config(ignored_layers), block, is_static_input=False
            )
        return block

    # An ignore entry that matches nothing must not suppress quantization.
    attn = _quantized_block(["some_other_layer"]).self_attn
    assert isinstance(attn.q_proj, QuantizedLinear)
    assert isinstance(attn.o_proj, QuantizedLinear)

    # Only the explicitly ignored path stays a plain linear layer.
    attn = _quantized_block(["self_attn.o_proj"]).self_attn
    assert isinstance(attn.q_proj, QuantizedLinear)
    assert isinstance(attn.o_proj, LinearBase)
    # Guards against the LinearBase check passing vacuously should
    # QuantizedLinear ever derive from LinearBase.
    assert not isinstance(attn.o_proj, QuantizedLinear)
| 331 | + |
| 332 | + |
if __name__ == "__main__":
    # Direct-execution entry point: run the checks below without pytest.
    for scale_format in ("per_channel", "block_channel", "block_quant"):
        test_quantized_linear_offline_scale_formats(scale_format)

    # Parameterless checks, invoked in their original order.
    for standalone_check in (
        test_xla_quantized_matmul_block_quant_all,
        _assert_blockwise_tuning_fallback_uses_compatible_seed,
        test_linear_rule_weight_block_size_override,
        test_linear_return_contract_with_bias,
        test_ignored_layers_only_skips_requested_paths,
    ):
        standalone_check()
0 commit comments