Skip to content

Commit 8bc48fc

Browse files
committed
[fix] Use torch_dtype when loading transformer model
1 parent 94bc7d5 commit 8bc48fc

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/cogkit/utils/load.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ def load_pipeline(
1212
pipeline = DiffusionPipeline.from_pretrained(model_id_or_path, torch_dtype=dtype)
1313
if transformer_path is not None:
1414
pipeline.transformer.save_config(transformer_path)
15-
pipeline.transformer = pipeline.transformer.from_pretrained(transformer_path)
15+
pipeline.transformer = pipeline.transformer.from_pretrained(
16+
transformer_path, torch_dtype=dtype
17+
)
1618
if lora_model_id_or_path is not None:
1719
load_lora_checkpoint(pipeline, lora_model_id_or_path)
1820
return pipeline

0 commit comments

Comments
 (0)