-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathvthree_utils.py
1815 lines (1547 loc) · 65.6 KB
/
vthree_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from io import StringIO
import os
from dotenv import load_dotenv
import logging
from pathlib import Path
from google.oauth2 import service_account
import climpred
import xarray as xr
import xesmf as xe
import numpy as np
import pandas as pd
import regionmask
from shapely import wkb
import geopandas as gp
from climpred import HindcastEnsemble
from datetime import datetime
#from datatree import DataTree
import dask.dataframe as daskdf
import xhistogram.xarray as xhist
from sklearn.metrics import roc_auc_score
import xskillscore as xs
from xbootstrap import block_bootstrap
from dask.distributed import Client
# matplotlib.use("Agg")
import itertools
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import six
import textwrap as tw
from functools import reduce
import json
from datetime import datetime
from dateutil.relativedelta import relativedelta
from calendar import monthrange
from PIL import Image
from google.oauth2 import service_account
# Set up logging: INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger shared by the functions in this file.
logger = logging.getLogger(__name__)
# load_dotenv()
# data_path = os.getenv("data_path")
# latex_path = os.getenv("latex_path")
class BinCreateParams:
    """Configuration bundle for one verification run.

    Holds the region / season / lead / level selection together with the
    input and output locations used by the functions in this module.

    The directory attributes ``data_path``, ``spi4_data_path`` and
    ``output_path`` are normalised to end with the OS path separator.
    Callers can then build file paths by plain string concatenation, e.g.
    ``f"{params.data_path}file.shp"``.

    The output directory is created when the object is constructed.
    """

    def __init__(
        self,
        region_id,
        season_str,
        lead_int,
        level,
        region_name_dict,
        spi_prod_name,
        data_path,
        spi4_data_path,
        output_path,
        obs_netcdf_file,
        fct_netcdf_file,
        service_account_json,
        gcs_file_url,
        region_filter,
    ):
        # Selection of what to verify.
        self.region_id = region_id
        self.season_str = season_str
        self.lead_int = lead_int
        # Lower-case copy of the season code (e.g. "MAM" -> "mam").
        self.sc_season_str = season_str.lower()
        self.level = level
        self.region_name_dict = region_name_dict
        self.spi_prod_name = spi_prod_name
        # Directory paths, normalised to end with a separator.
        self.data_path = self._ensure_trailing_slash(data_path)
        self.spi4_data_path = self._ensure_trailing_slash(spi4_data_path)
        self.output_path = self._ensure_trailing_slash(output_path)
        # Input file names and remote resources.
        self.obs_netcdf_file = obs_netcdf_file
        self.fct_netcdf_file = fct_netcdf_file
        self.service_account_json = service_account_json
        self.gcs_file_url = gcs_file_url
        self.region_filter = region_filter
        # Create necessary directories
        self._create_directories()

    def _ensure_trailing_slash(self, path):
        """Return *path* guaranteed to end with ``os.path.sep``.

        An empty string is returned unchanged; ``os.path.join("", "")`` is
        ``""``, and the empty path denotes the current directory.
        """
        if path.endswith(os.path.sep):
            return path
        return os.path.join(path, "")

    def _create_directories(self):
        """Create the output directory if it does not already exist.

        An empty ``output_path`` means "current directory", which always
        exists, so it is skipped. ``os.makedirs("")`` would raise
        ``FileNotFoundError``.
        """
        directories = [self.output_path]
        for directory in directories:
            if directory:
                os.makedirs(directory, exist_ok=True)
        # Use the module's logging rather than print, like the rest of the file.
        logging.getLogger(__name__).info(
            "Directories created/checked: %s", ", ".join(directories)
        )
def get_credentials(service_account_json):
    """Load read-only Google Cloud Storage credentials from a key file.

    Parameters
    ----------
    service_account_json : str
        Path to a service-account JSON key file.

    Returns
    -------
    google.oauth2.service_account.Credentials
        Credentials limited to the ``devstorage.read_only`` scope.
    """
    gcs_read_only = "https://www.googleapis.com/auth/devstorage.read_only"
    return service_account.Credentials.from_service_account_file(
        service_account_json, scopes=[gcs_read_only]
    )
def transform_data(data_at_time):
    """
    Convert a precipitation-rate dataset into per-month precipitation totals.

    For each ``forecastMonth`` value the valid month is the dataset's ``time``
    shifted by ``forecastMonth - 1`` months. The number of days in that month
    is attached as a ``numdays`` coordinate (alongside ``valid_time``). Every
    data variable is multiplied by ``numdays * 24 * 60 * 60 * 1000``, and the
    result is labelled ``units="mm"`` / ``long_name="Total precipitation"``.

    NOTE(review): the 86400 s/day * 1000 factor implies the input rate is in
    m/s (SEAS5 ``tprate`` convention). Confirm against the source data.
    NOTE(review): ``time`` is added to a ``relativedelta``, which presumably
    requires a single (scalar) initialisation time. Confirm with callers.

    Parameters:
    - data_at_time (xarray.Dataset): Precipitation-rate data with a ``time``
      coordinate and a ``forecastMonth`` dimension (1-based lead in months).

    Returns:
    - data_at_time_tp (xarray.Dataset): The scaled dataset carrying the
      ``valid_time`` and ``numdays`` coordinates.
    """
    init_time = pd.to_datetime(data_at_time.time.values)
    valid_time = [
        init_time + relativedelta(months=lead_month - 1)
        for lead_month in data_at_time.forecastMonth
    ]
    days_in_month = [monthrange(vt.year, vt.month)[1] for vt in valid_time]
    with_coords = data_at_time.assign_coords(
        valid_time=("forecastMonth", valid_time)
    ).assign_coords(numdays=("forecastMonth", days_in_month))
    # Keep the factor order (days, hours, minutes, seconds, m->mm) so the
    # floating-point result is unchanged.
    totals = with_coords * with_coords.numdays * 24 * 60 * 60 * 1000
    totals.attrs["units"] = "mm"
    totals.attrs["long_name"] = "Total precipitation"
    return totals
