
[FEATURE] Load model config AND weights from local folder (support locally cloned HF hub models) #2338

Open
@rwightman

Description


Is your feature request related to a problem? Please describe.

Right now, to load pretrained models w/ their associated pretrained config (preprocess + head details), pretrained=True must be used with

  1. the model name as a string specifying a builtin model + pretrained tag (one of timm.list_pretrained())
    OR
  2. the model name as a string specifying a model repo on the HuggingFace (HF) hub, prefixed with hf-hub: (i.e. hf-hub:repo_name)
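To make the two current naming forms concrete, here is a minimal stdlib-only sketch of how such a name string could be split. The helper split_model_name is hypothetical (not timm's actual API); it assumes builtin names take the form <architecture>.<tag> and hub names carry the hf-hub: prefix.

```python
from typing import Optional, Tuple

HF_HUB_PREFIX = "hf-hub:"


def split_model_name(name: str) -> Tuple[str, str, Optional[str]]:
    """Hypothetical helper: split a create_model() name into (source, id, tag).

    'hf-hub:timm/resnet50.a1_in1k' -> ('hf-hub', 'timm/resnet50.a1_in1k', None)
        config.json and weights both come from that hub repo.
    'resnet50.a1_in1k' -> ('builtin', 'resnet50', 'a1_in1k')
        pretrained_cfg comes from the library's builtin configs.
    """
    if name.startswith(HF_HUB_PREFIX):
        return "hf-hub", name[len(HF_HUB_PREFIX):], None
    # Builtin names: '<architecture>.<pretrained tag>', with the tag optional.
    arch, _, tag = name.partition(".")
    return "builtin", arch, tag or None


print(split_model_name("resnet50.a1_in1k"))  # ('builtin', 'resnet50', 'a1_in1k')
```

Note that neither form gives the caller a way to say "this config lives in a folder on disk", which is the gap this issue is about.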

For option 1 above, the pretrained_cfg is loaded from builtins in the library. The builtin config can specify a weight location at a url or a specific HF hub repo + filename.

For option 2, the pretrained_cfg is loaded from the config.json file in the specified repo, and the weights are loaded from that same repo.

In either case, if you want to use an existing pretrained_cfg sourced by either mechanism above but override some of its values, the create_model factory accepts a pretrained_cfg_overlay argument. It should be a dict, and each key-value pair in that dict overrides the corresponding value in the originally sourced config. This allows one to load weights locally by passing a file= key that overrides any url or hf_hub entry, but it cannot change how/where the config itself is sourced from.

Example:

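```python
import timm

# Builtin pretrained_cfg for this model, but weights loaded from a local file
model = timm.create_model(
    'swinv2_large_window12to16_192to256',
    pretrained=True,
    pretrained_cfg_overlay=dict(file='path/to/checkpoint'),
)
```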
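The overlay behaviour described above boils down to a dict merge followed by a choice of weight source. The sketch below illustrates that under stated assumptions: apply_overlay and resolve_weight_source are hypothetical names, the precedence file > hf_hub_id > url is assumed to match the description here (timm's internal resolution may differ in details), and the config values are made up for illustration.

```python
from typing import Any, Dict, Tuple


def apply_overlay(pretrained_cfg: Dict[str, Any], overlay: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of pretrained_cfg where every overlay key-value wins."""
    merged = dict(pretrained_cfg)
    merged.update(overlay)
    return merged


def resolve_weight_source(cfg: Dict[str, Any]) -> Tuple[str, str]:
    """Pick the weight location; assumed precedence: file > hf_hub_id > url."""
    if cfg.get("file"):
        return "file", cfg["file"]
    if cfg.get("hf_hub_id"):
        return "hf-hub", cfg["hf_hub_id"]
    if cfg.get("url"):
        return "url", cfg["url"]
    raise ValueError("pretrained_cfg has no weight location")


# Made-up builtin-style config, for illustration only.
builtin_cfg = {
    "hf_hub_id": "timm/example-model",  # hypothetical repo id
    "input_size": (3, 256, 256),
    "num_classes": 1000,
}
cfg = apply_overlay(builtin_cfg, dict(file="path/to/checkpoint"))
print(resolve_weight_source(cfg))  # ('file', 'path/to/checkpoint')
print(cfg["input_size"])           # (3, 256, 256)
```

The second print shows the limitation: the weights now come from disk, but every other field (preprocessing, head details) is still whatever the original source provided.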

Describe the solution you'd like
It should be possible to call timm.create_model('...', pretrained=True) and specify a local folder from which both the config and weights will be sourced, as with transformers.AutoModel.from_pretrained().
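A minimal sketch of what "source both from a local folder" could mean, assuming the folder is a clone of a timm HF hub repo: a config.json holding 'architecture' and 'pretrained_cfg' keys, plus a weight file named model.safetensors or pytorch_model.bin. The helper load_local_pretrained is hypothetical, and the demo builds a throwaway folder with an empty placeholder weight file.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, Tuple, Union

# Assumed weight file names, in order of preference.
WEIGHT_FILES = ("model.safetensors", "pytorch_model.bin")


def load_local_pretrained(folder: Union[str, Path]) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical helper: read config + locate weights in one local folder.

    Reads config.json (assumed to hold 'architecture' and 'pretrained_cfg')
    and sets pretrained_cfg['file'] to the first weight file found.
    Returns (architecture, pretrained_cfg).
    """
    folder = Path(folder)
    config = json.loads((folder / "config.json").read_text(encoding="utf-8"))
    architecture = config["architecture"]
    pretrained_cfg = dict(config.get("pretrained_cfg", {}))
    for filename in WEIGHT_FILES:
        candidate = folder / filename
        if candidate.is_file():
            # Same effect as a file= overlay, but derived from the folder itself.
            pretrained_cfg["file"] = str(candidate)
            return architecture, pretrained_cfg
    raise FileNotFoundError(f"none of {WEIGHT_FILES} found in {folder}")


# Demo: a stand-in "cloned repo" with a tiny config and an empty placeholder weight file.
demo_dir = Path(tempfile.mkdtemp())
(demo_dir / "config.json").write_text(
    json.dumps({"architecture": "resnet50", "pretrained_cfg": {"num_classes": 1000}}),
    encoding="utf-8",
)
(demo_dir / "model.safetensors").touch()
arch, cfg = load_local_pretrained(demo_dir)
print(arch, cfg["num_classes"], Path(cfg["file"]).name)  # resnet50 1000 model.safetensors
```

Reading the files is the easy part; how the folder is passed into create_model, and how the resulting config is handed to the model builder, is the open API question this issue raises.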

The hardest part of adding this is figuring out how to integrate passing a folder into the create_model API. The pretrained_cfg_overlay was always a bit clunky/kludgy ... ideally this would be cleaner.

Based on existing usage, the two options that initially stand out to me are:

  1. if the model name isn't found, check for a folder at the specified path, e.g. timm.create_model('/blah/blah/my-resnet50')
  2. use a prefix like hf-hub:, with something like local: or folder:, e.g. timm.create_model('local:/blah/blah/my-resnet50')
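Both options can be sketched as a small name-resolution step that runs before the existing builtin / hf-hub handling. Everything here is hypothetical: resolve_model_source, the local: prefix string, and the known_models set (in timm, such a set could plausibly be built from timm.list_models()).

```python
import os
from typing import Set, Tuple

LOCAL_PREFIX = "local:"  # hypothetical; 'folder:' is the other candidate
HF_HUB_PREFIX = "hf-hub:"


def resolve_model_source(
    name: str, known_models: Set[str], allow_implicit_folder: bool = False
) -> Tuple[str, str]:
    """Sketch of the two proposals for pointing create_model() at a folder.

    Option 2: explicit prefix, parallel to 'hf-hub:'.
    Option 1: only if the name is not a known model, fall back to an
    existing directory at that path (enabled via allow_implicit_folder).
    """
    if name.startswith(LOCAL_PREFIX):
        return "local", name[len(LOCAL_PREFIX):]
    # Known names (builtin 'arch.tag' or hf-hub:) go through existing handling.
    if name.startswith(HF_HUB_PREFIX) or name.split(".", 1)[0] in known_models:
        return "name", name
    if allow_implicit_folder and os.path.isdir(name):
        return "local", name
    raise ValueError(f"unknown model name or folder: {name}")


print(resolve_model_source("local:/blah/blah/my-resnet50", {"resnet50"}))
# ('local', '/blah/blah/my-resnet50')
```

The option 1 branch shows why it is less safe: a folder whose path happens to match a builtin model name (e.g. a directory called resnet50 in the working directory) is silently shadowed by the builtin, while the prefix in option 2 removes that ambiguity.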

I'm leaning toward option 2 because it's a bit safer and more explicit, and it parallels hf-hub: usage. There may be some considerations regarding the timm wrapper in transformers.

Additional context
This is useful generally, but would be particularly useful with the new timm model wrapper in transformers, to maintain full API compatibility. Users expect to be able to clone a model repo from the hub and point to a local folder with those files when calling from_pretrained().

Metadata

Assignees

No one assigned

    Labels

    enhancement (New feature or request), help wanted (Extra attention is needed)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
