Skip to content

Commit 47c7e92

Browse files
support loading safetensors format. (#123)
Co-authored-by: Juan Acevedo <[email protected]>
1 parent 6271ab7 commit 47c7e92

File tree

2 files changed

+20
-17
lines changed

2 files changed

+20
-17
lines changed

src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -196,21 +196,20 @@ def load_diffusers_checkpoint(self):
196196
precision=precision,
197197
)
198198

199-
if len(self.config.unet_checkpoint) > 0:
200-
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
201-
self.config.unet_checkpoint,
202-
split_head_dim=self.config.split_head_dim,
203-
norm_num_groups=self.config.norm_num_groups,
204-
attention_kernel=self.config.attention,
205-
flash_block_sizes=flash_block_sizes,
206-
dtype=self.activations_dtype,
207-
weights_dtype=self.weights_dtype,
208-
mesh=self.mesh,
209-
)
210-
params["unet"] = unet_params
211-
pipeline.unet = unet
212-
params = jax.tree_util.tree_map(lambda x: x.astype(self.config.weights_dtype), params)
213-
199+
if len(self.config.unet_checkpoint) > 0:
200+
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
201+
self.config.unet_checkpoint,
202+
split_head_dim=self.config.split_head_dim,
203+
norm_num_groups=self.config.norm_num_groups,
204+
attention_kernel=self.config.attention,
205+
flash_block_sizes=flash_block_sizes,
206+
dtype=self.activations_dtype,
207+
weights_dtype=self.weights_dtype,
208+
mesh=self.mesh,
209+
)
210+
params["unet"] = unet_params
211+
pipeline.unet = unet
212+
params = jax.tree_util.tree_map(lambda x: x.astype(self.config.weights_dtype), params)
214213
return pipeline, params
215214

216215
def save_checkpoint(self, train_step, pipeline, params, train_states):

src/maxdiffusion/models/modeling_flax_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
FLAX_WEIGHTS_NAME,
3535
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
3636
WEIGHTS_NAME,
37+
SAFETENSORS_WEIGHTS_NAME,
3738
PushToHubMixin,
3839
logging,
3940
)
@@ -331,9 +332,12 @@ def from_pretrained(
331332
)
332333
if os.path.isdir(pretrained_path_with_subfolder):
333334
if from_pt:
334-
if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
335+
if os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
336+
model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
337+
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, SAFETENSORS_WEIGHTS_NAME)):
338+
model_file = os.path.join(pretrained_path_with_subfolder, SAFETENSORS_WEIGHTS_NAME)
339+
else:
335340
raise EnvironmentError(f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} ")
336-
model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
337341
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
338342
# Load from a Flax checkpoint
339343
model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)

0 commit comments

Comments
 (0)