Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/llmcompressor/entrypoints/model_free/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Iterable, Optional
from typing import Iterable, List, Optional, Union

import torch
import tqdm
from compressed_tensors.quantization import QuantizationScheme
from loguru import logger

from llmcompressor.entrypoints.model_free.helpers import gpu_if_available
from llmcompressor.entrypoints.model_free.device_balancer import DeviceLoadBalancer
from llmcompressor.entrypoints.model_free.microscale import (
is_microscale_scheme,
)
Expand All @@ -31,7 +31,7 @@
validate_scheme,
)

__all__ = ["model_free_ptq"]
__all__ = ["model_free_ptq", "DeviceLoadBalancer"]


def model_free_ptq(
Expand All @@ -40,7 +40,7 @@ def model_free_ptq(
scheme: QuantizationScheme | str,
ignore: Iterable[str] = tuple(),
max_workers: int = 1,
device: Optional[torch.device | str] = None,
device: Optional[Union[torch.device, str, List[Union[torch.device, str]]]] = None,
):
"""
Quantize a model without the need for a model definition. This function operates on
Expand All @@ -51,12 +51,13 @@ def model_free_ptq(
:param ignore: modules to ignore. Modules ending with "norm" are automatically
ignored
:param max_workers: number of worker threads to process files with
:param device: gpu device to accelerate quantization with
:param device: gpu device to accelerate quantization with. Can be a single device
or a list of devices for multi-GPU support
"""
# validate arguments
model_files = get_checkpoint_files(model_stub)
scheme_name, scheme = validate_scheme(scheme)
device = gpu_if_available(device)
device_balancer = DeviceLoadBalancer(device)
validate_safetensors_index(model_files, scheme)

# 0. collect safetensors files, copy files
Expand All @@ -70,7 +71,9 @@ def model_free_ptq(
save_path = Path(save_directory) / file_path

if file_path.endswith("safetensors"):
jobs.append((job_fn, resolved_path, save_path, scheme, ignore, device))
jobs.append(
(job_fn, resolved_path, save_path, scheme, ignore, device_balancer)
)

else:
if is_weights_file(file_path):
Expand Down
102 changes: 102 additions & 0 deletions src/llmcompressor/entrypoints/model_free/device_balancer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import functools
import inspect
from threading import Lock
from typing import List, Optional, Union

import torch
from loguru import logger

from llmcompressor.entrypoints.model_free.helpers import gpu_if_available

__all__ = ["DeviceLoadBalancer"]


class DeviceLoadBalancer:
    """
    Load balancer for distributing jobs across multiple GPU devices.

    Tracks the number of active jobs per device and hands out the least
    busy device on request. All bookkeeping is guarded by a single lock,
    so one instance may be shared safely across worker threads.
    """

    def __init__(
        self,
        device: Optional[
            Union[torch.device, str, List[Union[torch.device, str]]]
        ] = None,
    ):
        """
        Initialize the load balancer with device(s).

        :param device: Device specification - can be:
            - None: auto-select available GPU (cuda, xpu, npu) or fallback to CPU
            - Single device: torch.device or str (e.g., "cuda:0")
            - List of devices: List[torch.device | str] for multi-GPU support
        :raises ValueError: if an empty device list is provided
        """
        # Normalize the argument into a non-empty list of resolved devices
        if isinstance(device, list):
            # Multi-GPU: validate and convert each device
            resolved = [gpu_if_available(entry) for entry in device]
        else:
            # Single device (or None): resolve into a one-element list
            resolved = [gpu_if_available(device)]

        # Fail fast here rather than with a confusing ValueError from
        # min() inside get_device() later
        if not resolved:
            raise ValueError("device list must contain at least one device")

        self.devices = resolved
        # Active-job count per device. Duplicate entries in the input
        # collapse into a single key (they refer to the same device).
        self.device_usage = {dev: 0 for dev in resolved}
        self.lock = Lock()

    def get_device(self) -> torch.device:
        """
        Get the least busy device and mark it as in use. Thread-safe.

        :return: the device with the fewest active jobs
        """
        with self.lock:
            # Find the device with minimum usage and claim it
            device = min(self.device_usage, key=self.device_usage.__getitem__)
            self.device_usage[device] += 1
        return device

    def release_device(self, device: torch.device):
        """
        Release a device back to the pool. Thread-safe.

        :param device: The device to release
        """
        with self.lock:
            if device in self.device_usage:
                self.device_usage[device] -= 1
            else:
                logger.warning(f"Attempted to release unknown device: {device}")

    @staticmethod
    def inject_device(func):
        """
        Decorator that manages device lifecycle for functions.

        The decorated function must have a 'device' parameter. Callers pass
        a DeviceLoadBalancer instance for that parameter; the wrapper then:
        1. Gets the least busy device from the load balancer
        2. Calls the function with that device
        3. Releases the device when complete (even if an exception occurs)

        :param func: Function to decorate (must have a 'device' parameter)
        :return: Wrapped function that accepts a load balancer as 'device'
        """
        # Computing the signature is relatively expensive; do it once at
        # decoration time instead of on every call.
        signature = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()

            load_balancer: DeviceLoadBalancer = bound.arguments["device"]
            device = load_balancer.get_device()
            # Replace the balancer with the concrete device before the call
            bound.arguments["device"] = device

            try:
                # bound.args/bound.kwargs preserve positional-only and
                # *args parameters, unlike expanding everything as keywords
                return func(*bound.args, **bound.kwargs)
            finally:
                load_balancer.release_device(device)

        return wrapper
Comment on lines +73 to +105
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While using a decorator to manage the device lifecycle is a clever approach, the current implementation of inject_device introduces ambiguity. It requires the decorated function's device parameter to accept a DeviceLoadBalancer instance at the call site, which is then replaced by a torch.device object within the function. This name-based argument type override can be confusing for developers and static analysis tools.

A more explicit and less magical pattern would be to remove the decorator and use a try...finally block directly in the functions that require a device. This would improve readability and maintainability.

For example, process_file in src/llmcompressor/entrypoints/model_free/process.py could be refactored as follows:

# No decorator here
def process_file(
    ...,
    load_balancer: "DeviceLoadBalancer",
):
    device = load_balancer.get_device()
    try:
        # original function body using `device`
        ...
    finally:
        load_balancer.release_device(device)

This approach is clearer and aligns with how validate_file handles the load_balancer argument, promoting consistency across the codebase.

13 changes: 10 additions & 3 deletions src/llmcompressor/entrypoints/model_free/process.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import os
from collections import defaultdict
from collections.abc import Iterator, Mapping
from typing import Iterable
from typing import TYPE_CHECKING, Iterable

import torch
from compressed_tensors.quantization import QuantizationScheme
from compressed_tensors.utils.match import match_name
from safetensors.torch import load_file, save_file
from torch.nn import Module

from llmcompressor.entrypoints.model_free.device_balancer import DeviceLoadBalancer
from llmcompressor.entrypoints.model_free.lifecycle import (
calibrate_global_scale,
calibrate_scale_zp,
Expand All @@ -21,6 +22,9 @@
is_microscale_scheme,
)

if TYPE_CHECKING:
pass

__all__ = ["validate_file", "process_file", "process_file_microscale_scheme"]


Expand All @@ -43,7 +47,7 @@ def validate_file(
save_path: str | os.PathLike,
scheme: QuantizationScheme,
ignore: Iterable[str],
device: str | torch.device,
load_balancer: DeviceLoadBalancer,
):
"""
Validate that each quantizable tensor in a safetensors file can be quantized.
Expand All @@ -52,13 +56,15 @@ def validate_file(
:param scheme: quantization scheme to apply to tensors
:param ignore: modules to ignore. Modules ending with "norm" are automatically
ignored
    :param load_balancer: device load balancer (unused, kept for signature
        consistency)
"""
tensors = load_file(file_path)
tensors = load_file(file_path, device="meta")

for _, name in iter_quantizable_tensors(tensors, ignore):
validate_weight_for_quantization(tensors[name], scheme, name)


@DeviceLoadBalancer.inject_device
def process_file(
file_path: str | os.PathLike,
save_path: str | os.PathLike,
Expand Down Expand Up @@ -103,6 +109,7 @@ def process_file(
return total_size, weight_map


@DeviceLoadBalancer.inject_device
def process_file_microscale_scheme(
file_path: str | os.PathLike,
save_path: str | os.PathLike,
Expand Down
Loading