def apply_spi(
    cont_db,
    lead_val,
    spi_name_int,
    cal_start="1991-01-01",
    cal_end="2018-01-01",
):
    """
    Compute the Standardized Precipitation Index (SPI) for every ensemble member at one lead time.

    Parameters:
    - cont_db (xarray.Dataset): Dataset with a ``tprate`` variable (monthly
      precipitation totals) and ``forecastMonth`` / ``number`` coordinates.
    - lead_val (int): The ``forecastMonth`` value to select.
    - spi_name_int (int): SPI accumulation window in months (e.g. 3 for SPI3).
    - cal_start, cal_end (str): Calibration period passed to
      ``standardized_precipitation_index``. The defaults reproduce the
      previously hard-coded period 1991-01-01 .. 2018-01-01.

    Returns:
    - cont_spi (list): One computed xarray.DataArray of SPI values per value
      of the ``number`` coordinate, in coordinate order.
    """
    # NOTE(review): `standardized_precipitation_index` is not among this
    # file's visible imports (presumably xclim.indices). Confirm it is in scope.
    lead_ds = cont_db.sel(forecastMonth=lead_val)
    # Label the values as monthly totals for the SPI routine.
    lead_ds["tprate"].attrs["units"] = "mm/month"
    cont_spi = []
    for member in lead_ds.number.values:
        # One dask chunk per member (chunk(-1)) before fitting the SPI.
        member_precip = lead_ds.sel(number=member).chunk(-1).tprate
        member_spi = standardized_precipitation_index(
            member_precip,
            freq="MS",
            window=spi_name_int,
            dist="gamma",
            method="APP",
            cal_start=cal_start,
            cal_end=cal_end,
        )
        cont_spi.append(member_spi.compute())
        logger.info("SPI computed for ensemble member %s", member)
    return cont_spi
def apply_spii_mem(
    cont_db,
    lead_val,
    spi_name_int,
    cal_start="2017-01-01",
    cal_end="2023-12-01",
):
    """
    Compute the Standardized Precipitation Index (SPI) for every ensemble member at one lead time.

    Uses a 2017-01-01 .. 2023-12-01 calibration period by default.

    Parameters:
    - cont_db (xarray.Dataset): Dataset with a ``tprate`` variable (monthly
      precipitation totals) and ``forecastMonth`` / ``number`` coordinates.
    - lead_val (int): The ``forecastMonth`` value to select.
    - spi_name_int (int): SPI accumulation window in months (e.g. 3 for SPI3).
    - cal_start, cal_end (str): Calibration period passed to
      ``standardized_precipitation_index``. The defaults reproduce the
      previously hard-coded values.

    Returns:
    - cont_spi (list): One computed xarray.DataArray of SPI values per value
      of the ``number`` coordinate, in coordinate order.
    """
    # NOTE(review): `standardized_precipitation_index` is not among this
    # file's visible imports (presumably xclim.indices). Confirm it is in scope.
    lead_ds = cont_db.sel(forecastMonth=lead_val)
    # Label the values as monthly totals for the SPI routine.
    lead_ds["tprate"].attrs["units"] = "mm/month"
    cont_spi = []
    for member in lead_ds.number.values:
        # One dask chunk per member (chunk(-1)) before fitting the SPI.
        member_precip = lead_ds.sel(number=member).chunk(-1).tprate
        member_spi = standardized_precipitation_index(
            member_precip,
            freq="MS",
            window=spi_name_int,
            dist="gamma",
            method="APP",
            cal_start=cal_start,
            cal_end=cal_end,
        )
        cont_spi.append(member_spi.compute())
        logger.info("SPI computed for ensemble member %s", member)
    return cont_spi
def ken_mask_creator(data_path):
    """
    Build the Karamoja / Marsabit / Wajir region table from local shapefiles.

    Reads ``Karamoja_boundary_dissolved.shp`` and ``wajir_mbt_extent.shp``
    from *data_path* and stacks them. The rows are then labelled 0, 1, 2 with
    the names "Karamoja", "Marsabit", "Wajir", in row order.

    Parameters
    ----------
    data_path : str
        Directory prefix. It is concatenated directly with the file names, so
        it must end with a path separator.

    Returns
    -------
    the_mask : list
        Currently always an empty list; regionmask creation is disabled and
        the value is kept only to preserve the return signature.
    rl_dict : dict
        Dictionary mapping region numbers to region names.
    mds2 : geopandas.GeoDataFrame
        GeoDataFrame containing geometry, region, and region_name information.

    Raises
    ------
    FileNotFoundError
        Re-raised (after logging) if a shapefile is reported missing.
    ValueError
        If the combined table is empty or does not contain exactly three rows.
    """
    region_numbers = [0, 1, 2]
    region_names = ["Karamoja", "Marsabit", "Wajir"]
    logger.info("Starting ken_mask_creator function")
    try:
        logger.info(
            f"Reading Karamoja boundary file from {data_path}Karamoja_boundary_dissolved.shp"
        )
        dis = gp.read_file(f"{data_path}Karamoja_boundary_dissolved.shp")
        logger.info(
            f"Reading Wajir and Marsabit extent file from {data_path}wajir_mbt_extent.shp"
        )
        reg = gp.read_file(f"{data_path}wajir_mbt_extent.shp")
        logger.info("Concatenating district and region data")
        mds1 = pd.concat([dis, reg]).reset_index()
        # Validate before labelling: the fixed labels below only fit a
        # three-row table (pandas would otherwise fail with a length error).
        if mds1.empty:
            raise ValueError("GeoDataFrame is empty")
        if len(mds1) != len(region_numbers):
            raise ValueError(
                f"Expected {len(region_numbers)} features in the combined "
                f"shapefiles, found {len(mds1)}"
            )
        logger.info("Assigning region numbers and names")
        # NOTE(review): assumes wajir_mbt_extent.shp lists Marsabit before
        # Wajir. Confirm the feature order in that file.
        mds1["region"] = region_numbers
        mds1["region_name"] = region_names
        mds2 = mds1[["geometry", "region", "region_name"]]
        logger.info("Creating region-name dictionary")
        rl_dict = dict(zip(mds2.region, mds2.region_name))
        logger.info("Creating regionmask from GeoDataFrame")
        # regionmask.from_geopandas creation is currently disabled; an empty
        # list is returned in its place.
        the_mask = []
        logger.info("ken_mask_creator function completed successfully")
        return the_mask, rl_dict, mds2
    except FileNotFoundError as e:
        logger.error(f"File not found: {e}")
        raise
    except Exception as e:
        logger.error(f"An error occurred in ken_mask_creator: {e}")
        raise
