Skip to content

Commit 1ee572f

Browse files
committed
fix: abstracted catalog forecast repository from IO operations
1 parent 71f1ea9 commit 1ee572f

File tree

10 files changed

+204
-39
lines changed

10 files changed

+204
-39
lines changed

floatcsep/infrastructure/registries.py

-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
log = logging.getLogger("floatLogger")
1515

16-
1716
class FilepathMixin:
1817
"""
1918
Small mixin to provide filepath management functionality to Registries that uses files to
@@ -147,7 +146,6 @@ def factory(cls, registry_type: str = 'file', **kwargs) -> "ModelRegistry":
147146
elif registry_type == 'hdf5':
148147
return ModelHDF5Registry(**kwargs)
149148

150-
151149
class ModelFileRegistry(ModelRegistry, FilepathMixin):
152150
def __init__(
153151
self,
@@ -318,7 +316,6 @@ def as_dict(self) -> dict:
318316
"forecasts": self.forecasts,
319317
}
320318

321-
322319
class ModelHDF5Registry(ModelRegistry):
323320

324321
def __init__(self, workdir: str, path: str):

floatcsep/infrastructure/repositories.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from csep.models import EvaluationResult
1313
from csep.utils.time_utils import decimal_year
1414

15-
from floatcsep.utils.readers import ForecastParsers
15+
from floatcsep.utils.readers import GriddedForecastParsers, CatalogForecastParsers
1616
from floatcsep.infrastructure.registries import ExperimentRegistry, ModelRegistry
1717
from floatcsep.utils.helpers import str2timewindow, parse_csep_func
1818
from floatcsep.utils.helpers import timewindow2str
@@ -102,28 +102,39 @@ def __init__(self, registry: ModelRegistry, **kwargs):
102102
self.forecasts = {}
103103

104104
def load_forecast(
105-
self, tstring: Union[str, list], region=None
105+
self, tstring: Union[str, list], region=None, n_sims=None,
106106
) -> Union[CatalogForecast, list[CatalogForecast]]:
107107
"""
108108
Returns a forecast object or a sequence of them for a set of time window strings.
109109
110110
Args:
111111
tstring (str, list): String representing the time-window
112112
region (optional): A region, in case the forecast requires to be filtered lazily.
113+
n_sims (optional: The number of simulations/synthetic catalogs of the forecast
113114
114115
Returns:
115116
The CSEP CatalogForecast object or a list of them.
116117
"""
117118
if isinstance(tstring, str):
118-
return self._load_single_forecast(tstring, region)
119+
return self._load_single_forecast(tstring, region=region, n_sims=n_sims)
119120
else:
120121
return [self._load_single_forecast(t, region) for t in tstring]
121122

122-
def _load_single_forecast(self, t: str, region=None):
123-
fc_path = self.registry.get_forecast_key(t)
124-
return csep.load_catalog_forecast(
125-
fc_path, region=region, apply_filters=True, filter_spatial=True
126-
)
123+
def _load_single_forecast(self, tstring: str, region=None, n_sims=None):
124+
start_date, end_date = str2timewindow(tstring)
125+
126+
fc_path = self.registry.get_forecast_key(tstring)
127+
f_parser = getattr(CatalogForecastParsers, self.registry.fmt)
128+
129+
forecast_ = f_parser(fc_path,
130+
start_time=start_date,
131+
end_time=end_date,
132+
n_cat=n_sims,
133+
region=region,
134+
apply_filters=True,
135+
filter_spatial=True,
136+
)
137+
return forecast_
127138

128139
def remove(self, tstring: Union[str, Sequence[str]]):
129140
pass
@@ -190,7 +201,7 @@ def _load_single_forecast(self, tstring: str, fc_unit: float = 1, name_=""):
190201
tstring_ = timewindow2str([start_date, end_date])
191202

192203
f_path = self.registry.get_forecast_key(tstring_)
193-
f_parser = getattr(ForecastParsers, self.registry.fmt)
204+
f_parser = getattr(GriddedForecastParsers, self.registry.fmt)
194205

195206
rates, region, mags = f_parser(f_path)
196207

floatcsep/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from floatcsep.utils.accessors import from_zenodo, from_git
1313
from floatcsep.infrastructure.environments import EnvironmentFactory
14-
from floatcsep.utils.readers import ForecastParsers, HDF5Serializer
14+
from floatcsep.utils.readers import GriddedForecastParsers, HDF5Serializer
1515
from floatcsep.infrastructure.registries import ModelRegistry
1616
from floatcsep.infrastructure.repositories import ForecastRepository
1717
from floatcsep.utils.helpers import timewindow2str, str2timewindow, parse_nested_dicts
@@ -247,7 +247,7 @@ def init_db(self, dbpath: str = "", force: bool = False) -> None:
247247
exists
248248
"""
249249

250-
parser = getattr(ForecastParsers, self.registry.fmt)
250+
parser = getattr(GriddedForecastParsers, self.registry.fmt)
251251
rates, region, mag = parser(self.registry.get_attr("path"))
252252
db_func = HDF5Serializer.grid2hdf5
253253

floatcsep/utils/helpers.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ def _getattr(obj_, attr_):
7575
floatcsep.utils.helpers,
7676
floatcsep.utils.accessors,
7777
floatcsep.utils.readers.HDF5Serializer,
78-
floatcsep.utils.readers.ForecastParsers,
78+
floatcsep.utils.readers.GriddedForecastParsers,
79+
floatcsep.utils.readers.CatalogForecastParsers,
80+
7981
]
8082
for module in _target_modules:
8183
try:

