Description
Problem
For a model containing a QuantumModule, torch.save(model.state_dict(), "model.pt") followed by model.load_state_dict(torch.load("model.pt")) may not work, because some state_dict keys are only created lazily during the forward pass.
Example
I used the Model1 from the Quantum Convolution (Quanvolution) example. The detailed code is below:
import torchquantum as tq
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from torchquantum.dataset import MNIST
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchquantum.layer import U3CU3Layer0
class TrainableQuanvFilter(tq.QuantumModule):
    def __init__(self):
        super().__init__()
        self.n_wires = 4
        self.encoder = tq.GeneralEncoder(
            [
                {"input_idx": [0], "func": "ry", "wires": [0]},
                {"input_idx": [1], "func": "ry", "wires": [1]},
                {"input_idx": [2], "func": "ry", "wires": [2]},
                {"input_idx": [3], "func": "ry", "wires": [3]},
            ]
        )
        self.arch = {"n_wires": self.n_wires, "n_blocks": 5, "n_layers_per_block": 2}
        self.q_layer = U3CU3Layer0(self.arch)
        self.measure = tq.MeasureAll(tq.PauliZ)

    def forward(self, x, use_qiskit=False):
        bsz = x.shape[0]
        qdev = tq.QuantumDevice(self.n_wires, bsz=bsz, device=x.device)
        x = F.avg_pool2d(x, 6).view(bsz, 4, 4)

        size = 4
        stride = 2
        x = x.view(bsz, size, size)

        data_list = []
        for c in range(0, size, stride):
            for r in range(0, size, stride):
                data = torch.transpose(
                    torch.cat(
                        (x[:, c, r], x[:, c, r + 1], x[:, c + 1, r], x[:, c + 1, r + 1])
                    ).view(4, bsz),
                    0,
                    1,
                )
                if use_qiskit:
                    data = self.qiskit_processor.process_parameterized(
                        qdev, self.encoder, self.q_layer, self.measure, data
                    )
                else:
                    self.encoder(qdev, data)
                    self.q_layer(qdev)
                    data = self.measure(qdev)

                data_list.append(data.view(bsz, 4))

        # transpose to (bsz, channel, 2x2)
        result = torch.transpose(
            torch.cat(data_list, dim=1).view(bsz, 4, 4), 1, 2
        ).float()

        return result
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.qf = TrainableQuanvFilter()
        self.linear = torch.nn.Linear(16, 4)

    def forward(self, x, use_qiskit=False):
        x = x.view(-1, 28, 28)
        x = self.qf(x)
        x = x.reshape(-1, 16)
        x = self.linear(x)
        return F.log_softmax(x, -1)
def train(dataflow, model, device, optimizer):
    for feed_dict in dataflow["train"]:
        inputs = feed_dict["image"].to(device)
        targets = feed_dict["digit"].to(device)

        outputs = model(inputs)
        loss = F.nll_loss(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss: {loss.item()}", end="\r")
def valid_test(dataflow, split, model, device, qiskit=False):
    target_all = []
    output_all = []
    with torch.no_grad():
        for feed_dict in dataflow[split]:
            inputs = feed_dict["image"].to(device)
            targets = feed_dict["digit"].to(device)
            outputs = model(inputs, use_qiskit=qiskit)

            target_all.append(targets)
            output_all.append(outputs)
        target_all = torch.cat(target_all, dim=0)
        output_all = torch.cat(output_all, dim=0)

    _, indices = output_all.topk(1, dim=1)
    masks = indices.eq(target_all.view(-1, 1).expand_as(indices))
    size = target_all.shape[0]
    corrects = masks.sum().item()
    accuracy = corrects / size
    loss = F.nll_loss(output_all, target_all).item()

    print(f"{split} set accuracy: {accuracy}")
    print(f"{split} set loss: {loss}")

    return accuracy, loss
n_epochs = 1
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
dataset = MNIST(
    root="./mnist_data",
    train_valid_split_ratio=[0.9, 0.1],
    digits_of_interest=[0, 1, 2, 3],
    n_test_samples=300,
    n_train_samples=500,
)

dataflow = dict()
for split in dataset:
    sampler = torch.utils.data.RandomSampler(dataset[split])
    dataflow[split] = torch.utils.data.DataLoader(
        dataset[split],
        batch_size=10,
        sampler=sampler,
        num_workers=8,
        pin_memory=True,
    )
device = torch.device("cpu")
model = Model().to(device)
# print(f"training model...")
# optimizer = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
# scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)
# for epoch in range(1, n_epochs + 1):
# # train
# print(f"Epoch {epoch}:")
# train(dataflow, model, device, optimizer)
# print(optimizer.param_groups[0]["lr"])
# # valid
# accu, loss = valid_test(dataflow, "test", model, device)
# scheduler.step()
The save and load example in Save and Load QNN models may not work directly. Without commenting out the training part (i.e., actually running the training loop above), I trained the model and saved the state_dict with:
torch.save(model.state_dict(), "model.pt")
When I tried to load the saved state_dict into a newly created model, loading failed: the new model does not yet have some of the saved keys, because those keys are only created lazily during forward. For example, loading into a new model with:
model2 = Model().to(device)
model2.load_state_dict(torch.load("model.pt"))
The error is:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[4], line 2
      1 model2 = Model().to(device)
----> 2 model2.load_state_dict(torch.load("model.pt"))

File d:\Programming\Anaconda3\envs\torchquantum\lib\site-packages\torch\nn\modules\module.py:2153, in Module.load_state_dict(self, state_dict, strict, assign)
   2148     error_msgs.insert(
   2149         0, 'Missing key(s) in state_dict: {}. '.format(
   2150             ', '.join(f'"{k}"' for k in missing_keys)))
   2152 if len(error_msgs) > 0:
-> 2153     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2154         self.__class__.__name__, "\n\t".join(error_msgs)))
   2155 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for Model:
	Unexpected key(s) in state_dict: "qf.q_layer.q_device.state", "qf.q_layer.q_device.states".
Potential Solution
The cause is that the model may create new attributes (and therefore new state_dict keys) during the forward pass.
For example, in TrainableQuanvFilter above, self.q_layer is created inside __init__ with self.q_layer = U3CU3Layer0(self.arch). In torchquantum/layer/layers/u3_layer.py, U3CU3Layer0 inherits from LayerTemplate0, and its forward() method is also inherited from LayerTemplate0. Inside forward(), a new attribute is attached to the object, and it is only attached when the model is forwarded (torchquantum/layer/layers/layers.py):
@tq.static_support
def forward(self, q_device: tq.QuantumDevice):
    self.q_device = q_device
    for k in range(len(self.layers_all)):
        self.layers_all[k](q_device)
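The lazily created keys can be observed by comparing state_dict keys before and after a single forward pass. A minimal check, assuming the Model class defined above (the dummy input shape just mirrors the MNIST example; any correctly shaped tensor works):

probe = Model().to(device)
keys_before = set(probe.state_dict().keys())

# one throwaway forward pass; Model.forward reshapes its input to (-1, 28, 28)
with torch.no_grad():
    probe(torch.zeros(1, 1, 28, 28, device=device))

keys_after = set(probe.state_dict().keys())
# expected to show the lazily registered device buffers,
# e.g. "qf.q_layer.q_device.state" and "qf.q_layer.q_device.states"
print(keys_after - keys_before)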
One solution is to create all attributes during __init__, but I am not familiar with torchquantum's design principles. Since forward() takes q_device as an input and assigns it to the attribute, the layer may be designed to be reused across many devices. This change could therefore require a large interface change, and might mean creating one layer object per device instead of reusing a single layer object.
Another way is to forward the model once before every load_state_dict call, so that the lazily created keys already exist. This may be time-consuming when the model is large.
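A sketch of that workaround, assuming the Model class from the example above. One caveat: since the lazily created device buffers are batch-sized (the QuantumDevice is built with bsz=bsz in forward), the dummy batch size presumably has to match the batch size of the last forward pass before saving (10 in this example), otherwise load_state_dict may complain about shape mismatches:

model2 = Model().to(device)

# one throwaway forward pass so the lazily created buffers exist before loading;
# batch size 10 matches the DataLoader batch_size used when the model was saved
with torch.no_grad():
    model2(torch.zeros(10, 1, 28, 28, device=device))

model2.load_state_dict(torch.load("model.pt"))

Alternatively, model2.load_state_dict(torch.load("model.pt"), strict=False) on a fresh model would simply ignore the extra keys; since the device buffers are rebuilt on every forward pass anyway, dropping them should be harmless, though strict=False also silently ignores any other missing or unexpected keys.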
Related Issues
The following issues may be related to this one:
#210 is not resolved yet.
#49 provided the save and load example, but as explained above, it may not work in general. In that example, the same (already forwarded) model is saved and immediately loaded, so no state_dict keys are missing.