Skip to content

Commit e083020

Browse files
committed
Updated scripts
1 parent 385d746 commit e083020

File tree

5 files changed

+27
-19
lines changed

5 files changed

+27
-19
lines changed
4.9 MB
Binary file not shown.

data/az/az-2048-3-true.pkl.gz

1.99 MB
Binary file not shown.

scripts/encoding/encode_az_reactions.py

Lines changed: 16 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ def get_az_rxns(fold_idx: int = 0):
5252
for data, split in [(train, "train"), (valid, "valid"), (test, "test")]:
5353
X, mapping = DrfpEncoder.encode(
5454
data.smiles.to_numpy(),
55-
n_folded_length=10240,
55+
n_folded_length=2048,
5656
radius=3,
5757
rings=True,
5858
mapping=True,
@@ -66,18 +66,19 @@ def get_az_rxns(fold_idx: int = 0):
6666

6767
y = data["yield"].to_numpy()
6868

69-
X_props = data.drop(
70-
columns=[
71-
"yield",
72-
"reactant_smiles",
73-
"solvent_smiles",
74-
"base_smiles",
75-
"product_smiles",
76-
"id",
77-
]
78-
).to_numpy()
79-
80-
X = np.concatenate((X, X_props), axis=1)
69+
# X_props = data.drop(
70+
# columns=[
71+
# "yield",
72+
# "reactant_smiles",
73+
# "solvent_smiles",
74+
# "base_smiles",
75+
# "product_smiles",
76+
# "id",
77+
# "smiles",
78+
# ]
79+
# ).to_numpy()
80+
#
81+
# X = np.concatenate((X, X_props), axis=1)
8182

8283
output_splits[split] = {
8384
"X": X,
@@ -88,8 +89,8 @@ def get_az_rxns(fold_idx: int = 0):
8889

8990
output.append(output_splits)
9091

91-
out_file_name = Path(az_path, f"az-10240-3-true-props.pkl")
92-
out_file_name_gz = Path(az_path, f"az-10240-3-true-props.pkl.gz")
92+
out_file_name = Path(az_path, f"az-2048-3-true.pkl")
93+
out_file_name_gz = Path(az_path, f"az-2048-3-true.pkl.gz")
9394

9495
with open(out_file_name, "wb+") as f:
9596
pickle.dump(output, f)

scripts/training/yield_prediction_az.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,9 @@
44
from statistics import stdev
55
import numpy as np
66
from xgboost import XGBRegressor
7+
from catboost import CatBoostRegressor
8+
from sklearn import svm
9+
from sklearn.ensemble import RandomForestRegressor
710
from sklearn.metrics import r2_score, mean_absolute_error
811

912

@@ -21,7 +24,7 @@ def save_results(
2124

2225
def predict_az():
2326
root_path = Path(__file__).resolve().parent
24-
az_file_path = Path(root_path, "../../data/az/az-2048-3-true.pkl")
27+
az_file_path = Path(root_path, "../../data/az/az-2048-3-true-props.pkl")
2528

2629
data = pickle.load(open(az_file_path, "rb"))
2730

@@ -59,9 +62,13 @@ def predict_az():
5962
)
6063

6164
y_pred = model.predict(X_test, iteration_range=(0, model.best_iteration))
62-
# y_pred[y_pred < 0.0] = 0.0
6365

64-
# save_results("az", split, sample_file, y_test, y_pred)
66+
# X_train = np.concatenate((X_train, X_valid))
67+
# y_train = np.concatenate((y_train, y_valid))
68+
# model = RandomForestRegressor(n_estimators=1000, random_state=42)
69+
# model.fit(X_train, y_train)
70+
# y_pred = model.predict(X_test)
71+
6572
r_squared = r2_score(y_test, y_pred)
6673
mae = mean_absolute_error(y_test, y_pred)
6774
print(f"Test {i + 1}", r_squared, mae / 100)

scripts/training/yield_prediction_az_mlp.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ def save_results(
3737

3838
def predict_az():
3939
root_path = Path(__file__).resolve().parent
40-
az_file_path = Path(root_path, "../../data/az/az-2048-3-true.pkl")
40+
az_file_path = Path(root_path, "../../data/az/az-10240-3-true.pkl")
4141

4242
data = pickle.load(open(az_file_path, "rb"))
4343

0 commit comments

Comments
 (0)