@@ -179,11 +179,11 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] =
179179 self .loss_obj = self .loss_obj .to (self .device )
180180 self .timers ["loss handler init" ] = timer .time
181181
182- # channel weights:
183- if self .log_to_screen :
184- chw_weights = self .loss_obj .channel_weights .squeeze ().cpu ().numpy ().tolist ()
185- chw_output = {k : v for k ,v in zip (self .params .channel_names , chw_weights )}
186- self .logger .info (f"Channel weights: { chw_output } " )
182+ # # channel weights:
183+ # if self.log_to_screen:
184+ # chw_weights = self.loss_obj.channel_weights.squeeze().cpu().numpy().tolist()
185+ # chw_output = {k: v for k,v in zip(self.params.channel_names, chw_weights)}
186+ # self.logger.info(f"Channel weights: {chw_output}")
187187
188188 # optimizer and scheduler setup
189189 with Timer () as timer :
@@ -243,7 +243,10 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] =
243243
244244 # visualization wrapper:
245245 with Timer () as timer :
246- plot_list = [{"name" : "windspeed_uv10" , "functor" : "lambda x: np.sqrt(np.square(x[0, ...]) + np.square(x[1, ...]))" , "diverging" : False }]
246+ plot_channel = "q50"
247+ plot_index = self .params .channel_names .index (plot_channel )
248+ print (self .params .channel_names )
249+ plot_list = [{"name" : plot_channel , "functor" : f"lambda x: x[{ plot_index } , ...]" , "diverging" : False }]
247250 out_bias , out_scale = self .train_dataloader .get_output_normalization ()
248251 self .visualizer = visualize .VisualizationWrapper (
249252 self .params .log_to_wandb ,
0 commit comments