floatcsep/utils/readers.py

+160-6
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,173 @@
11
import argparse
2+
import csv
23
import logging
34
import os.path
45
import time
56
import xml.etree.ElementTree as eTree
67

8+
import csep
79
import h5py
810
import numpy
911
import pandas
12+
import pandas as pd
13+
from csep.core.catalogs import CSEPCatalog
1014
from csep.core.regions import QuadtreeGrid2D, CartesianGrid2D
1115
from csep.models import Polygon
16+
from csep.utils.time_utils import strptime_to_utc_epoch
1217

1318
log = logging.getLogger(__name__)
1419

20+
class CatalogForecastParsers:
1521

16-
class ForecastParsers:
22+
@staticmethod
23+
def csv(filename, **kwargs):
24+
csep_headers = ['lon', 'lat', 'magnitude', 'time_string', 'depth', 'catalog_id',
25+
'event_id']
26+
hermes_headers = ['realization_id', 'magnitude', 'depth', 'latitude', 'longitude',
27+
'time']
28+
headers_df = pd.read_csv(filename, nrows=0).columns.str.strip().to_list()
29+
30+
# CSEP headers
31+
if headers_df[:2] == csep_headers[:2]:
32+
33+
return csep.load_catalog_forecast(filename, **kwargs)
34+
35+
elif headers_df == hermes_headers:
36+
return csep.load_catalog_forecast(filename,
37+
catalog_loader=CatalogForecastParsers.load_hermes_catalog,
38+
**kwargs
39+
)
40+
else:
41+
raise Exception('Catalog Forecast could not be loaded')
42+
43+
@staticmethod
44+
def load_hermes_catalog(filename, **kwargs):
45+
""" Loads hermes synthetic catalogs in csep-ascii format.
46+
47+
This function can load multiple catalogs stored in a single file. This typically called to
48+
load a catalog-based forecast, but could also load a collection of catalogs stored in the same file
49+
50+
Args:
51+
filename (str): filepath or directory of catalog files
52+
**kwargs (dict): passed to class constructor
53+
54+
Return:
55+
yields CSEPCatalog class
56+
"""
57+
58+
def read_float(val):
59+
"""Returns val as float or None if unable"""
60+
try:
61+
val = float(val)
62+
except:
63+
val = None
64+
return val
65+
66+
def is_header_line(line):
67+
if line[0].lower() == 'realization_id':
68+
return True
69+
else:
70+
return False
71+
72+
def read_catalog_line(line):
73+
# convert to correct types
74+
75+
catalog_id = int(line[0])
76+
magnitude = read_float(line[1])
77+
depth = read_float(line[2])
78+
lat = read_float(line[3])
79+
lon = read_float(line[4])
80+
# maybe fractional seconds are not included
81+
origin_time = line[5]
82+
if origin_time:
83+
try:
84+
origin_time = strptime_to_utc_epoch(origin_time,
85+
format='%Y-%m-%d %H:%M:%S.%f')
86+
except ValueError:
87+
origin_time = strptime_to_utc_epoch(origin_time,
88+
format='%Y-%m-%d %H:%M:%S')
89+
90+
event_id = 0
91+
# temporary event
92+
temp_event = (event_id, origin_time, lat, lon, depth, magnitude)
93+
return temp_event, catalog_id
94+
95+
# handle all catalogs in single file
96+
if os.path.isfile(filename):
97+
with open(filename, 'r', newline='') as input_file:
98+
catalog_reader = csv.reader(input_file, delimiter=',')
99+
# csv treats everything as a string convert to correct types
100+
events = []
101+
# all catalogs should start at zero
102+
prev_id = None
103+
for line in catalog_reader:
104+
# skip header line on first read if included in file
105+
if prev_id is None:
106+
if is_header_line(line):
107+
continue
108+
# read line and return catalog id
109+
temp_event, catalog_id = read_catalog_line(line)
110+
empty = False
111+
# OK if event_id is empty
112+
if all([val in (None, '') for val in temp_event[1:]]):
113+
empty = True
114+
# first event is when prev_id is none, catalog_id should always start at zero
115+
if prev_id is None:
116+
prev_id = 0
117+
# if the first catalog doesn't start at zero
118+
if catalog_id != prev_id:
119+
if not empty:
120+
events = [temp_event]
121+
else:
122+
events = []
123+
for id in range(catalog_id):
124+
yield CSEPCatalog(data=[], catalog_id=id, **kwargs)
125+
prev_id = catalog_id
126+
continue
127+
# accumulate event if catalog_id is the same as previous event
128+
if catalog_id == prev_id:
129+
if not all([val in (None, '') for val in temp_event]):
130+
events.append(temp_event)
131+
prev_id = catalog_id
132+
# create and yield class if the events are from different catalogs
133+
elif catalog_id == prev_id + 1:
134+
yield CSEPCatalog(data=events, catalog_id=prev_id, **kwargs)
135+
# add event to new event list
136+
if not empty:
137+
events = [temp_event]
138+
else:
139+
events = []
140+
prev_id = catalog_id
141+
# this implies there are empty catalogs, because they are not listed in the ascii file
142+
elif catalog_id > prev_id + 1:
143+
yield CSEPCatalog(data=events, catalog_id=prev_id, **kwargs)
144+
# if prev_id = 0 and catalog_id = 2, then we skipped one catalog. thus, we skip catalog_id - prev_id - 1 catalogs
145+
num_empty_catalogs = catalog_id - prev_id - 1
146+
# first yield empty catalog classes
147+
for id in range(num_empty_catalogs):
148+
yield CSEPCatalog(data=[],
149+
catalog_id=catalog_id - num_empty_catalogs + id,
150+
**kwargs)
151+
prev_id = catalog_id
152+
# add event to new event list
153+
if not empty:
154+
events = [temp_event]
155+
else:
156+
events = []
157+
else:
158+
raise ValueError(
159+
"catalog_id should be monotonically increasing and events should be ordered by catalog_id")
160+
# yield final catalog, note: since this is just loading catalogs, it has no idea how many should be there
161+
cat = CSEPCatalog(data=events, catalog_id=prev_id, **kwargs)
162+
yield cat
163+
164+
elif os.path.isdir(filename):
165+
raise NotImplementedError(
166+
"reading from directory or batched files not implemented yet!")
167+
168+
169+
170+
class GriddedForecastParsers:
17171

