Skip to content

Commit 03ee026

Browse files
authored
fix: streamline mass matrix validation in FlowMCBased class (#835)
1 parent a9749a9 commit 03ee026

1 file changed

Lines changed: 12 additions & 15 deletions

File tree

src/gwkokab/analysis/core/flowMC_based.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -842,21 +842,18 @@ def driver(
842842
if mass_matrix <= 0.0:
843843
raise LoggedValueError("mass_matrix must be positive")
844844
mass_matrix = jnp.eye(n_dims) * float(mass_matrix)
845-
elif isinstance(mass_matrix, list):
846-
mass_matrix = jnp.array(mass_matrix)
847-
if mass_matrix.ndim > 2:
848-
raise LoggedValueError("mass_matrix must be 1D or 2D array")
849-
_shape = mass_matrix.shape
850-
if _shape != (n_dims, n_dims) and _shape != (n_dims,):
851-
raise LoggedValueError(
852-
f"mass_matrix must be of shape ({n_dims}, {n_dims}) or ({n_dims},), got {_shape}"
853-
)
854-
if _shape == (n_dims,):
855-
if jnp.any(mass_matrix <= 0):
856-
raise LoggedValueError(
857-
"mass_matrix diagonal elements must be positive"
858-
)
859-
mass_matrix = jnp.diag(mass_matrix)
845+
846+
if mass_matrix.ndim > 2:
847+
raise LoggedValueError("mass_matrix must be 1D or 2D array")
848+
_shape = mass_matrix.shape
849+
if _shape != (n_dims, n_dims) and _shape != (n_dims,):
850+
raise LoggedValueError(
851+
f"mass_matrix must be of shape ({n_dims}, {n_dims}) or ({n_dims},), got {_shape}"
852+
)
853+
if _shape == (n_dims,):
854+
if jnp.any(mass_matrix <= 0):
855+
raise LoggedValueError("mass_matrix diagonal elements must be positive")
856+
mass_matrix = jnp.diag(mass_matrix)
860857

861858
bundle = Local_Global_Sampler_Bundle(
862859
rng_key=self.rng_key,

0 commit comments

Comments
 (0)