6 changes: 2 additions & 4 deletions src/llmcompressor/modifiers/logarithmic_equalization/base.py
@@ -1,5 +1,3 @@
-from typing import List
-
 import torch
 from torch.nn import Module
 
@@ -52,8 +50,8 @@ class LogarithmicEqualizationModifier(SmoothQuantModifier):
     """
 
     def _calculate_smoothing_scales(
-        self, balance_layers: List[Module], activation_scales: torch.Tensor
-    ) -> List[float]:
+        self, balance_layers: list[Module], activation_scales: torch.Tensor
+    ) -> torch.Tensor:
         """
         Calculate how much smoothing to apply to each channel based on the dynamic
         range of the activations and the following weights.
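The return annotation changes from `List[float]` to `torch.Tensor` because the per-channel smoothing scales are computed and returned as a tensor. For reference, here is a minimal sketch of a SmoothQuant-style scale computation; the helper name and exact formula are assumptions for illustration, not the repository's implementation:

```python
import torch
from torch.nn import Module


def sketch_calculate_smoothing_scales(
    balance_layers: list[Module],
    activation_scales: torch.Tensor,
    smoothing_strength: float = 0.5,
) -> torch.Tensor:
    # Hypothetical sketch, not llm-compressor's implementation. Assumes each
    # balance layer has a 2D weight of shape (out_features, in_features);
    # take the largest weight magnitude per input channel across all layers.
    weight_scales = torch.stack(
        [layer.weight.abs().max(dim=0).values for layer in balance_layers]
    ).max(dim=0).values

    # s_j = max|X_j|^alpha / max|W_j|^(1 - alpha), with a clamp to avoid
    # dividing by zero; the result is one scale per channel, hence a Tensor.
    return activation_scales.pow(smoothing_strength) / weight_scales.clamp(
        min=1e-8
    ).pow(1 - smoothing_strength)
```

Returning the tensor directly also keeps the scales on the same device and dtype as the weights, rather than round-tripping through a Python list of floats.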
20 changes: 10 additions & 10 deletions src/llmcompressor/modifiers/smoothquant/base.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Optional
 
 import torch
 from compressed_tensors.utils import align_module_device, match_modules_set
@@ -51,7 +51,7 @@ class SmoothQuantMapping:
 
     smooth_name: str
     smooth_layer: Module
-    balance_layers: List[Module]
+    balance_layers: list[Module]
 
 
 class SmoothQuantModifier(Modifier):
@@ -96,15 +96,15 @@ class SmoothQuantModifier(Modifier):
     """
 
     smoothing_strength: float = 0.5
-    mappings: Optional[List[Union[Tuple, List]]] = None
-    ignore: Optional[List[str]] = None
+    mappings: Optional[list[tuple | list]] = None
+    ignore: Optional[list[str]] = None
     num_calibration_steps: Optional[int] = None
     calibration_function: Optional[Callable] = None
 
-    resolved_mappings_: Optional[List[SmoothQuantMapping]] = Field(
+    resolved_mappings_: Optional[list[SmoothQuantMapping]] = Field(
         default=None, repr=False
     )
-    scales_: Optional[Dict] = Field(default=None, repr=False)
+    scales_: Optional[dict] = Field(default=None, repr=False)
 
     def on_initialize(self, state: State, **kwargs) -> bool:
        """
@@ -178,7 +178,7 @@ def on_finalize(self, state: State, **kwargs) -> bool:
     def _infer_mappings_from_model(
         self,
         model: Module,
-    ) -> List[Tuple]:
+    ) -> list[tuple]:
         if self.mappings is not None:
             return self.mappings
 
@@ -188,7 +188,7 @@ def _infer_mappings_from_model(
         )
 
     @handle_mapping_resolution_errors
-    def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
+    def _resolve_mappings(self, model: Module) -> list[SmoothQuantMapping]:
         """
         Transforms the list of activations to smooth and their corresponding weights
         into SmoothQuantMapping objects, resolving regular expressions.
@@ -309,8 +309,8 @@ def smooth(module):
             del self.scales_[mapping.smooth_name]
 
     def _calculate_smoothing_scales(
-        self, balance_layers: List[Module], activation_scales: torch.Tensor
-    ) -> List[float]:
+        self, balance_layers: list[Module], activation_scales: torch.Tensor
+    ) -> torch.Tensor:
         """
         Calculate how much smoothing to apply to each channel based on the dynamic
         range of the activation and the following weights
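The remaining changes are the same annotation modernization applied throughout: built-in generics (PEP 585) replace `typing.List`, `typing.Dict`, and `typing.Tuple`, and `Union[Tuple, List]` becomes `tuple | list` (PEP 604). A minimal before/after illustration with made-up attribute names, not repository code:

```python
from typing import Dict, List, Optional, Tuple, Union

# Before: typing-module generics, as removed by this diff.
mappings_before: Optional[List[Union[Tuple, List]]] = None
scales_before: Optional[Dict] = None

# After: built-in generics (PEP 585, Python 3.9+) and the `|` union syntax
# (PEP 604, Python 3.10+); Optional is still imported from typing.
from typing import Optional

mappings_after: Optional[list[tuple | list]] = None
scales_after: Optional[dict] = None
```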