Skip to content

Commit 42d007b

Browse files
committed
Removed unused files
1 parent defe0b3 commit 42d007b

File tree

5 files changed

+75
-9
lines changed

5 files changed

+75
-9
lines changed
-4.9 MB
Binary file not shown.

data/az/az-2048-3-true.pkl.gz

74.9 KB
Binary file not shown.

scripts/training/yield_prediction_az.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,15 @@ def save_results(
2424

2525
def predict_az():
2626
root_path = Path(__file__).resolve().parent
27-
az_file_path = Path(root_path, "../../data/az/az-2048-3-true-props.pkl")
27+
az_file_path = Path(root_path, "../../data/az/az-2048-3-true.pkl")
2828

2929
data = pickle.load(open(az_file_path, "rb"))
3030

3131
r2s = []
3232
maes = []
3333

3434
for i, split in enumerate(data):
35-
print(f"Evaluating split {i+1}/10 ...")
35+
print(f"Evaluating split {i + 1}/10 ...")
3636
X_train, y_train, X_valid, y_valid, X_test, y_test = (
3737
split["train"]["X"],
3838
split["train"]["y"],
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import pickle
2+
from pathlib import Path
3+
from typing import Tuple
4+
from statistics import stdev
5+
import numpy as np
6+
from xgboost import XGBRegressor
7+
from catboost import CatBoostRegressor
8+
from sklearn import svm
9+
from sklearn.ensemble import RandomForestRegressor
10+
from sklearn.metrics import r2_score, mean_absolute_error
11+
12+
13+
def save_results(
14+
set_name: str,
15+
split_id: str,
16+
file_name: str,
17+
ground_truth: np.ndarray,
18+
prediction: np.ndarray,
19+
) -> None:
20+
with open(f"{set_name}_{split_id}_{file_name}.csv", "w+") as f:
21+
for gt, pred in zip(ground_truth, prediction):
22+
f.write(f"{set_name},{split_id},{file_name},{gt},{pred}\n")
23+
24+
25+
def predict_az():
26+
root_path = Path(__file__).resolve().parent
27+
az_file_path = Path(root_path, "../../data/az/az-2048-3-true.pkl")
28+
29+
data = pickle.load(open(az_file_path, "rb"))
30+
31+
r2s = []
32+
maes = []
33+
34+
for i, split in enumerate(data):
35+
print(f"Evaluating split {i + 1}/10 ...")
36+
X_train, y_train, X_valid, y_valid, X_test, y_test = (
37+
split["train"]["X"],
38+
split["train"]["y"],
39+
split["valid"]["X"],
40+
split["valid"]["y"],
41+
split["test"]["X"],
42+
split["test"]["y"],
43+
)
44+
45+
X_train = np.concatenate((X_train, X_valid))
46+
y_train = np.concatenate((y_train, y_valid))
47+
model = RandomForestRegressor(n_estimators=1000, random_state=42)
48+
model.fit(X_train, y_train)
49+
y_pred = model.predict(X_test)
50+
51+
r_squared = r2_score(y_test, y_pred)
52+
mae = mean_absolute_error(y_test, y_pred)
53+
print(f"Test {i + 1}", r_squared, mae / 100)
54+
r2s.append(r_squared)
55+
maes.append(mae)
56+
57+
print("Tests R2:", sum(r2s) / len(r2s), stdev(r2s))
58+
print("Tests MAE:", sum(maes) / (100 * len(maes)), stdev(maes) / 100)
59+
60+
61+
def main():
62+
predict_az()
63+
64+
65+
if __name__ == "__main__":
66+
main()

uv.lock

Lines changed: 7 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)