Skip to content

Commit 7c32525

Browse files
authored
ResolutionSpeedupChecker (#1102)
![workflow (3)](https://github.com/user-attachments/assets/6f5839f2-1139-4e0b-8518-ea559842d909) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a `ResolutionSpeedupChecker` class for validating input sample dimensions against a configuration. - Implemented a mechanism to enable or disable deployable functionality via a global variable and associated functions. - **Bug Fixes** - Enhanced the constructor of the `Hijacker` class to ensure robust handling of the `funcs_list` parameter. - **Documentation** - Updated mappings to include the new `ResolutionSpeedupChecker` class for better recognition within the system. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent a3cc989 commit 7c32525

File tree

5 files changed

+90
-2
lines changed

5 files changed

+90
-2
lines changed

onediff_comfy_nodes/extras_nodes/nodes_nexfort_booster.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
import collections
2+
import os
3+
4+
import yaml
5+
from onediff.infer_compiler.backends.nexfort import fallback_to_eager
26

37
from ..modules.nexfort.booster_basic import BasicNexFortBoosterExecutor
48

@@ -52,8 +56,40 @@ def apply(
5256
)
5357

5458

59+
class ResolutionSpeedupChecker:
60+
current_dir = os.path.dirname(os.path.abspath(__file__))
61+
config_path = os.path.join(current_dir, "resolutions_config.yaml")
62+
with open(config_path, "r") as file:
63+
resolutions = yaml.safe_load(file)
64+
height_width_dict = {x["height"]: x["width"] for x in resolutions["resolutions"]}
65+
66+
@classmethod
67+
def INPUT_TYPES(s):
68+
return {
69+
"required": {
70+
"samples": ("LATENT",),
71+
}
72+
}
73+
74+
RETURN_TYPES = ("LATENT",)
75+
FUNCTION = "check"
76+
77+
def check(self, samples):
78+
_, _, H, W = samples["samples"].shape
79+
H, W = H * 8, W * 8
80+
if H in self.height_width_dict and self.height_width_dict[H] == W:
81+
fallback_to_eager(True)
82+
else:
83+
fallback_to_eager(False)
84+
return (samples,)
85+
86+
5587
NODE_CLASS_MAPPINGS = {
5688
"OneDiffNexfortBooster": OneDiffNexfortBooster,
89+
"ResolutionSpeedupChecker": ResolutionSpeedupChecker,
5790
}
5891

59-
NODE_DISPLAY_NAME_MAPPINGS = {"OneDiffNexfortBooster": "Nexfort Booster - OneDiff"}
92+
NODE_DISPLAY_NAME_MAPPINGS = {
93+
"OneDiffNexfortBooster": "Nexfort Booster - OneDiff",
94+
"ResolutionSpeedupChecker": "Speedup Checker - Resolution",
95+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
resolutions:
2+
- width: 1024
3+
height: 1024
4+
aspect_ratio: '1:1 Square'
5+
- width: 1152
6+
height: 896
7+
aspect_ratio: '9:7'
8+
- width: 896
9+
height: 1152
10+
aspect_ratio: '7:9'
11+
- width: 1216
12+
height: 832
13+
aspect_ratio: '19:13'
14+
- width: 832
15+
height: 1216
16+
aspect_ratio: '13:19'
17+
- width: 1344
18+
height: 768
19+
aspect_ratio: '7:4 Horizontal'
20+
- width: 768
21+
height: 1344
22+
aspect_ratio: '4:7 Vertical'
23+
- width: 1536
24+
height: 640
25+
aspect_ratio: '12:5 Horizontal'
26+
- width: 640
27+
height: 1536
28+
aspect_ratio: '5:12 Vertical'
29+
- width: 512
30+
height: 512
31+
aspect_ratio: '1:1 Square'

onediff_comfy_nodes/modules/sd_hijack_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,10 @@ class Hijacker:
162162
"""
163163

164164
def __init__(self, funcs_list=[]):
165-
self.funcs_list = funcs_list
165+
if funcs_list and isinstance(funcs_list, List):
166+
self.funcs_list = funcs_list.copy()
167+
else:
168+
self.funcs_list = []
166169
self.unhijack_funcs = []
167170

168171
def hijack(self, last=True):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from . import nexfort as _nexfort_backend
2+
from .deployable_module import fallback_to_eager

src/onediff/infer_compiler/backends/nexfort/deployable_module.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,19 @@
66

77
from ..deployable_module import DeployableModule
88

9+
DISABLE_DEPLOYABLE = False
10+
11+
12+
def fallback_to_eager(value: bool = False):
13+
"""
14+
Set the DISABLE_DEPLOYABLE environment variable.
15+
16+
Args:
17+
value (bool): The value to set for the environment variable.
18+
"""
19+
global DISABLE_DEPLOYABLE
20+
DISABLE_DEPLOYABLE = not value
21+
922

1023
class NexfortDeployableModule(DeployableModule):
1124
def __init__(self, compiled_module, torch_module):
@@ -23,6 +36,8 @@ def __init__(self, compiled_module, torch_module):
2336
object.__setattr__(self, "_buffers", compiled_module._orig_mod._buffers)
2437

2538
def forward(self, *args, **kwargs):
39+
if DISABLE_DEPLOYABLE:
40+
return self._torch_module(*args, **kwargs)
2641
with torch._dynamo.utils.disable_cache_limit():
2742
return self._deployable_module_model(*args, **kwargs)
2843

@@ -34,6 +49,8 @@ def _create_deployable_function(
3449
compiled_model, torch_module: FunctionType = None
3550
) -> FunctionType:
3651
def deploy_function(*args, **kwargs):
52+
if DISABLE_DEPLOYABLE:
53+
return torch_module(*args, **kwargs)
3754
with torch._dynamo.utils.disable_cache_limit():
3855
return compiled_model(*args, **kwargs)
3956

0 commit comments

Comments
 (0)