Skip to content

Commit c3c6210

Browse files
committed
First attempt to build in causal symmetry hack to InflationProblem
1 parent fc59c81 commit c3c6210

File tree

1 file changed

+84
-7
lines changed

1 file changed

+84
-7
lines changed

inflation/InflationProblem.py

Lines changed: 84 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(self,
4848
nonclassical_intermediate_latents: Union[Tuple[str,...], List[str]]=tuple(),
4949
classical_intermediate_latents: Union[Tuple[str,...], List[str]]=tuple(),
5050
order: Union[Tuple[str,...], List[str]]=tuple(),
51+
really_just_one_source: bool=True,
5152
verbose=0):
5253
"""Class for encoding relevant details concerning the causal compatibility
5354
scenario.
@@ -121,6 +122,8 @@ def __init__(self,
121122
assert self.classical_intermediate_latents.isdisjoint(self.nonclassical_intermediate_latents), "An intermediate latent cannot be both classical and nonclassical."
122123
self.intermediate_latents = self.classical_intermediate_latents.union(self.nonclassical_intermediate_latents)
123124

125+
self.really_just_one_source = really_just_one_source
126+
124127
# Assign names to the visible variables
125128
names_have_been_set_yet = False
126129
if dag:
@@ -314,7 +317,7 @@ def __init__(self,
314317
# Determine if the inflation problem has a factorizing pair of parties.
315318
shared_sources = [np.all(np.vstack(pair), axis=0) for pair in
316319
combinations_with_replacement(self.hypergraph.T, 2)]
317-
just_one_copy = (self.inflation_level_per_source == 1)
320+
just_one_copy = np.asarray(self.inflation_level_per_source == 1)
318321
self.ever_factorizes = False
319322
for sources_are_shared in shared_sources:
320323
# If for some two parties, the sources that they share in common
@@ -356,7 +359,11 @@ def __init__(self,
356359
self._inflation_indices_hash = {op.tobytes(): i for i, op
357360
in enumerate(
358361
self._all_unique_inflation_indices)}
359-
self._inflation_indices_overlap = nb_overlap_matrix(
362+
if really_just_one_source:
363+
overlap_matrix = self.one_source_overlap_matrix
364+
else:
365+
overlap_matrix = nb_overlap_matrix
366+
self._inflation_indices_overlap = overlap_matrix(
360367
np.asarray(self._all_unique_inflation_indices, dtype=self._np_dtype))
361368

362369
# Create the measurements (formerly generate_parties)
@@ -437,7 +444,10 @@ def __init__(self,
437444
return_inverse=True, axis=0)
438445

439446
# Symmetries implied by the inflation
440-
self.symmetries = self.inflation_symmetries
447+
if really_just_one_source:
448+
self.symmetries = self.inflation_symmetries_one_source
449+
else:
450+
self.symmetries = self.inflation_symmetries
441451

442452

443453
@property
@@ -672,12 +682,12 @@ def _make_interpretation_hashable(op_as_dict):
672682
tuple(int(i) for i in op_as_dict["Setting as Tuple"]),
673683
int(op_as_dict["Outcome"]))
674684

675-
@staticmethod
676-
def _interpretation_to_name(op: dict, include_copy_indices=True) -> str:
685+
686+
def _interpretation_to_name(self, op: dict, include_copy_indices=True) -> str:
677687
op_as_str = op["Party"]
678688
if not op["Private Setting is Trivial"]:
679689
op_as_str += '_'+str(op["Private Setting"])
680-
if include_copy_indices:
690+
if include_copy_indices or self.really_just_one_source:
681691
if len(op["Relevant Copy Indices"]):
682692
copy_index_string = '^{'
683693
copy_index_string += ','.join(map(str,op["Relevant Copy Indices"].flat))
@@ -834,7 +844,7 @@ def factorize_monomial_2d(self,
834844
[3, 1, 4, 0, 0, 0],
835845
[3, 6, 6, 0, 0, 0],
836846
[3, 4, 5, 0, 0, 0]])
837-
>>> factorised = factorize_monomial(monomial_as_2darray)
847+
>>> factorised = InflationProblem.factorize_monomial_2d(monomial_as_2darray)
838848
[array([[1, 0, 1, 1, 0, 0]]),
839849
array([[1, 0, 3, 3, 0, 0]]),
840850
array([[2, 1, 0, 2, 0, 0],
@@ -1118,3 +1128,70 @@ def _all_possible_symmetries(self) -> np.ndarray:
11181128
self._setting_specific_outcome_relabelling_symmetries))
11191129
group_elements = group_elements_from_generators(group_generators)
11201130
return group_elements
1131+
1132+
###########################################################################
1133+
# FUNCTIONS PERTAINING TO ACTUALLY ONE SOURCE #
1134+
###########################################################################
1135+
@staticmethod
1136+
def exists_shared_source_modified(inf_indices1: np.ndarray,
1137+
inf_indices2: np.ndarray) -> bool:
1138+
common_sources = np.logical_and(inf_indices1, inf_indices2)
1139+
if not np.any(common_sources):
1140+
return False
1141+
return not set(inf_indices1[common_sources]).isdisjoint(set(inf_indices2[common_sources]))
1142+
1143+
1144+
def one_source_overlap_matrix(self, all_inflation_indxs: np.ndarray) -> np.ndarray:
1145+
n = len(all_inflation_indxs)
1146+
adj_mat = np.eye(n, dtype=bool)
1147+
for i in range(1, n):
1148+
inf_indices_i = all_inflation_indxs[i]
1149+
for j in range(i):
1150+
inf_indices_j = all_inflation_indxs[j]
1151+
if self.exists_shared_source_modified(inf_indices_i, inf_indices_j):
1152+
adj_mat[i, j] = True
1153+
adj_mat = np.logical_or(adj_mat, adj_mat.T)
1154+
return adj_mat
1155+
1156+
@cached_property
1157+
def inflation_symmetries_one_source(self) -> np.ndarray:
1158+
"""Calculates all the symmetries pertaining to the set of generating
1159+
monomials due to copy index relabelling. The new set of operators is a
1160+
permutation of the old. The function outputs a list of all permutations.
1161+
1162+
Returns
1163+
-------
1164+
numpy.ndarray[int]
1165+
The permutations of the lexicographic order implied by the inflation
1166+
symmetries.
1167+
"""
1168+
assert np.array_equiv(self.inflation_level_per_source,
1169+
self.inflation_level_per_source[0]), """
1170+
Only call this with uniform inflation level!"""
1171+
inf_level = max(self.inflation_level_per_source)
1172+
if inf_level>1:
1173+
permutation_failed = False
1174+
symmetries = []
1175+
identity_perm = np.arange(self._nr_operators, dtype=np.intc)
1176+
perms = format_permutations(list(
1177+
permutations(range(inf_level)))[1:])
1178+
all_sources_simultanous = np.arange(len(self.inflation_level_per_source))
1179+
for permutation in perms:
1180+
adjusted_ops = apply_source_perm(self._lexorder,
1181+
all_sources_simultanous,
1182+
permutation)
1183+
try:
1184+
new_order = np.fromiter(
1185+
(self._lexorder_lookup[op.tobytes()]
1186+
for op in adjusted_ops),
1187+
dtype=np.intc
1188+
)
1189+
symmetries.append(new_order)
1190+
except KeyError:
1191+
permutation_failed = True
1192+
if permutation_failed and (self.verbose > 0):
1193+
warn("The generating set is not closed under source swaps."
1194+
+ " Some symmetries will not be implemented.")
1195+
return np.unique(symmetries, axis=0)
1196+
return np.arange(self._nr_operators, dtype=np.intc)[np.newaxis]
1197+

0 commit comments

Comments
 (0)