
Fix compressed tensors moe test.#2076

Open
dmmolitor wants to merge 2 commits into vllm-project:main from dmmolitor:fix_moe_tests

Conversation

@dmmolitor
Contributor

@dmmolitor dmmolitor commented Mar 28, 2026

Fixes the previously failing tests/layers/vllm/test_compressed_tensors_moe.py


  • Fixes tests that previously failed due to outdated function signatures and an incorrect weight / weight-scale structure.
  • Initializes the weights with non-trivial values so the MoE logic is genuinely exercised.
  • Updates the tests to use a shared utility for the MoE reference math.
  • Moves a shared reference quantization utility to utils.py.

Detailed Report of All the Changes

An investigation into the Mixture of Experts (MoE) tests revealed several initialization and implementation discrepancies. These findings and subsequent resolutions are based on the codebase at commit fe631969.

1. Missing Arguments in FusedMoEConfig and FusedMoEParallelConfig Initialization

At the current HEAD, executing the tests resulted in TypeError exceptions due to missing required positional arguments during configuration setup. Specifically, the following errors were observed:

  • FusedMoEParallelConfig: Missing pcp_size, pcp_rank, sp_size, and enable_eplb.
  • FusedMoEConfig: Missing intermediate_size_per_partition, num_logical_experts, activation, device, and routing_method.

Resolution: Both issues were resolved by using the get_moe_config utility to build the configuration:

moe = quant_config.get_moe_config(layer)
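The failure mode here is generic: dataclass-style configs with required fields raise TypeError at every call site that was not updated when fields were added. The sketch below is a hypothetical stand-in (ParallelConfig and make_parallel_config are illustrative names, not the real vLLM classes) showing why routing construction through a single factory, as get_moe_config does, avoids this class of breakage:

```python
from dataclasses import dataclass


# Hypothetical stand-in for FusedMoEParallelConfig; the real class lives in
# vLLM and has more fields. This only illustrates the TypeError failure mode.
@dataclass
class ParallelConfig:
    tp_size: int
    tp_rank: int
    pcp_size: int  # newly required field
    pcp_rank: int  # newly required field


try:
    # Old call site: predates pcp_size/pcp_rank, so construction now fails.
    ParallelConfig(tp_size=1, tp_rank=0)
except TypeError as e:
    print(e)  # missing required positional arguments


# A factory (analogous in spirit to get_moe_config) keeps construction in one
# place, so call sites survive future signature changes.
def make_parallel_config() -> ParallelConfig:
    return ParallelConfig(tp_size=1, tp_rank=0, pcp_size=1, pcp_rank=0)


cfg = make_parallel_config()
```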

2. Resolving Initialization Errors in VllmCompressedTensorsW8A8Fp8MoEMethod

After addressing the configuration issues, the initialization of VllmCompressedTensorsW8A8Fp8MoEMethod failed because it expected four arguments instead of three.

  • Error: TypeError: VllmCompressedTensorsW8A8Fp8MoEMethod.__init__() missing 1 required positional argument: 'mesh'.

Resolution: The constructor call was updated to pass weight_quant and input_quant (in place of the former quant_config argument):

# Updated initialization logic for lines 133-184
method = VllmCompressedTensorsW8A8Fp8MoEMethod(
    weight_quant, input_quant, moe, mesh
)

3. Transitioning to apply_monolithic

The apply method in the test suite encountered a TypeError due to an unexpected keyword argument top_k.

The CompressedTensorsW8A8Fp8MoEMethod.apply method expects the following signature:

def apply(self, layer: FusedMoE, x: torch.Tensor, topk_weights: torch.Tensor, 
          topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None)

However, because VllmCompressedTensorsW8A8Fp8MoEMethod has the is_monolithic property set to True, the apply_monolithic function should be used instead.

Resolution: The implementation was updated to use apply_monolithic with the following signature:

# Updated logic for lines 197-202
def apply_monolithic(self, layer: FusedMoE, x: torch.Tensor, 
                     router_logits: torch.Tensor) -> torch.Tensor:
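The difference between the two entry points is where routing happens: apply receives precomputed topk_weights and topk_ids, while apply_monolithic receives raw router_logits and performs routing itself. A rough numpy sketch of the monolithic flow, under the assumption of standard softmax top-k routing (the actual kernel fuses these steps and may differ in detail):

```python
import numpy as np


def route(router_logits: np.ndarray, top_k: int):
    """Softmax the logits and pick the top_k experts per token."""
    probs = np.exp(router_logits - router_logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    topk_ids = np.argsort(-probs, axis=-1)[:, :top_k]
    topk_weights = np.take_along_axis(probs, topk_ids, axis=-1)
    # Renormalize so the selected weights sum to 1 per token.
    topk_weights /= topk_weights.sum(axis=-1, keepdims=True)
    return topk_weights, topk_ids


def apply_monolithic(x, router_logits, expert_fn, top_k=2):
    """Monolithic MoE: routing and expert mixing in a single call."""
    topk_weights, topk_ids = route(router_logits, top_k)
    out = np.zeros_like(x)
    for token in range(x.shape[0]):
        for w, e in zip(topk_weights[token], topk_ids[token]):
            out[token] += w * expert_fn(int(e), x[token])
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))
logits = rng.standard_normal((4, 3))
# Trivial "experts": expert e scales its input by (e + 1), so the output is a
# convex combination of per-expert scalings.
out = apply_monolithic(x, logits, lambda e, t: (e + 1.0) * t)
```

Since is_monolithic is True for this method, the test must hand it raw router logits rather than doing the routing itself.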

4. Correcting Weight and Scale Formats in the Reference Implementation

Tests previously failed with an AttributeError indicating that VllmFusedMoE lacked a w3_weight attribute. This occurred because weights w1 and w3 are stored together in w13_weight.

Resolution: The test was updated to use the shared test_utils.ref_moe function, replacing the outdated custom reference implementation. Following this change, the tests pass successfully.
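The fused layout stores w1 (gate) and w3 (up) stacked in a single w13_weight tensor, so a reference implementation must split it before the SwiGLU-style math. A hedged numpy sketch of the idea for one expert (ref_moe itself lives in test_utils and may differ in shape conventions):

```python
import numpy as np


def silu(v):
    return v / (1.0 + np.exp(-v))


def expert_forward(x, w13, w2):
    """One expert: split fused w13 into gate (w1) and up (w3) halves."""
    inter = w13.shape[0] // 2
    w1, w3 = w13[:inter], w13[inter:]
    # SwiGLU-style activation: silu(x @ w1.T) * (x @ w3.T), then down-project.
    return (silu(x @ w1.T) * (x @ w3.T)) @ w2.T


rng = np.random.default_rng(0)
hidden, inter = 8, 16
x = rng.standard_normal((4, hidden))
w13 = rng.standard_normal((2 * inter, hidden))  # fused [w1; w3]
w2 = rng.standard_normal((hidden, inter))
y = expert_forward(x, w13, w2)
```

Accessing a separate w3_weight attribute fails precisely because only the fused tensor exists on the layer.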

5. Implementing Non-Zero Weight Initialization

To ensure the mathematical correctness of the kernels, the tests were updated to use non-zero weights. A reference quantization function was moved from the tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py test file to a shared utils.py file to facilitate this initialization.
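Zero-initialized weights would make every kernel output zero, so bugs in scale handling or layout would go undetected. The sketch below shows the kind of symmetric per-channel quantize/dequantize round trip such a shared utility might implement; it is an illustration using int8-style math, whereas the real helper in utils.py targets fp8:

```python
import numpy as np


def quantize_per_channel(w: np.ndarray, qmax: float = 127.0):
    """Symmetric per-output-channel quantization: w is approximated by q * scale."""
    scale = np.abs(w).max(axis=1, keepdims=True) / qmax
    scale = np.where(scale == 0, 1.0, scale)  # guard all-zero rows
    q = np.clip(np.round(w / scale), -qmax, qmax)
    return q, scale


def dequantize(q, scale):
    return q * scale


rng = np.random.default_rng(0)
w = rng.standard_normal((4, 8))
q, scale = quantize_per_channel(w)
w_ref = dequantize(q, scale)  # reference for comparing against kernel output
```

With non-trivial weights, the test can quantize a random tensor, run the kernel, and compare against the dequantized reference within a tolerance.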

Tests

Ran the affected unit tests.

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

Signed-off-by: dmmolitor <dmolitor@google.com>
@dmmolitor dmmolitor requested a review from vipannalla as a code owner March 28, 2026 03:10
@kyuyeunk
Collaborator

@dmmolitor, do you need review for this pr?

@dmmolitor
Contributor Author

kyuyeunk, yes, hoping for a review.

@kyuyeunk kyuyeunk added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 30, 2026
@kyuyeunk
Collaborator

> kyuyeunk, yes, hoping for a review.

It's a bit hard to tell just from the code change. Can you elaborate on what the bug was and how this PR fixes it? Maybe add it to the PR description.

@dmmolitor
Contributor Author

I added a detailed description of all the issues encountered and the corresponding changes made. I could split the weight initialization change into a separate PR or commit if needed.

Note that /workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_moe.py was previously ignored by Buildkite.

Signed-off-by: Denali Molitor <dmolitor@google.com>
@dmmolitor
Contributor Author

I updated .buildkite/pipeline_jax.yml#L123 to include this test in the CI tests.

top_k=topk,
hidden_size=hidden_size,
intermediate_size=intermediate_size)
quant_config = VllmCompressedTensorsConfig(
Collaborator


Sorry, why is this no longer needed ?

Contributor Author


The arguments for VllmCompressedTensorsW8A8Fp8MoEMethod changed from
method = VllmCompressedTensorsW8A8Fp8MoEMethod(quant_config, moe, mesh)
to
method = VllmCompressedTensorsW8A8Fp8MoEMethod(weight_quant, input_quant, moe, mesh)

so we just need the weight_quant and input_quant components. The new version grabs these from the test model's quant_config, but we could explicitly override them.

initialize_layer_weights(layer)
method.process_weights_after_loading(layer)

def unquantize_weight_for_ref(weight, scale):
Collaborator


Worth pulling out to utils.py too?

Contributor Author


We could, although I have a slight preference for keeping it here for now. I'm not sure the .transpose(1, 2).cpu() is very generalizable (it depends on the result format), and without that, this function is essentially just weight * scale.

@kyuyeunk kyuyeunk enabled auto-merge (squash) April 3, 2026 19:45
- rm -f .coverage.part1.${TPU_VERSION:-tpu6e}
- |
.buildkite/scripts/run_in_docker.sh bash -c "python3 -m pytest -s -v -x /workspace/tpu_inference/tests/layers/vllm/backends/test_flash_attn.py /workspace/tpu_inference/tests/layers/vllm/test_awq.py /workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py /workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_w8a8_int8.py --cov-config=/workspace/tpu_inference/.coveragerc --cov tpu_inference; if [ -f .coverage ]; then cp .coverage /tmp/hf_home/.coverage.part1.${TPU_VERSION:-tpu6e}; else exit 1; fi"
.buildkite/scripts/run_in_docker.sh bash -c "python3 -m pytest -s -v -x /workspace/tpu_inference/tests/layers/vllm/backends/test_flash_attn.py /workspace/tpu_inference/tests/layers/vllm/test_awq.py /workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_moe.py /workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py /workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_w8a8_int8.py --cov-config=/workspace/tpu_inference/.coveragerc --cov tpu_inference; if [ -f .coverage ]; then cp .coverage /tmp/hf_home/.coverage.part1.${TPU_VERSION:-tpu6e}; else exit 1; fi"
Collaborator


Slightly confused by this. Are you adding the moe test to the ignore list? Doesn't that defeat the purpose?

