Skip to content

Commit dd44897

Browse files
minettekaumsimlanggsprochette
authored
feat: add recoverer algorithms (#491)
* added PERP recovery algorithms without distillation * adding recovery algorithms with distillation * fixed linting errors * fixing typo * fixing pr comments * fixed typo * fixing typo * fixing linting error * fixing style * Add co-authors Co-authored-by: Simon Langrieger <simon.langrieger@pruna.ai> Co-authored-by: Gaspar Rochette <gaspar.rochette@pruna.ai> --------- Co-authored-by: Simon Langrieger <simon.langrieger@pruna.ai> Co-authored-by: Gaspar Rochette <gaspar.rochette@pruna.ai>
1 parent 11bae28 commit dd44897

32 files changed

+4625
-1
lines changed

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ dependencies = [
144144
"vbench-pruna; sys_platform != 'darwin'",
145145
"imageio-ffmpeg",
146146
"jaxtyping",
147-
"peft>=0.17.1",
147+
"peft>=0.18.0",
148+
"trl<=0.21.0",
149+
"termcolor==2.3.0",
148150
]
149151

150152
[project.optional-dependencies]

src/pruna/algorithms/base/tags.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ class AlgorithmTag(Enum):
7979
"resampler",
8080
"Resamplers change the shape of image or video latents during generation to speed up inference.",
8181
)
82+
RECOVERER = (
83+
"recoverer",
84+
"Recovery restores the performance of a model after compression.",
85+
)
8286
DECODER = (
8387
"decoder",
8488
"Decoders speed up autoregressive generation by changing their decoding strategy to be more parallelizable.",
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Iterable
18+
19+
from pruna.algorithms.base.tags import AlgorithmTag
20+
from pruna.algorithms.global_utils.recovery.perp_recoverer import PERPRecoverer
21+
22+
23+
class TextToImagePERPDistillation(PERPRecoverer):
    """
    General-purpose PERP distillation recoverer for text-to-image models.

    Recovery combines in-place norm and bias finetuning of the model weights with
    optional LoRA adapter layers, trained via distillation.

    Parameters
    ----------
    use_lora : bool
        Whether to use LoRA adapters.
    use_in_place : bool
        Whether to use norm and bias finetuning which will modify the model in place.
    """

    algorithm_name = "text_to_image_distillation_perp"
    group_tags: list[AlgorithmTag] = [AlgorithmTag.DISTILLER, AlgorithmTag.RECOVERER]  # type: ignore[attr-defined]
    runs_on: list[str] = ["cuda"]
    tokenizer_required = False
    compatible_before: Iterable[str | AlgorithmTag] = ["quanto", "torch_dynamic", "deepcache"]
    compatible_after: Iterable[str | AlgorithmTag] = ["torch_compile"]

    def __init__(self, use_lora: bool = True, use_in_place: bool = True) -> None:
        super().__init__(
            task_name="text_to_image",
            is_distillation=True,
            use_lora=use_lora,
            use_in_place=use_in_place,
        )
47+
48+
49+
class TextToImageInPlacePERPDistillation(TextToImagePERPDistillation):
    """
    PERP distillation recoverer for text-to-image models without LoRA adapters.

    Identical to ``text_to_image_distillation_perp`` except that no LoRA layers are
    attached; LoRA layers add extra computations and thus slow down inference of
    the final model.
    """

    algorithm_name = "text_to_image_distillation_inplace_perp"

    def __init__(self) -> None:
        super().__init__(use_in_place=True, use_lora=False)
61+
62+
63+
class TextToImageLoraDistillation(TextToImagePERPDistillation):
    """
    LoRA distillation recoverer for text-to-image models.

    Attaches LoRA adapters to the model and trains only those adapters during
    distillation; no in-place norm/bias finetuning is performed.
    """

    algorithm_name = "text_to_image_distillation_lora"

    def __init__(self) -> None:
        super().__init__(use_in_place=False, use_lora=True)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from abc import ABC, abstractmethod
18+
from typing import Any
19+
20+
import torch
21+
22+
from pruna.config.smash_config import SmashConfigPrefixWrapper
23+
24+
25+
class PrunaAdapter(ABC):
    """Base class for adapters, defining which parameters to finetune for recovery."""

    @property
    @abstractmethod
    def adapter_prefix(self) -> str:
        """The prefix under which this adapter's options appear in the config."""

    @classmethod
    @abstractmethod
    def get_hyperparameters(cls, task_name: str, **override_defaults: Any) -> list:
        """
        Configure all algorithm-specific hyperparameters with ConfigSpace.

        Parameters
        ----------
        task_name : str
            The name of the task, e.g. "text-to-image" or "text-to-text".
        **override_defaults : Any
            Values used to override the default hyperparameters when using multiple finetuners together.

        Returns
        -------
        list
            The hyperparameters.
        """

    @classmethod
    @abstractmethod
    def activate(
        cls,
        model: torch.nn.Module,
        smash_config: SmashConfigPrefixWrapper,
        seed: int | None = None,
    ) -> tuple[torch.nn.Module, int, int]:
        """
        Activate or create the parameters in the model corresponding to the adapter.

        Parameters
        ----------
        model : torch.nn.Module
            The model to apply the component to.
        smash_config : SmashConfigPrefixWrapper
            The configuration for the component.
        seed : int
            The seed to use for the adapter if it requires initialization.

        Returns
        -------
        torch.nn.Module
            The model with the adapter activated.
        int
            The number of trainable parameters.
        int
            The number of skipped parameters.
        """

    @classmethod
    def pre_smash_hook(
        cls, model: torch.nn.Module, smash_config: SmashConfigPrefixWrapper, seed: int | None = None
    ) -> None:
        """
        Optional hook to prepare the model/config before smashing.

        The default implementation is a no-op; subclasses override it when setup work
        is needed ahead of smashing.

        Parameters
        ----------
        model : torch.nn.Module
            The model to prepare.
        smash_config : SmashConfigPrefixWrapper
            Configuration scoped to this adapter.
        seed : int | None
            Optional seed for deterministic initialization.
        """
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import torch
16+
17+
from pruna.algorithms.global_utils.recovery.adapters import PrunaAdapter, utils
18+
19+
20+
class BiasAdapter(PrunaAdapter):
    """Adapter for bias finetuning."""

    adapter_prefix = "bias"

    @classmethod
    def get_hyperparameters(cls, *args, **kwargs) -> list:
        """
        Configure all method-specific hyperparameters with ConfigSpace.

        Bias finetuning exposes no tunable hyperparameters, so the list is empty.

        Parameters
        ----------
        *args : Any
            Unused arguments.
        **kwargs : Any
            Unused keyword arguments.

        Returns
        -------
        list
            The (empty) list of hyperparameters.
        """
        return []

    @classmethod
    def activate(cls, model: torch.nn.Module, *args, **kwargs) -> tuple[torch.nn.Module, int, int]:
        """
        Activate all biases for training.

        Parameters
        ----------
        model : torch.nn.Module
            The model containing the biases.
        *args : Any
            Unused additional arguments.
        **kwargs : Any
            Unused additional keyword arguments.

        Returns
        -------
        torch.nn.Module
            The model with the biases activated.
        int
            The number of trainable bias parameters.
        int
            The number of skipped bias parameters.
        """
        # Delegate the name-based matching to the shared utility; it reports how many
        # parameters were unfrozen and how many were left untouched.
        trainable_count, skipped_count = utils.unfreeze_parameters_by_name(model, target_modules=("bias",))
        return model, trainable_count, skipped_count
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import inspect
16+
17+
import torch
18+
19+
from pruna.algorithms.global_utils.recovery.adapters import PrunaAdapter, utils
20+
from pruna.logging.logger import pruna_logger
21+
22+
23+
class HeadAdapter(PrunaAdapter):
    """Adapter for finetuning the model's head while keeping the backbone as is."""

    adapter_prefix = "head"

    @classmethod
    def get_hyperparameters(cls, *args, **kwargs) -> list:
        """
        Configure all method-specific hyperparameters with ConfigSpace.

        Head finetuning exposes no tunable hyperparameters, so the list is empty.

        Parameters
        ----------
        *args : tuple
            The arguments for the adapter.
        **kwargs : dict
            The hyperparameters for the adapter.

        Returns
        -------
        list
            The hyperparameters.
        """
        return []

    @classmethod
    def activate(cls, model: torch.nn.Module, *args, **kwargs) -> tuple[torch.nn.Module, int, int]:
        """
        Activate the model's head for training.

        Parameters
        ----------
        model : torch.nn.Module
            The model containing the head.
        *args : tuple
            The arguments for the adapter.
        **kwargs : dict
            The hyperparameters for the adapter.

        Returns
        -------
        torch.nn.Module
            The model with the head activated.
        int
            The number of trainable head parameters.
        int
            The number of skipped head parameters.
        """
        # Find head candidates by type and name in a single pass over the model's
        # members (the original code re-ran this scan to build the warning message).
        head_candidates = [
            (comp_name, component)
            for comp_name, component in inspect.getmembers(model)
            if isinstance(component, torch.nn.Linear) and "head" in comp_name.lower()
        ]
        if len(head_candidates) != 1:
            # = 0: model with no head, e.g. diffusers
            # > 1: model with multiple heads, e.g. for localization, not currently supported
            found_names = [name for name, _ in head_candidates]
            pruna_logger.warning(
                f"Expected exactly one head but found {len(found_names)}: {found_names}. "
                "Skipping head finetuning."
            )
            return model, 0, 0
        _, model_head = head_candidates[0]

        # Unfreeze head parameters, recording the number of trainable and skipped parameters.
        num_activ_param, num_skip_param = 0, 0
        for param in model_head.parameters():
            if utils.is_trainable(param):
                param.requires_grad = True
                num_activ_param += int(param.numel())
            else:
                num_skip_param += int(param.numel())

        return model, num_activ_param, num_skip_param

0 commit comments

Comments
 (0)