def get_region_bounds(
    region_id,
    credentials=None,
    buffer=0.5,
    use_local=False,
    local_shapefile_path='../kmj_polygon.shp',
    gcs_file_url='gs://seas51/ea_admin0_2_custom_polygon_shapefile_v5.parquet',
):
    """
    Get geographic bounds for a specific region with buffer, supporting both local and GCS data sources.

    Features are selected with ``str.contains(region_id)``, so *region_id*
    acts as a substring / regular-expression pattern. Every matching feature
    contributes to the extent. Missing (NaN) IDs never match.

    Args:
        region_id (str): Pattern matched against the region ID column.
        credentials: Google Cloud credentials, passed as the ``token`` storage
            option (required if use_local=False).
        buffer (float): Buffer in degrees to add on every side of the combined extent.
        use_local (bool): Whether to use local shapefile (True) or GCS (False).
        local_shapefile_path (str): Path to local shapefile if use_local=True.
        gcs_file_url (str): Parquet file with a ``gbid`` column and WKB
            ``geometry`` column, read when use_local=False. Defaults to the
            previously hard-coded location.

    Returns:
        tuple: (gdf, extent) where:
            - gdf is a GeoDataFrame containing the filtered region data
            - extent is [lat_min, lat_max, lon_min, lon_max]

    Raises:
        ValueError: If the local shapefile cannot be read or filtered (cause
            chained), if credentials are missing for GCS, or if no region matches.
    """
    if use_local:
        try:
            gdf = gp.read_file(local_shapefile_path)
            if 'gbid' in gdf.columns:
                gdf = gdf[gdf['gbid'].str.contains(region_id, na=False)]
            else:
                # Fall back to the first column whose name contains "id".
                id_columns = [col for col in gdf.columns if 'id' in col.lower()]
                if not id_columns:
                    raise ValueError("No suitable ID column found in local shapefile")
                gdf = gdf[gdf[id_columns[0]].astype(str).str.contains(region_id)]
        except Exception as e:
            # Report every read/filter failure as ValueError, keeping the cause.
            raise ValueError(f"Error reading local shapefile: {e}") from e
        if len(gdf) == 0:
            raise ValueError(f"No regions found with id containing '{region_id}' in local shapefile")
    else:
        if credentials is None:
            raise ValueError("Credentials required when not using local shapefile")
        ddf = daskdf.read_parquet(gcs_file_url, storage_options={'token': credentials}, engine='pyarrow')
        df1 = ddf[ddf['gbid'].str.contains(region_id, na=False)].compute()
        if len(df1) == 0:
            raise ValueError(f"No regions found with id containing '{region_id}' in GCS dataset")
        # Geometries are stored as WKB; decode them to shapely objects.
        df1['geometry'] = df1['geometry'].apply(wkb.loads)
        gdf = gp.GeoDataFrame(df1, geometry='geometry')
    # Combined extent of all selected features, padded by `buffer`.
    bounds = gdf.bounds
    lat_min = bounds['miny'].min() - buffer
    lat_max = bounds['maxy'].max() + buffer
    lon_min = bounds['minx'].min() - buffer
    lon_max = bounds['maxx'].max() + buffer
    extent = [lat_min, lat_max, lon_min, lon_max]
    return gdf, extent
def gcs_paraquet_mask_creator(params):
    """
    Load region polygons from a Parquet file on Google Cloud Storage.

    Parameters:
    - params: Object exposing
        - ``service_account_json`` (str): path to the service-account key file;
        - ``gcs_file_url`` (str): GCS URL of the Parquet file. The file needs a
          ``gbid`` column, a WKB ``geometry`` column and a ``name`` column;
        - ``region_filter`` (str): regular expression matched
          case-insensitively against ``gbid`` (e.g. ``'kmj|mbt|wjr'``).

    Returns:
    -------
    the_mask : str
        Currently always ``''``; regionmask creation is not implemented yet.
    rl_dict : dict
        Mapping of ``region`` (former ``gbid``) to ``region_name`` (former ``name``).
    gdf : geopandas.GeoDataFrame
        The matching rows with ``gbid``/``name`` renamed to ``region``/``region_name``.

    Raises:
    - ValueError if no row matches the filter; any other error is logged and re-raised.
    """
    logger.info("Starting gcs_mask_creator function")
    try:
        # Same read-only GCS credentials as everywhere else in this module.
        credentials = get_credentials(params.service_account_json)
        logger.info(f"Reading parquet file from {params.gcs_file_url}")
        ddf = daskdf.read_parquet(params.gcs_file_url, storage_options={'token': credentials}, engine='pyarrow')
        logger.info(f"Filtering regions based on: {params.region_filter}")
        # na=False: rows with a missing gbid are excluded instead of breaking the boolean mask.
        fdf = ddf[ddf['gbid'].str.contains(params.region_filter, case=False, na=False)]
        df = fdf.compute()
        logger.info("Converting WKB to Shapely geometries")
        df['geometry'] = df['geometry'].apply(wkb.loads)
        logger.info("Creating GeoDataFrame")
        gdf = gp.GeoDataFrame(df, geometry='geometry')
        # Assumes 'gbid' holds region codes and 'name' region names; adjust if the schema differs.
        gdf = gdf.rename(columns={'gbid': 'region', 'name': 'region_name'})
        if gdf.empty:
            raise ValueError("GeoDataFrame is empty")
        logger.info("Creating region-name dictionary")
        rl_dict = dict(zip(gdf.region, gdf.region_name))
        logger.info("Creating regionmask from GeoDataFrame")
        # TODO: build the_mask with regionmask.from_geopandas(gdf, numbers="region", names="region_name").
        the_mask = ''
        logger.info("gcs_mask_creator function completed successfully")
        return the_mask, rl_dict, gdf
    except Exception as e:
        logger.error(f"An error occurred in gcs_mask_creator: {e}")
        raise
