We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 94bc7d5 commit 8bc48fcCopy full SHA for 8bc48fc
src/cogkit/utils/load.py
@@ -12,7 +12,9 @@ def load_pipeline(
12
pipeline = DiffusionPipeline.from_pretrained(model_id_or_path, torch_dtype=dtype)
13
if transformer_path is not None:
14
pipeline.transformer.save_config(transformer_path)
15
- pipeline.transformer = pipeline.transformer.from_pretrained(transformer_path)
+ pipeline.transformer = pipeline.transformer.from_pretrained(
16
+ transformer_path, torch_dtype=dtype
17
+ )
18
if lora_model_id_or_path is not None:
19
load_lora_checkpoint(pipeline, lora_model_id_or_path)
20
return pipeline
0 commit comments