|
1 | | -import time |
2 | 1 | import numpy as np |
3 | 2 | import pandas as pd |
4 | 3 | import sys |
5 | | -import json |
6 | 4 | import matplotlib.pyplot as plt |
7 | 5 | from models.physics.cfe import CFE |
8 | 6 | import torch |
9 | 7 | from torch import Tensor |
10 | | -import torch.nn as nn |
11 | 8 |
|
12 | 9 | torch.set_default_dtype(torch.float64) |
13 | 10 |
|
@@ -108,6 +105,16 @@ def __init__( |
108 | 105 | # None |
109 | 106 |
|
110 | 107 | def load_cfe_params(self): |
| 108 | + for param in self.cfe_params.values(): |
| 109 | + if torch.is_tensor(param): |
| 110 | + if param.grad is not None: |
| 111 | + param.grad = None |
| 112 | + |
| 113 | + for param in self.cfe_params["soil_params"].values(): |
| 114 | + if torch.is_tensor(param): |
| 115 | + if param.grad is not None: |
| 116 | + param.grad = None |
| 117 | + |
111 | 118 | # GET VALUES FROM Data class. |
112 | 119 |
|
113 | 120 | # Catchment area |
@@ -188,12 +195,6 @@ def initialize(self, current_time_step=0): |
188 | 195 | # Set these values now that we have the information from the configuration file. |
189 | 196 | self.num_giuh_ordinates = self.giuh_ordinates.size(1) |
190 | 197 | self.num_lateral_flow_nash_reservoirs = self.nash_storage.size(1) |
191 | | - # ________________________________________________ |
192 | | - # ----------- The output is area normalized, this is needed to un-normalize it |
193 | | - # mm->m km2 -> m2 hour->s |
194 | | - self.output_factor_cms = ( |
195 | | - (1 / 1000) * (self.catchment_area_km2 * 1000 * 1000) * (1 / 3600) |
196 | | - ) |
197 | 198 |
|
198 | 199 | # ________________________________________________ |
199 | 200 | # The configuration should let the BMI know what mode to run in (framework vs standalone) |
@@ -259,10 +260,10 @@ def reset_flux_and_states(self): |
259 | 260 | self.gw_reservoir_storage_deficit_m = torch.zeros( |
260 | 261 | (1, self.num_basins), dtype=torch.float64 |
261 | 262 | ) # the available space in the conceptual groundwater reservoir |
262 | | - self.primary_flux = torch.zeros( |
| 263 | + self.primary_flux_m = torch.zeros( |
263 | 264 | (1, self.num_basins), dtype=torch.float64 |
264 | 265 | ) # temporary vars. |
265 | | - self.secondary_flux = torch.zeros( |
| 266 | + self.secondary_flux_m = torch.zeros( |
266 | 267 | (1, self.num_basins), dtype=torch.float64 |
267 | 268 | ) # temporary vars. |
268 | 269 | self.primary_flux_from_gw_m = torch.zeros( |
@@ -293,6 +294,13 @@ def reset_flux_and_states(self): |
293 | 294 | (1, self.num_basins), dtype=torch.float64 |
294 | 295 | ) |
295 | 296 |
|
| 297 | + # ________________________________________________ |
| 298 | + # ----------- The output is area normalized, this is needed to un-normalize it |
| 299 | + # mm->m km2 -> m2 hour->s |
| 300 | + self.output_factor_cms = ( |
| 301 | + (1 / 1000) * (self.catchment_area_km2 * 1000 * 1000) * (1 / 3600) |
| 302 | + ) |
| 303 | + |
296 | 304 | # ________________________________________________ |
297 | 305 | # ________________________________________________ |
298 | 306 | # SOIL RESERVOIR CONFIGURATION |
@@ -366,6 +374,8 @@ def reset_flux_and_states(self): |
366 | 374 | self.volstart = self.volstart.add(self.gw_reservoir["storage_m"]) |
367 | 375 | self.vol_in_gw_start = self.gw_reservoir["storage_m"] |
368 | 376 |
|
| 377 | + # TODO: update soil parameter |
| 378 | + |
369 | 379 | self.soil_reservoir = { |
370 | 380 | "is_exponential": False, |
371 | 381 | "wilting_point_m": self.soil_params["wltsmc"] * self.soil_params["D"], |
@@ -404,6 +414,15 @@ def reset_flux_and_states(self): |
404 | 414 | self.giuh_ordinates.shape[0], self.num_giuh_ordinates + 1 |
405 | 415 | ) |
406 | 416 |
|
| 417 | + # __________________________________________________________ |
| 418 | + self.surface_runoff_m = torch.zeros((1, self.num_basins), dtype=torch.float64) |
| 419 | + self.streamflow_cmh = torch.zeros((1, self.num_basins), dtype=torch.float64) |
| 420 | + self.flux_nash_lateral_runoff_m = torch.zeros( |
| 421 | + (1, self.num_basins), dtype=torch.float64 |
| 422 | + ) |
| 423 | + self.flux_giuh_runoff_m = torch.zeros((1, self.num_basins), dtype=torch.float64) |
| 424 | + self.flux_Qout_m = torch.zeros((1, self.num_basins), dtype=torch.float64) |
| 425 | + |
407 | 426 | def update_params(self, refkdt, satdk): |
408 | 427 | """Update dynamic parameters""" |
409 | 428 | self.refkdt = refkdt.unsqueeze(dim=0) |
@@ -492,6 +511,9 @@ def reset_volume_tracking(self): |
492 | 511 | self.vol_et_from_soil = torch.zeros((1, self.num_basins), dtype=torch.float64) |
493 | 512 | self.vol_et_from_rain = torch.zeros((1, self.num_basins), dtype=torch.float64) |
494 | 513 | self.vol_PET = torch.zeros((1, self.num_basins), dtype=torch.float64) |
| 514 | + |
| 515 | + self.vol_in_gw_start = torch.zeros((1, self.num_basins), dtype=torch.float64) |
| 516 | + |
495 | 517 | return |
496 | 518 |
|
497 | 519 | # ________________________________________________________ |
|
0 commit comments