def spi3_prod_name_creator(ds_ens, var_name):
    """
    Label each time step with its 3-month SPI season code.

    The SPI3 value at month *m* accumulates months *m-2*, *m-1* and *m*, so
    the label is built from the initials of those three months. For example,
    March -> ``"JFM"``, May -> ``"MAM"``, January -> ``"NDJ"``.

    Each label is derived from the timestamp's own month. The result therefore
    does not depend on the series being sorted or gap-free, and seasons that
    cross the turn of the year (``"NDJ"``, ``"DJF"``) are produced. An earlier
    implementation shifted within calendar-year groups, which gave NaN for the
    first two entries of every year.

    Month initials come from a fixed English table, so the result does not
    depend on the process locale (``strftime("%b")`` does).

    Parameters
    ----------
    ds_ens : xarray.Dataset or pandas.DataFrame
        Any object for which ``ds_ens[var_name].values`` is a 1-D array of
        datetime-like values (``numpy.datetime64`` or cftime objects).
    var_name : str
        Name of the time coordinate/column, e.g. ``"valid_time"`` or ``"time"``.

    Returns
    -------
    spi_prod_list : list of str
        One upper-case season code per time step, e.g. ``['NDJ', 'DJF', 'JFM', ...]``.
    """
    month_initials = "JFMAMJJASOND"
    window = 3
    # Series.map hands Timestamps (datetime64 input) or cftime objects to the lambda.
    months = pd.Series(ds_ens[var_name].values).map(lambda t: t.month)
    return [
        "".join(
            month_initials[(month - 1 - back) % 12]
            for back in range(window - 1, -1, -1)
        )
        for month in months
    ]
def spi4_prod_name_creator(ds_ens, var_name):
    """
    Label each time step with its 4-month SPI season code.

    The SPI4 value at month *m* accumulates months *m-3* .. *m*, so the label
    is built from the initials of those four months. For example,
    September -> ``"JJAS"``, May -> ``"FMAM"``, January -> ``"ONDJ"``.

    Each label is derived from the timestamp's own month. The result therefore
    does not depend on the series being sorted or gap-free, and seasons that
    cross the turn of the year are produced. An earlier implementation shifted
    within calendar-year groups, which gave NaN for the first three entries of
    every year.

    Month initials come from a fixed English table, so the result does not
    depend on the process locale (``strftime("%b")`` does).

    Parameters
    ----------
    ds_ens : xarray.Dataset or pandas.DataFrame
        Any object for which ``ds_ens[var_name].values`` is a 1-D array of
        datetime-like values (``numpy.datetime64`` or cftime objects).
    var_name : str
        Name of the time coordinate/column, e.g. ``"valid_time"`` or ``"time"``.

    Returns
    -------
    spi_prod_list : list of str
        One upper-case season code per time step, e.g. ``['ONDJ', 'NDJF', 'DJFM', ...]``.
    """
    month_initials = "JFMAMJJASOND"
    window = 4
    # Series.map hands Timestamps (datetime64 input) or cftime objects to the lambda.
    months = pd.Series(ds_ens[var_name].values).map(lambda t: t.month)
    return [
        "".join(
            month_initials[(month - 1 - back) % 12]
            for back in range(window - 1, -1, -1)
        )
        for month in months
    ]
