Load PyTorch model weights into Lux #1657

@gregrmunday

Description

Hi! As part of some work on SpeedyWeather.jl#959 that I'm doing with @milankl, I've needed to load model weights from a PyTorch-trained model into Lux (just a simple feed-forward network).

After I save the weights from PyTorch into a .npz file, I can load them into Lux using

using Lux, Random, NPZ

rng = Random.default_rng()
dense_model = Chain(
    Dense(13 => 32, leakyrelu),
    Dense(32 => 64, leakyrelu),
    Dropout(0.2),
    Dense(64 => 64, leakyrelu),
    Dropout(0.1),
    Dense(64 => 32, leakyrelu),
    Dense(32 => 1)
)

parameters, states = Lux.setup(rng, dense_model)
weights = npzread("model_weights.npz")

# Map PyTorch parameter names to Lux layer symbols. The Dropout layers
# (layer_3 and layer_5) have no parameters, so they don't appear here.
layer_map = [
    "embed_layer"  => :layer_1,
    "layer_1"      => :layer_2,
    "layer_2"      => :layer_4,
    "layer_3"      => :layer_6,
    "output_layer" => :layer_7
]
# Copy each PyTorch layer's arrays into the matching Lux layer in place
for (py_name, lux_sym) in layer_map
    lux_layer_params = getproperty(parameters, lux_sym)

    lux_layer_params.weight .= Float32.(weights[py_name * ".weight"])
    lux_layer_params.bias   .= Float32.(weights[py_name * ".bias"])
end

and from there use the model like

test_states = Lux.testmode(states)  # disable dropout for inference
model = (u, p, s) -> first(Lux.apply(dense_model, u, p, s))
y = model(zeros(Float32, 13), parameters, test_states)
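For reuse, the copy loop above could be factored into a small helper. This is just a sketch, not Lux API: `load_pytorch_weights!` is a hypothetical name, and it only assumes each mapped entry has mutable `weight` and `bias` arrays (as `Dense` parameters do). Checking shapes before copying catches mapping mistakes early instead of letting broadcasting silently succeed or fail confusingly.

```julia
# Hypothetical helper (not part of Lux): copy PyTorch-exported arrays into a
# Lux-style parameter NamedTuple in place, using a name map like the one above.
function load_pytorch_weights!(ps, weights::AbstractDict, layer_map)
    for (py_name, lux_sym) in layer_map
        layer = getproperty(ps, lux_sym)
        W = Float32.(weights[py_name * ".weight"])
        b = Float32.(weights[py_name * ".bias"])
        # Fail early on a wrong mapping rather than relying on broadcasting
        size(layer.weight) == size(W) || error("shape mismatch for $py_name.weight")
        size(layer.bias)   == size(b) || error("shape mismatch for $py_name.bias")
        layer.weight .= W
        layer.bias   .= b
    end
    return ps
end
```

Since only the mapped layers are touched, parameter-free layers like `Dropout` are skipped automatically, and the same helper would work for any `Chain` of `Dense` layers given an appropriate `layer_map`.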

I'm basically wondering whether there's a better way to do the above, and if not, whether it would be worth adding a simpler load function to Lux. This feels like it could be a fairly common requirement, especially for people coming from PyTorch who are looking for an easy bridge between the two!
