Skip to content

Commit 41ef683

Browse files
committed
some changes to deterministic trainer
1 parent c42bf38 commit 41ef683

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

makani/utils/training/deterministic_trainer.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,11 +179,11 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] =
179179
self.loss_obj = self.loss_obj.to(self.device)
180180
self.timers["loss handler init"] = timer.time
181181

182-
# channel weights:
183-
if self.log_to_screen:
184-
chw_weights = self.loss_obj.channel_weights.squeeze().cpu().numpy().tolist()
185-
chw_output = {k: v for k,v in zip(self.params.channel_names, chw_weights)}
186-
self.logger.info(f"Channel weights: {chw_output}")
182+
# # channel weights:
183+
# if self.log_to_screen:
184+
# chw_weights = self.loss_obj.channel_weights.squeeze().cpu().numpy().tolist()
185+
# chw_output = {k: v for k,v in zip(self.params.channel_names, chw_weights)}
186+
# self.logger.info(f"Channel weights: {chw_output}")
187187

188188
# optimizer and scheduler setup
189189
with Timer() as timer:
@@ -243,7 +243,10 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] =
243243

244244
# visualization wrapper:
245245
with Timer() as timer:
246-
plot_list = [{"name": "windspeed_uv10", "functor": "lambda x: np.sqrt(np.square(x[0, ...]) + np.square(x[1, ...]))", "diverging": False}]
246+
plot_channel = "q50"
247+
plot_index = self.params.channel_names.index(plot_channel)
248+
print(self.params.channel_names)
249+
plot_list = [{"name": plot_channel, "functor": f"lambda x: x[{plot_index}, ...]", "diverging": False}]
247250
out_bias, out_scale = self.train_dataloader.get_output_normalization()
248251
self.visualizer = visualize.VisualizationWrapper(
249252
self.params.log_to_wandb,

0 commit comments

Comments
 (0)