Hi! As part of some work on SpeedyWeather.jl#959 that I'm doing with @milankl, I've needed to load model weights from a PyTorch-trained model into Lux (just a simple feed-forward network).
After I save the weights from PyTorch into a .npz file, I can load them into Lux using

```julia
using Lux, Random, NPZ

rng = Random.default_rng()

dense_model = Chain(
    Dense(13 => 32, leakyrelu),
    Dense(32 => 64, leakyrelu),
    Dropout(0.2),
    Dense(64 => 64, leakyrelu),
    Dropout(0.1),
    Dense(64 => 32, leakyrelu),
    Dense(32 => 1)
)

parameters, states = Lux.setup(rng, dense_model)
weights = npzread("model_weights.npz")

# Map PyTorch module names to Lux layer symbols.
# The Dropout layers (:layer_3 and :layer_5) have no parameters, so they're skipped.
layer_map = [
    "embed_layer" => :layer_1,
    "layer_1" => :layer_2,
    "layer_2" => :layer_4,
    "layer_3" => :layer_6,
    "output_layer" => :layer_7
]

for (py_name, lux_sym) in layer_map
    lux_layer_params = getproperty(parameters, lux_sym)
    lux_layer_params.weight .= Float32.(weights[py_name * ".weight"])
    lux_layer_params.bias .= Float32.(weights[py_name * ".bias"])
end
```

and from there use the model like
```julia
test_states = Lux.testmode(states)
model = (u, p, s) -> first(Lux.apply(dense_model, u, p, s))
y = model(zeros(Float32, 13), parameters, test_states)
```

I'm basically wondering if there's a better way to do the above, and if not, whether it would be worth adding a simpler load function to Lux. It feels like this could be quite a common requirement, especially for people coming from PyTorch who are looking for an easy bridge between the two!
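For what it's worth, the copy loop above generalizes into a small helper that could be the seed of such a load function. This is just a sketch: the name `load_npz_weights!` is mine, not an existing Lux API, and it assumes the flat `"name.weight"` / `"name.bias"` key layout produced by saving a PyTorch `state_dict` with `numpy.savez`. It also checks shapes before broadcasting, which catches mismapped layers early:

```julia
# Sketch of a generic loader (hypothetical, not part of Lux): copies arrays
# from a flat Dict of "name.weight" / "name.bias" entries into a Lux
# parameter NamedTuple, given an explicit mapping from PyTorch module names
# to Lux layer symbols. Arrays are converted to the destination eltype
# (e.g. Float64 -> Float32) on copy.
function load_npz_weights!(parameters, weights::AbstractDict, layer_map)
    for (py_name, lux_sym) in layer_map
        ps = getproperty(parameters, lux_sym)
        for (py_field, lux_field) in ("weight" => :weight, "bias" => :bias)
            src = weights[py_name * "." * py_field]
            dst = getproperty(ps, lux_field)
            # Fail loudly on a mismapped or reshaped layer instead of
            # letting broadcasting error (or silently succeed) later.
            size(src) == size(dst) ||
                error("shape mismatch for $py_name.$py_field: got $(size(src)), expected $(size(dst))")
            dst .= eltype(dst).(src)
        end
    end
    return parameters
end
```

With this, the snippet above would collapse to `load_npz_weights!(parameters, weights, layer_map)`. A real version in Lux would presumably also want to recurse into nested `Chain`s and handle layers with parameter names other than `weight`/`bias`, but the flat case already covers simple feed-forward models like this one.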