Skip to content

Commit 4d678b0

Browse files
committed
refactor: evaluate always uses no_mask ETf, drop irr switching and dead code
1 parent 52df8bd commit 4d678b0

1 file changed

Lines changed: 39 additions & 155 deletions

File tree

examples/4_Flux_Network/evaluate.py

Lines changed: 39 additions & 155 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,6 @@
2525
from swimrs.process.loop_fast import run_daily_loop_fast
2626
from swimrs.swim.config import ProjectConfig
2727

28-
IRR_THRESHOLD = 0.3
29-
3028
# Canonical exclusion list — sites with known data quality issues that should
3129
# not appear in any comparative evaluation. Keep this list general so new
3230
# exclusions can be added without ad hoc filters elsewhere.
@@ -141,60 +139,7 @@ def load_flux_et(fid, flux_dir):
141139
return pd.Series(dtype=float)
142140

143141

144-
def load_ssebop_etf(container, fid, irr_data):
145-
"""Load SSEBop NHM ETf from the container with year-appropriate mask.
146-
147-
Returns pd.Series of ETf values, or None if no data available.
148-
"""
149-
field_irr = irr_data.get(fid, {})
150-
irr_years = set()
151-
for k, v in field_irr.items():
152-
if k == "fallow_years":
153-
continue
154-
try:
155-
if isinstance(v, dict) and v.get("f_irr", 0.0) >= IRR_THRESHOLD:
156-
irr_years.add(int(k))
157-
except (ValueError, TypeError):
158-
continue
159-
160-
# Load both masks
161-
etf_inv = etf_irr = None
162-
for mask in ["inv_irr", "irr"]:
163-
etf_path = f"remote_sensing/etf/landsat/ssebop/{mask}"
164-
try:
165-
etf_df = container.query.dataframe(etf_path, fields=[fid])
166-
if fid in etf_df.columns:
167-
series = etf_df[fid]
168-
if mask == "inv_irr":
169-
etf_inv = series
170-
else:
171-
etf_irr = series
172-
except Exception:
173-
pass
174-
175-
inv_valid = etf_inv is not None and etf_inv.notna().any()
176-
irr_valid = etf_irr is not None and etf_irr.notna().any()
177-
178-
if not inv_valid and not irr_valid:
179-
return None
180-
181-
# Default to inv_irr, switch to irr for irrigated years
182-
if inv_valid:
183-
combined = etf_inv.copy()
184-
else:
185-
combined = pd.Series(np.nan, index=etf_irr.index)
186-
187-
if irr_valid and irr_years:
188-
irr_mask = combined.index.year.isin(irr_years)
189-
combined.loc[irr_mask] = etf_irr.loc[irr_mask]
190-
191-
if not inv_valid and irr_valid:
192-
combined = etf_irr.copy()
193-
194-
return combined
195-
196-
197-
def load_ssebop_etf_no_mask(container, fid):
142+
def load_ssebop_etf(container, fid):
198143
"""Load SSEBop NHM ETf from the no_mask path (full footprint)."""
199144
etf_path = "remote_sensing/etf/landsat/ssebop/no_mask"
200145
try:
@@ -219,7 +164,7 @@ def calc_metrics(obs, mod):
219164
return {"n": len(obs), "r2": r2, "r": r, "rmse": rmse, "bias": bias}
220165

221166

