diff --git a/nodes/models/flux.py b/nodes/models/flux.py index bc3b074..37767be 100644 --- a/nodes/models/flux.py +++ b/nodes/models/flux.py @@ -260,7 +260,7 @@ def load_model( data_type: str, **kwargs, ) -> tuple[FluxTransformer2DModel]: - device = f"cuda:{device_id}" + device = torch.device(f"cuda:{device_id}") prefixes = folder_paths.folder_names_and_paths["diffusion_models"][0] for prefix in prefixes: if os.path.exists(os.path.join(prefix, model_path)):