@@ -83,9 +83,9 @@ def __init__(
8383 [2] : https://github.com/sbi-dev/sbi/blob/main/sbi/utils/metrics.py
8484 """
8585
86- assert (
87- thetas . shape [ 0 ] == xs . shape [ 0 ] == posterior_samples . shape [ 0 ]
88- ), "Number of samples must match"
86+ assert thetas . shape [ 0 ] == xs . shape [ 0 ] == posterior_samples . shape [ 0 ], (
87+ "Number of samples must match"
88+ )
8989
9090 # set observed data for classification
9191 self .theta_p = posterior_samples
@@ -283,9 +283,9 @@ def get_statistic_on_observed_data(
283283 Returns:
284284 L-C2ST statistic at `x_o`.
285285 """
286- assert (
287- self . trained_clfs is not None
288- ), "No trained classifiers found. Run `train_on_observed_data` first."
286+ assert self . trained_clfs is not None , (
287+ "No trained classifiers found. Run `train_on_observed_data` first."
288+ )
289289 _ , scores = self .get_scores (
290290 theta_o = theta_o ,
291291 x_o = x_o ,
@@ -372,9 +372,9 @@ def train_under_null_hypothesis(
372372 joint_q_perm [:, self .theta_q .shape [1 ] :],
373373 )
374374 else :
375- assert (
376- self . null_distribution is not None
377- ), "You need to provide a null distribution"
375+ assert self . null_distribution is not None , (
376+ "You need to provide a null distribution"
377+ )
378378 theta_p_t = self .null_distribution .sample ((self .theta_p .shape [0 ],))
379379 theta_q_t = self .null_distribution .sample ((self .theta_p .shape [0 ],))
380380 x_p_t , x_q_t = self .x_p , self .x_q
@@ -419,9 +419,9 @@ def get_statistics_under_null_hypothesis(
419419 Run `train_under_null_hypothesis`."
420420 )
421421 else :
422- assert (
423- len ( self . trained_clfs_null ) == self . num_trials_null
424- ), "You need one classifier per trial."
422+ assert len ( self . trained_clfs_null ) == self . num_trials_null , (
423+ "You need one classifier per trial."
424+ )
425425
426426 probs_null , stats_null = [], []
427427 for t in tqdm (
@@ -433,9 +433,9 @@ def get_statistics_under_null_hypothesis(
433433 if self .permutation :
434434 theta_o_t = theta_o
435435 else :
436- assert (
437- self . null_distribution is not None
438- ), "You need to provide a null distribution"
436+ assert self . null_distribution is not None , (
437+ "You need to provide a null distribution"
438+ )
439439
440440 theta_o_t = self .null_distribution .sample ((theta_o .shape [0 ],))
441441
0 commit comments