Skip to content

Commit 1bb4055

Browse files
authored
Merge pull request #56 from QTC-UMD/parameter_checks
Scannable Parameter checks
2 parents 6924cfd + 07b1b7c commit 1bb4055

3 files changed

Lines changed: 39 additions & 6 deletions

File tree

docs/source/changelog.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ Improvements
1414
density matrices into the complex basis with ground state present.
1515
- Made project `uv` compatible and added installation docs describing how to use it.
1616
- Updated license metadata to follow PEP 639 standard
17+
- Improve scannable parameter handling so that all sequences are saved to the graph as numpy arrays.
18+
Also ensures proper handling of length-1 and length-0 sequences.
1719

1820
Bug Fixes
1921
+++++++++

src/rydiqule/sensor.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import itertools
1010

1111
from .sensor_utils import (ScannableParameter, CouplingDict, State, States, StateSpec, StateSpecs, TimeFunc,
12-
match_states, _squeeze_dims, expand_statespec, state_tuple_to_str)
12+
match_states, _squeeze_dims, expand_statespec, state_tuple_to_str, process_scannable_parameter)
1313
from .exceptions import RydiquleError, CouplingNotAllowedError
1414
from .exceptions import RWAWarning, PopulationNotConservedWarning, RydiquleWarning, debug_state
1515

@@ -416,6 +416,7 @@ def add_single_energy_shift(self, state: State, shift: ScannableParameter, label
416416
raise RydiquleError(f"state {state} is not a node on the graph")
417417

418418
self._remove_edge_data((state, state), kind="coherent")
419+
shift = process_scannable_parameter(shift)
419420
self.couplings.add_edge(state, state, e_shift=shift, label=label)
420421
if debug_state():
421422
print(f' Added energy shift for {state}')
@@ -856,8 +857,8 @@ def add_single_coupling(
856857
field_params_trimmed = {k:v for k,v in field_params.items() if v is not None}
857858

858859
full_edge_data = {
859-
param: np.array(val)
860-
if param in self.scannable_parameters and hasattr(val, "__len__")
860+
param: process_scannable_parameter(val)
861+
if param in self.scannable_parameters
861862
else val
862863
for (param, val) in {**field_params_trimmed, **extra_kwargs}.items()
863864
}
@@ -1718,8 +1719,7 @@ def add_single_decoherence(self, states: States, gamma: ScannableParameter,
17181719

17191720
states = self._states_valid(states)
17201721
# coerce gamma to numpy array if a sequence
1721-
if isinstance(gamma, Sized):
1722-
gamma = np.array(gamma)
1722+
gamma = process_scannable_parameter(gamma)
17231723

17241724
gamma_full = decoherent_cc*gamma
17251725
if np.all(gamma_full==0.0):

src/rydiqule/sensor_utils.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import scipy.constants
1212
from scipy.constants import hbar, e
1313

14-
from typing import Dict, Tuple, Union, List, Callable, TYPE_CHECKING
14+
from typing import Dict, Tuple, Union, List, Callable, Sequence, TYPE_CHECKING
1515
if TYPE_CHECKING:
1616
# only import when type checking, avoid circular import
1717
from .sensor import Sensor
@@ -1036,6 +1036,37 @@ def state_tuple_to_str(states:States) -> str:
10361036
return "(" + ",".join([str(s) for s in states]) + ")"
10371037

10381038

1039+
def process_scannable_parameter(val: ScannableParameter) -> Union[np.ndarray, float]:
1040+
"""Ensures that scannable parameters are coerced to numpy arrays.
1041+
1042+
If the parameter only has length of one, content is extracted.
1043+
1044+
Parameters
1045+
----------
1046+
val : ScannableParameter
1047+
Scannable parameter to process.
1048+
1049+
Returns
1050+
-------
1051+
float or numpy.ndarray
1052+
1053+
Raises
1054+
------
1055+
RydiquleError:
1056+
Raised if passed an empty sequence
1057+
"""
1058+
1059+
if isinstance(val, (Sequence, np.ndarray)):
1060+
if len(val) > 1:
1061+
return np.array(val)
1062+
elif len(val) == 1:
1063+
return val[0]
1064+
else:
1065+
raise RydiquleError('Length-0 sequence passed as scannable parameter')
1066+
else:
1067+
return val
1068+
1069+
10391070
def _validate_sols(sols) -> np.ndarray:
10401071
"""Helper function to validate that solutions are of an appropriate type.
10411072
There are 3 outcomes:

0 commit comments

Comments
 (0)