
[Distributed] Extend QuantizationModifier to support distributed activation calibration#2391

Open
Etelis wants to merge 9 commits into vllm-project:main from Etelis:feature/quantization-modifier-ddp

Conversation

Contributor

@Etelis Etelis commented Feb 22, 2026

Closes #2220

Adds DDP support to QuantizationModifier for activation observer synchronization across multiple GPUs during calibration.

At SEQUENTIAL_EPOCH_END and CALIBRATION_EPOCH_END, activation observer min/max values are all-reduced across ranks. Scale/zp are then recomputed from the global statistics so all ranks have identical quantization parameters.

Changes

  • Add synchronize(), recompute_qparams(), recompute_global_scale() to Observer base class
  • Add sync_activation_observers() to QuantizationMixin (shared by QuantizationModifier and GPTQModifier)
  • Batch all async dist.all_reduce operations and wait once, matching GPTQ DDP pattern
  • Remove recompute_qparams_from_observer from calibration.py (now encapsulated in Observer methods)
  • Align distributed example with existing DDP patterns (init_dist, get_rank_partition)
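The batched async all-reduce described in the third bullet can be sketched roughly as follows. This is a hedged illustration, not the PR's actual code: `sync_observer_stats` is a hypothetical stand-in for `sync_activation_observers()`, though the observer attribute names (`past_min_vals`, `past_max_vals`) follow those used in this PR.

```python
import torch
import torch.distributed as dist

def sync_observer_stats(observers):
    """All-reduce each observer's accumulated min/max across ranks,
    batching the async collectives and waiting once at the end."""
    if not (dist.is_available() and dist.is_initialized()):
        return  # no-op in single-process runs
    comms = []
    for obs in observers:
        for attr, op in (("past_min_vals", dist.ReduceOp.MIN),
                         ("past_max_vals", dist.ReduceOp.MAX)):
            val = getattr(obs, attr, None)
            if val is not None:
                # async_op=True returns a Work handle instead of blocking
                comms.append(dist.all_reduce(val, op=op, async_op=True))
    # wait once for all outstanding reductions (GPTQ DDP pattern)
    for work in comms:
        work.wait()
```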

Runtime & Evaluation Results

Model: Meta-Llama-3-8B-Instruct, W8A8 (static input activations), 256 calibration samples

| Setup | Quantization Time | Speedup | Perplexity (wikitext) |
|---|---|---|---|
| 1x A100-80GB (baseline) | 270.1s | 1.0x | 37.51 |
| 2x A100-80GB (DDP) | 138.0s | 1.96x | 37.64 |
| 4x A100-80GB (DDP) | 72.1s | 3.75x | 36.40 |
  • Near-linear scaling with GPU count
  • Perplexity stays within ~1 point of the baseline across all configurations, indicating activation observer sync produces correct quantization parameters

Test plan

  • Unit tests: pytest tests/llmcompressor/utils/test_distributed.py (8 tests)
  • Multi-GPU tests: torchrun --nproc_per_node=2 -m pytest tests/llmcompressor/modifiers/quantization/test_quantization_ddp.py (2 tests)
  • End-to-end quantization + eval on 1/2/4 GPUs (see results table above)

@github-actions

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

@gemini-code-assist
Contributor

Summary of Changes

Hello @Etelis, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the QuantizationModifier by integrating robust distributed training support. It introduces mechanisms to efficiently partition weight calibration tasks and synchronize activation observer statistics across multiple GPUs, thereby enabling large language model quantization in distributed environments. The changes are designed to be backward compatible, maintaining existing functionality for single-device operations.

Highlights

  • Distributed Quantization Support: The QuantizationModifier now supports Distributed Data Parallel (DDP) environments, enabling quantization workflows to scale across multiple GPUs.
  • Partitioned Weight Calibration: Weight calibration is distributed across ranks using a greedy bin-packing algorithm based on weight size. Each rank calibrates its assigned subset of modules and broadcasts the resulting quantization parameters (global_scale, scale, zero_point) to all other ranks.
  • Synchronized Activation Observers: Activation observer statistics (min/max values) are now synchronized across all DDP ranks at SEQUENTIAL_EPOCH_END and CALIBRATION_EPOCH_END events using all-reduce operations. This ensures all ranks compute identical quantization parameters for activations.
  • Backward Compatibility: All distributed operations are guarded by is_distributed() checks, ensuring that the modifier's behavior remains unchanged in single-GPU or non-distributed setups.
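The greedy bin-packing step from the "Partitioned Weight Calibration" bullet can be illustrated with a small sketch. Function and argument names here are hypothetical, not the PR's actual API:

```python
def assign_ranks(module_sizes, world_size):
    """Greedy bin-packing: assign each module to the currently
    least-loaded rank, largest modules first.

    module_sizes: list of (module_name, num_weight_elements)
    Returns {module_name: rank}.
    """
    loads = [0] * world_size
    assignments = {}
    # sort descending so small modules fill in the gaps
    for name, size in sorted(module_sizes, key=lambda x: -x[1]):
        rank = loads.index(min(loads))  # least-loaded rank so far
        assignments[name] = rank
        loads[rank] += size
    return assignments
```

Every rank computes the same assignment deterministically, so no communication is needed to agree on the partition.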


Changelog
  • examples/big_models_with_sequential_onloading/llama3_8b_w8a8_distributed.py
    • Added a new example script demonstrating distributed W8A8 quantization for the Llama-3-8B-Instruct model.
  • src/llmcompressor/modifiers/quantization/calibration.py
    • Exported the new recompute_qparams_from_observer function.
    • Implemented recompute_qparams_from_observer to recalculate quantization parameters from observer statistics, particularly useful after distributed synchronization.
  • src/llmcompressor/modifiers/quantization/quantization/base.py
    • Imported new distributed utility functions to facilitate DDP operations.
    • Updated the QuantizationModifier docstring to reflect its new DDP capabilities.
    • Refactored the on_start method to conditionally execute single-process or distributed weight calibration.
    • Added _calibrate_weights_distributed method to manage DDP-specific weight calibration, including module partitioning, global scale computation, and broadcasting of results.
    • Introduced _sync_activation_observers to perform all-reduce on activation observer min/max values and recompute quantization parameters.
    • Integrated _sync_activation_observers into the on_event method, triggering synchronization at SEQUENTIAL_EPOCH_END and CALIBRATION_EPOCH_END.
  • src/llmcompressor/utils/__init__.py
    • Imported the newly added distributed utility module.
  • src/llmcompressor/utils/distributed.py
    • Added a new utility module for distributed processing.
    • Implemented is_distributed, get_rank, and get_world_size functions for DDP environment detection and information retrieval.
    • Developed _compute_rank_assignments for greedy bin-packing of modules based on weight size to balance workload across ranks.
    • Provided partition_modules_by_weight_size to return the subset of modules assigned to the current rank.
    • Created build_module_to_rank_map to establish a consistent mapping of modules to ranks.
    • Implemented broadcast_module_parameter to broadcast module parameters from a source rank to all other ranks, supporting CPU-offloaded parameters.
    • Added all_reduce_min and all_reduce_max functions for distributed minimum and maximum aggregation.
  • tests/llmcompressor/modifiers/quantization/test_quantization_ddp.py
    • Added new multi-GPU tests to verify the correctness of all_reduce_min and all_reduce_max operations.
    • Included tests to confirm that synchronized quantization parameters are identical across all DDP ranks.
  • tests/llmcompressor/utils/test_distributed.py
    • Added unit tests for the new distributed utility functions, covering non-distributed behavior, module partitioning logic, and all-reduce operations.
Activity
  • Unit tests passed locally (5/5).
  • Multi-GPU tests passed on 2x A100 (2/2).
  • An end-to-end oneshot() run with nm-testing/tinysmokellama-3.2 on 2x A100 successfully quantized 42 modules.

@mergify mergify bot added the documentation Improvements or additions to documentation label Feb 22, 2026

mergify bot commented Feb 22, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Etelis.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 22, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request introduces distributed support for the QuantizationModifier, enabling weight calibration and activation observer synchronization across multiple GPUs. This is a significant improvement for scaling quantization to large models. The implementation uses a greedy bin-packing algorithm for load balancing weight calibration, which is a solid choice. However, the current approach to synchronization involves a large number of individual collective communication calls (all-reduces and broadcasts) within loops, which will likely become a performance bottleneck due to network latency. Additionally, there are a few issues with device indexing in multi-node environments that should be addressed to ensure robustness.

Comment on lines +174 to +208
```python
for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
    for base_name in ("input", "output", "q", "k", "v"):
        observer = getattr(module, f"{base_name}_observer", None)
        if observer is None:
            continue

        # all-reduce accumulated min/max across ranks
        if (
            hasattr(observer, "past_min_vals")
            and observer.past_min_vals is not None
        ):
            observer.past_min_vals = all_reduce_min(observer.past_min_vals)
        if (
            hasattr(observer, "past_max_vals")
            and observer.past_max_vals is not None
        ):
            observer.past_max_vals = all_reduce_max(observer.past_max_vals)

        # all-reduce global min/max (TENSOR_GROUP strategy)
        if (
            hasattr(observer, "past_global_min_vals")
            and observer.past_global_min_vals is not None
        ):
            observer.past_global_min_vals = all_reduce_min(
                observer.past_global_min_vals
            )
        if (
            hasattr(observer, "past_global_max_vals")
            and observer.past_global_max_vals is not None
        ):
            observer.past_global_max_vals = all_reduce_max(
                observer.past_global_max_vals
            )

        recompute_qparams_from_observer(module, base_name)
```
Contributor

high

The current implementation of _sync_activation_observers performs multiple all_reduce operations per module inside a nested loop. For a typical transformer model, this can result in hundreds or even thousands of small collective communication calls. In distributed settings, the latency overhead of many small calls is much higher than a single large call.

Consider aggregating all tensors that need reduction into a single list, concatenating them into one or two large buffers (e.g., one for MIN and one for MAX), performing a single all_reduce on each buffer, and then unpacking the results back into the observers. This will significantly improve performance on high-latency networks.
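The suggested buffer-packing could look roughly like this. A hypothetical sketch, not code from this PR, shown only for the MIN reduction:

```python
import torch
import torch.distributed as dist

def fused_all_reduce_min(tensors):
    """Reduce many small tensors with a single collective call by
    packing them into one flat buffer and unpacking afterwards."""
    if not tensors:
        return
    # one contiguous buffer instead of len(tensors) separate calls
    flat = torch.cat([t.reshape(-1) for t in tensors])
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(flat, op=dist.ReduceOp.MIN)
    # scatter the reduced values back into the original tensors
    offset = 0
    for t in tensors:
        n = t.numel()
        t.copy_(flat[offset:offset + n].reshape(t.shape))
        offset += n
```

This trades a little extra memory for one large collective per reduction op, which amortizes the per-call latency the comment above is concerned about.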

```python
    return

# NCCL requires each rank to use its own GPU
device = torch.device(f"cuda:{dist.get_rank()}")
```
Contributor

high

dist.get_rank() returns the global rank of the process. In multi-node environments, this rank will exceed the number of GPUs available on a single node (e.g., rank 8 on the second node of an 8-GPU cluster). Using the global rank as a CUDA device index will result in an 'invalid device ordinal' error. Use torch.cuda.current_device() instead to ensure the correct local GPU is targeted.

Suggested change:

```diff
- device = torch.device(f"cuda:{dist.get_rank()}")
+ device = torch.device(torch.cuda.current_device())
```

```python
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank)
```
Contributor

medium

Setting the CUDA device using the global rank will fail in multi-node setups where the rank exceeds the number of GPUs per node. It is standard practice to use the local rank for device assignment.

Suggested change:

```diff
- torch.cuda.set_device(rank)
+ torch.cuda.set_device(rank % torch.cuda.device_count())
```

Add shared utility functions for multi-GPU weight calibration and
activation observer synchronization. All functions are no-ops when
torch.distributed is not initialized.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Add a helper function to recompute scale and zero_point from an
observer's accumulated min/max after DDP all-reduce synchronization.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Refactor QuantizationModifier.on_start to support distributed weight
calibration. Each rank calibrates a subset of modules (assigned by
greedy bin-packing on weight size) and broadcasts results to all ranks.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Use each rank's own GPU device for NCCL broadcast instead of the
module's execution device, which may be CPU or shared across ranks
when the model is not GPU-resident.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@Etelis Etelis force-pushed the feature/quantization-modifier-ddp branch from 61c255a to 72ed4b2 Compare February 22, 2026 12:13
@mergify mergify bot removed the needs-rebase label Feb 22, 2026
@Etelis Etelis force-pushed the feature/quantization-modifier-ddp branch from 72ed4b2 to 9975edc Compare February 22, 2026 12:15

@kylesayrs kylesayrs left a comment


I recommend letting @GOavi101 handle parallelized weight quantization, and instead focusing on parallelized activation quantization.

Once the requested changes have been made, please add a table of runtime and eval results, similar to GPTQ

@@ -0,0 +1,194 @@
"""
Distributed utilities for multi-GPU (DDP) calibration and optimization.
Collaborator

Consolidate with src/llmcompressor/utils/dist.py

Contributor Author

Done. Deleted distributed.py entirely — all functions were either for weight partitioning (removed) or all_reduce wrappers (inlined into Observer.synchronize()). No functions needed to move to dist.py.

Comment on lines +135 to +137
```python
# fuse global_scales (all ranks, idempotent)
for module in model.modules():
    update_fused_layer_weight_global_scales(module)
```
Collaborator

This part should be tested with NVFP4. The problem is that there may be cases where submodules which need to be fused are assigned to different ranks. This is why I suggested breaking this out for @GOavi101 to focus on.

Collaborator

My suggestion is to do the weight calibration independently on each rank for now, and focus on the activation calibration.

Contributor Author

Done. Weight calibration now runs identically on every rank — no partitioning or broadcasting. PR focuses exclusively on activation observer synchronization.

```python
    return tensor

device = tensor.device
if device.type == "cpu":
```
Collaborator

Why does the tensor need to be moved to the gpu?
I also don't see observer values being on the cpu as a common case, I'm wondering when you ever saw this?

Contributor Author

Removed. The CPU-to-GPU movement and the all_reduce wrapper are both gone — Observer.synchronize() now calls dist.all_reduce directly on the observer tensors (which are always on GPU in practice).

```python
update_offload_parameter(module, param_name, tensor)


def all_reduce_min(tensor: torch.Tensor) -> torch.Tensor:
```
Collaborator

I think the code would be a bit clearer if we didn't break this function out, and instead just used dist.all_reduce directly.

Contributor Author

Done. Removed all_reduce_min/all_reduce_max wrappers. Observer.synchronize() calls dist.all_reduce directly with the appropriate ReduceOp.


```python
# all-reduce accumulated min/max across ranks
if (
    hasattr(observer, "past_min_vals")
```
Collaborator

Instead of having many hasattr calls, which is not very maintainable, consider implementing a synchronize() method on the observers directly. I also think that you need to synchronize global scales.

In addition, having many synchronization ops has lots of runtime cost. Consider implementing synchronize() with a return value of the comms, similar to the GPTQ implementation

Contributor Author

Done. Added synchronize() on the Observer base class that returns a list of dist.Work handles (matching the GPTQ async pattern). All all_reduce ops are batched and waited on once via wait_for_comms. Also added recompute_global_scale() and recompute_qparams() to encapsulate the recomputation from accumulated state. Memoryless observers (no past_* attributes) return empty list/None automatically via getattr with defaults.

```python
if not self.ended_:
    self.on_end(state, None)

def _sync_activation_observers(self, model):
```
Collaborator

I think that you'll want to implement this method on the QuantizationMixin, that way DDP activation calibration logic can be shared with other modifiers like GPTQModifier

Contributor Author

Done. sync_activation_observers() is now on QuantizationMixin, so both QuantizationModifier and GPTQModifier can use it.

```python
calibrate_activations(module, value_states, base_name="v")


def recompute_qparams_from_observer(module: Module, base_name: str):
```
Collaborator

Is this function not redundant with call_observer?

Contributor Author

Not directly redundant — call_observer needs a tensor value to run the forward pass through the observer, while recompute_qparams_from_observer recomputes scale/zp from already-accumulated past_min_vals/past_max_vals state (needed after DDP sync). Removed recompute_qparams_from_observer and moved the logic into Observer.recompute_qparams() and Observer.recompute_global_scale() methods instead.
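For intuition, recomputing asymmetric qparams from accumulated min/max generally follows the standard affine-quantization formula. A hedged sketch, assuming int8 asymmetric quantization; the actual `Observer.recompute_qparams()` may differ in clamping, rounding, and dtype details:

```python
def qparams_from_minmax(min_val, max_val, qmin=-128, qmax=127):
    """Compute (scale, zero_point) from observed min/max.

    Hypothetical illustration, not the Observer implementation.
    """
    # include zero in the range so zero maps exactly to an integer
    min_val = min(min_val, 0.0)
    max_val = max(max_val, 0.0)
    scale = (max_val - min_val) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # degenerate all-zero range
    zero_point = round(qmin - min_val / scale)
    return scale, zero_point
```

Because the synced `past_min_vals`/`past_max_vals` are identical on every rank after the all-reduce, this recomputation yields identical scale/zp everywhere without further communication.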

```python
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank)
```
Collaborator

Please use init_distributed util

Contributor Author

Done. Example now uses init_dist() from compressed_tensors.offload, along with load_offloaded_model(), get_rank_partition(), and dispatch_model() — matching the existing llama3_ddp_example.py pattern.

Remove distributed weight calibration (partition, broadcast, rank
assignment) and focus exclusively on activation observer synchronization.

Key changes:
- Add synchronize(), recompute_qparams(), recompute_global_scale()
  to Observer base class for clean DDP interface
- Move sync_activation_observers() to QuantizationMixin for reuse
  by both QuantizationModifier and GPTQModifier
- Batch all async all_reduce ops and wait once, matching GPTQ pattern
- Delete distributed.py (consolidated into Observer methods + dist.py)
- Remove recompute_qparams_from_observer from calibration.py
- Align example with existing DDP patterns (init_dist, get_rank_partition)
- Update unit and multi-GPU tests for new observer-based sync

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Contributor Author

Etelis commented Feb 24, 2026

All review comments have been addressed:

  • Removed distributed weight calibration — weight calibration runs identically on each rank
  • Focus is exclusively on activation observer synchronization
  • synchronize(), recompute_qparams(), recompute_global_scale() added to Observer base class
  • sync_activation_observers() moved to QuantizationMixin (shared with GPTQModifier)
  • All async all_reduce ops batched + waited once via wait_for_comms() (GPTQ pattern)
  • Deleted distributed.py, removed all_reduce_min/all_reduce_max wrappers, removed CPU-to-GPU tensor movement
  • Removed recompute_qparams_from_observer from calibration.py
  • Example uses init_dist(), get_rank_partition(), load_offloaded_model()

Runtime & eval results added to PR description (Llama-3-8B, W8A8 static activations, 256 samples, A100-80GB):

| Setup | Time | Speedup | Perplexity (wikitext) |
|---|---|---|---|
| 1x GPU | 270.1s | 1.0x | 37.51 |
| 2x GPU (DDP) | 138.0s | 1.96x | 37.64 |
| 4x GPU (DDP) | 72.1s | 3.75x | 36.40 |


@kylesayrs kylesayrs left a comment


I think this code looks really great, the benchmarks look great as well. A couple notes from me:

  1. The fact that the 4x DDP setup does not increase perplexity gives me confidence that syncing once per epoch (rather than once per batch) is good enough, nice work.
  2. From your speedup benchmarks, it seems like repeating work (calculate_q/gparams) across ranks is not too much of a cost. That seems to match expectations as well, nice work.

I'll make sure this code gets merged as part of the next LLM Compressor release.

Comment on lines +147 to +152
```python
for attr, op in [
    ("past_min_vals", dist.ReduceOp.MIN),
    ("past_max_vals", dist.ReduceOp.MAX),
    ("past_global_min_vals", dist.ReduceOp.MIN),
    ("past_global_max_vals", dist.ReduceOp.MAX),
]:
```
Collaborator

I agree that I think this approach is more elegant than reimplementing for each subclass


@HDCharles HDCharles Mar 2, 2026


Don't we need to do this for every subclass? This strategy only makes sense for a single observer: static_minmax.

For memoryless_minmax and memoryless_mse it also works, I guess, because they don't have any of those values.

For minmax (moving average) and mse (moving average) it makes no sense; you probably need to average across ranks, though in theory it wouldn't be hard to do:

$$f(x_0,\ldots,x_{n-1}) = w\sum_{i=0}^{n-1} x_i(1-w)^{n-1-i}$$ (rank 0 avg)
$$f(x_n,\ldots,x_{2n-1}) = w\sum_{i=n}^{2n-1} x_i(1-w)^{2n-1-i}$$ (rank 1 avg)
$$f(x_0,\ldots,x_{2n-1}) = w\sum_{i=0}^{2n-1} x_i(1-w)^{2n-1-i}$$ (all-data avg)
$$= (1-w)^{n}\, w\sum_{i=0}^{n-1} x_i(1-w)^{n-1-i} + w\sum_{i=n}^{2n-1} x_i(1-w)^{2n-1-i}$$
$$= f(x_0,\ldots,x_{n-1})(1-w)^n + f(x_n,\ldots,x_{2n-1})$$ (accumulate in terms of rank averages)
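The merge identity above can be checked numerically with a small sketch (pure Python; `w` is the EMA weight and `ema` is a hypothetical helper implementing the unrolled moving average with zero initial state):

```python
def ema(xs, w):
    """Unrolled EMA with zero initial state:
    w * sum_i x_i * (1 - w)^(n - 1 - i)."""
    acc = 0.0
    for x in xs:
        acc = acc * (1 - w) + w * x
    return acc

w = 0.3
xs = [1.0, 4.0, 2.0, 8.0, 5.0, 3.0]
r0, r1 = ema(xs[:3], w), ema(xs[3:], w)
# rank-0 average decayed by (1-w)^n, plus rank-1 average
merged = r0 * (1 - w) ** 3 + r1
full = ema(xs, w)  # EMA over all the data in order
assert abs(merged - full) < 1e-12
```

So merging per-rank moving averages only requires each rank's final EMA value and its sample count, not the raw per-batch statistics.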

Collaborator

Yeah let's just average for now

```python
self.on_start(state, None)

if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
    QuantizationMixin.sync_activation_observers(self, state.model)
```
Collaborator

I think we'll need to add these to GPTQ and AWQ, right?

@kylesayrs kylesayrs self-requested a review February 24, 2026 20:15
@kylesayrs kylesayrs changed the title [Distributed] Extend QuantizationModifier to support weight-parallel optimization [Distributed] Extend QuantizationModifier to support distributed activation calibration Feb 26, 2026
```python
]:
    val = getattr(self, attr, None)
    if val is not None:
        comms.append(dist.all_reduce(val, op=op, async_op=True))
```
Collaborator

Think we need the fp8 trick here from GPTQ base.py


Labels

documentation Improvements or additions to documentation


Development

Successfully merging this pull request may close these issues.

[Performance Refactor] Extend modifiers to support weight-parallel optimization - QuantizationModifier

4 participants