Skip to content

Commit eab0805

Browse files
committed
Certificate dictionaries should not contain entries where the coefficient is zero.
1 parent d68f7c3 commit eab0805

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

inflation/utils.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,20 @@ def clean_coefficients(cert: Dict[str, float],
7979
if not cert:
8080
return cert
8181
chop_tol = np.abs(chop_tol)
82-
coeffs = np.asarray(list(cert.values()))
82+
coeffs = np.array(list(cert.values()))
83+
good_coeffs_places = (np.abs(coeffs) > chop_tol)
84+
new_coeffs = np.compress(good_coeffs_places, coeffs)
85+
new_keys = [k for i,k in enumerate(cert.keys()) if good_coeffs_places[i]]
8386
if chop_tol > 0:
8487
# Try to take the smallest nonzero one and make it 1, when possible
85-
normalising_factor = np.min(np.abs(coeffs[np.abs(coeffs) > chop_tol]))
88+
normalising_factor = np.min(np.abs(new_coeffs))
8689
else:
8790
# Take the largest nonzero one and make it 1
88-
normalising_factor = np.max(np.abs(coeffs[np.abs(coeffs) > chop_tol]))
89-
coeffs /= normalising_factor
90-
# Set to zero very small coefficients
91-
coeffs[np.abs(coeffs) <= chop_tol] = 0
91+
normalising_factor = np.max(np.abs(new_coeffs))
92+
new_coeffs /= normalising_factor
9293
# Round
93-
coeffs = np.round(coeffs, decimals=round_decimals)
94-
return dict(zip(cert.keys(), coeffs.flat))
94+
new_coeffs = np.round(new_coeffs, decimals=round_decimals)
95+
return dict(zip(new_keys, new_coeffs.flat))
9596

9697

9798
def eprint(*args, **kwargs):

0 commit comments

Comments
 (0)