fix: Handle multi-parameter gate initialization correctly #313
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary
Fixes #232
This PR fixes the bug where multi-parameter gates (U2, U3, Rot, R, etc.) fail when passing 1D parameters like
[0.1, 0.2].The Problem
The error occurred because
gate_wrapperwas transforming 1D params incorrectly:[0.1, 0.2](shape[2]) →[2, 1]viaunsqueeze(-1)[0.1, 0.2](shape[2]) →[[0.1, 0.2]](shape[1, 2])Multi-parameter gate matrix functions access params as
params[:, 0]andparams[:, 1], requiring shape[batch, n_params].The Fix
Changed the dimension handling in
gate_wrapper:unsqueeze(0)instead ofunsqueeze(-1)- adds batch dimension at frontunsqueeze(0).unsqueeze(0)- adds both batch and param dimensionsAffected Gates
These gates now work correctly with 1D input for single-batch usage:
U2:tq.u2(qdev, [0], [phi, lam])U3:tq.u3(qdev, [0], [theta, phi, lam])Rot:tq.rot(qdev, [0], [phi, theta])R:tq.r(qdev, [0], [theta, phi])XXMinusYY,XXPlusYY: with[theta, beta]CU2,CU3,CRot: controlled versionsBackward Compatibility
The change maintains backward compatibility:
[[0.1, 0.2], [0.3, 0.4]]still work as before[0.1]→[[0.1]]still work correctlyTest Plan
tq.u2(qdev, [0], [0.1, 0.2])with bsz=1 - should work nowtq.u3(qdev, [0], [0.1, 0.2, 0.3])- should work[[0.1, 0.2], [0.3, 0.4]]still work🤖 Generated with Claude Code