
Load models from Pytorch#1658

Draft
milankl wants to merge 2 commits into LuxDL:main from SpeedyWeather:mk/torchloader

Conversation


@milankl milankl commented Feb 3, 2026

Just an initial draft following #1657. Current questions:

  • Which Torch formats should we support? npz, pt, pth?
  • How should the extension be triggered? `using NPZ` for .npz is okay, but what about .pt?
  • How do we interpret the model architecture stored in the file?
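For the .npz route specifically, a minimal sketch of what the loader would see, assuming the state_dict was exported on the Python side with numpy's `np.savez` so that keys follow PyTorch's `"layer.weight"`/`"layer.bias"` naming (the toy file written here is purely illustrative):

```julia
using NPZ  # the dependency added in this PR

# Write a toy "state_dict" the way numpy's np.savez would lay it out
# (keys follow PyTorch's "layer.weight"/"layer.bias" naming) ...
path = tempname() * ".npz"
NPZ.npzwrite(path, Dict(
    "embed_layer.weight" => rand(Float32, 32, 13),
    "embed_layer.bias"   => rand(Float32, 32),
))

# ... then read it back: npzread on an archive returns a Dict{String, Array},
# so inspecting keys and shapes is the natural starting point for
# interpreting the architecture.
file = NPZ.npzread(path)
for (key, array) in sort!(collect(file); by = first)
    println(key, " => ", size(array))
end
```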

@gemini-code-assist

Summary of Changes

Hello @milankl, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the foundational elements for loading PyTorch models, specifically those saved in the NPZ format, into the Lux deep learning framework. It establishes the necessary package dependencies and a new extension that provides a basic mechanism to read NPZ files, reconstruct a predefined Lux model architecture, and populate its weights. This is an initial draft aimed at exploring the integration of external model formats.

Highlights

  • NPZ Dependency and Extension: The NPZ package has been added as a dependency and a new extension, NPZExt, has been introduced to handle operations related to NPZ files.
  • Pytorch Model Loading Infrastructure: Initial abstract types DataLoader and PytorchLoader are defined, along with a concrete NPZLoader struct, to lay the groundwork for loading PyTorch models.
  • Lux.load Implementation for NPZ: A Lux.load method is implemented for NPZLoader which reads .npz files, extracts model architecture information (currently hardcoded), constructs a Lux model, initializes its parameters, and loads weights from the NPZ file into the Lux model.
  • Project File Updates: The Project.toml file was updated to include the NPZ dependency and extension, and minor reordering of sections occurred, along with the removal of commented-out precompile extensions.
Changelog
  • Project.toml
    • Added NPZ to the [deps] section.
    • Added NPZExt to the [extensions] section.
    • Reordered the authors and [workspace] sections.
    • Removed commented-out Reactant precompile extensions.
  • ext/NPZExt/NPZExt.jl
    • New file added to implement the NPZExt module.
    • Defined DataLoader, PytorchLoader, PTLoader, and NPZLoader types.
    • Implemented Lux.load for NPZLoader to read .npz files.
    • Includes logic to extract model architecture (currently hardcoded) and construct a Lux Chain model.
    • Contains functionality to initialize Lux model parameters and load weights from the NPZ file.
    • Includes TODO comments highlighting areas for future development, such as dynamic architecture interpretation and zero initialization.
Activity
  • The author, milankl, opened this pull request as an initial draft.
  • The PR description includes open questions for discussion regarding supported Torch formats, how extensions should be triggered, and how to interpret model architecture from files, indicating it's in an early stage of development and seeking feedback.
  • The PR references issue Load PyTorch model weights into Lux #1657, suggesting this work is a direct follow-up or implementation related to that discussion.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces initial support for loading PyTorch models from NPZ files into Lux.jl. It correctly updates the Project.toml to include the NPZ dependency and adds a new extension module NPZExt. The NPZExt.jl file defines a DataLoader abstraction and an NPZLoader to handle the loading process. However, there are critical issues related to undefined variables and significant limitations due to hardcoded model architectures, which need to be addressed for this feature to be generally usable.

Comment on lines +51 to +52
lux_layer_params.weight .= Float32.(weights[py_name * ".weight"])
lux_layer_params.bias .= Float32.(weights[py_name * ".bias"])


critical

The variable weights is used to access the loaded parameters, but it is not defined within the Lux.load function. The result of NPZ.npzread(filepath) is assigned to the file variable. This will lead to an UndefVarError when trying to access weights[py_name * ".weight"] and weights[py_name * ".bias"].

The correct variable to use here should be file.

        lux_layer_params.weight .= Float32.(file[py_name * ".weight"])
        lux_layer_params.bias   .= Float32.(file[py_name * ".bias"])

Comment on lines +23 to +30
# TODO don't hardcode
layer_map = [
"embed_layer" => :layer_1,
"layer_1" => :layer_2,
"layer_2" => :layer_4,
"layer_3" => :layer_6,
"output_layer" => :layer_7
]


high

The layer_map is currently hardcoded for a very specific model architecture. This significantly limits the utility of the NPZLoader as it can only load models that precisely match this predefined mapping. For a general PyTorch model loading feature, this mapping should ideally be dynamic, inferred from the NPZ file metadata, or provided as a configurable input.

This hardcoding makes the loader non-reusable for other PyTorch models.
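One possible direction, sketched under two stated assumptions: every trainable PyTorch layer contributes exactly one ".weight" key, and the order of those keys matches the order of the Lux model's trainable layers. The function name `infer_layer_map` is hypothetical, not part of the PR:

```julia
# Hypothetical sketch: derive the PyTorch-layer names from the ".weight"
# keys of the npz file, then pair them positionally with the names of the
# Lux model's trainable layers (parameter-free layers like Dropout are
# assumed to already be excluded from `lux_layer_names`).
function infer_layer_map(npz_keys, lux_layer_names)
    suffix = ".weight"
    py_layers = [k[1:end-length(suffix)] for k in npz_keys if endswith(k, suffix)]
    length(py_layers) == length(lux_layer_names) ||
        error("layer count mismatch: $(length(py_layers)) vs $(length(lux_layer_names))")
    return [py => lux for (py, lux) in zip(py_layers, lux_layer_names)]
end

infer_layer_map(
    ["embed_layer.weight", "embed_layer.bias", "layer_1.weight", "layer_1.bias"],
    [:layer_1, :layer_2],
)
# => ["embed_layer" => :layer_1, "layer_1" => :layer_2]
```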

Comment on lines +33 to +42
# TODO don't hardcode
model = Chain(
Dense(13 => 32, leakyrelu),
Dense(32 => 64, leakyrelu),
Dropout(0.2),
Dense(64 => 64, leakyrelu),
Dropout(0.1),
Dense(64 => 32, leakyrelu),
Dense(32 => 1)
)


high

Similar to the layer_map, the model architecture itself is hardcoded. This means the NPZLoader can only load weights into this exact Chain structure. To support loading diverse PyTorch models, the Lux model architecture should be constructed dynamically based on information extracted from the NPZ file or provided by the user, rather than being fixed.

This hardcoding prevents the loader from being a general-purpose utility.
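A sketch of the pure inference step this suggestion implies, with no Lux dependency: PyTorch stores a Linear layer's weight as an (out, in) matrix, so the Dense sizes can be read off the weight shapes. Activations and Dropout placement are not recorded in a state_dict, so those would still need to come from the user or from file metadata; the helper name `layer_dims` is hypothetical:

```julia
# Hypothetical helper: map PyTorch Linear weight matrices (out × in)
# to Lux-style `in => out` size pairs, in layer order.
layer_dims(weights) = [size(W, 2) => size(W, 1) for W in weights]

dims = layer_dims([rand(32, 13), rand(64, 32), rand(1, 64)])
# => [13 => 32, 32 => 64, 64 => 1]
# From these pairs something like Chain((Dense(d, leakyrelu) for d in dims)...)
# could then be built, once the user has supplied the activations.
```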

Comment on lines +11 to +13
@kwdef struct NPZLoader <: PytorchLoader
dummy_field::Bool = true
end


medium

The NPZLoader struct includes a dummy_field that is always true and doesn't appear to serve any configuration purpose. If there are no specific fields required for NPZLoader configuration, it can be defined as a simple struct without fields to improve clarity and avoid unnecessary data.

Consider defining it as struct NPZLoader <: PytorchLoader end if no fields are needed.

struct NPZLoader <: PytorchLoader end

end

# fallback
Lux.load(::DataLoader, args...) = error("DataLoader $(typeof(args[1])) not implemented.")


medium

The fallback error message error("DataLoader $(typeof(args[1])) not implemented.") reports the wrong type: the DataLoader is the unnamed first positional argument, so args[1] is the first trailing argument, and when args is empty this even throws a BoundsError instead of the intended message. Bind the loader to a name (e.g. N) and interpolate its type instead.

For example, error("DataLoader $(typeof(N)) not implemented.") would be more precise.

Lux.load(N::DataLoader, args...) = error("DataLoader $(typeof(N)) not implemented.")

