Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 42 additions & 19 deletions pygsti/forwardsims/torchfwdsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@
try:
import torch
TORCH_ENABLED = True

def todevice_kwargs():
if torch.cuda.device_count() > 0:
return {'dtype': torch.float64, 'device': 'cuda:0'}
elif torch.mps.device_count() > 0:
return {'dtype': torch.float32, 'device': 'mps:0'}
elif torch.xpu.device_count() > 0:
return {'dtype': torch.float64, 'device': 'xpu:0'}
else:
return {'dtype': torch.float64, 'device': -1}
DEVICE_KWARGS = todevice_kwargs()

except ImportError:
TORCH_ENABLED = False
pass
Expand Down Expand Up @@ -71,10 +83,10 @@ class StatelessModel:
the sophiciated machinery in TorchForwardSimulator's base class.
"""

def __init__(self, model: ExplicitOpModel, layout: CircuitOutcomeProbabilityArrayLayout):
def __init__(self, model: ExplicitOpModel, layout: CircuitOutcomeProbabilityArrayLayout, use_gpu: bool):
circuits = []
self.outcome_probs_dim = 0
#TODO: Refactor this to use the bulk_expand_instruments_and_separate_povm codepath
# TODO: Refactor this to use the bulk_expand_instruments_and_separate_povm codepath
for _, circuit, outcomes in layout.iter_unique_circuits():
expanded_circuits = model.expand_instruments_and_separate_povm(circuit, outcomes)
if len(expanded_circuits) > 1:
Expand All @@ -84,12 +96,14 @@ def __init__(self, model: ExplicitOpModel, layout: CircuitOutcomeProbabilityArra
circuits.append(c)
self.outcome_probs_dim += c.outcome_probs_dim
self.circuits = circuits
self.use_gpu = use_gpu

# We need to verify assumptions on what layout.iter_unique_circuits() returns.
# Looking at the implementation of that function, the assumptions can be
# framed in terms of the "layout._element_indicies" dict.
eind = layout._element_indices
assert isinstance(eind, dict)
assert len(eind) > 0
items = iter(eind.items())
k_prev, v_prev = next(items)
assert k_prev == 0
Expand All @@ -105,13 +119,16 @@ def __init__(self, model: ExplicitOpModel, layout: CircuitOutcomeProbabilityArra
for lbl, obj in model._iter_parameterized_objs():
assert isinstance(obj, Torchable), f"{type(obj)} does not subclass {Torchable}."
param_type = type(obj)
param_data = (lbl, param_type) + (obj.stateless_data(),)
sld = obj.stateless_data()
if self.use_gpu:
sld = tuple((i.to(**DEVICE_KWARGS) if isinstance(i, torch.Tensor) else i) for i in sld)
param_data = (lbl, param_type) + (sld,)
self.param_metadata.append(param_data)
self.params_dim = None
# ^ That's set in get_free_params.

self.default_to_reverse_ad = None
# ^ That'll be set to a boolean the next time that get_free_params is called.
self.params_dim = []
self.free_param_sizes = []
self.default_to_reverse_ad = False
# ^ Those are set in get_free_params
return

def get_free_params(self, model: ExplicitOpModel) -> Tuple[torch.Tensor]:
Expand All @@ -123,6 +140,7 @@ def get_free_params(self, model: ExplicitOpModel) -> Tuple[torch.Tensor]:
to StatelessModel.__init__(...). We raise an error if an inconsistency is detected.
"""
free_params = []
free_param_sizes = []
prev_idx = 0
for i, (lbl, obj) in enumerate(model._iter_parameterized_objs()):
gpind = obj.gpindices_as_array()
Expand All @@ -143,7 +161,11 @@ def get_free_params(self, model: ExplicitOpModel) -> Tuple[torch.Tensor]:
call to get_torch_bases would silently fail, so we're forced to raise an error here.
"""
raise ValueError(message)
if self.use_gpu:
vec = vec.to(**DEVICE_KWARGS)
free_params.append(vec)
free_param_sizes.append(vec_size)
self.free_param_sizes = free_param_sizes
self.params_dim = prev_idx
self.default_to_reverse_ad = self.outcome_probs_dim < self.params_dim
return tuple(free_params)
Expand All @@ -157,7 +179,7 @@ def get_torch_bases(self, free_params: Tuple[torch.Tensor]) -> Dict[Label, torch
----
If you want to use the returned dict to build a PyTorch Tensor that supports the
.backward() method, then you need to make sure that fp.requires_grad is True for all
fp in free_params. This can be done by calling fp._requires_grad(True) before calling
fp in free_params. This can be done by calling fp.requires_grad_(True) before calling
this function.
"""
assert len(free_params) == len(self.param_metadata)
Expand Down Expand Up @@ -202,8 +224,9 @@ def circuit_probs_from_free_params(self, *free_params: Tuple[torch.Tensor], enab
"""
if enable_backward:
for fp in free_params:
fp._requires_grad(True)
torch_bases = self.get_torch_bases(free_params)
fp.requires_grad_(True)

torch_bases = self.get_torch_bases(free_params) # type: ignore
probs = self.circuit_probs_from_torch_bases(torch_bases)
return probs

Expand All @@ -215,15 +238,16 @@ class TorchForwardSimulator(ForwardSimulator):

ENABLED = TORCH_ENABLED

def __init__(self, model : Optional[ExplicitOpModel] = None):
def __init__(self, model : Optional[ExplicitOpModel] = None, use_gpu=True):
if not self.ENABLED:
raise RuntimeError('PyTorch could not be imported.')
self.model = model
self.use_gpu = use_gpu
super(ForwardSimulator, self).__init__(model)

def _bulk_fill_probs(self, array_to_fill, layout, split_model = None) -> None:
if split_model is None:
slm = StatelessModel(self.model, layout)
slm = StatelessModel(self.model, layout, self.use_gpu)
free_params = slm.get_free_params(self.model)
torch_bases = slm.get_torch_bases(free_params)
else:
Expand All @@ -234,7 +258,7 @@ def _bulk_fill_probs(self, array_to_fill, layout, split_model = None) -> None:
return

def _bulk_fill_dprobs(self, array_to_fill, layout, pr_array_to_fill) -> None:
slm = StatelessModel(self.model, layout)
slm = StatelessModel(self.model, layout, self.use_gpu)
# ^ TODO: figure out how to safely recycle StatelessModel objects from one
# call to another. The current implementation is wasteful if we need to
# compute many jacobians without structural changes to layout or self.model.
Expand All @@ -247,18 +271,17 @@ def _bulk_fill_dprobs(self, array_to_fill, layout, pr_array_to_fill) -> None:

argnums = tuple(range(len(slm.param_metadata)))
if slm.default_to_reverse_ad:
# Then slm.circuit_probs_from_free_params will automatically construct the
# torch_base dict to support reverse-mode AD.
J_func = torch.func.jacrev(slm.circuit_probs_from_free_params, argnums=argnums)
J_func = torch.func.jacrev(slm.circuit_probs_from_free_params, argnums=argnums) # type: ignore
else:
# Then slm.circuit_probs_from_free_params will automatically skip the extra
# steps needed for torch_base to support reverse-mode AD.
J_func = torch.func.jacfwd(slm.circuit_probs_from_free_params, argnums=argnums)
J_func = torch.func.jacfwd(slm.circuit_probs_from_free_params, argnums=argnums) # type: ignore
# ^ Note that this _bulk_fill_dprobs function doesn't accept parameters that
# could be used to override the default behavior of the StatelessModel. If we
# have a need to override the default in the future then we'd need to override
# the ForwardSimulator function(s) that call self._bulk_fill_dprobs(...).

# if self.use_gpu:
# J_func = make_fx(J_func, tracing_mode='fake')
# ^ see https://github.com/pytorch/pytorch/issues/152701#issuecomment-2847838362
J_val = J_func(*free_params)
J_val = torch.column_stack(J_val)
array_to_fill[:] = J_val.cpu().detach().numpy()
Expand Down
17 changes: 8 additions & 9 deletions pygsti/modelmembers/operations/fulltpop.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
from pygsti.modelmembers.torchable import Torchable as _Torchable




class FullTPOp(_DenseOperator, _Torchable):
"""
A trace-preserving operation matrix.
Expand Down Expand Up @@ -169,15 +167,16 @@ def from_vector(self, v, close=False, dirty_value=True):
self._ptr_has_changed() # because _rep.base == _ptr (same memory)
self.dirty = dirty_value

def stateless_data(self) -> Tuple[int]:
return (self.dim,)

@staticmethod
def torch_base(sd: Tuple[int], t_param: _torch.Tensor) -> _torch.Tensor:
dim = sd[0]
def stateless_data(self) -> Tuple[int, _torch.Tensor]:
dim = self.dim
t_const = _torch.zeros(size=(1, dim), dtype=_torch.double)
t_const[0,0] = 1.0
t_param_mat = t_param.reshape((dim - 1, dim))
return (dim, t_const)

@staticmethod
def torch_base(sd: Tuple[int, _torch.Tensor], t_param: _torch.Tensor) -> _torch.Tensor:
dim, t_const = sd
t_param_mat = t_param.view(dim - 1, dim)
t = _torch.row_stack((t_const, t_param_mat))
return t

Expand Down
6 changes: 6 additions & 0 deletions pygsti/modelmembers/povms/conjugatedeffect.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ def __setitem__(self, key, val):
ret = self.columnvec.__setitem__(key, val)
self._ptr_has_changed()
return ret

def __getstate__(self):
return self.__dict__

def __setstate__(self, d):
self.__dict__.update(d)

def __getattr__(self, attr):
#use __dict__ so no chance for recursive __getattr__
Expand Down
23 changes: 11 additions & 12 deletions pygsti/modelmembers/povms/tppovm.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,29 +102,28 @@ def to_vector(self):
vec = _np.concatenate(effect_vecs)
return vec

def stateless_data(self) -> Tuple[int, _np.ndarray]:
def stateless_data(self) -> Tuple[int, _torch.Tensor, int]:
num_effects = len(self)
complement_effect = self[self.complement_label]
identity = complement_effect.identity.to_vector()
return (num_effects, identity)

@staticmethod
def torch_base(sd: Tuple[int, _np.ndarray], t_param: _torch.Tensor) -> _torch.Tensor:
num_effects, identity = sd
identity = identity.reshape((1, -1)) # make into a row vector
t_identity = _torch.from_numpy(identity)

dim = identity.size

first_basis_vec = _np.zeros(dim)
first_basis_vec[0] = dim ** 0.25
first_basis_vec = _np.zeros((1,dim))
first_basis_vec[0,0] = dim ** 0.25
TOL = 1e-15 * _np.sqrt(dim)
if _np.linalg.norm(first_basis_vec - identity) > TOL:
# Don't error out. The documentation for the class
# clearly indicates that the meaning of "identity"
# can be nonstandard.
warnings.warn('Unexpected normalization!')
return (num_effects, t_identity, dim)

identity = identity.reshape((1, -1)) # make into a row vector
t_identity = _torch.from_numpy(identity)
t_param_mat = t_param.reshape((num_effects - 1, dim))
@staticmethod
def torch_base(sd: Tuple[int, _torch.Tensor, int], t_param: _torch.Tensor) -> _torch.Tensor:
num_effects, t_identity, dim = sd
t_param_mat = t_param.view(num_effects - 1, dim)
t_func = t_identity - t_param_mat.sum(axis=0, keepdim=True)
t = _torch.row_stack((t_param_mat, t_func))
return t
6 changes: 6 additions & 0 deletions pygsti/modelmembers/states/densestate.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ def __setitem__(self, key, val):
ret = self.columnvec.__setitem__(key, val)
self._ptr_has_changed()
return ret

def __getstate__(self):
return self.__dict__

def __setstate__(self, d):
self.__dict__.update(d)

def __getattr__(self, attr):
#use __dict__ so no chance for recursive __getattr__
Expand Down
11 changes: 6 additions & 5 deletions pygsti/modelmembers/states/tpstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,14 @@ def from_vector(self, v, close=False, dirty_value=True):
self._ptr_has_changed()
self.dirty = dirty_value

def stateless_data(self) -> Tuple[int]:
return (self.dim,)
def stateless_data(self) -> Tuple[_torch.Tensor]:
dim = self.dim
t_const = (dim ** -0.25) * _torch.ones(1, dtype=_torch.double)
return (t_const,)

@staticmethod
def torch_base(sd: Tuple[int], t_param: _torch.Tensor) -> _torch.Tensor:
dim = sd[0]
t_const = (dim ** -0.25) * _torch.ones(1, dtype=_torch.double)
def torch_base(sd: Tuple[_torch.Tensor], t_param: _torch.Tensor) -> _torch.Tensor:
t_const = sd[0]
t = _torch.concat((t_const, t_param))
return t

Expand Down
50 changes: 50 additions & 0 deletions temporary_demos/function_profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@

import cProfile
import pstats
import io


def profile_function(func, subfunction_to_profile, *args, **kwargs, ):
"""
Profiles a given function and returns the time spent in that function.

Parameters:
- func: The function to run.
- subfunction_to_profile: the subfunction to extract stats for.
- *args: Positional arguments to pass to the function.
- **kwargs: Keyword arguments to pass to the function.


Returns:
- A dictionary containing the total time spent in the function and the number of calls.
"""
# Create a profiler
profiler = cProfile.Profile()
profiler.enable() # Start profiling

# Call the function with the provided arguments
func(*args, **kwargs)

profiler.disable() # Stop profiling

# Create a stream to hold the profiling results
s = io.StringIO()
ps = pstats.Stats(profiler, stream=s).sort_stats('cumulative')
ps.print_stats() # Print the profiling results

# Parse the output to find the specific function's stats
function_name = subfunction_to_profile.__name__
function_stats = {}

for line in s.getvalue().splitlines():
#print(line)
if function_name in line:
parts = line.split()
# Extract the relevant statistics
function_stats['ncalls'] = int(parts[0]) # Number of calls
function_stats['tottime'] = float(parts[2]) # Total time spent in the function
function_stats['percall'] = float(parts[2]) / int(parts[0]) if int(parts[0]) > 0 else 0 # Time per call
function_stats['cumtime'] = float(parts[3]) # Cumulative time spent in the function
break

return function_stats
Loading
Loading