@@ -8,26 +8,38 @@ The configuration file ``config.yml`` and data for this example can be found `he
88
99.. code-block :: python
1010
11- # Import hydroecolstm function
12- from hydroecolstm.model_run import run_train
11+ from hydroecolstm.model_run import run_config
1312 from hydroecolstm.utility.plot import plot
1413 from hydroecolstm.data.read_config import read_config
14+ from hydroecolstm.model.create_model import create_model
1515 from hydroecolstm.utility.evaluation_function import EvaluationFunction
1616 from hydroecolstm.data.read_data import read_forecast_data
1717 import matplotlib.pyplot as plt
1818 from pathlib import Path
1919 import torch
20-
21- # Read configuration file
22- # Please modify the path to the config.yml and link to data
23- config = read_config(" C:/example/1_streamflow_simulation/config.yml" )
24-
25- # Create model and train
26- model, x_scaler, y_scaler, data = run_train(config)
27-
28- # Plot training and validation losses
29- data[" trainer" ].loss.drop([' epoch' , ' best_model' ], axis = 1 ).plot()
30-
20+
21+ # -----------------------------------------------------------------------------#
22+ # Set up, train, test model #
23+ # -----------------------------------------------------------------------------#
24+
25+ # Read configuration file, please modify the path to the config.yml file
26+ config = read_config(" C:/hydroecolstm/examples/1_streamflow_simulation/config.yml" )
27+
28+ # Create model and train from config
29+ model, data, best_config = run_config(config)
30+
31+ # Evaluate the model and transform to normal scale
32+ data[' y_train_simulated' ] = data[" y_scaler" ].inverse(model.evaluate(data[" x_train_scale" ]))
33+ data[' y_valid_simulated' ] = data[" y_scaler" ].inverse(model.evaluate(data[" x_valid_scale" ]))
34+ data[' y_test_simulated' ] = data[" y_scaler" ].inverse(model.evaluate(data[" x_test_scale" ]))
35+
36+ # Plot train and validation loss with epoch
37+ data[" loss_epoch" ].plot()
38+ plt.show()
39+
40+ # Want to see all keys in data
41+ data.keys()
42+
3143 # Objective function values: MAE, NSE, RMSE, MSE
3244 objective = EvaluationFunction(config[' eval_function' ], config[' warmup_length' ])
3345 objective(data[' y_train' ], data[' y_train_simulated' ])
@@ -39,18 +51,18 @@ The configuration file ``config.yml`` and data for this example can be found `he
3951 for target in config[" target_features" ]:
4052 p = plot(data, object_id = str (object_id), target_feature = target)
4153 p.show()
42-
54+
4355 # Application of the model for (assumed) ungagued basins
4456 forecast_dataset = read_forecast_data(config)
45- x_forecast_scale = x_scaler.transform(forecast_dataset[" x_forecast" ])
57+ x_forecast_scale = data[ " x_scaler" ] .transform(forecast_dataset[" x_forecast" ])
4658 y_forecast_scale = model.evaluate(x_forecast_scale)
4759
4860 # Application of the model for (assumed) ungagued basins
4961 forecast_dataset = read_forecast_data(config)
50- x_forecast_scale = x_scaler.transform(forecast_dataset[" x_forecast" ])
62+ x_forecast_scale = data[ " x_scaler" ] .transform(forecast_dataset[" x_forecast" ])
5163 y_forecast_scale = model.evaluate(x_forecast_scale)
52- y_forecast_simulated = y_scaler.inverse(y_forecast_scale)
53-
64+ y_forecast_simulated = data[ " y_scaler" ] .inverse(y_forecast_scale)
65+
5466 # Visualize result: train_test_period = "train" or "test"
5567 for object_id in y_forecast_simulated.keys():
5668 plt.plot(forecast_dataset[" time_forecast" ][object_id],
@@ -61,14 +73,27 @@ The configuration file ``config.yml`` and data for this example can be found `he
6173 color = ' red' , label = " Simulated" , alpha = 0.9 , linewidth = 0.75 )
6274 plt.legend()
6375 plt.show()
64-
76+
6577 # Objective function for forecast
6678 objective(forecast_dataset[' y_forecast' ], y_forecast_simulated)
67-
79+
6880 # Save all data and model state dicts to the output_directory
6981 torch.save(data, Path(config[" output_directory" ][0 ], " data.pt" ))
70- # torch.save(model.state_dict(), Path(config["output_directory"][0], "model.pt"))
71-
82+ torch.save(model.state_dict(),
83+ Path(config[" output_directory" ][0 ], " model_state_dict.pt" ))
84+
85+ # -----------------------------------------------------------------------------#
86+ # Incase you close this file and open again, #
87+ # you can load your data, model as follows #
88+ # -----------------------------------------------------------------------------#
89+ config = read_config(" C:/hydroecolstm/examples/1_streamflow_simulation/config.yml" )
90+
91+ model = create_model(config)
92+ model.load_state_dict(torch.load(Path(config[" output_directory" ][0 ],
93+ " model_state_dict.pt" )))
94+
95+ data = torch.load(Path(config[" output_directory" ][0 ], " data.pt" ))
96+
7297
7398 Multiple outputs simulation
7499---------------------------
@@ -78,25 +103,38 @@ The configuration file ``config.yml`` and data for this example can be found `he
78103.. code-block :: python
79104
80105 # Import hydroecolstm function
81- from hydroecolstm.model_run import run_train
106+
107+ from hydroecolstm.model_run import run_config
82108 from hydroecolstm.utility.plot import plot
83109 from hydroecolstm.data.read_config import read_config
110+ from hydroecolstm.model.create_model import create_model
84111 from hydroecolstm.utility.evaluation_function import EvaluationFunction
85- from hydroecolstm.data.read_data import read_forecast_data
86112 import matplotlib.pyplot as plt
87113 from pathlib import Path
88114 import torch
89-
90- # Read configuration file
91- # Please modify the path to the config.yml and link to data
92- config = read_config(" C:/example/2_streamflow_isotope_simulation/config.yml" )
93-
94- # Create model and train
95- model, x_scaler, y_scaler, data = run_train(config)
96-
97- # Plot training and validation losses
98- data[" trainer" ].loss.drop([' epoch' , ' best_model' ], axis = 1 ).plot()
99-
115+
116+ # -----------------------------------------------------------------------------#
117+ # Set up, train, test model #
118+ # -----------------------------------------------------------------------------#
119+
120+ # Read configuration file, please modify the path to the config.yml file
121+ config = read_config(" C:/hydroecolstm/examples/2_streamflow_isotope_simulation/config.yml" )
122+
123+ # Create model and train from config
124+ model, data, best_config = run_config(config)
125+
126+ # Evaluate the model and transform to normal scale
127+ data[' y_train_simulated' ] = data[" y_scaler" ].inverse(model.evaluate(data[" x_train_scale" ]))
128+ data[' y_valid_simulated' ] = data[" y_scaler" ].inverse(model.evaluate(data[" x_valid_scale" ]))
129+ data[' y_test_simulated' ] = data[" y_scaler" ].inverse(model.evaluate(data[" x_test_scale" ]))
130+
131+ # Plot train and validation loss with epoch
132+ data[" loss_epoch" ].plot()
133+ plt.show()
134+
135+ # Want to see all keys in data
136+ data.keys()
137+
100138 # Objective function values: MAE, NSE, RMSE, MSE
101139 objective = EvaluationFunction(config[' eval_function' ], config[' warmup_length' ])
102140 objective(data[' y_train' ], data[' y_train_simulated' ])
@@ -108,7 +146,19 @@ The configuration file ``config.yml`` and data for this example can be found `he
108146 for target in config[" target_features" ]:
109147 p = plot(data, object_id = str (object_id), target_feature = target)
110148 p.show()
111-
149+
112150 # Save all data and model state dicts to the output_directory
113151 torch.save(data, Path(config[" output_directory" ][0 ], " data.pt" ))
114- # torch.save(model.state_dict(), Path(config["output_directory"][0], "model.pt"))
152+ torch.save(model.state_dict(),
153+ Path(config[" output_directory" ][0 ], " model_state_dict.pt" ))
154+
155+ # -----------------------------------------------------------------------------#
156+ # Incase you close this file and open again, #
157+ # you can load your data, model as follows #
158+ # -----------------------------------------------------------------------------#
159+ config = read_config(" C:/hydroecolstm/examples/2_streamflow_isotope_simulation/config.yml" )
160+ model = create_model(config)
161+ model.load_state_dict(torch.load(Path(config[" output_directory" ][0 ],
162+ " model_state_dict.pt" )))
163+ data = torch.load(Path(config[" output_directory" ][0 ], " data.pt" ))
164+
0 commit comments