Fix compressed tensors moe test #2076
Conversation
Signed-off-by: dmmolitor <dmolitor@google.com>
@dmmolitor, do you need review for this PR?
kyuyeunk, yes, hoping for a review.
It's a bit hard to tell just from the code change. Can you elaborate on what the bug was and how this PR fixes it? Maybe add it to the PR description.
I added a detailed description of all the issues encountered and the corresponding changes made. I could split the weight initialization change into a separate PR or commit if needed. |
The `/workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_moe.py` test was previously ignored by buildkite.

Signed-off-by: Denali Molitor <dmolitor@google.com>
I updated .buildkite/pipeline_jax.yml#L123 to include this test in the CI tests. |
```python
top_k=topk,
hidden_size=hidden_size,
intermediate_size=intermediate_size)
quant_config = VllmCompressedTensorsConfig(
```
Sorry, why is this no longer needed?
The arguments for `VllmCompressedTensorsW8A8Fp8MoEMethod` changed from

```python
method = VllmCompressedTensorsW8A8Fp8MoEMethod(quant_config, moe, mesh)
```

to

```python
method = VllmCompressedTensorsW8A8Fp8MoEMethod(weight_quant, input_quant,
                                               moe, mesh)
```

so we just need the `weight_quant` and `input_quant` components. The new version just grabs these from the test model's `quant_config`, but we could explicitly overwrite them.
```python
initialize_layer_weights(layer)
method.process_weights_after_loading(layer)
```
```python
def unquantize_weight_for_ref(weight, scale):
```
Worth pulling out to utils.py too?
We could, although I have a slight preference for keeping it here for now. I'm not sure the `.transpose(1, 2).cpu()` is super generalizable (it depends on the result format), and without that, this function is essentially just `weight * scale`.
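For context, a minimal sketch of what such an unquantize helper does, in NumPy rather than torch (the `(num_experts, out_dim, in_dim)` layout and the final transpose are assumptions matching the test's result format, not a general convention; `transpose(0, 2, 1)` is the NumPy equivalent of torch's `.transpose(1, 2)`):

```python
import numpy as np

def unquantize_weight_for_ref(weight, scale):
    """Dequantize a quantized expert weight for a reference comparison.

    weight: quantized values, shape (num_experts, out_dim, in_dim)
    scale:  per-output-channel scales, broadcastable against weight
    """
    # Dequantization itself is just weight * scale in float...
    dequant = weight.astype(np.float32) * scale
    # ...plus a layout fix that is specific to how the kernel under
    # test returns its results (the non-generalizable part).
    return dequant.transpose(0, 2, 1)

w = np.ones((2, 4, 3), dtype=np.int8)
s = np.full((2, 4, 1), 0.5, dtype=np.float32)
ref = unquantize_weight_for_ref(w, s)
print(ref.shape)  # (2, 3, 4)
```

Without the transpose, the function really is just an elementwise `weight * scale`, which is why it is borderline for a shared `utils.py`.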
```diff
 - rm -f .coverage.part1.${TPU_VERSION:-tpu6e}
 - |
-  .buildkite/scripts/run_in_docker.sh bash -c "python3 -m pytest -s -v -x /workspace/tpu_inference/tests/layers/vllm/backends/test_flash_attn.py /workspace/tpu_inference/tests/layers/vllm/test_awq.py /workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py /workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_w8a8_int8.py --cov-config=/workspace/tpu_inference/.coveragerc --cov tpu_inference; if [ -f .coverage ]; then cp .coverage /tmp/hf_home/.coverage.part1.${TPU_VERSION:-tpu6e}; else exit 1; fi"
+  .buildkite/scripts/run_in_docker.sh bash -c "python3 -m pytest -s -v -x /workspace/tpu_inference/tests/layers/vllm/backends/test_flash_attn.py /workspace/tpu_inference/tests/layers/vllm/test_awq.py /workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_moe.py /workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py /workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_w8a8_int8.py --cov-config=/workspace/tpu_inference/.coveragerc --cov tpu_inference; if [ -f .coverage ]; then cp .coverage /tmp/hf_home/.coverage.part1.${TPU_VERSION:-tpu6e}; else exit 1; fi"
```
Slightly confused by this. Are you adding the moe test to the ignore list? Does this defeat the purpose?
Fixes the previously failing `tests/layers/vllm/test_compressed_tensors_moe.py`.
## Detailed Report of All the Changes
An investigation into the Mixture of Experts (MoE) tests revealed several initialization and implementation discrepancies. These findings and the corresponding resolutions are based on the codebase at commit fe631969.

### 1. Missing Arguments in `FusedMoEConfig` and `FusedMoEParallelConfig` Initialization

At the current HEAD, executing the tests resulted in `TypeError` exceptions due to missing required positional arguments during configuration setup. Specifically, the following errors were observed:

- `FusedMoEParallelConfig`: missing `pcp_size`, `pcp_rank`, `sp_size`, and `enable_eplb`.
- `FusedMoEConfig`: missing `intermediate_size_per_partition`, `num_logical_experts`, `activation`, `device`, and `routing_method`.

Resolution: These issues were resolved by using the `get_moe_config` utility to automate the configuration process.
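The failure mode here is the usual one when a config class grows new required fields upstream. A self-contained illustration with toy classes (not the real `FusedMoEParallelConfig` or `get_moe_config`, whose exact signatures live in the tpu_inference/vLLM codebase):

```python
from dataclasses import dataclass

# Toy stand-in: upstream added required fields, breaking old call sites.
@dataclass
class ParallelConfig:
    tp_size: int
    pcp_size: int      # newly required
    pcp_rank: int      # newly required
    sp_size: int       # newly required
    enable_eplb: bool  # newly required

def make_parallel_config_old_style():
    # Old test code passed only the original argument -> TypeError
    # on the new class definition.
    try:
        return ParallelConfig(tp_size=1)
    except TypeError as e:
        return str(e)

def get_parallel_config(**overrides):
    # Factory in the spirit of get_moe_config: supplies a sane default
    # for every required field, so tests keep tracking upstream changes.
    defaults = dict(tp_size=1, pcp_size=1, pcp_rank=0, sp_size=1,
                    enable_eplb=False)
    defaults.update(overrides)
    return ParallelConfig(**defaults)

print(make_parallel_config_old_style())  # error names the 4 missing fields
print(get_parallel_config(tp_size=4).tp_size)  # 4
```

Centralizing construction in a factory is what makes the tests robust to the next batch of added config fields.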
### 2. Resolving Initialization Errors in `VllmCompressedTensorsW8A8Fp8MoEMethod`

After addressing the configuration issues, the initialization of `VllmCompressedTensorsW8A8Fp8MoEMethod` failed because it expected four arguments instead of three:

`TypeError: VllmCompressedTensorsW8A8Fp8MoEMethod.__init__() missing 1 required positional argument: 'mesh'`

Resolution: The constructor call was updated to include both `weight_quant` and `input_quant`.
### 3. Transitioning to `apply_monolithic`

The `apply` method in the test suite encountered a `TypeError` due to an unexpected keyword argument `top_k`, because `CompressedTensorsW8A8Fp8MoEMethod.apply` expects a different signature. However, since `VllmCompressedTensorsW8A8Fp8MoEMethod` has the `is_monolithic` property set to `True`, the `apply_monolithic` function should be used instead.

Resolution: The implementation was updated to call `apply_monolithic`.
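The `is_monolithic` dispatch can be sketched with toy classes (everything below other than the `is_monolithic` / `apply` / `apply_monolithic` names is illustrative, not the real tpu_inference API):

```python
class MoEMethodBase:
    # Non-monolithic methods expose apply(); monolithic ones fuse
    # routing and expert compute into a single apply_monolithic() call.
    is_monolithic = False

    def apply(self, layer, x, router_logits):
        raise NotImplementedError

    def apply_monolithic(self, layer, x, router_logits):
        raise NotImplementedError

class MonolithicFp8Method(MoEMethodBase):
    is_monolithic = True

    def apply_monolithic(self, layer, x, router_logits):
        # Routing kwargs like top_k are owned by the method itself,
        # which is why passing top_k to apply() raised a TypeError.
        return ("monolithic", x)

def run_moe(method, layer, x, router_logits):
    # Test-side dispatch: pick the entry point the method implements.
    if method.is_monolithic:
        return method.apply_monolithic(layer, x, router_logits)
    return method.apply(layer, x, router_logits)

out = run_moe(MonolithicFp8Method(), layer=None, x=[1.0], router_logits=[0.0])
print(out)  # ('monolithic', [1.0])
```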
### 4. Correcting Weight and Scale Formats in the Reference Implementation

Tests previously failed with an `AttributeError` indicating that `VllmFusedMoE` lacked a `w3_weight` attribute. This occurred because weights `w1` and `w3` are stored together in `w13_weight`.

Resolution: The test was updated to use the shared `test_utils.ref_moe` function, replacing the outdated custom reference implementation. Following this change, the tests pass successfully.
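The `w13_weight` packing means a reference implementation must slice the fused tensor rather than look up a separate `w3_weight`. A NumPy sketch, assuming w1 (gate) and w3 (up) are concatenated along the output dimension, which is the common fused-MoE convention:

```python
import numpy as np

num_experts, hidden, intermediate = 2, 8, 16

# w1 and w3 projections stored fused along the output dimension:
# shape (num_experts, 2 * intermediate, hidden).
w13_weight = np.arange(num_experts * 2 * intermediate * hidden,
                       dtype=np.float32).reshape(num_experts,
                                                 2 * intermediate, hidden)

# Reference code recovers the halves by slicing instead of reading
# a (nonexistent) w3_weight attribute.
w1 = w13_weight[:, :intermediate, :]
w3 = w13_weight[:, intermediate:, :]
print(w1.shape, w3.shape)  # (2, 16, 8) (2, 16, 8)
```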
### 5. Implementing Non-Zero Weight Initialization

To ensure the mathematical correctness of the kernels, the tests were updated to use non-zero weights. A reference quantization function was moved from the `tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py` test file to a shared `utils.py` file to facilitate this initialization.
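A minimal per-tensor reference quantizer of the kind described can be sketched in NumPy (this is an illustration, not the moved function itself; 448.0 is the float8_e4m3 maximum, and rounding to the actual fp8 value grid is skipped):

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # max representable magnitude in float8_e4m3

def ref_quantize(weight):
    """Per-tensor symmetric quantization to an fp8-style range.

    Returns (qweight, scale) such that qweight * scale
    approximately reconstructs weight.
    """
    scale = np.abs(weight).max() / FP8_E4M3_MAX
    qweight = np.clip(weight / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return qweight.astype(np.float32), np.float32(scale)

# Non-zero initialization matters: with all-zero weights the scale
# degenerates and a kernel-vs-reference comparison proves nothing.
rng = np.random.default_rng(0)
w = rng.standard_normal((4, 8)).astype(np.float32)
qw, s = ref_quantize(w)
print(np.allclose(qw * s, w, atol=1e-4))  # True
```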
## Tests

Tested the affected unit tests.
## Checklist
Before submitting this PR, please make sure: