Skip to content

Commit 64ff390

Browse files
authored
Add dtype and backend support to optimizers and centralize RNG (#32)
- Update `load_tn` and `BaseOptimizer` to support `dtype` and `backend` parameters, allowing for better control over numerical types and integration with different backends via `autoray`. - Centralize random number generation by adding a shared _rng to BaseOptimizer, ensuring consistent seeding across optimizer implementations. - Refactor `finite_width` and `infinite_memory` simulated annealing optimizers to use the base class RNG instead of local instances.
1 parent 9c89eb8 commit 64ff390

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

tnco/app/app.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@
2626
from decimal import Decimal
2727
from importlib import import_module
2828
from pathlib import Path
29+
from random import Random
2930
from typing import Any, Dict, List, Optional, Tuple, Union
3031
from warnings import warn
3132

33+
import autoray as ar
3234
import more_itertools as mit
3335
from pathvalidate import ValidationError, validate_filepath
3436

@@ -152,6 +154,7 @@ def load_tn(obj: Any,
152154
output_index_token: Optional[str] = '*',
153155
sparse_index_token: Optional[str] = '/',
154156
atol: Optional[float] = 1e-5,
157+
dtype: Optional[any] = None,
155158
backend: Optional[str] = None,
156159
seed: Optional[int] = None,
157160
verbose: Optional[int] = False) -> TensorNetwork:
@@ -187,6 +190,7 @@ def load_tn(obj: Any,
187190
sparse_index_token: If 'obj' is a list of indices, the token to use to
188191
identify sparse inds.
189192
atol: Absolute tollerance when checking for hyper-indices.
193+
dtype: Type to use for arrays.
190194
backend: Backend to use to fuse arrays. See: `autoray.do`.
191195
seed: Seed to use.
192196
verbose: Verbose output.
@@ -298,7 +302,9 @@ def is_gate(x):
298302
# Get tensors
299303
ts_inds = list(obj.ts_inds)
300304
dims = obj.dims
301-
arrays = list(obj.arrays)
305+
arrays = list(None if a is
306+
None else ar.do('asarray', a, dtype=dtype, like=backend)
307+
for a in obj.arrays)
302308
tags = dict(obj.tags)
303309
ts_tags = list(obj.ts_tags)
304310
output_inds = obj.output_inds
@@ -484,6 +490,8 @@ def is_gate(x):
484490
decompose_hyper_inds=False,
485491
fuse=False,
486492
atol=atol,
493+
dtype=dtype,
494+
backend=backend,
487495
seed=seed,
488496
verbose=verbose)
489497

@@ -509,6 +517,8 @@ def is_gate(x):
509517
decompose_hyper_inds=False,
510518
fuse=False,
511519
atol=atol,
520+
dtype=dtype,
521+
backend=backend,
512522
seed=seed,
513523
verbose=verbose)
514524

@@ -533,6 +543,8 @@ def is_gate(x):
533543
decompose_hyper_inds=False,
534544
fuse=False,
535545
atol=atol,
546+
dtype=dtype,
547+
backend=backend,
536548
seed=seed,
537549
verbose=verbose)
538550

@@ -716,6 +728,7 @@ class BaseOptimizer:
716728
overwrite_output_file: If 'True', the 'output_filename' will be
717729
overwritten if it exists.
718730
atol: Absolute tollerance when checking for hyper-indices.
731+
dtype: Type to use for arrays.
719732
backend: Backend to use to fuse arrays. See: `autoray.do`.
720733
seed: Seed to use.
721734
verbose: Verbose output.
@@ -729,6 +742,7 @@ class BaseOptimizer:
729742
output_compression: Optional[str] = 'auto'
730743
overwrite_output_file: Optional[bool] = False
731744
atol: Optional[float] = 1e-5
745+
dtype: Optional[any] = None
732746
backend: Optional[str] = None
733747
seed: Optional[int] = None
734748
verbose: Optional[int] = False
@@ -739,6 +753,7 @@ def optimize(self, *args, **kwargs):
739753
def _load_tn(self, tn, **load_tn_options):
740754
return load_tn(tn,
741755
atol=self.atol,
756+
dtype=self.dtype,
742757
backend=self.backend,
743758
seed=self.seed,
744759
verbose=self.verbose,
@@ -754,6 +769,9 @@ def _dump_results(self, tn, res, **dump_results_options):
754769
**dump_results_options)
755770

756771
def __post_init__(self):
772+
# Initialize common rng
773+
self._rng = Random(self.seed)
774+
757775
# Check dumper
758776
self._dump_results(None, None, check_only=True)
759777

@@ -768,6 +786,7 @@ def Optimizer(method: Optional[str] = 'sa',
768786
output_compression: Optional[str] = 'auto',
769787
overwrite_output_file: Optional[bool] = False,
770788
atol: Optional[float] = 1e-5,
789+
dtype: Optional[any] = None,
771790
backend: Optional[str] = None,
772791
seed: Optional[int] = None,
773792
verbose: Optional[int] = False) -> BaseOptimizer:
@@ -803,6 +822,7 @@ def Optimizer(method: Optional[str] = 'sa',
803822
overwrite_output_file: If 'True', the 'output_filename' will be
804823
overwritten if it exists.
805824
atol: Absolute tollerance when checking for hyper-indices.
825+
dtype: Type to use for arrays.
806826
backend: Backend to use to fuse arrays. See: `autoray.do`.
807827
seed: Seed to use.
808828
verbose: Verbose output.

tnco/app/finite_width/sa.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import json
1717
import operator as op
1818
from dataclasses import dataclass
19-
from random import Random
2019
from sys import stderr
2120
from time import perf_counter
2221
from typing import Any, FrozenSet, Iterable, List, Optional, Tuple, Union
@@ -149,7 +148,7 @@ def optimize(self,
149148
tn = self._load_tn(tn, **load_tn_options)
150149

151150
# Initialize random generator
152-
rng = Random(self.seed)
151+
rng = self._rng
153152

154153
# Check 'n_steps'
155154
if n_steps is not None:

tnco/app/infinite_memory/sa.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import json
1616
from dataclasses import dataclass
17-
from random import Random
1817
from sys import stderr
1918
from time import perf_counter
2019
from typing import Any, Iterable, List, Optional, Tuple, Union
@@ -131,7 +130,7 @@ def optimize(self,
131130
tn = self._load_tn(tn, **load_tn_options)
132131

133132
# Initialize random generator
134-
rng = Random(self.seed)
133+
rng = self._rng
135134

136135
# Check 'n_steps'
137136
if n_steps is not None:

0 commit comments

Comments
 (0)