2626from decimal import Decimal
2727from importlib import import_module
2828from pathlib import Path
29+ from random import Random
2930from typing import Any , Dict , List , Optional , Tuple , Union
3031from warnings import warn
3132
33+ import autoray as ar
3234import more_itertools as mit
3335from pathvalidate import ValidationError , validate_filepath
3436
@@ -152,6 +154,7 @@ def load_tn(obj: Any,
152154 output_index_token : Optional [str ] = '*' ,
153155 sparse_index_token : Optional [str ] = '/' ,
154156 atol : Optional [float ] = 1e-5 ,
157+ dtype : Optional [any ] = None ,
155158 backend : Optional [str ] = None ,
156159 seed : Optional [int ] = None ,
157160 verbose : Optional [int ] = False ) -> TensorNetwork :
@@ -187,6 +190,7 @@ def load_tn(obj: Any,
187190 sparse_index_token: If 'obj' is a list of indices, the token to use to
188191 identify sparse inds.
189192 atol: Absolute tollerance when checking for hyper-indices.
193+ dtype: Type to use for arrays.
190194 backend: Backend to use to fuse arrays. See: `autoray.do`.
191195 seed: Seed to use.
192196 verbose: Verbose output.
@@ -298,7 +302,9 @@ def is_gate(x):
298302 # Get tensors
299303 ts_inds = list (obj .ts_inds )
300304 dims = obj .dims
301- arrays = list (obj .arrays )
305+ arrays = list (None if a is
306+ None else ar .do ('asarray' , a , dtype = dtype , like = backend )
307+ for a in obj .arrays )
302308 tags = dict (obj .tags )
303309 ts_tags = list (obj .ts_tags )
304310 output_inds = obj .output_inds
@@ -484,6 +490,8 @@ def is_gate(x):
484490 decompose_hyper_inds = False ,
485491 fuse = False ,
486492 atol = atol ,
493+ dtype = dtype ,
494+ backend = backend ,
487495 seed = seed ,
488496 verbose = verbose )
489497
@@ -509,6 +517,8 @@ def is_gate(x):
509517 decompose_hyper_inds = False ,
510518 fuse = False ,
511519 atol = atol ,
520+ dtype = dtype ,
521+ backend = backend ,
512522 seed = seed ,
513523 verbose = verbose )
514524
@@ -533,6 +543,8 @@ def is_gate(x):
533543 decompose_hyper_inds = False ,
534544 fuse = False ,
535545 atol = atol ,
546+ dtype = dtype ,
547+ backend = backend ,
536548 seed = seed ,
537549 verbose = verbose )
538550
@@ -716,6 +728,7 @@ class BaseOptimizer:
716728 overwrite_output_file: If 'True', the 'output_filename' will be
717729 overwritten if it exists.
718730 atol: Absolute tollerance when checking for hyper-indices.
731+ dtype: Type to use for arrays.
719732 backend: Backend to use to fuse arrays. See: `autoray.do`.
720733 seed: Seed to use.
721734 verbose: Verbose output.
@@ -729,6 +742,7 @@ class BaseOptimizer:
729742 output_compression : Optional [str ] = 'auto'
730743 overwrite_output_file : Optional [bool ] = False
731744 atol : Optional [float ] = 1e-5
745+ dtype : Optional [any ] = None
732746 backend : Optional [str ] = None
733747 seed : Optional [int ] = None
734748 verbose : Optional [int ] = False
@@ -739,6 +753,7 @@ def optimize(self, *args, **kwargs):
739753 def _load_tn (self , tn , ** load_tn_options ):
740754 return load_tn (tn ,
741755 atol = self .atol ,
756+ dtype = self .dtype ,
742757 backend = self .backend ,
743758 seed = self .seed ,
744759 verbose = self .verbose ,
@@ -754,6 +769,9 @@ def _dump_results(self, tn, res, **dump_results_options):
754769 ** dump_results_options )
755770
756771 def __post_init__ (self ):
772+ # Initialize common rng
773+ self ._rng = Random (self .seed )
774+
757775 # Check dumper
758776 self ._dump_results (None , None , check_only = True )
759777
@@ -768,6 +786,7 @@ def Optimizer(method: Optional[str] = 'sa',
768786 output_compression : Optional [str ] = 'auto' ,
769787 overwrite_output_file : Optional [bool ] = False ,
770788 atol : Optional [float ] = 1e-5 ,
789+ dtype : Optional [any ] = None ,
771790 backend : Optional [str ] = None ,
772791 seed : Optional [int ] = None ,
773792 verbose : Optional [int ] = False ) -> BaseOptimizer :
@@ -803,6 +822,7 @@ def Optimizer(method: Optional[str] = 'sa',
803822 overwrite_output_file: If 'True', the 'output_filename' will be
804823 overwritten if it exists.
805824 atol: Absolute tollerance when checking for hyper-indices.
825+ dtype: Type to use for arrays.
806826 backend: Backend to use to fuse arrays. See: `autoray.do`.
807827 seed: Seed to use.
808828 verbose: Verbose output.
0 commit comments