def make_obs_fct_dataset(params):
    """
    Prepare matching observed and forecast SPI subsets for one season and lead time.

    Steps:
    1. Open the observed and forecast NetCDF files. The length of
       ``params.season_str`` decides which files are used:
       - 3 characters (SPI3): the hard-coded relative files
         ``./kmj_obs_spi3_masked.nc`` and ``./kmj_rgr_seas51_spi3_masked.nc``;
       - otherwise (SPI4): ``params.data_path`` joined with
         ``params.obs_netcdf_file`` / ``params.fct_netcdf_file``.
    2. Wrap the forecast in a climpred ``HindcastEnsemble`` with the
       observations attached. ``get_initialized()`` then yields the forecast
       with a ``valid_time`` coordinate, from which lead index
       ``params.lead_int`` is selected.
    3. Label every forecast init and observed time with its SPI season code
       (``spi3_prod_name_creator`` / ``spi4_prod_name_creator``). Keep only
       entries whose label equals ``params.season_str``.
    4. Intersect forecast valid times with observed times. Select the
       observations at those dates, and the forecast at the inits that are
       ``lead`` months earlier.

    Region subsetting is currently disabled (see the commented-out code), so
    ``params.region_id`` is not used.

    Parameters:
    - params: Object (e.g. BinCreateParams) providing ``season_str``,
      ``lead_int``, ``data_path``, ``obs_netcdf_file`` and ``fct_netcdf_file``.
      ``season_str`` is compared verbatim with the generated labels, so it
      must use the same case as those labels (upper-case, e.g. ``'MAM'``).

    Returns:
    - a_obs3 (xarray.Dataset): Observations at the dates shared with the forecast.
    - a_fc4 (xarray.Dataset): Forecast for the selected lead at the matching inits.

    Raises:
    - Any exception is logged and re-raised unchanged.

    Example Usage:
    >>> obs_data, ens_data = make_obs_fct_dataset(params)
    """
    try:
        # Region subsetting (disabled):
        #the_mask, rl_dict, mds1 = gcs_paraquet_mask_creator(params)
        #bounds = mds1.bounds
        #llon, llat = bounds.iloc[params.region_id][["minx", "miny"]]
        #ulon, ulat = bounds.iloc[params.region_id][["maxx", "maxy"]]
        #logger.debug(
        #    f"Region bounds: llon={llon}, llat={llat}, ulon={ulon}, ulat={ulat}"
        #)
        if len(params.season_str) == 3:
            # SPI3: hard-coded files relative to the working directory.
            kn_obs = xr.open_dataset('./kmj_obs_spi3_masked.nc')
            kn_fct = xr.open_dataset('./kmj_rgr_seas51_spi3_masked.nc')
            logger.info("Loaded SPI3 datasets")
        else:
            # SPI4: files configured on params.
            kn_fct = xr.open_dataset(os.path.join(params.data_path, params.fct_netcdf_file))
            kn_obs = xr.open_dataset(os.path.join(params.data_path, params.obs_netcdf_file))
            logger.info("Loaded SPI4 datasets")
        #a_fc = kn_fct.sel(lon=slice(llon, ulon), lat=slice(llat, ulat))
        #a_obs = kn_obs.sel(lon=slice(llon, ulon), lat=slice(llat, ulat))
        # No spatial subsetting is applied at present; the full fields are used.
        a_obs=kn_obs
        a_fc=kn_fct
        logger.info("subsetted obs and fcst to given region")
        # NOTE(review): this debug message is emitted before the object is created.
        logger.debug("Created HindcastEnsemble")
        hindcast = HindcastEnsemble(a_fc)
        hindcast = hindcast.add_observations(a_obs)
        # get_initialized() returns the forecast with climpred's valid_time coordinate.
        a_fc1 = hindcast.get_initialized()
        logger.debug("Added climpred HindcastEnsemble to add valid_time in fcst")
        a_fc2 = a_fc1.isel(lead=params.lead_int)
        # Season labels: forecast by valid_time, observations by time.
        if len(params.season_str) == 3:
            spi_prod_list = spi3_prod_name_creator(a_fc2, "valid_time")
            obs_spi_prod_list = spi3_prod_name_creator(a_obs, "time")
        else:
            spi_prod_list = spi4_prod_name_creator(a_fc2, "valid_time")
            obs_spi_prod_list = spi4_prod_name_creator(a_obs, "time")
        logger.info(
            f"added SPI prodcut in obs and fcst dataset, filtered to {params.season_str}"
        )
        # Attach the labels and drop every init/time whose label differs from season_str.
        a_fc2 = a_fc2.assign_coords(spi_prod=("init", spi_prod_list))
        a_fc3 = a_fc2.where(a_fc2.spi_prod == params.season_str, drop=True)
        a_obs1 = a_obs.assign_coords(spi_prod=("time", obs_spi_prod_list))
        a_obs2 = a_obs1.where(a_obs1.spi_prod == params.season_str, drop=True)
        # Convert valid_time to numpy datetime64 for comparison from cftime of a_fc3
        # NOTE(review): relies on each element having .isoformat() (cftime/datetime);
        # numpy.datetime64 elements would fail here.
        fct_valid_times = np.array(
            [np.datetime64(vt.isoformat()) for vt in a_fc3.valid_time.values]
        )
        obs_times = a_obs2.time.values
        # Find common dates (np.intersect1d returns them sorted and unique)
        common_dates = np.intersect1d(fct_valid_times, obs_times)
        # common_dates = np.unique(a_fc3.valid_time.values.ravel())
        # Filter both datasets to include only common dates
        # a_fc4 = a_fc3.sel(valid_time=common_dates)
        a_obs3 = a_obs2.sel(time=common_dates)
        # a_obs3 = a_obs2.sel(time=common_dates)
        # Init = valid month minus the lead; assumes the lead is expressed in
        # months (TODO confirm the climpred lead units).
        a_fc3_init_dates = common_dates.astype("datetime64[M]") - np.timedelta64(
            int(a_fc3.lead.values), "M"
        )
        a_fc4 = a_fc3.sel(init=a_fc3_init_dates)
        # Ensure the time dimension in a_fc4 matches the valid_time coordinate
        # a_fc4 = a_fc4.assign_coords(time=('valid_time', common_dates))
        # a_fc4 = a_fc4.swap_dims({'valid_time': 'time'})
        logger.info(
            f"Found {len(common_dates)} common dates between observed and forecast data"
        )
        logger.info("Successfully prepared observed and forecasted datasets")
        return a_obs3, a_fc4
    except FileNotFoundError as e:
        logger.error(f"File not found: {e}")
        raise
    except ValueError as e:
        logger.error(f"Value error: {e}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error in make_obs_fct_dataset: {e}")
        raise
    # return a_obs3, a_fc3
def get_threshold(region_id, season, level=None):
    """
    Look up drought thresholds for a region and season.

    Thresholds come from a small CSV table embedded in this function. Each
    row gives, for one (region_id, season) pair, the SPI value at or below
    which a moderate ('mod'), severe ('sev') or extreme ('ext') drought is
    considered to occur.

    Parameters:
    - region_id (int): Region number as used in the table (0 = kmj, 1 = mbt, 2 = wjr).
    - season (str): Season code exactly as in the table (lower-case):
      'mamo', 'mam', 'jja', 'jjas' for region 0; 'mam', 'ond' for regions 1 and 2.
    - level (str, optional): 'mod', 'sev' or 'ext'. When omitted, all three
      thresholds are returned as a dict (the original behaviour).

    Returns:
    - level is None: dict ``{'mod': float, 'sev': float, 'ext': float}``, or
      ``{}`` if the (region_id, season) pair is not in the table.
    - level given: the float threshold, or None if the pair or level is unknown.

    Note:
    - The table is hard-coded; a production setup should load it from a
      managed data source.
    - Superseded v1 table, kept for reference:
        region_id,region,season,mod,sev,ext
        0,kmj,mam,-0.03,-0.56,-0.99
        0,kmj,jjas,-0.01,-0.41,-0.99
        1,mbt,mam,-0.14,-0.38,-0.8
        1,mbt,ond,-0.15,-0.53,-0.71
        2,wjr,mam,-0.19,-0.45,-0.75
        2,wjr,ond,-0.29,-0.76,-0.9

    Example usage:
    >>> get_threshold(1, 'mam', 'mod')
    -0.15
    >>> get_threshold(1, 'mam')
    {'mod': -0.15, 'sev': -0.53, 'ext': -0.71}
    """
    data = """region_id,region,season,mod,sev,ext
0,kmj,mamo,-0.55,-0.98,-0.99
0,kmj,mam,-0.43,-0.67,-0.84
0,kmj,jja,-0.43,-0.67,-0.84
0,kmj,jjas,-0.40,-0.98,-0.99
1,mbt,mam,-0.15,-0.53,-0.71
1,mbt,ond,-0.15,-0.53,-0.71
2,wjr,mam,-0.29,-0.76,-0.90
2,wjr,ond,-0.29,-0.76,-0.90
"""
    df = pd.read_csv(StringIO(data))
    thresholds_dict = {
        (row.region_id, row.season): {"mod": row.mod, "sev": row.sev, "ext": row.ext}
        for row in df.itertuples(index=False)
    }
    season_thresholds = thresholds_dict.get((region_id, season), {})
    if level is None:
        return season_thresholds
    return season_thresholds.get(level)
