Skip to content

Commit d5d131b

Browse files
committed
add dwt checkpoint load
1 parent c753cd6 commit d5d131b

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

rudalle/vae/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ def get_vae(pretrained=True, dwt=False, cache_dir='/tmp/rudalle'):
2222
config_file_url = hf_hub_url(repo_id=repo_id, filename=filename)
2323
cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename)
2424
checkpoint = torch.load(join(cache_dir, filename), map_location='cpu')
25-
vae.model.load_state_dict(checkpoint['state_dict'], strict=False)
25+
if dwt:
26+
vae.load_state_dict(checkpoint['state_dict'])
27+
else:
28+
vae.model.load_state_dict(checkpoint['state_dict'], strict=False)
2629
print('vae --> ready')
2730
return vae

0 commit comments

Comments
 (0)