@@ -177,10 +177,7 @@ def __init__(
         self.output_dim = output_dim
 
         self.blocks = nn.ModuleList(
-            [
-                NormformerBlock(input_dim=hidden_dim, mlp_dim=hidden_dim, num_heads=num_heads)
-                for _ in range(num_blocks)
-            ]
+            [NormformerBlock(input_dim=hidden_dim, mlp_dim=hidden_dim, num_heads=num_heads) for _ in range(num_blocks)]
         )
         self.project_out = nn.Linear(hidden_dim, output_dim)
 
@@ -336,8 +333,8 @@ class VQVAELightning(L.LightningModule):
 
     def __init__(
         self,
-        optimizer_kwargs={},
-        #scheduler_kwargs = {},
+        optimizer_kwargs={},
+        # scheduler_kwargs = {},
         model_kwargs={},
         model_type="Transformer",
         **kwargs,
@@ -373,8 +370,7 @@ def __init__(
         self.val_mask = []
 
     def configure_optimizers(self):
-        optimizer = torch.optim.AdamW(
-            self.model.parameters(), **self.optimizer_kwargs)
+        optimizer = torch.optim.AdamW(self.model.parameters(), **self.optimizer_kwargs)
         """
         if self.lr_scheduler:
             return {
@@ -387,7 +383,7 @@ def configure_optimizers(self):
             }
         """
         return optimizer
-
+
     def forward(self, x_particle, mask_particle):
         x_particle_reco, vq_out = self.model(x_particle, mask=mask_particle)
         return x_particle_reco, vq_out
@@ -428,24 +424,22 @@ def on_train_start(self) -> None:
             self.trainer.datamodule.hparams.dataset_kwargs_common.feature_dict
         )
         """
+
     def on_train_epoch_start(self):
         logger.info(f"Epoch {self.trainer.current_epoch} starting.")
         self.epoch_train_start_time = time.time()  # start timing the epoch
 
     def on_train_epoch_end(self):
         self.epoch_train_end_time = time.time()
-        self.epoch_train_duration_minutes = (
-            self.epoch_train_end_time - self.epoch_train_start_time
-        ) / 60
+        self.epoch_train_duration_minutes = (self.epoch_train_end_time - self.epoch_train_start_time) / 60
         self.log(
             "epoch_train_duration_minutes",
             self.epoch_train_duration_minutes,
             on_epoch=True,
             prog_bar=False,
         )
         logger.info(
-            f"Epoch {self.trainer.current_epoch} finished in"
-            f" {self.epoch_train_duration_minutes:.1f} minutes."
+            f"Epoch {self.trainer.current_epoch} finished in" f" {self.epoch_train_duration_minutes:.1f} minutes."
         )
 
     def on_train_end(self):
@@ -492,9 +486,7 @@ def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: i
                 saveas=plot_filename,
             )
             if comet_logger is not None:
-                comet_logger.log_image(
-                    plot_filename, name=plot_filename.split("/")[-1], step=curr_step
-                )
+                comet_logger.log_image(plot_filename, name=plot_filename.split("/")[-1], step=curr_step)
 
         return loss
 
@@ -569,9 +561,7 @@ def tokenize_ak_array(self, ak_arr, pp_dict, batch_size=256, pad_length=128, hid
         tokens = np_to_ak(codes, names=["token"], mask=mask)["token"]
         return tokens
 
-    def reconstruct_ak_tokens(
-        self, tokens_ak, pp_dict, batch_size=256, pad_length=128, hide_pbar=False
-    ):
+    def reconstruct_ak_tokens(self, tokens_ak, pp_dict, batch_size=256, pad_length=128, hide_pbar=False):
         """Reconstruct tokenized awkward array.
 
         Parameters
@@ -635,9 +625,7 @@ def reconstruct_ak_tokens(
             if hasattr(self.model, "latent_projection_out"):
                 x_reco_batch = self.model.latent_projection_out(z_q) * mask_batch.unsqueeze(-1)
                 x_reco_batch = self.model.decoder_normformer(x_reco_batch, mask=mask_batch)
-                x_reco_batch = self.model.output_projection(
-                    x_reco_batch
-                ) * mask_batch.unsqueeze(-1)
+                x_reco_batch = self.model.output_projection(x_reco_batch) * mask_batch.unsqueeze(-1)
             elif hasattr(self.model, "decoder"):
                 x_reco_batch = self.model.decoder(z_q)
             else:
@@ -665,8 +653,6 @@ def on_test_epoch_end(self):
         self.test_labels_concat = np.concatenate(self.test_labels)
         self.test_code_idx_concat = np.concatenate(self.test_code_idx)
 
-
-
 
 def plot_model(model, samples, device="cuda", n_examples_to_plot=200, masks=None, saveas=None):
     """Visualize the model.
@@ -833,22 +819,13 @@ def plot_model(model, samples, device="cuda", n_examples_to_plot=200, masks=None
     ax.set_yscale("log")
     print(idx)
     ax.set_title(
-        "Codebook histogram\n(Each entry corresponds to one sample\nbeing associated with that"
-        " codebook entry)",
+        "Codebook histogram\n(Each entry corresponds to one sample\nbeing associated with that" " codebook entry)",
         fontsize=8,
     )
 
     # make empty axes invisible
     def is_axes_empty(ax):
-        return not (
-            ax.lines
-            or ax.patches
-            or ax.collections
-            or ax.images
-            or ax.texts
-            or ax.artists
-            or ax.tables
-        )
+        return not (ax.lines or ax.patches or ax.collections or ax.images or ax.texts or ax.artists or ax.tables)
 
     for ax in axarr.flatten():
         if is_axes_empty(ax):
@@ -882,4 +859,4 @@ def plot_loss(loss_history, lr_history, moving_average=100):
     ax2.set_ylabel("Learning Rate")
 
     fig.tight_layout()
-    plt.show()
+    plt.show()
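
For orientation only (not part of the diff): the reformatted tokenize_ak_array and reconstruct_ak_tokens keep their signatures, so a round trip through the tokenizer could look roughly like the sketch below. The checkpoint path, pp_dict, and jets_ak are placeholder names, and restoring the module via the standard Lightning load_from_checkpoint classmethod is an assumption, not something this PR shows.

# Hedged usage sketch with placeholder inputs (not taken from this PR):
# tokenize an awkward array of per-particle features, then decode the tokens back.
model = VQVAELightning.load_from_checkpoint("path/to/vqvae.ckpt")  # hypothetical checkpoint path
model.eval()

tokens = model.tokenize_ak_array(jets_ak, pp_dict, batch_size=256, pad_length=128)  # jets_ak, pp_dict are placeholders
jets_reco = model.reconstruct_ak_tokens(tokens, pp_dict, batch_size=256, pad_length=128)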