
Feature Request: Load pre-trained weights (.pth files etc.) into a Flux model #2164

Open
ejmeitz opened this issue Jan 16, 2023 · 6 comments

Comments

@ejmeitz

ejmeitz commented Jan 16, 2023

Motivation and description

I think it would be useful to load pre-trained weights from PyTorch or TensorFlow. I'm sure this has been discussed before (e.g., in the PyTorch feature parity doc), but I could not find an open issue on this.

Possible Implementation

My current workaround takes the .pth file and opens it with Pickle.jl. I am still working on parsing the resulting dictionary to create a Flux model. It would make my life much easier if there were an associated Flux function to just load pre-trained weights and evaluate a model.
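
A rough sketch of that workaround, assuming the .pth file holds a standard state_dict (the file name and key names below are illustrative):

```julia
using Pickle  # Pickle.jl can deserialize PyTorch pickle files

# Load the pickled state_dict into a Dict mapping parameter names to arrays
state = Pickle.Torch.THload("model.pth")

# Inspect what was saved, e.g. "fc.weight" => (5, 10), "fc.bias" => (5,)
for (name, tensor) in state
    println(name, " => ", size(tensor))
end
```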

@ToucheSir
Member

Are you aware of https://fluxml.ai/Flux.jl/stable/saving/#Flux.loadmodel!? I'm not sure it's our place to be including functionality for converting between PyTorch and Flux model structures given the lack of uniformity on the PyTorch side (excepting specific cases like Metalhead.jl where we have a known, limited set of models to map from).
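
A minimal sketch of what loadmodel! covers, copying parameters between two models of identical structure (the models here are illustrative):

```julia
using Flux

src = Chain(Dense(10 => 5, relu), Dense(5 => 2))
dst = Chain(Dense(10 => 5, relu), Dense(5 => 2))

# Recursively copies src's arrays into dst in place;
# throws an error if the two structures don't line up.
Flux.loadmodel!(dst, src)

dst[1].weight == src[1].weight  # true
```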

@ejmeitz
Author

ejmeitz commented Jan 16, 2023

Yes, I believe we discussed this on Slack yesterday haha. That is what I plan on using for my implementation, as it is only for a single model. Just thought I'd post an issue here since I also saw it come up in several open threads online and in the PyTorch feature parity document in this repo.

@ToucheSir
Member

I would say it's partially covered by the "We should expose the possibility to load pretrained weights" point under "PyTorch Extras" in #1431. As for more general solutions, were someone to come up with a general Dict -> nested struct transformation that works with most PyTorch models, we could consider depending on, integrating, or advertising it on the Flux side.
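
Purely as illustration, here is a hypothetical shape such a transformation might take for the simplest case of Dense/Linear layers; the key scheme ("layers.0.weight") and function name are assumptions, not an existing API:

```julia
using Flux

# Hypothetical sketch: map flat PyTorch state_dict keys like
# "layers.0.weight" onto the Dense layers of a Flux Chain.
function load_pytorch_dense!(chain::Chain, state::AbstractDict)
    for (i, layer) in enumerate(chain.layers)
        layer isa Dense || continue
        # PyTorch's Linear and Flux's Dense both store weights as (out, in),
        # but depending on how the .pth loader handles PyTorch's row-major
        # layout, a permutedims may be needed here.
        copyto!(layer.weight, state["layers.$(i-1).weight"])
        copyto!(layer.bias, state["layers.$(i-1).bias"])
    end
    return chain
end
```

A fully general version would have to walk arbitrary nested structs and handle the many ways PyTorch models name their submodules, which is exactly the uniformity problem mentioned above.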

@ejmeitz
Author

ejmeitz commented Jan 16, 2023

I'll let you know if what I come up with is general enough.

Flux should have all the same layer types and hyperparameters as PyTorch (just with different names), correct?

@ToucheSir
Member

Not necessarily, which is another reason this is difficult to generalize. In general we try to keep close to PyTorch if there's no good reason to diverge, but that's not a hard rule.

@CarloLucibello
Member

Some scripts for porting weights can be found in the Metalhead repo: https://github.com/FluxML/Metalhead.jl/tree/master/scripts
