Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/llmcompressor/modifiers/logarithmic_equalization/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

import torch
from torch.nn import Module

Expand Down Expand Up @@ -52,8 +50,8 @@ class LogarithmicEqualizationModifier(SmoothQuantModifier):
"""

def _calculate_smoothing_scales(
self, balance_layers: List[Module], activation_scales: torch.Tensor
) -> List[float]:
self, balance_layers: list[Module], activation_scales: torch.Tensor
) -> torch.Tensor:
"""
Calculate how much smoothing to apply to each channel based on the dynamic
range of the activations and the following weights.
Expand Down
24 changes: 12 additions & 12 deletions src/llmcompressor/modifiers/smoothquant/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable

import torch
from compressed_tensors.utils import align_module_device, match_modules_set
Expand Down Expand Up @@ -51,7 +51,7 @@ class SmoothQuantMapping:

smooth_name: str
smooth_layer: Module
balance_layers: List[Module]
balance_layers: list[Module]


class SmoothQuantModifier(Modifier):
Expand Down Expand Up @@ -96,15 +96,15 @@ class SmoothQuantModifier(Modifier):
"""

smoothing_strength: float = 0.5
mappings: Optional[List[Union[Tuple, List]]] = None
ignore: Optional[List[str]] = None
num_calibration_steps: Optional[int] = None
calibration_function: Optional[Callable] = None
mappings: list[tuple | list] | None = None
ignore: list[str] | None = None
num_calibration_steps: int | None = None
calibration_function: Callable | None = None

resolved_mappings_: Optional[List[SmoothQuantMapping]] = Field(
resolved_mappings_: list[SmoothQuantMapping] | None = Field(
default=None, repr=False
)
scales_: Optional[Dict] = Field(default=None, repr=False)
scales_: dict | None = Field(default=None, repr=False)

def on_initialize(self, state: State, **kwargs) -> bool:
"""
Expand Down Expand Up @@ -178,7 +178,7 @@ def on_finalize(self, state: State, **kwargs) -> bool:
def _infer_mappings_from_model(
self,
model: Module,
) -> List[Tuple]:
) -> list[tuple]:
if self.mappings is not None:
return self.mappings

Expand All @@ -188,7 +188,7 @@ def _infer_mappings_from_model(
)

@handle_mapping_resolution_errors
def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
def _resolve_mappings(self, model: Module) -> list[SmoothQuantMapping]:
"""
Transforms the list of activations to smooth and their corresponding weights
into SmoothQuantMapping objects, resolving regular expressions.
Expand Down Expand Up @@ -309,8 +309,8 @@ def smooth(module):
del self.scales_[mapping.smooth_name]

def _calculate_smoothing_scales(
self, balance_layers: List[Module], activation_scales: torch.Tensor
) -> List[float]:
self, balance_layers: list[Module], activation_scales: torch.Tensor
) -> torch.Tensor:
"""
Calculate how much smoothing to apply to each channel based on the dynamic
range of the activation and the following weights
Expand Down