Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/llmcompressor/modifiers/awq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Iterator, Literal

import torch
from compressed_tensors.distributed import wait_for_comms
from compressed_tensors.modeling.kvcache import QuantizedKVCache
from compressed_tensors.offload.dist_utils import as_broadcastable, is_distributed
from compressed_tensors.quantization import (
Expand Down Expand Up @@ -47,7 +46,7 @@
from llmcompressor.observers.base import Observer
from llmcompressor.pipelines.cache import IntermediatesCache
from llmcompressor.sentinel import Sentinel
from llmcompressor.utils import get_high_precision
from llmcompressor.utils import get_high_precision, wait_for_comms
from llmcompressor.utils.helpers import calibration_forward_context
from llmcompressor.utils.pytorch.module import get_module_to_name_dict

Expand Down
2 changes: 1 addition & 1 deletion src/llmcompressor/modifiers/gptq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Dict, Optional, Tuple, Union

import torch
from compressed_tensors.distributed import greedy_bin_packing, wait_for_comms
from compressed_tensors.offload.dist_utils import as_broadcastable, is_distributed
from compressed_tensors.quantization import (
QuantizationConfig,
Expand Down Expand Up @@ -32,6 +31,7 @@
from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
from llmcompressor.modifiers.utils import update_fused_layer_weight_global_scales
from llmcompressor.sentinel import Sentinel
from llmcompressor.utils import greedy_bin_packing, wait_for_comms
from llmcompressor.utils.metric_logging import CompressionLogger

__all__ = ["GPTQModifier"]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, Dict, List, Optional, Set, Union

import torch
from compressed_tensors.distributed import wait_for_comms
from compressed_tensors.modeling import (
IMPL_ATTR,
KV_CACHE_ATTR,
Expand Down Expand Up @@ -43,6 +42,7 @@
from llmcompressor.utils import (
targets_embeddings,
untie_word_embeddings,
wait_for_comms,
)

__all__ = ["QuantizationMixin"]
Expand Down
2 changes: 1 addition & 1 deletion src/llmcompressor/modifiers/transform/smoothquant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import torch
import torch.distributed as dist
from compressed_tensors.distributed import wait_for_comms
from compressed_tensors.offload import update_offload_parameter
from compressed_tensors.offload.dist_utils import is_distributed
from compressed_tensors.utils import match_modules_set, match_named_modules
Expand All @@ -18,6 +17,7 @@
get_layer_mappings_from_architecture,
handle_mapping_resolution_errors,
)
from llmcompressor.utils.dist import wait_for_comms
from llmcompressor.utils.pytorch.module import get_module_to_name_dict

MINIMUM_SMOOTHING_SCALE = 1e-5
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import torch.distributed as dist
from compressed_tensors import ModelCompressor, SparsityCompressionConfig
from compressed_tensors.config import CompressionFormat
from compressed_tensors.distributed import is_source_process
from compressed_tensors.offload import from_accelerate, to_accelerate
from compressed_tensors.offload import from_accelerate, is_rank0, to_accelerate
from compressed_tensors.utils import deprecated
from loguru import logger
from transformers import PreTrainedModel
Expand Down Expand Up @@ -74,7 +73,7 @@ def save_pretrained_wrapper(
# convert to accelerate offloaded for optimal saving with transformers
to_accelerate(model)

if is_source_process():
if is_rank0():
# save model structure
original_save_fn.__get__(model, model_class)(save_directory, **kwargs)

Expand Down
39 changes: 23 additions & 16 deletions src/llmcompressor/utils/dist.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
from typing import Callable, Hashable, TypeVar

import torch.distributed as dist

T = TypeVar("T", bound=Hashable)


def greedy_bin_packing(
    items: list[T],
    num_bins: int,
    item_weight_fn: Callable[[T], float] = lambda x: 1,
) -> tuple[list[T], list[list[T]], dict[T, int]]:
    """Distribute items across bins using a greedy bin-packing heuristic.

    Items are sorted by weight in descending order, then each item is
    assigned to the bin with the smallest current total weight (the
    longest-processing-time-first greedy heuristic).

    :param items: hashable items to distribute. The input list is left
        unmodified; a sorted copy is returned instead.
    :param num_bins: number of bins to distribute items across; must be
        positive when ``items`` is non-empty.
    :param item_weight_fn: callable mapping an item to its weight;
        defaults to weighting every item equally. It is called exactly
        once per item.
    :return: tuple of
        - items: the items sorted by descending weight.
        - bin_to_items: list of length ``num_bins`` where each entry is
          the list of items assigned to that bin.
        - item_to_bin: mapping from each item to its assigned bin index.
    :raises ValueError: if items are provided but ``num_bins`` is not
        positive.
    """
    if items and num_bins <= 0:
        # without this guard, min() over zero bins raises the opaque
        # "min() arg is an empty sequence"
        raise ValueError(
            f"Cannot pack {len(items)} items into {num_bins} bins"
        )

    # evaluate each weight once; the previous implementation invoked the
    # weight fn twice per item (sort key + accumulation)
    weights = {item: item_weight_fn(item) for item in items}
    # sorted() rather than list.sort() so the caller's list is not
    # mutated as a side effect; sort is stable, so equal-weight items
    # keep their original relative order
    ordered = sorted(items, key=weights.__getitem__, reverse=True)

    bin_to_items: list[list[T]] = [[] for _ in range(num_bins)]
    item_to_bin: dict[T, int] = {}
    bin_weights: list[float] = [0.0] * num_bins
    for item in ordered:
        # place the item on the currently lightest bin
        target_bin = bin_weights.index(min(bin_weights))
        bin_to_items[target_bin].append(item)
        item_to_bin[item] = target_bin
        bin_weights[target_bin] += weights[item]
    return ordered, bin_to_items, item_to_bin


def wait_for_comms(pending_comms: list[dist.Work]) -> None:
    """Block until all pending async distributed operations complete.

    Calls ``wait()`` on each work handle, then clears the list in-place
    so the same list can be reused for the next batch of async ops.

    :param pending_comms: work handles returned by async collectives
        (e.g. operations launched with ``async_op=True``). The list is
        cleared after all operations have completed.
    """
    # snapshot the handles so clearing the list cannot race the iteration
    snapshot = tuple(pending_comms)
    for work in snapshot:
        work.wait()
    pending_comms.clear()
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import pytest
import torch
import torch.distributed as dist
from compressed_tensors.distributed import wait_for_comms
from compressed_tensors.quantization import QuantizationArgs

from llmcompressor.observers.min_max import StaticMinMaxObserver
from llmcompressor.utils.dist import wait_for_comms
from tests.testing_utils import requires_gpu

# initialize process group when running under torchrun
Expand Down
Loading