@@ -318,14 +318,14 @@ def train( # noqa: C901
318318 if self .config .get ("verbose_splits" , False ):
319319 print ("\n Data split summary:" )
320320 print (
321- f" Training samples: { len (train )} ({ len (train )/ len (samples )* 100 :.1f} %)"
321+ f" Training samples: { len (train )} ({ len (train ) / len (samples ) * 100 :.1f} %)"
322322 )
323323 print (
324- f" Validation samples: { len (test )} ({ len (test )/ len (samples )* 100 :.1f} %)"
324+ f" Validation samples: { len (test )} ({ len (test ) / len (samples ) * 100 :.1f} %)"
325325 )
326326 if len (pred ) > 0 :
327327 print (
328- f" Prediction samples (no coords): { len (pred )} ({ len (pred )/ len (samples )* 100 :.1f} %)"
328+ f" Prediction samples (no coords): { len (pred )} ({ len (pred ) / len (samples ) * 100 :.1f} %)"
329329 )
330330 print (f" Total samples: { len (samples )} " )
331331 print (f" Total SNPs: { self .filtered_genotypes .shape [0 ]} " )
@@ -548,13 +548,13 @@ def train_holdout( # noqa: C901
548548 if self .config .get ("verbose_splits" , False ):
549549 print ("\n Holdout split summary:" )
550550 print (
551- f" Training samples: { len (train_indices )} ({ len (train_indices )/ len (samples )* 100 :.1f} %)"
551+ f" Training samples: { len (train_indices )} ({ len (train_indices ) / len (samples ) * 100 :.1f} %)"
552552 )
553553 print (
554- f" Validation samples: { len (test_indices )} ({ len (test_indices )/ len (samples )* 100 :.1f} %)"
554+ f" Validation samples: { len (test_indices )} ({ len (test_indices ) / len (samples ) * 100 :.1f} %)"
555555 )
556556 print (
557- f" Holdout samples: { len (holdout_idx )} ({ len (holdout_idx )/ len (samples )* 100 :.1f} %)"
557+ f" Holdout samples: { len (holdout_idx )} ({ len (holdout_idx ) / len (samples ) * 100 :.1f} %)"
558558 )
559559 print (f" Total samples: { len (samples )} " )
560560 print (f" Total SNPs: { self .filtered_genotypes .shape [0 ]} " )
@@ -739,12 +739,12 @@ def _create_model(self, input_shape):
739739 """Create neural network model. Extracted to avoid duplication."""
740740 loss_fn = None
741741 if self .config .get ("use_range_penalty" ):
742- assert (
743- self . config . get ( "species_range_shapefile" ) is not None
744- ), "species_range_shapefile must be provided if use_range_penalty is True"
745- assert (
746- self . config . get ( "resolution" ) is not None
747- ), "resolution must be provided if use_range_penalty is True"
742+ assert self . config . get ( "species_range_shapefile" ) is not None , (
743+ "species_range_shapefile must be provided if use_range_penalty is True"
744+ )
745+ assert self . config . get ( "resolution" ) is not None , (
746+ "resolution must be provided if use_range_penalty is True"
747+ )
748748
749749 mask_tensor , mask_transform = rasterize_species_range (
750750 self .config ["species_range_shapefile" ],
0 commit comments