def mean_obs_spi(obs_data, spi_string_name):
    """
    Area-average an observed SPI field and tag each value with its year.

    Parameters
    ----------
    obs_data : xarray.Dataset
        Observed data with ``lat``/``lon`` dimensions, a ``time`` coordinate
        and the variable *spi_string_name*.
    spi_string_name : str
        Name of the SPI variable to average.

    Returns
    -------
    pandas.DataFrame
        Columns ``[spi_string_name, "year"]``: the lat/lon mean per time step
        and its year as a ``"%Y"`` string.
    """
    obs_data_mean = obs_data.mean(dim=["lat", "lon"])
    obs_data_df = obs_data_mean.to_dataframe().reset_index()
    # Work on an explicit copy: adding a column to a column-subset of another
    # frame is chained assignment (SettingWithCopyWarning).
    wdf = obs_data_df[["time", spi_string_name]].copy()
    # Normalise to plain datetimes (works for cftime and pandas timestamps)
    # before formatting the year.
    as_datetime = wdf["time"].apply(
        lambda x: datetime(x.year, x.month, x.day, x.hour, x.minute, x.second)
    )
    wdf["year"] = as_datetime.dt.strftime("%Y")
    return wdf[[spi_string_name, "year"]]
def empirical_probability(ens_data, threshold_dict):
    """
    Calculate empirical probabilities for moderate, severe, and extreme drought conditions.

    For each level the probability is the fraction of ensemble members whose
    value is at or below that level's threshold:
    ``(ens_data <= threshold).mean(dim="member")``. A NaN member compares
    False, so it counts as "no drought" in the fraction.

    Args:
        ens_data (xarray.Dataset or xarray.DataArray): Ensemble drought-index
            values with a 'member' dimension.
        threshold_dict (dict): Thresholds under the keys 'mod', 'sev' and 'ext'.

    Returns:
        tuple: (fct_mod, fct_sev, fct_ext), each the same xarray type as
        *ens_data*, with the 'member' dimension reduced away.

    Raises:
        KeyError: If required keys are missing from threshold_dict.
        ValueError: If ens_data is not an xarray Dataset/DataArray or lacks a 'member' dimension.
    """
    levels = ("mod", "sev", "ext")
    try:
        if not isinstance(ens_data, (xr.Dataset, xr.DataArray)):
            raise ValueError("ens_data must be an xarray.Dataset or xarray.DataArray")
        if "member" not in ens_data.dims:
            raise ValueError("ens_data must have a 'member' dimension")
        for key in levels:
            if key not in threshold_dict:
                raise KeyError(f"threshold_dict is missing required key: {key}")
        fct_mod, fct_sev, fct_ext = (
            (ens_data <= threshold_dict[level]).mean(dim="member") for level in levels
        )
        logger.info("Empirical probabilities calculated successfully")
        return fct_mod, fct_sev, fct_ext
    except Exception as e:
        logger.error(f"Error in empirical_probability: {str(e)}")
        raise
def seas51_patch_empirical_probability(ens_data, threshold_dict):
    """
    Calculate empirical probabilities for SEAS5.1, handling the change in ensemble size.

    SEAS5.1 hindcasts (here: inits 1981-2016) have 25 members, while the
    operational forecasts (inits from 2017) have 51. Probabilities are
    therefore computed separately:
    - inits 1981-2016: only the first 25 members (positions 0-24) are used.
      Presumably the remaining positions are padding in those years; confirm.
    - inits from 2017: all members are used.
    The two parts are concatenated along 'init'. Inits before 1981 are not
    included in the output.

    Args:
        ens_data (xarray.Dataset): Ensemble drought-index values with 'init' and 'member' dimensions.
        threshold_dict (dict): Thresholds under the keys 'mod', 'sev' and 'ext'.

    Returns:
        tuple: (fct_mod, fct_sev, fct_ext) xarray.Datasets of probabilities per init.

    Raises:
        ValueError: If ens_data is not an xarray.Dataset or lacks 'init'/'member' dimensions.
        KeyError: Propagated from empirical_probability for missing threshold keys.
    """
    try:
        if not isinstance(ens_data, xr.Dataset):
            raise ValueError("ens_data must be an xarray.Dataset")
        if "init" not in ens_data.dims or "member" not in ens_data.dims:
            raise ValueError("ens_data must have 'init' and 'member' dimensions")
        hindcast_era = ens_data.sel(init=slice("1981", "2016")).isel(member=slice(0, 25))
        forecast_era = ens_data.sel(init=slice("2017", None))
        hindcast_probs = empirical_probability(hindcast_era, threshold_dict)
        forecast_probs = empirical_probability(forecast_era, threshold_dict)
        # Re-join each level's two periods along init.
        fct_mod, fct_sev, fct_ext = (
            xr.concat([old, new], dim="init", coords="minimal", compat="override")
            for old, new in zip(hindcast_probs, forecast_probs)
        )
        logger.info("SEAS5.1 patch empirical probabilities calculated successfully")
        return fct_mod, fct_sev, fct_ext
    except Exception as e:
        logger.error(f"Error in seas51_patch_empirical_probability: {str(e)}")
        raise
