|
| 1 | +import pickle |
| 2 | +from pathlib import Path |
| 3 | +from typing import Tuple |
| 4 | +from statistics import stdev |
| 5 | +import numpy as np |
| 6 | +from xgboost import XGBRegressor |
| 7 | +from catboost import CatBoostRegressor |
| 8 | +from sklearn import svm |
| 9 | +from sklearn.ensemble import RandomForestRegressor |
| 10 | +from sklearn.metrics import r2_score, mean_absolute_error |
| 11 | + |
| 12 | + |
def save_results(
    set_name: str,
    split_id: str,
    file_name: str,
    ground_truth: np.ndarray,
    prediction: np.ndarray,
) -> None:
    """Write paired ground-truth/prediction values to a CSV file.

    The file is created in the current working directory as
    ``{set_name}_{split_id}_{file_name}.csv`` and is overwritten if it
    already exists. Each row has the form
    ``set_name,split_id,file_name,ground_truth,prediction``; no header row
    is written.

    Args:
        set_name: Data set label; used in the file name and in every row.
        split_id: Split label; used in the file name and in every row.
        file_name: Run/model label; used in the file name and in every row.
        ground_truth: Reference values, one per row.
        prediction: Predicted values, aligned element-wise with
            ``ground_truth``.

    Raises:
        ValueError: If ``ground_truth`` and ``prediction`` differ in length.
    """
    # zip() would silently drop the trailing rows of the longer input, so a
    # mismatch is rejected before the output file is created or truncated.
    if len(ground_truth) != len(prediction):
        raise ValueError(
            "ground_truth and prediction must have the same length "
            f"(got {len(ground_truth)} and {len(prediction)})"
        )

    out_path = f"{set_name}_{split_id}_{file_name}.csv"
    # Write-only mode is sufficient; an explicit encoding keeps the output
    # independent of the platform's locale.
    with open(out_path, "w", encoding="utf-8") as f:
        f.writelines(
            f"{set_name},{split_id},{file_name},{gt},{pred}\n"
            for gt, pred in zip(ground_truth, prediction)
        )
| 23 | + |
| 24 | + |
def predict_az():
    """Train and evaluate a random forest on every split of the AZ data.

    Loads the pickled splits from ``../../data/az/az-2048-3-true.pkl``
    (resolved relative to this file's directory). For each split, the train
    and validation sets are concatenated, a ``RandomForestRegressor``
    (1000 trees, ``random_state=42``) is fitted on them, and R2 and MAE are
    computed on the test set. Per-split scores and the mean and sample
    standard deviation across splits are printed to stdout.

    Printed MAE values are divided by 100.
    NOTE(review): presumably this rescales the targets (e.g. percent to
    fraction) -- confirm against the data preparation code.

    Raises:
        ValueError: If the pickle file contains no splits.
    """
    root_path = Path(__file__).resolve().parent
    az_file_path = Path(root_path, "../../data/az/az-2048-3-true.pkl")

    # NOTE(security): pickle.load can execute arbitrary code; only use this
    # with data files from a trusted source.
    # The context manager closes the file handle once loading is done.
    with open(az_file_path, "rb") as f:
        data = pickle.load(f)

    n_splits = len(data)
    if n_splits == 0:
        raise ValueError(f"No splits found in {az_file_path}")

    r2s = []
    maes = []

    for i, split in enumerate(data):
        # Report the real number of splits rather than assuming ten.
        print(f"Evaluating split {i + 1}/{n_splits} ...")
        X_train, y_train, X_valid, y_valid, X_test, y_test = (
            split["train"]["X"],
            split["train"]["y"],
            split["valid"]["X"],
            split["valid"]["y"],
            split["test"]["X"],
            split["test"]["y"],
        )

        # No model selection happens here, so the validation data is folded
        # into the training set.
        X_train = np.concatenate((X_train, X_valid))
        y_train = np.concatenate((y_train, y_valid))
        model = RandomForestRegressor(n_estimators=1000, random_state=42)
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)

        r_squared = r2_score(y_test, y_pred)
        mae = mean_absolute_error(y_test, y_pred)
        print(f"Test {i + 1}", r_squared, mae / 100)
        r2s.append(r_squared)
        maes.append(mae)

    # statistics.stdev needs at least two data points; with a single split
    # the spread is undefined, so NaN is reported instead of raising.
    r2_std = stdev(r2s) if len(r2s) > 1 else float("nan")
    mae_std = stdev(maes) / 100 if len(maes) > 1 else float("nan")

    print("Tests R2:", sum(r2s) / len(r2s), r2_std)
    print("Tests MAE:", sum(maes) / (100 * len(maes)), mae_std)
| 59 | + |
| 60 | + |
def main():
    """Script entry point: run the evaluation performed by predict_az()."""
    predict_az()
| 63 | + |
| 64 | + |
# Run the evaluation only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
0 commit comments