222-
def evaluate(cfg, container, par_csv, fids, flux_dir, no_mask=False):
167+
def evaluate(cfg, container, par_csv, fids, flux_dir):
223168
"""Run calibrated model and evaluate against flux tower ET and SSEBop NHM.
224169
225170
Both SWIM and SSEBop are scored on the exact same set of days per site
@@ -231,16 +176,6 @@ def evaluate(cfg, container, par_csv, fids, flux_dir, no_mask=False):
231176
fids = apply_exclusions(fids)
232177
print(f"Evaluating {len(fids)} fields from {par_csv}")
233178

234-
# Load irrigation data from container
235-
if no_mask:
236-
irr_data = {}
237-
else:
238-
try:
239-
dynamics = container.export._get_dynamics_dict(fids)
240-
irr_data = dynamics.get("irr", {})
241-
except Exception:
242-
irr_data = {}
243-
244179
calibrated_params = parse_pest_params(par_csv, fids)
245180
missing = [f for f in fids if f not in calibrated_params]
246181
if missing:
@@ -270,10 +205,7 @@ def evaluate(cfg, container, par_csv, fids, flux_dir, no_mask=False):
270205
swim_vals = swim_et.loc[common].values
271206

272207
# SSEBop NHM ET (interpolated ETf × ETo)
273-
if no_mask:
274-
etf_series = load_ssebop_etf_no_mask(container, fid)
275-
else:
276-
etf_series = load_ssebop_etf(container, fid, irr_data)
208+
etf_series = load_ssebop_etf(container, fid)
277209
if etf_series is not None:
278210
etf_interp = etf_series.interpolate(method="linear")
279211
ssebop_et = etf_interp * etref
@@ -339,44 +271,16 @@ def evaluate(cfg, container, par_csv, fids, flux_dir, no_mask=False):
339271
return metrics_df
340272

341273

342-
def _monthly_sum(daily_series, max_interp=5):
343-
"""Resample daily series to monthly sums, interpolating up to max_interp gaps per month.
344-
345-
Returns monthly Series with NaN for months that had more than max_interp missing days.
346-
"""
347-
monthly = daily_series.resample("MS").apply(lambda grp: _interp_and_sum(grp, max_interp))
348-
return monthly
349-
350-
351-
def _interp_and_sum(grp, max_interp):
352-
"""Interpolate up to max_interp NaNs in a month, then sum. Return NaN if too many gaps."""
353-
n_missing = grp.isna().sum()
354-
if n_missing > max_interp:
355-
return np.nan
356-
if n_missing > 0:
357-
grp = grp.interpolate(method="linear", limit=max_interp)
358-
return grp.sum()
359-
360-
361-
def evaluate_monthly(cfg, container, par_csv, fids, flux_dir, max_interp=5, no_mask=False):
274+
def evaluate_monthly(cfg, container, par_csv, fids, flux_dir):
362275
"""Monthly aggregation of ET evaluation with strictly paired months.
363276
364-
Resamples daily ET to monthly totals (mm/month), interpolating up to
365-
max_interp missing flux days per month before summing. Both SWIM and SSEBop
366-
are scored on the exact same set of months per site (paired evaluation).
277+
Intersects daily indices first, then aggregates to monthly sums. Only
278+
months with at least 20 valid daily flux observations are kept. Both SWIM
279+
and SSEBop are scored on the exact same set of months per site.
367280
"""
368281
fids = apply_exclusions(fids)
369282
print(f"Monthly evaluation: {len(fids)} fields from {par_csv}")
370283

371-
if no_mask:
372-
irr_data = {}
373-
else:
374-
try:
375-
dynamics = container.export._get_dynamics_dict(fids)
376-
irr_data = dynamics.get("irr", {})
377-
except Exception:
378-
irr_data = {}
379-
380284
calibrated_params = parse_pest_params(par_csv, fids)
381285
missing = [f for f in fids if f not in calibrated_params]
382286
if missing:
@@ -395,53 +299,53 @@ def evaluate_monthly(cfg, container, par_csv, fids, flux_dir, max_interp=5, no_m
395299
swim_et = model_df["et_act"]
396300
etref = model_df["etref"]
397301

398-
# Build aligned daily frame on common dates
399-
common_dates = swim_et.index.intersection(flux_et.index)
400-
if len(common_dates) < 30:
302+
# Intersect daily indices first, then aggregate to monthly
303+
daily_common = swim_et.index.intersection(flux_et.index)
304+
if len(daily_common) < 30:
401305
continue
402306

403-
# Full daily index spanning the overlap
404-
full_idx = pd.date_range(common_dates.min(), common_dates.max(), freq="D")
405-
flux_daily = flux_et.reindex(full_idx)
406-
swim_daily = swim_et.reindex(full_idx)
307+
swim_daily = swim_et.loc[daily_common]
308+
flux_daily = flux_et.loc[daily_common]
407309

408-
# Monthly sums (ET) with interpolation limit on flux
409-
flux_monthly = _monthly_sum(flux_daily, max_interp)
310+
# Aggregate to monthly totals
410311
swim_monthly = swim_daily.resample("MS").sum()
312+
flux_monthly = flux_daily.resample("MS").sum()
411313

412-
# SSEBop monthly ET
413-
if no_mask:
414-
etf_series = load_ssebop_etf_no_mask(container, fid)
415-
else:
416-
etf_series = load_ssebop_etf(container, fid, irr_data)
314+
# Only keep months with >= 20 valid daily flux obs
315+
flux_count = flux_daily.resample("MS").count()
316+
valid_months = flux_count[flux_count >= 20].index
317+
swim_monthly = swim_monthly.loc[swim_monthly.index.isin(valid_months)]
318+
flux_monthly = flux_monthly.loc[flux_monthly.index.isin(valid_months)]
319+
320+
# SSEBop monthly ET on the same daily common index
321+
etf_series = load_ssebop_etf(container, fid)
417322
if etf_series is not None:
418-
etf_interp = etf_series.reindex(full_idx).interpolate(method="linear")
419-
ssebop_daily = etf_interp * etref.reindex(full_idx)
323+
etf_interp = etf_series.reindex(daily_common).interpolate(method="linear")
324+
ssebop_daily = etf_interp * etref.reindex(daily_common)
420325
ssebop_monthly = ssebop_daily.resample("MS").sum()
421326
else:
422327
ssebop_monthly = pd.Series(np.nan, index=swim_monthly.index)
423328

424-
# Strictly paired months: all three (flux, swim, ssebop) must be finite
425-
all_months = flux_monthly.index.union(swim_monthly.index).union(ssebop_monthly.index)
426-
flux_vals = flux_monthly.reindex(all_months)
427-
swim_vals = swim_monthly.reindex(all_months)
428-
ssebop_vals = ssebop_monthly.reindex(all_months)
429-
430-
paired_mask = flux_vals.notna() & swim_vals.notna() & ssebop_vals.notna()
431-
paired_months = all_months[paired_mask]
329+
# Strictly paired months: flux, swim, and ssebop all finite
330+
all_idx = flux_monthly.index
331+
ssebop_on_idx = ssebop_monthly.reindex(all_idx)
332+
paired_mask = (
333+
flux_monthly.notna() & swim_monthly.reindex(all_idx).notna() & ssebop_on_idx.notna()
334+
)
335+
paired_months = all_idx[paired_mask]
432336
n_paired = len(paired_months)
433337

434338
if n_paired < 6:
435339
continue
436340

437-
obs = flux_vals.loc[paired_months].values
341+
obs = flux_monthly.loc[paired_months].values
438342
row = {"fid": fid, "n_months": n_paired}
439343

440-
m = calc_metrics(obs, swim_vals.loc[paired_months].values)
344+
m = calc_metrics(obs, swim_monthly.reindex(paired_months).values)
441345
for k in ["r2", "r", "rmse", "bias"]:
442346
row[f"{k}_swim"] = m[k]
443347

444-
m = calc_metrics(obs, ssebop_vals.loc[paired_months].values)
348+
m = calc_metrics(obs, ssebop_on_idx.loc[paired_months].values)
445349
for k in ["r2", "r", "rmse", "bias"]:
446350
row[f"{k}_ssebop"] = m[k]
447351

@@ -481,7 +385,7 @@ def evaluate_monthly(cfg, container, par_csv, fids, flux_dir, max_interp=5, no_m
481385
return metrics_df
482386

483387

484-
def evaluate_etf(cfg, container, par_csv, fids, no_mask=False):
388+
def evaluate_etf(cfg, container, par_csv, fids):
485389
"""Compare SWIM ETf against SSEBop NHM ETf at Landsat capture dates.
486390
487391
Isolates model skill from ETo conversion issues by comparing ETf directly.
@@ -494,26 +398,13 @@ def evaluate_etf(cfg, container, par_csv, fids, no_mask=False):
494398
calibrated_params = parse_pest_params(par_csv, fids)
495399
model_results = run_calibrated_model(cfg, container, fids, calibrated_params)
496400

