From 3b78824e9bd97c335abb532bd8ab195a77017e79 Mon Sep 17 00:00:00 2001 From: Hyunho Richard Lee Date: Fri, 15 Jul 2022 14:25:02 -0400 Subject: [PATCH] MEGA_FULL needs dtype = jnp.float32 --- backend/dalle_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/dalle_model.py b/backend/dalle_model.py index 485836ab3..0889fb6b2 100644 --- a/backend/dalle_model.py +++ b/backend/dalle_model.py @@ -48,7 +48,7 @@ class DalleModel: def __init__(self, model_version: ModelSize) -> None: if model_version == ModelSize.MEGA_FULL: dalle_model = DALLE_MODEL_MEGA_FULL - dtype = jnp.float16 + dtype = jnp.float32 elif model_version == ModelSize.MEGA: dalle_model = DALLE_MODEL_MEGA dtype = jnp.float16