Skip to content

Commit 3623d08

Browse files
authored
fix: symmetric compatibility (#555)
1 parent 0dc3392 commit 3623d08

File tree

18 files changed

+111
-22
lines changed

18 files changed

+111
-22
lines changed

src/pruna/algorithms/deepcache.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,21 @@ class DeepCache(PrunaAlgorithmBase):
5050
"diffusers_int8",
5151
"quanto",
5252
"sage_attn",
53+
"hyper",
54+
"padding_pruning",
5355
]
54-
compatible_after: Iterable[str] = ["stable_fast", "torch_compile"]
56+
compatible_after: Iterable[str] = [
57+
"stable_fast",
58+
"torch_compile",
59+
"img2img_denoise",
60+
"realesrgan_upscale",
61+
"text_to_image_distillation_inplace_perp",
62+
"text_to_image_distillation_lora",
63+
"text_to_image_distillation_perp",
64+
"text_to_image_inplace_perp",
65+
"text_to_image_lora",
66+
"text_to_image_perp",
67+
]
5568

5669
def get_hyperparameters(self) -> list:
5770
"""

src/pruna/algorithms/denoise.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ class Img2ImgDenoise(PrunaAlgorithmBase):
7575
"torchao",
7676
"qkv_diffusers",
7777
"ring_attn",
78+
"hyper",
7879
]
7980

8081
def get_hyperparameters(self) -> list:

src/pruna/algorithms/fastercache.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ class FasterCache(PrunaAlgorithmBase):
5757
processor_required: bool = False
5858
dataset_required: bool = False
5959
runs_on: list[str] = ["cpu", "cuda", "accelerate"]
60-
compatible_before: Iterable[str] = ["hqq_diffusers", "diffusers_int8", "sage_attn"]
60+
compatible_before: Iterable[str] = ["hqq_diffusers", "diffusers_int8", "sage_attn", "hyper", "padding_pruning"]
61+
compatible_after: Iterable[str] = ["img2img_denoise", "realesrgan_upscale"]
6162

6263
def get_hyperparameters(self) -> list:
6364
"""

src/pruna/algorithms/flash_attn3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class FlashAttn3(PrunaAlgorithmBase):
5151
processor_required: bool = False
5252
runs_on: list[str] = ["cuda", "accelerate"]
5353
dataset_required: bool = False
54-
compatible_before: Iterable[str] = ["torchao"]
54+
compatible_before: Iterable[str] = ["torchao", "padding_pruning"]
5555
compatible_after: Iterable[str] = ["fora", "torch_compile"]
5656

5757
def model_check_fn(self, model: Any) -> bool:

src/pruna/algorithms/fora.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,11 @@ class FORA(PrunaAlgorithmBase):
5050
"hqq_diffusers",
5151
"torchao",
5252
"flash_attn3",
53-
"sage_attn"
53+
"sage_attn",
54+
"hyper",
55+
"padding_pruning",
5456
]
55-
compatible_after: Iterable[str] = ["stable_fast", "torch_compile"]
57+
compatible_after: Iterable[str] = ["stable_fast", "torch_compile", "img2img_denoise", "realesrgan_upscale"]
5658

5759
def get_hyperparameters(self) -> list:
5860
"""
@@ -80,9 +82,7 @@ def get_hyperparameters(self) -> list:
8082
"backbone_calls_per_step",
8183
sequence=range(1, 4),
8284
default_value=1,
83-
meta=dict(
84-
desc="Number of backbone forward passes per diffusion step (e.g., 2 for CFG)."
85-
),
85+
meta=dict(desc="Number of backbone forward passes per diffusion step (e.g., 2 for CFG)."),
8686
),
8787
]
8888

src/pruna/algorithms/half.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ class Half(PrunaAlgorithmBase):
5252
"ifw",
5353
"whisper_s2t",
5454
"sage_attn",
55+
"hyper",
56+
"ipex_llm",
57+
"text_to_text_inplace_perp",
58+
"text_to_text_lora",
59+
"text_to_text_perp",
60+
"x_fast",
5561
]
5662

5763
def model_check_fn(self, model: Any) -> bool:

src/pruna/algorithms/hqq_diffusers.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,17 @@ class HQQDiffusers(PrunaAlgorithmBase):
6464
processor_required: bool = False
6565
runs_on: list[str] = ["cuda"]
6666
dataset_required: bool = False
67-
compatible_before: Iterable[str] = ["qkv_diffusers"]
68-
compatible_after: Iterable[str] = ["deepcache", "fastercache", "fora", "pab", "torch_compile", "sage_attn"]
67+
compatible_before: Iterable[str] = ["qkv_diffusers", "padding_pruning"]
68+
compatible_after: Iterable[str] = [
69+
"deepcache",
70+
"fastercache",
71+
"fora",
72+
"pab",
73+
"torch_compile",
74+
"sage_attn",
75+
"img2img_denoise",
76+
"realesrgan_upscale",
77+
]
6978
disjointly_compatible_before: Iterable[str] = []
7079
disjointly_compatible_after: Iterable[str] = ["torchao"]
7180

src/pruna/algorithms/huggingface_diffusers_int8.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,18 @@ class DiffusersInt8(PrunaAlgorithmBase):
6060
dataset_required: bool = False
6161
runs_on: list[str] = ["cuda", "accelerate"]
6262
save_fn: None = None
63-
compatible_before: Iterable[str] = ["qkv_diffusers"]
64-
compatible_after: Iterable[str] = ["deepcache", "fastercache", "fora", "pab", "torch_compile", "sage_attn"]
63+
compatible_before: Iterable[str] = ["qkv_diffusers", "padding_pruning"]
64+
compatible_after: Iterable[str] = [
65+
"deepcache",
66+
"fastercache",
67+
"fora",
68+
"pab",
69+
"torch_compile",
70+
"sage_attn",
71+
"img2img_denoise",
72+
"hyper",
73+
"realesrgan_upscale",
74+
]
6575

6676
def get_hyperparameters(self) -> list:
6777
"""

src/pruna/algorithms/llm_compressor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class LLMCompressor(PrunaAlgorithmBase):
5252
dataset_required: bool = True
5353
runs_on: list[str] = ["cuda"]
5454
compatible_before: Iterable[str] = []
55-
compatible_after: Iterable[str] = []
55+
compatible_after: Iterable[str] = ["sage_attn"]
5656

5757
def get_hyperparameters(self) -> list:
5858
"""

src/pruna/algorithms/pab.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ class PAB(PrunaAlgorithmBase):
5353
processor_required: bool = False
5454
dataset_required: bool = False
5555
runs_on: list[str] = ["cpu", "cuda", "accelerate"]
56-
compatible_before: Iterable[str] = ["hqq_diffusers", "diffusers_int8", "sage_attn"]
57-
compatible_after: Iterable[str] = []
56+
compatible_before: Iterable[str] = ["hqq_diffusers", "diffusers_int8", "sage_attn", "hyper", "padding_pruning"]
57+
compatible_after: Iterable[str] = ["img2img_denoise", "realesrgan_upscale"]
5858

5959
def get_hyperparameters(self) -> list:
6060
"""

0 commit comments

Comments
 (0)