Skip to content

Commit 0f1457f

Browse files
committed
Possibility to override dtype trhough env var
Signed-off-by: Raphael Glon <[email protected]>
1 parent 40a21b7 commit 0f1457f

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

docker_images/diffusers/app/pipelines/image_to_image.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ def __init__(self, model_id: str):
4545
if model_id.startswith("hf-internal-testing/")
4646
else {}
4747
)
48-
if torch.cuda.is_available():
48+
env_dtype = os.getenv("TORCH_DTYPE")
49+
if env_dtype:
50+
kwargs["torch_dtype"] = getattr(torch, env_dtype)
51+
elif torch.cuda.is_available():
4952
kwargs["torch_dtype"] = torch.float16
5053
if model_id == "stabilityai/stable-diffusion-xl-refiner-1.0":
5154
kwargs["variant"] = "fp16"

docker_images/diffusers/app/pipelines/text_to_image.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ def __init__(self, model_id: str):
4040
if model_id.startswith("hf-internal-testing/")
4141
else {}
4242
)
43-
if torch.cuda.is_available():
43+
env_dtype = os.getenv("TORCH_DTYPE")
44+
if env_dtype:
45+
kwargs["torch_dtype"] = getattr(torch, env_dtype)
46+
elif torch.cuda.is_available():
4447
kwargs["torch_dtype"] = torch.float16
4548

4649
has_model_index = any(

0 commit comments

Comments
 (0)