497-
# Load irrigation data for mask selection
498-
if no_mask:
499-
irr_data = {}
500-
else:
501-
try:
502-
dynamics = container.export._get_dynamics_dict(fids)
503-
irr_data = dynamics.get("irr", {})
504-
except Exception:
505-
irr_data = {}
506-
507401
rows = []
508402
for fid in fids:
509403
if fid not in model_results:
510404
continue
511405
swim_etf = model_results[fid]["etf_model"]
512406

513-
if no_mask:
514-
etf_series = load_ssebop_etf_no_mask(container, fid)
515-
else:
516-
etf_series = load_ssebop_etf(container, fid, irr_data)
407+
etf_series = load_ssebop_etf(container, fid)
517408
if etf_series is None:
518409
continue
519410

@@ -598,11 +489,6 @@ def find_par_csv(results_dir, project_name):
598489
default=None,
599490
help="Override container path (default: derived from config)",
600491
)
601-
parser.add_argument(
602-
"--no-mask",
603-
action="store_true",
604-
help="Use no_mask ETf (full footprint) instead of irr/inv_irr switching",
605-
)
606492
args = parser.parse_args()
607493

608494
cfg = load_config()
@@ -631,15 +517,13 @@ def find_par_csv(results_dir, project_name):
631517

632518
try:
633519
if args.monthly:
634-
metrics = evaluate_monthly(
635-
cfg, container, par_csv, fids, flux_dir, no_mask=args.no_mask
636-
)
520+
metrics = evaluate_monthly(cfg, container, par_csv, fids, flux_dir)
637521
out_csv = os.path.join(results_dir, "evaluation_monthly_metrics.csv")
638522
elif args.etf:
639-
metrics = evaluate_etf(cfg, container, par_csv, fids, no_mask=args.no_mask)
523+
metrics = evaluate_etf(cfg, container, par_csv, fids)
640524
out_csv = os.path.join(results_dir, "evaluation_etf_metrics.csv")
641525
else:
642-
metrics = evaluate(cfg, container, par_csv, fids, flux_dir, no_mask=args.no_mask)
526+
metrics = evaluate(cfg, container, par_csv, fids, flux_dir)
643527
out_csv = os.path.join(results_dir, "evaluation_metrics.csv")
644528
os.makedirs(results_dir, exist_ok=True)
645529
metrics.to_csv(out_csv)

0 commit comments

Comments
 (0)