Skip to content

Commit 65600d4

Browse files
committed
fix(sktime-quant): enforce rule provenance/errors, add classifier status+cache, trim heavy sample data
1 parent e753f67 commit 65600d4

8 files changed

Lines changed: 101 additions & 11653 deletions

File tree

sktime_quant/Ingest-outside-code/smalldata/^INDIAVIX.csv

Lines changed: 0 additions & 3963 deletions
This file was deleted.

sktime_quant/Ingest-outside-code/smalldata/^NSEBANK.csv

Lines changed: 0 additions & 3688 deletions
This file was deleted.

sktime_quant/Ingest-outside-code/smalldata/^NSEI.csv

Lines changed: 0 additions & 3964 deletions
This file was deleted.

sktime_quant/backtest/walkforward.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from sktime_quant.risk.metrics import max_drawdown
1919
from sktime_quant.strategy.blender import blend_signals
2020
from sktime_quant.strategy.classifier import predict_classifier_signal_at
21-
from sktime_quant.strategy.rule_dsl import evaluate_rules_signal, load_rules_yaml
21+
from sktime_quant.strategy.rule_dsl import evaluate_rules_signal
2222

2323

2424
@dataclass(slots=True)
@@ -251,12 +251,6 @@ def run(
251251
assets_list = sorted(market["asset"].astype(str).unique().tolist())
252252
strategy_mode = str(getattr(strategy_config, "mode", "forecast_only"))
253253
strategy_rules = list(getattr(strategy_config, "rules", []) or [])
254-
strategy_rules_path = getattr(strategy_config, "rules_path", None)
255-
if strategy_rules_path:
256-
try:
257-
strategy_rules = load_rules_yaml(strategy_rules_path)
258-
except Exception:
259-
pass
260254
rule_chain = str(getattr(strategy_config, "rule_chain", "any"))
261255
classifier_type = str(getattr(strategy_config, "classifier_type", "random_forest"))
262256
classifier_min_train = int(
@@ -281,6 +275,7 @@ def run(
281275
.set_index("timestamp")["close"]
282276
.astype(float)
283277
)
278+
classifier_cache: dict[tuple[str, str, str, int, float], tuple[int, float, str]] = {}
284279
y = self._series_for_asset(market, asset)
285280
if len(y) <= backtest_config.window_length + backtest_config.horizon + 2:
286281
continue
@@ -436,6 +431,7 @@ def run(
436431
signals_classifier: list[int] = []
437432
signals_blended: list[int] = []
438433
classifier_confidence_vals: list[float] = []
434+
classifier_status_vals: list[str] = []
439435
interval_sources: list[str] = []
440436

441437
for _, row in result.iterrows():
@@ -468,14 +464,25 @@ def run(
468464
chain=rule_chain,
469465
default_signal=0,
470466
)
471-
classifier_signal, cls_conf = predict_classifier_signal_at(
472-
feature_frame=strategy_feature_frame,
473-
close_series=strategy_close,
474-
cutoff=cutoff,
475-
classifier_type=classifier_type,
476-
min_train_samples=classifier_min_train,
477-
probability_threshold=classifier_prob_threshold,
467+
cache_key = (
468+
str(asset),
469+
pd.Timestamp(cutoff).isoformat(),
470+
classifier_type,
471+
classifier_min_train,
472+
round(classifier_prob_threshold, 4),
478473
)
474+
cached = classifier_cache.get(cache_key)
475+
if cached is None:
476+
cached = predict_classifier_signal_at(
477+
feature_frame=strategy_feature_frame,
478+
close_series=strategy_close,
479+
cutoff=cutoff,
480+
classifier_type=classifier_type,
481+
min_train_samples=classifier_min_train,
482+
probability_threshold=classifier_prob_threshold,
483+
)
484+
classifier_cache[cache_key] = cached
485+
classifier_signal, cls_conf, cls_status = cached
479486
blended_signal = blend_signals(
480487
forecast_signal=signal,
481488
rule_signal=rule_signal,
@@ -512,6 +519,7 @@ def run(
512519
signals_classifier.append(classifier_signal)
513520
signals_blended.append(blended_signal)
514521
classifier_confidence_vals.append(cls_conf)
522+
classifier_status_vals.append(cls_status)
515523
interval_sources.append(interval_source)
516524

517525
result["fold_return"] = fold_returns
@@ -524,6 +532,7 @@ def run(
524532
result["signal_classifier"] = signals_classifier
525533
result["signal_blended"] = signals_blended
526534
result["classifier_confidence"] = classifier_confidence_vals
535+
result["classifier_status"] = classifier_status_vals
527536
result["interval_source"] = interval_sources
528537

529538
returns = result["fold_return"].fillna(0.0)
@@ -543,12 +552,17 @@ def run(
543552

544553
excluded = False
545554
exclusion_reason = ""
555+
classifier_status_series = result["classifier_status"].astype(str)
556+
classifier_missing = bool((classifier_status_series == "sklearn_missing").any())
546557
if failure_rate > backtest_config.max_failure_rate:
547558
excluded = True
548559
exclusion_reason = "high_failure_rate"
549560
elif empirical_coverage < backtest_config.confidence_floor:
550561
excluded = True
551562
exclusion_reason = "low_empirical_coverage"
563+
elif strategy_mode in {"classifier_only", "blended"} and classifier_missing:
564+
excluded = True
565+
exclusion_reason = "classifier_unavailable"
552566

553567
metrics_rows.append(
554568
{
@@ -570,6 +584,8 @@ def run(
570584
"strategy_mode": strategy_mode,
571585
"blend_policy": blend_policy,
572586
"classifier_type": classifier_type,
587+
"classifier_unavailable": classifier_missing,
588+
"classifier_status_counts": classifier_status_series.value_counts().to_dict(),
573589
}
574590
)
575591

@@ -587,6 +603,7 @@ def run(
587603
"signal_classifier",
588604
"signal_blended",
589605
"classifier_confidence",
606+
"classifier_status",
590607
"interval_source",
591608
]
592609
].copy()

sktime_quant/pipelines/orchestrator.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
)
3434
from sktime_quant.portfolio.optimizer import AllocationResult, PortfolioEngine
3535
from sktime_quant.reporting.run_report import write_run_report
36-
from sktime_quant.strategy.rule_dsl import save_rules_yaml
36+
from sktime_quant.strategy.rule_dsl import load_rules_yaml, save_rules_yaml
3737

3838

3939
@dataclass(slots=True)
@@ -210,6 +210,24 @@ def _effective_data_config(self, cfg: AppConfig, paths: dict[str, Path]):
210210
return cfg.data
211211
return replace(cfg.data, start=prev_ts)
212212

213+
def _resolve_effective_strategy_config(self, cfg: AppConfig) -> tuple[object, list[dict[str, object]], str]:
214+
rules_source = "inline"
215+
effective_rules = list(cfg.strategy.rules or [])
216+
if cfg.strategy.rules_path:
217+
if not Path(cfg.strategy.rules_path).exists():
218+
raise ValueError(
219+
f"Failed to load strategy rules from path '{cfg.strategy.rules_path}': FileNotFoundError: file does not exist"
220+
)
221+
try:
222+
effective_rules = load_rules_yaml(cfg.strategy.rules_path)
223+
rules_source = f"path:{cfg.strategy.rules_path}"
224+
except Exception as exc:
225+
raise ValueError(
226+
f"Failed to load strategy rules from path '{cfg.strategy.rules_path}': {type(exc).__name__}: {exc}"
227+
) from exc
228+
strategy_cfg = replace(cfg.strategy, rules=effective_rules)
229+
return strategy_cfg, effective_rules, rules_source
230+
213231
def _write_incremental_state(self, cfg: AppConfig, paths: dict[str, Path], market: pd.DataFrame) -> None:
214232
if not cfg.data.incremental_mode or market.empty:
215233
return
@@ -359,6 +377,7 @@ def notify(payload: dict[str, object]) -> None:
359377

360378
notify({"stage": "start", "event": "run_start", "run_id": cfg.run_id})
361379
paths = self._artifact_paths(cfg)
380+
strategy_cfg, effective_rules, strategy_rules_source = self._resolve_effective_strategy_config(cfg)
362381
notify({"stage": "data", "event": "loading_data"})
363382
effective_data_cfg = self._effective_data_config(cfg, paths)
364383
market, exog = self.data_provider.load_history(effective_data_cfg)
@@ -373,15 +392,13 @@ def notify(payload: dict[str, object]) -> None:
373392
min_points_for_freq=cfg.execution.data_quality_min_points_for_freq,
374393
)
375394
paths["data_quality"].write_text(json.dumps(data_quality, indent=2), encoding="utf-8")
376-
strategy_payload = asdict(cfg.strategy)
395+
strategy_payload = asdict(strategy_cfg)
396+
strategy_payload["rules_source"] = strategy_rules_source
377397
paths["strategy_config"].write_text(
378398
json.dumps(strategy_payload, indent=2, default=str),
379399
encoding="utf-8",
380400
)
381-
try:
382-
save_rules_yaml(paths["strategy_rules"], cfg.strategy.rules)
383-
except Exception:
384-
pass
401+
save_rules_yaml(paths["strategy_rules"], effective_rules)
385402
notify(
386403
{
387404
"stage": "data",
@@ -418,11 +435,12 @@ def notify(payload: dict[str, object]) -> None:
418435
"model_governance_path": str(paths["model_governance"]),
419436
"strategy_config_path": str(paths["strategy_config"]),
420437
"strategy_rules_path": str(paths["strategy_rules"]),
438+
"strategy_rules_source": strategy_rules_source,
421439
"timestamp_utc": datetime.now(UTC).isoformat(),
422440
"allocation_diagnostics": {},
423441
"execution_diagnostics": self.order_exporter._empty_diagnostics(),
424442
"governance_alert_count": 0,
425-
"strategy_mode": cfg.strategy.mode,
443+
"strategy_mode": strategy_cfg.mode,
426444
"run_status": "no_new_data",
427445
"message": "No rows available after applying ingestion filters/incremental window.",
428446
}
@@ -475,7 +493,7 @@ def notify(payload: dict[str, object]) -> None:
475493
model_names=candidate_models,
476494
backtest_config=cfg.backtest,
477495
exog=exog_model,
478-
strategy_config=cfg.strategy,
496+
strategy_config=strategy_cfg,
479497
holiday_by_asset=holiday_by_asset,
480498
progress_hook=progress_hook,
481499
)
@@ -570,8 +588,9 @@ def notify(payload: dict[str, object]) -> None:
570588
"execution_diagnostics": order_diagnostics,
571589
"strategy_config_path": str(paths["strategy_config"]),
572590
"strategy_rules_path": str(paths["strategy_rules"]),
573-
"strategy_mode": cfg.strategy.mode,
574-
"strategy_blend_policy": cfg.strategy.blend_policy,
591+
"strategy_rules_source": strategy_rules_source,
592+
"strategy_mode": strategy_cfg.mode,
593+
"strategy_blend_policy": strategy_cfg.blend_policy,
575594
"forecast_update_mode": cfg.model.update_mode,
576595
"forecast_update_status_counts": forecast.predictions.get(
577596
"update_status", pd.Series(dtype=str)

sktime_quant/strategy/classifier.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,54 +39,54 @@ def predict_classifier_signal_at(
3939
classifier_type: str = "random_forest",
4040
min_train_samples: int = 30,
4141
probability_threshold: float = 0.55,
42-
) -> tuple[int, float]:
42+
) -> tuple[int, float, str]:
4343
if feature_frame is None or feature_frame.empty or close_series is None or close_series.empty:
44-
return 0, 0.0
44+
return 0, 0.0, "no_features_or_close"
4545

4646
frame = feature_frame.copy()
4747
if not isinstance(frame.index, pd.DatetimeIndex):
48-
return 0, 0.0
48+
return 0, 0.0, "invalid_feature_index"
4949
frame = frame.sort_index()
5050
close = close_series.copy().sort_index()
5151
if not isinstance(close.index, pd.DatetimeIndex):
52-
return 0, 0.0
52+
return 0, 0.0, "invalid_close_index"
5353

5454
common = frame.index.intersection(close.index)
5555
if len(common) < max(5, int(min_train_samples)):
56-
return 0, 0.0
56+
return 0, 0.0, "insufficient_common_samples"
5757

5858
frame = frame.loc[common]
5959
y_cls = _build_label_from_close(close.loc[common])
6060
data = frame.copy()
6161
data["target"] = y_cls
6262
data = data.dropna(how="any")
6363
if data.empty:
64-
return 0, 0.0
64+
return 0, 0.0, "empty_after_dropna"
6565

6666
cutoff = pd.Timestamp(cutoff)
6767
train = data[data.index < cutoff]
6868
test = data[data.index == cutoff]
6969
if test.empty:
7070
prior = data[data.index <= cutoff]
7171
if prior.empty:
72-
return 0, 0.0
72+
return 0, 0.0, "no_test_row"
7373
test = prior.tail(1)
7474
train = data[data.index < test.index[0]]
7575

7676
if len(train) < max(5, int(min_train_samples)):
77-
return 0, 0.0
77+
return 0, 0.0, "insufficient_train_samples"
7878

7979
x_cols = [c for c in train.columns if c != "target"]
8080
if not x_cols:
81-
return 0, 0.0
81+
return 0, 0.0, "no_feature_columns"
8282

8383
x_train = train[x_cols]
8484
y_train = train["target"].astype(int)
8585
x_test = test[x_cols]
8686

8787
model = _make_classifier(classifier_type)
8888
if model is None:
89-
return 0, 0.0
89+
return 0, 0.0, "sklearn_missing"
9090
try:
9191
model.fit(x_train, y_train)
9292
pred = int(model.predict(x_test)[0])
@@ -98,7 +98,7 @@ def predict_classifier_signal_at(
9898
else:
9999
conf = 0.5
100100
if conf < float(probability_threshold):
101-
return 0, conf
102-
return pred if pred in {-1, 0, 1} else 0, conf
101+
return 0, conf, "low_confidence"
102+
return (pred if pred in {-1, 0, 1} else 0), conf, "ok"
103103
except Exception:
104-
return 0, 0.0
104+
return 0, 0.0, "classifier_error"

sktime_quant/tests/test_orchestrator_integration.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pandas as pd
22
import json
33
from pathlib import Path
4+
import pytest
45

56
from sktime_quant.config.schema import AppConfig
67
from sktime_quant.pipelines.orchestrator import Orchestrator
@@ -50,6 +51,7 @@ def test_orchestrator_end_to_end_csv(tmp_path):
5051
assert "strategy_mode" in summary
5152
assert "strategy_config_path" in summary
5253
assert "strategy_rules_path" in summary
54+
assert "strategy_rules_source" in summary
5355
governance = json.loads((tmp_path / "results" / "reports" / "it_run_model_governance.json").read_text(encoding="utf-8"))
5456
assert "alerts" in governance
5557

@@ -84,6 +86,29 @@ def test_orchestrator_handles_no_new_data_incremental_window(tmp_path):
8486
assert result.report_path.endswith(".md")
8587

8688

89+
def test_orchestrator_raises_on_invalid_rules_path(tmp_path):
90+
n = 80
91+
df = pd.DataFrame(
92+
{
93+
"timestamp": pd.date_range("2023-01-01", periods=n, freq="D"),
94+
"asset": ["A"] * n,
95+
"close": [100 + i * 0.1 for i in range(n)],
96+
}
97+
)
98+
csv_path = tmp_path / "market.csv"
99+
df.to_csv(csv_path, index=False)
100+
101+
cfg = AppConfig()
102+
cfg.run_id = "it_bad_rules"
103+
cfg.data.source_type = "csv"
104+
cfg.data.csv_path = str(csv_path)
105+
cfg.execution.output_dir = str(tmp_path / "results")
106+
cfg.strategy.rules_path = str(tmp_path / "missing_rules.yaml")
107+
108+
with pytest.raises(ValueError, match="Failed to load strategy rules"):
109+
Orchestrator().run(cfg)
110+
111+
87112
def test_data_quality_reports_frequency_drift_and_missing_bars(tmp_path):
88113
base_dates = pd.date_range("2023-01-01", periods=40, freq="D")
89114
a_dates = base_dates.delete(10) # introduce one missing bar in daily sequence

sktime_quant/tests/test_strategy_engine.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_classifier_predict_signal_at():
4747
},
4848
index=idx,
4949
)
50-
sig, conf = predict_classifier_signal_at(
50+
sig, conf, status = predict_classifier_signal_at(
5151
feature_frame=features,
5252
close_series=close,
5353
cutoff=idx[100],
@@ -57,6 +57,7 @@ def test_classifier_predict_signal_at():
5757
)
5858
assert sig in {-1, 0, 1}
5959
assert 0.0 <= conf <= 1.0
60+
assert isinstance(status, str)
6061

6162

6263
def test_walkforward_strategy_modes_emit_signal_columns():
@@ -90,3 +91,4 @@ def test_walkforward_strategy_modes_emit_signal_columns():
9091
assert "signal_rule" in result.fold_predictions.columns
9192
assert "signal_classifier" in result.fold_predictions.columns
9293
assert "signal_blended" in result.fold_predictions.columns
94+
assert "classifier_status" in result.fold_predictions.columns

0 commit comments

Comments (0)