Skip to content

Commit c17ce56

Browse files
chMoussaCharles MOUSSA
and
Charles MOUSSA
authored
[Bug] Repeated PSR (#46)
* add uuid to parametric and analog * make parameterize test + add extra op in differentiation * _param uuid in initialization * use suffix for gpsr * fix analog _gpsr suffix * fix _gpsr parametric * adding gate-wise psr for no shots * simplify computation no shots * add finite shots fwd notimplementederror for repeated case * change syntax for defjvp by using a shift argument in gates * change test for repeated params * tried gate ind but not working * repeated param case working * add test with jit * more docstr * fix shape jax random * use tuple of int for values * add chex * using chex * add spectral gap via gates * adding spectral gap from gates * also assert hessian * rm hessian parts * decorate hessian fcts with variants * rm forward_mode * rm also strenum * hessian too unstable * jax.hessian * rm hessian - still unstable * fixing shift add for hessian * add second derivatives in test * change naming align eigenvectors * replace chex with if for n_shots * refactor * remove n_shots None and small syntax changed * reput checkify * adding python 3.13 --------- Co-authored-by: Charles MOUSSA <[email protected]>
1 parent 52adc73 commit c17ce56

File tree

9 files changed

+455
-157
lines changed

9 files changed

+455
-157
lines changed

.github/workflows/run-tests-and-mypy.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
runs-on: ubuntu-latest
2323
strategy:
2424
matrix:
25-
python-version: ["3.9", "3.10", "3.11", "3.12"]
25+
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
2626
steps:
2727
- name: Checkout main code and submodules
2828
uses: actions/checkout@v4

horqrux/analog.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,15 @@ class _HamiltonianEvolution(Primitive):
2020
generator_name: str
2121
target: QubitSupport
2222
control: QubitSupport
23+
shift: float = 0.0
24+
25+
def __post_init__(self) -> None:
26+
super().__post_init__()
2327

2428
def _unitary(self, values: dict[str, Array] = dict()) -> Array:
25-
return expm(values["hamiltonian"] * (-1j * values["time_evolution"]))
29+
# note: GPSR trick when the same param_name is used in many operations
30+
time_val = values["time_evolution"] + self.shift
31+
return expm(values["hamiltonian"] * (-1j * time_val))
2632

2733

2834
def HamiltonianEvolution(

horqrux/api.py

+10-18
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from __future__ import annotations
22

33
from collections import Counter
4-
from typing import Any, Optional
4+
from typing import Any
55

66
import jax
77
import jax.numpy as jnp
88
from jax import Array
9-
from jax.experimental import checkify
109

1110
from horqrux.composite import Observable, OpSequence
1211
from horqrux.differentiation.ad import ad_expectation
@@ -15,7 +14,6 @@
1514
from horqrux.utils import (
1615
DensityMatrix,
1716
DiffMode,
18-
ForwardMode,
1917
State,
2018
num_qubits,
2119
probabilities,
@@ -96,8 +94,7 @@ def expectation(
9694
observables: list[Observable],
9795
values: dict[str, float],
9896
diff_mode: DiffMode = DiffMode.AD,
99-
forward_mode: ForwardMode = ForwardMode.EXACT,
100-
n_shots: Optional[int] = None,
97+
n_shots: int = 0,
10198
key: Any = jax.random.PRNGKey(0),
10299
) -> Array:
103100
"""Run 'state' through a sequence of 'gates' given parameters 'values'
@@ -109,8 +106,7 @@ def expectation(
109106
observables (list[Observable]): List of observables.
110107
values (dict[str, float]): Parameter values.
111108
diff_mode (DiffMode, optional): Differentiation mode. Defaults to DiffMode.AD.
112-
forward_mode (ForwardMode, optional): Type of forward method. Defaults to ForwardMode.EXACT.
113-
n_shots (Optional[int], optional): Number of shots. Defaults to None.
109+
n_shots (int): Number of shots. Defaults to 0 for no shots.
114110
key (Any, optional): Random key. Defaults to jax.random.PRNGKey(0).
115111
116112
Returns:
@@ -123,25 +119,21 @@ def expectation(
123119
raise TypeError("Adjoint does not support density matrices.")
124120
return adjoint_expectation(state, circuit, observables, values)
125121
elif diff_mode == DiffMode.GPSR:
126-
if forward_mode == ForwardMode.SHOTS:
127-
checkify.check(
128-
type(n_shots) is int and n_shots > 0,
129-
"Number of shots must be an integer for finite shots.",
130-
)
131-
# Type checking is disabled because mypy doesn't parse checkify.check.
132-
# type: ignore
133-
return finite_shots_fwd(
122+
if n_shots < 0:
123+
raise ValueError("The number of shots should be positive.")
124+
if n_shots == 0:
125+
return no_shots_fwd(
134126
state=state,
135127
gates=circuit.operations,
136128
observables=observables,
137129
values=values,
138-
n_shots=n_shots,
139-
key=key,
140130
)
141131
else:
142-
return no_shots_fwd(
132+
return finite_shots_fwd(
143133
state=state,
144134
gates=circuit.operations,
145135
observables=observables,
146136
values=values,
137+
n_shots=n_shots,
138+
key=key,
147139
)

0 commit comments

Comments
 (0)