Skip to content

Commit 9231f55

Browse files
authored
Fix is_using_oneflow_backend check (#1112)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced backend detection logic for improved compatibility with the OneFlow library. - Added a function to check for OneFlow library availability and CUDA support. - **Bug Fixes** - Improved messaging for cases when the OneFlow backend is not detected. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 7c32525 commit 9231f55

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

onediff_comfy_nodes/modules/oneflow/utils/booster_utils.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Union
22

3+
import oneflow
4+
35
import torch
46
from comfy import model_management
57
from comfy.model_base import BaseModel, SVD_img2vid
@@ -9,6 +11,7 @@
911
OneflowDeployableModule as DeployableModule,
1012
)
1113
from onediff.utils import set_boolean_env_var
14+
from onediff.utils.import_utils import is_oneflow_available
1215

1316
from ..patch_management import create_patch_executor, PatchType
1417

@@ -63,6 +66,15 @@ def set_environment_for_svd_img2vid(model: ModelPatcher):
6366

6467

6568
def is_using_oneflow_backend(module):
69+
# First, check if oneflow is available and CUDA is enabled
70+
if is_oneflow_available() and not oneflow.cuda.is_available():
71+
print("OneFlow CUDA support is not available")
72+
return False
73+
74+
# Check if the module
75+
if isinstance(module, oneflow.nn.Module):
76+
return True
77+
6678
dc_patch_executor = create_patch_executor(PatchType.DCUNetExecutorPatch)
6779
if isinstance(module, ModelPatcher):
6880
deep_cache_module = dc_patch_executor.get_patch(module)
@@ -85,7 +97,17 @@ def is_using_oneflow_backend(module):
8597
if isinstance(module, DeployableModule):
8698
return True
8799

88-
raise RuntimeError("")
100+
if hasattr(module, "parameters"):
101+
for param in module.parameters():
102+
if isinstance(param, oneflow.Tensor):
103+
return True
104+
105+
warn_msg = (
106+
f"OneFlow backend is not detected for the module, the module is {type(module)}"
107+
)
108+
print(warn_msg)
109+
# If none of the above conditions are met, it's not using OneFlow backend
110+
return False
89111

90112

91113
def clear_deployable_module_cache_and_unbind(

0 commit comments

Comments
 (0)