def print_data_stats(obs_data, ens_prob_data, params):
    """Log min / max / mean / std of the observed and ensemble-probability fields.

    Both inputs are indexed with ``params.spi_prod_name`` and computed before
    the summaries are logged at INFO level. Any error is logged and re-raised.
    """

    def _summary(values):
        # The four statistics, each formatted to four decimals.
        return (
            f"min={values.min().item():.4f}, "
            f"max={values.max().item():.4f}, "
            f"mean={values.mean().item():.4f}, "
            f"std={values.std().item():.4f}"
        )

    try:
        logger.info("Calculating and printing data statistics")
        variable = params.spi_prod_name
        obs_stats = obs_data[variable].compute()
        ens_stats = ens_prob_data[variable].compute()
        logger.info("Observed data stats: " + _summary(obs_stats))
        logger.info("Ensemble data stats: " + _summary(ens_stats))
    except Exception as e:
        logger.error(f"Error in print_data_stats: {str(e)}")
        raise
def print_histogram(ens_prob_data, params):
    """Log a 20-bin histogram over [0, 1] of the forecast probabilities.

    The variable ``params.spi_prod_name`` is flattened and binned with
    ``numpy.histogram``. One INFO line is logged per bin, showing its edges
    and count. Any error is logged and re-raised.
    """
    try:
        logger.info("Calculating and printing histogram")
        flat_probs = ens_prob_data[params.spi_prod_name].values.flatten()
        counts, edges = np.histogram(flat_probs, bins=20, range=(0, 1))
        logger.info("Histogram of drought forecast probabilities:")
        for lower, upper, count in zip(edges[:-1], edges[1:], counts):
            logger.info(f" {lower:.2f} - {upper:.2f}: {count}")
    except Exception as e:
        logger.error(f"Error in print_histogram: {str(e)}")
        raise
def calculate_contingency_table(obs_event, forecast_event, params):
    """Count observed/forecast event pairs over lat/lon for every time step.

    ``params.spi_prod_name`` selects the event variable in both inputs. The
    forecast's ``init`` coordinate is first replaced by its ``valid_time``
    values, renamed to ``time`` and converted to ``datetime64[ns]``, so that
    it lines up with the observations. A joint 2-D histogram is then taken
    over ``lat``/``lon``.

    Parameters
    ----------
    obs_event, forecast_event : xarray.Dataset
        Event indicators, assumed to be 0 (no event) / 1 (event); TODO confirm
        with callers. ``forecast_event`` must carry a ``valid_time``
        coordinate along ``init``.
    params : object
        Only ``spi_prod_name`` is read.

    Returns
    -------
    xarray.DataArray
        Counts with the remaining (time) dimension followed by the
        ``observed_event`` and ``forecasted_event`` bins (bin 0 = value 0,
        bin 1 = value 1).
    """
    # Explicit edges centred on 0 and 1. With an integer bin count the edges
    # are derived from the data range. A field that is all 0 (e.g. no forecast
    # reaches a high trigger value) then gets edges [-0.5, 0, 0.5], and its
    # zeros land in the "event" bin.
    event_edges = np.array([-0.5, 0.5, 1.5])
    try:
        logger.info("Calculating contingency table")
        obs_event1 = obs_event[params.spi_prod_name]
        obs_event1.name = "observed_event"
        forecast_event1 = forecast_event[params.spi_prod_name]
        forecast_event1.name = "forecasted_event"
        # Index the forecast by valid time instead of initialisation time and
        # call that dimension "time", matching the observations.
        forecast_event3 = forecast_event1.assign_coords(
            init=forecast_event1.coords["valid_time"]
        ).rename({"init": "time"})
        # valid_time may be cftime; go through ISO strings to get datetime64[ns].
        time_np64 = np.array(
            [str(t) for t in forecast_event3["time"].values], dtype="datetime64[ns]"
        )
        forecast_event3 = forecast_event3.assign_coords(time=time_np64)
        contingency_table = xhist.histogram(
            obs_event1,
            forecast_event3,
            bins=[event_edges, event_edges],
            density=False,
            dim=["lat", "lon"],
        )
        logger.info(f"Contingency table shape: {contingency_table.shape}")
        logger.info(f"Contingency table contents:\n{contingency_table}")
        return contingency_table
    except Exception as e:
        logger.error(f"Error in calculate_contingency_table: {str(e)}")
        raise