18172
@staticmethod
19173
def dat(filename):
@@ -151,7 +305,7 @@ def is_mag(num):
151305
sep = " "
152306

153307
if "tile" in line:
154-
rates, region, magnitudes = ForecastParsers.quadtree(filename)
308+
rates, region, magnitudes = GriddedForecastParsers.quadtree(filename)
155309
return rates, region, magnitudes
156310

157311
data = pandas.read_csv(
@@ -308,13 +462,13 @@ def serialize():
308462
args = parser.parse_args()
309463

310464
if args.format == "quadtree":
311-
ForecastParsers.quadtree(args.filename)
465+
GriddedForecastParsers.quadtree(args.filename)
312466
if args.format == "dat":
313-
ForecastParsers.dat(args.filename)
467+
GriddedForecastParsers.dat(args.filename)
314468
if args.format == "csep" or args.format == "csv":
315-
ForecastParsers.csv(args.filename)
469+
GriddedForecastParsers.csv(args.filename)
316470
if args.format == "xml":
317-
ForecastParsers.xml(args.filename)
471+
GriddedForecastParsers.xml(args.filename)
318472

319473

320474
if __name__ == "__main__":
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
lon, lat, M, time_string, depth, catalog_id, event_id
1+
lon,lat,M,time_string,depth,catalog_id,event_id
22
1.0,1.0,5.0,2020-01-01T01:01:01.0,10.0,1,1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
lon, lat, M, time_string, depth, catalog_id, event_id
1+
lon,lat,M,time_string,depth,catalog_id,event_id
22
1.0,1.0,5.0,2020-01-02T01:01:01.0,10.0,1,1

tests/integration/test_model_infrastructure.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ def test_time_independent_model_stage(self):
3939
[datetime(2023, 1, 1), datetime(2023, 1, 2)],
4040
]
4141
self.time_independent_model.stage(time_windows=time_windows)
42-
print("a", self.time_independent_model.registry.as_dict())
4342
self.assertIn("2023-01-01_2023-01-02", self.time_independent_model.registry.forecasts)
4443

4544
def test_time_independent_model_get_forecast(self):
@@ -123,7 +122,7 @@ def forecast_(_):
123122
name = "mock"
124123
fname = os.path.join(self._dir, "model.csv")
125124

126-
with patch("floatcsep.readers.ForecastParsers.csv", forecast_):
125+
with patch("floatcsep.readers.GriddedForecastParsers.csv", forecast_):
127126
model = self.init_model(name, fname)
128127
model.registry.build_tree([[start, end]])
129128
forecast = model.get_forecast(timestring)

0 commit comments

Comments
 (0)