diff --git a/backend/dalle_model.py b/backend/dalle_model.py index 485836ab3..0889fb6b2 100644 --- a/backend/dalle_model.py +++ b/backend/dalle_model.py @@ -48,7 +48,7 @@ class DalleModel: def __init__(self, model_version: ModelSize) -> None: if model_version == ModelSize.MEGA_FULL: dalle_model = DALLE_MODEL_MEGA_FULL - dtype = jnp.float16 + dtype = jnp.float32 elif model_version == ModelSize.MEGA: dalle_model = DALLE_MODEL_MEGA dtype = jnp.float16