def calculate_scores(contingency_table, trigger_value, params):
    """Compute categorical verification scores for every time step of a 2x2 table.

    ``contingency_table`` must be 3-D with trailing shape (2, 2). Each
    ``contingency_table[t]`` slice is flattened row-major and unpacked as
    ``correct_negatives, false_alarms, misses, hits``. This means axis 1 is the
    observed bin and axis 2 is the forecast bin (0 = no event, 1 = event).

    With a = hits, b = false alarms, c = misses, d = correct negatives:

    - hit_rate (POD)            = a / (a + c)
    - false_alarm_ratio         = b / (a + b)
    - bias_score                = (a + b) / (a + c)
    - hanssen_kuipers_score     = a / (a + c) - b / (b + d)
    - heidke_skill_score        = 2(ad - bc) / [(a + c)(c + d) + (a + b)(b + d)]

    A score whose denominator is zero is reported as NaN.

    Parameters
    ----------
    contingency_table : xarray.DataArray
        Counts indexed ``[time, observed_bin, forecast_bin]``.
    trigger_value : float
        Probability threshold the table was built with. It is logged and copied
        into every result row.
    params : object
        Must expose ``region_id``, ``lead_int``, ``sc_season_str`` and ``level``.
        These are copied into every result row.

    Returns
    -------
    list of dict
        One dict per time step containing the metadata, the scores above, and
        ``total`` (the sum of the four counts).

    Raises
    ------
    ValueError
        If the table is not 3-D with trailing shape (2, 2).
    Exception
        Any error is logged and re-raised.
    """

    def _safe_ratio(numerator, denominator):
        # NaN instead of a 0/0 warning or inf when the score is undefined.
        return numerator / denominator if denominator > 0 else np.nan

    try:
        logger.info(f"Calculating scores for trigger value {trigger_value:.4f}")
        # Ensure the contingency table has the expected dimensions
        if contingency_table.ndim != 3 or contingency_table.shape[1:] != (2, 2):
            raise ValueError(
                f"Unexpected contingency table shape: {contingency_table.shape}"
            )
        time_steps = contingency_table.shape[0]
        scores = []
        for t in range(time_steps):
            correct_negatives, false_alarms, misses, hits = (
                contingency_table[t].values.flatten()
            )
            total = hits + false_alarms + misses + correct_negatives
            observed_events = hits + misses
            observed_non_events = false_alarms + correct_negatives

            hit_rate = _safe_ratio(hits, observed_events)
            false_alarm_ratio = _safe_ratio(false_alarms, false_alarms + hits)
            bias_score = _safe_ratio(hits + false_alarms, observed_events)
            # Hanssen-Kuipers (Peirce) score = POD - POFD.
            hanssen_kuipers_score = hit_rate - _safe_ratio(
                false_alarms, observed_non_events
            )
            # Heidke skill score. The factor 2 is part of the standard definition;
            # without it a perfect forecast scores 0.5 instead of 1. The guard is
            # on the denominator itself, which can be 0 even when total > 0
            # (e.g. all correct negatives).
            hss_denominator = (hits + misses) * (misses + correct_negatives) + (
                hits + false_alarms
            ) * (false_alarms + correct_negatives)
            heidke_skill_score = _safe_ratio(
                2 * (hits * correct_negatives - misses * false_alarms),
                hss_denominator,
            )
            scores.append(
                {
                    "x2d_region": params.region_id,
                    "x2d_leadtime": params.lead_int,
                    "x2d_season": params.sc_season_str,
                    "x2d_level": params.level,
                    "trigger_value": trigger_value,
                    "time_step": t,
                    "hit_rate": hit_rate,
                    "false_alarm_ratio": false_alarm_ratio,
                    "bias_score": bias_score,
                    "hanssen_kuipers_score": hanssen_kuipers_score,
                    "heidke_skill_score": heidke_skill_score,
                    "total": total,
                }
            )
        logger.info(f"Scores calculated for {time_steps} time steps")
        return scores
    except Exception as e:
        logger.error(f"Error in calculate_scores: {str(e)}")
        raise
def calculate_auroc(hits, misses, false_alarms, correct_negatives):
    """
    Calculates the Area Under the Receiver Operating Characteristic (AUROC) curve for a single 2x2 contingency table.

    The AUROC measures the forecast's ability to discriminate between events that occurred (drought)
    and events that did not occur (no drought). It ranges from 0 to 1: 0.5 means no discriminative
    ability (equivalent to random chance) and 1 means perfect discrimination.

    A yes/no forecast has a single ROC operating point (POFD, POD). The area under the
    piecewise-linear curve through (0, 0), (POFD, POD), (1, 1) is therefore

        AUROC = (1 + POD - POFD) / 2

    with POD = hits / (hits + misses) and POFD = false_alarms / (false_alarms + correct_negatives).
    This equals sklearn's ``roc_auc_score`` applied to per-sample labels (1 = observed event) and
    binary scores (1 = forecast event). The closed form needs O(1) memory instead of arrays as long
    as the sample count, which matters inside bootstrap loops.

    Parameters:
    - hits (int): The number of correctly forecasted events (true positives).
    - misses (int): The number of events that were observed but not forecasted (false negatives).
    - false_alarms (int): The number of non-events that were incorrectly forecasted as events (false positives).
    - correct_negatives (int): The number of non-events that were correctly forecasted (true negatives).
    Returns:
    - auroc (float): The AUROC for the given contingency table values.
    Raises:
    - ValueError: if there are no observed events or no observed non-events, since the AUROC is
      undefined then (``roc_auc_score`` also raises ValueError in that case).
    Example usage:
    >>> round(calculate_auroc(50, 30, 20, 100), 4)
    0.7292
    """
    total_positives = hits + misses
    total_negatives = correct_negatives + false_alarms
    if total_positives <= 0 or total_negatives <= 0:
        raise ValueError(
            "AUROC is undefined unless both observed events and non-events are present "
            f"(events={total_positives}, non-events={total_negatives})"
        )
    hit_rate = hits / total_positives  # POD
    false_alarm_rate = false_alarms / total_negatives  # POFD
    return float(0.5 * (1.0 + hit_rate - false_alarm_rate))
def calculate_auroc_score(contingency_table, trigger_value, n_bootstrap=1000):
try:
logger.info(
f"Calculating AUROC score for trigger value {trigger_value:.4f} with {n_bootstrap} bootstrap iterations"
)
# Ensure the contingency table has the expected dimensions
if contingency_table.ndim != 3 or contingency_table.shape[1:] != (2, 2):
raise ValueError(
f"Unexpected contingency table shape: {contingency_table.shape}"
)
# Sum the contingency table across all time steps
summed_table = contingency_table.sum(dim="time")
# Flatten the summed contingency table
ct = summed_table.values.flatten()
correct_negatives, false_alarms, misses, hits = ct
total = hits + false_alarms + misses + correct_negatives
auroc_bootstrap_scores = []
for _ in range(n_bootstrap):
bootstrap_counts = np.random.multinomial(
total,
[
hits / total,
misses / total,
false_alarms / total,
correct_negatives / total,
],
size=1,
)
(
bootstrap_hits,
bootstrap_misses,
bootstrap_false_alarms,
bootstrap_correct_negatives,
) = bootstrap_counts[0]
auroc_bootstrap_scores.append(
calculate_auroc(
bootstrap_hits,
bootstrap_misses,
bootstrap_false_alarms,
bootstrap_correct_negatives,
)
)
auroc_score = np.mean(auroc_bootstrap_scores)