@@ -842,21 +842,18 @@ def driver(
842842 if mass_matrix <= 0.0 :
843843 raise LoggedValueError ("mass_matrix must be positive" )
844844 mass_matrix = jnp .eye (n_dims ) * float (mass_matrix )
845- elif isinstance (mass_matrix , list ):
846- mass_matrix = jnp .array (mass_matrix )
847- if mass_matrix .ndim > 2 :
848- raise LoggedValueError ("mass_matrix must be 1D or 2D array" )
849- _shape = mass_matrix .shape
850- if _shape != (n_dims , n_dims ) and _shape != (n_dims ,):
851- raise LoggedValueError (
852- f"mass_matrix must be of shape ({ n_dims } , { n_dims } ) or ({ n_dims } ,), got { _shape } "
853- )
854- if _shape == (n_dims ,):
855- if jnp .any (mass_matrix <= 0 ):
856- raise LoggedValueError (
857- "mass_matrix diagonal elements must be positive"
858- )
859- mass_matrix = jnp .diag (mass_matrix )
845+
846+ if mass_matrix .ndim > 2 :
847+ raise LoggedValueError ("mass_matrix must be 1D or 2D array" )
848+ _shape = mass_matrix .shape
849+ if _shape != (n_dims , n_dims ) and _shape != (n_dims ,):
850+ raise LoggedValueError (
851+ f"mass_matrix must be of shape ({ n_dims } , { n_dims } ) or ({ n_dims } ,), got { _shape } "
852+ )
853+ if _shape == (n_dims ,):
854+ if jnp .any (mass_matrix <= 0 ):
855+ raise LoggedValueError ("mass_matrix diagonal elements must be positive" )
856+ mass_matrix = jnp .diag (mass_matrix )
860857
861858 bundle = Local_Global_Sampler_Bundle (
862859 rng_key = self .rng_key ,
0 commit comments