Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
is_valid_contraction_tree)
from tnco.utils.tensor import \
decompose_hyper_inds as tensor_decompose_hyper_inds
from tnco.utils.tensor import get_einsum_path, svd
from tnco.utils.tensor import get_einsum_subscripts, svd
from tnco.utils.tn import contract
from tnco.utils.tn import decompose_hyper_inds as tn_decompose_hyper_inds
from tnco.utils.tn import (fuse, get_random_contraction_path,
Expand Down Expand Up @@ -1123,8 +1123,8 @@ def permutation(x):
output_inds = permutation(output_inds)

# Compute tensordot providing the output inds
array_c = np.einsum(get_einsum_path(inds_a, inds_b, output_inds), array_a,
array_b)
array_c = np.einsum(get_einsum_subscripts(inds_a, inds_b, output_inds),
array_a, array_b)

# Check if correct
np.testing.assert_allclose(
Expand All @@ -1151,8 +1151,8 @@ def fuse_(arrays, path, fused_inds):
# Contract
arrays.append(
Tensor(
np.einsum(get_einsum_path(tx.inds, ty.inds, iz), tx.data,
ty.data), iz))
np.einsum(get_einsum_subscripts(tx.inds, ty.inds, iz),
tx.data, ty.data), iz))

# Get final fused tensor
return arrays
Expand Down
2 changes: 1 addition & 1 deletion tnco/utils/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def contraction_path_(qs_A, qs_B):
xs_C = list(
its.chain(map(lambda q: (q, 'i'), all_qubits),
map(lambda q: (q, 'f'), all_qubits)))
return tensor_utils.get_einsum_path(xs_A, xs_B, xs_C)
return tensor_utils.get_einsum_subscripts(xs_A, xs_B, xs_C)

# Reshape unitaries
array_A = array_A.reshape((2,) * 2 * len(qubits_A))
Expand Down
6 changes: 3 additions & 3 deletions tnco/utils/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from tnco.typing import Array, Index

__all__ = ['decompose_hyper_inds', 'get_einsum_path', 'svd']
__all__ = ['decompose_hyper_inds', 'get_einsum_subscripts', 'svd']


def is_diagonal(array: Array, /, *, atol: Optional[float] = 1e-8) -> bool:
Expand Down Expand Up @@ -140,8 +140,8 @@ def pad(xs):
return decompose_hyper_inds(array, inds, _hyper_inds=_hyper_inds)


def get_einsum_path(inds_a: Iterable[Index], inds_b: Iterable[Index],
output_inds: Iterable[Index], /) -> str:
def get_einsum_subscripts(inds_a: Iterable[Index], inds_b: Iterable[Index],
output_inds: Iterable[Index], /) -> str:
"""Return einsum path.

Return einsum path for the contraction 'inds_a @ inds_b -> output_inds'.
Expand Down
3 changes: 2 additions & 1 deletion tnco/utils/tn.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,8 @@ def contract(
# Append new tensor
if arrays is not None:
arrays.append(
do('einsum', tensor_utils.get_einsum_path(xs, ys, zs), ax, ay))
do('einsum', tensor_utils.get_einsum_subscripts(xs, ys, zs), ax,
ay))

# Append new indices
ts_inds.append(zs)
Expand Down