Skip to content

Conversation

@yurekami
Copy link

Summary

  • Add reset_states(bsz) call in QFCModel.forward() to ensure quantum device state dimensions match the current batch size
  • The QuantumDevice is initialized with bsz=1 by default, but during training the batch size varies
  • Without resetting states, tensor dimension mismatch causes RuntimeError during matrix operations

Root Cause

The QuantumDevice is created in __init__ with default batch size 1:

self.q_device = tq.QuantumDevice(n_wires=self.n_wires)  # bsz defaults to 1

But in forward(), the input batch size varies (256 for full batches, smaller for final batch). The quantum operations expect the device states to match the input batch size.

Fix

Call reset_states(bsz) before using the quantum device:

def forward(self, x, use_qiskit=False):
    bsz = x.shape[0]
    x = F.avg_pool2d(x, 6).view(bsz, 16)
    
    # Reset quantum device states for current batch size
    self.q_device.reset_states(bsz)
    
    # ... rest of forward

Test plan

  • Run python examples/clifford_qnn/mnist_clifford_qnn.py to verify no RuntimeError
  • Verify training completes with varying batch sizes

Fixes #213

🤖 Generated with Claude Code

Add reset_states(bsz) call in QFCModel.forward() to ensure the
quantum device state dimensions match the current batch size.

The QuantumDevice is initialized with bsz=1 by default, but during
training the batch size varies (e.g., 256 for full batches, smaller
for final batch). Without resetting states, the tensor dimensions
mismatch causes RuntimeError during matrix operations.

Fixes mit-han-lab#213

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [256, 2] but got: [1, 2].(mnist_clifford_qnn.py)

1 participant