
update #17


Open

wants to merge 27 commits into base: mol
Changes from all commits (27 commits)
6771b18
Merge pull request #7 from yuanqing-wang/data
yuanqing-wang Jun 30, 2023
fbe9db6
tune
yuanqing-wang Sep 25, 2023
3592a61
Merge pull request #15 from yuanqing-wang/mol
yuanqing-wang Sep 25, 2023
2d402ba
tuning
yuanqing-wang Sep 26, 2023
8355934
Merge branch 'consistency' of https://github.com/yuanqing-wang/bronx …
yuanqing-wang Sep 26, 2023
8b61b38
update gitignore
yuanqing-wang Sep 26, 2023
266907b
Merge branch 'consistency' of https://github.com/yuanqing-wang/bronx …
yuanqing-wang Sep 26, 2023
87a2c7a
tune
yuanqing-wang Sep 27, 2023
544f924
tune
yuanqing-wang Sep 27, 2023
07827fc
Merge branch 'consistency' of https://github.com/yuanqing-wang/bronx …
yuanqing-wang Sep 27, 2023
7938bed
tune
yuanqing-wang Sep 27, 2023
4491849
Merge pull request #16 from yuanqing-wang/consistency
yuanqing-wang Sep 29, 2023
5e3b405
tuning using hyperopt
yuanqing-wang Oct 2, 2023
a4cd105
tuning
yuanqing-wang Oct 4, 2023
be1d744
tuning
yuanqing-wang Oct 5, 2023
18cda9c
consistency tuning
yuanqing-wang Nov 11, 2023
5a6c31d
anneal
yuanqing-wang Nov 14, 2023
abdbb62
tuning
yuanqing-wang Nov 18, 2023
b331d20
anneal
yuanqing-wang Nov 18, 2023
0b2d160
tune
yuanqing-wang Nov 20, 2023
6896600
graph regression
yuanqing-wang Nov 20, 2023
9a0cb11
anneal
yuanqing-wang Nov 21, 2023
f38ffa7
Merge branch 'anneal' of https://github.com/yuanqing-wang/bronx into …
yuanqing-wang Nov 21, 2023
d0e3162
run script
yuanqing-wang Nov 21, 2023
3da2179
plotting
yuanqing-wang Nov 22, 2023
051e7a5
plot notebook
yuanqing-wang Nov 22, 2023
719e2b7
Merge pull request #18 from yuanqing-wang/anneal
yuanqing-wang Dec 29, 2023
.gitignore (5 changes: 4 additions & 1 deletion)
@@ -136,4 +136,7 @@ dmypy.json
*.txt
events*
*.pt
core.*
core.*

*dataset*
*.bin
bronx/layers.py (18 changes: 4 additions & 14 deletions)
@@ -29,9 +29,10 @@ def forward(self, t, x):
h0 = h
g = self.g.local_var()
g.edata["e"] = e
g.ndata["h"] = h
g.ndata["h"] = h
g.update_all(fn.u_mul_e("h", "e", "m"), fn.sum("m", "h"))
h = g.ndata["h"]
# h = h / g.in_degrees().float().clamp(min=1).view(-1, *((1,) * (h.dim()-1)))
# h = h.tanh()
h = h - h0 * self.gamma
if self.h0 is not None:
@@ -62,7 +63,6 @@ def forward(self, g, h, e):
e, h = e.swapaxes(0, 1), h.swapaxes(0, 1)

h = h.reshape(*h.shape[:-1], e.shape[-2], -1)
# e = edge_softmax(g, e)
g.edata["e"] = e
g.update_all(fn.copy_e("e", "m"), fn.sum("m", "e_sum"))
g.apply_edges(lambda edges: {"e": edges.data["e"] / edges.dst["e_sum"]})
@@ -75,6 +75,7 @@ def forward(self, g, h, e):
t = torch.tensor([0.0, self.t], device=h.device, dtype=h.dtype)
x = torch.cat([h.flatten(), g.edata["e"].flatten()])
x = self.integrator(self.odefunc, x, t, method="dopri5")[-1]
# x = self.integrator(self.odefunc, x, t, method="rk4", options={"step_size": 0.1})[-1]
h, e = x[:h.numel()], x[h.numel():]
h = h.reshape(*node_shape)
if parallel:
Expand All @@ -87,7 +88,6 @@ def __init__(
self,
in_features,
out_features,
edge_features=0,
activation=torch.nn.SiLU(),
idx=0,
num_heads=4,
@@ -158,16 +158,6 @@ def guide(self, g, h):
dgl.function.u_dot_v("k", "log_sigma", "log_sigma")
)

if self.edge_features > 0:
mu_e = self.fc_mu_e(he)
log_sigma_e = self.fc_log_sigma_e(he)
mu_e = mu_e.reshape(*mu_e.shape[:-1], self.num_heads, -1)
log_sigma_e = log_sigma_e.reshape(
*log_sigma_e.shape[:-1], self.num_heads, -1
)
g.edata["mu"] = g.edata["mu"] + mu_e
g.edata["log_sigma"] = g.edata["log_sigma"] + log_sigma_e

mu = g.edata["mu"]
log_sigma = g.edata["log_sigma"]

@@ -198,7 +188,7 @@ def forward(self, g, h, he=None):
if self.node_prior:
if self.norm:
h = self.norm(h)
h = self.dropout(h)
# h = self.dropout(h)
mu, log_sigma = self.fc_mu_prior(h), self.fc_log_sigma_prior(h)
src, dst = g.edges()
mu = mu[..., dst, :]
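The integrator call in the bronx/layers.py hunks above follows torchdiffeq's odeint interface, and the commented-out line swaps the adaptive dopri5 solver for a fixed-step rk4 run. A minimal sketch of that solver choice, assuming torchdiffeq is installed; ToyODEFunc is a hypothetical stand-in for the layer's odefunc, not code from this repository.

import torch
from torchdiffeq import odeint

class ToyODEFunc(torch.nn.Module):
    # Hypothetical stand-in for the layer's odefunc: dx/dt = -x.
    def forward(self, t, x):
        return -x

func = ToyODEFunc()
x0 = torch.randn(8)              # flattened state, as in x = torch.cat([...])
t = torch.tensor([0.0, 1.0])     # integrate from 0 to self.t

# Adaptive-step solver, as in the current code path.
x_dopri5 = odeint(func, x0, t, method="dopri5")[-1]

# Fixed-step alternative, as in the commented-out line.
x_rk4 = odeint(func, x0, t, method="rk4", options={"step_size": 0.1})[-1]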
bronx/models.py (2 changes: 1 addition & 1 deletion)
@@ -141,7 +141,7 @@ def forward(self, g, h, y=None, mask=None):
obs=y,
)

return h
return h


class GraphRegressionBronxModel(BronxModel):
scripts/graph_regression/run.py (69 changes: 33 additions & 36 deletions)
@@ -70,93 +70,90 @@ def run(args):
data_valid, batch_size=len(data_valid),
)

optimizer = SWA(
scheduler = pyro.optim.ReduceLROnPlateau(
{
"base": getattr(torch.optim, args.optimizer),
"base_args": {
"lr": args.learning_rate / batch_size,
"optimizer": getattr(torch.optim, args.optimizer),
"optim_args": {
"lr": args.learning_rate,
"weight_decay": args.weight_decay
},
"swa_args": {
"swa_start": args.swa_start,
"swa_freq": args.swa_freq,
"swa_lr": args.swa_lr,
},
}
"factor": args.lr_factor,
"patience": args.patience,
"mode": "min",
},
)

svi = pyro.infer.SVI(
model,
model.guide,
optimizer,
scheduler,
loss=pyro.infer.TraceMeanField_ELBO(
num_particles=args.num_particles, vectorize_particles=True
),
)

import tqdm
for idx in tqdm.tqdm(range(args.n_epochs)):
for idx in range(args.n_epochs):
for _, g, y in data_train:
if torch.cuda.is_available():
g = g.to("cuda:0")
y = y.to("cuda:0")
model.train()
loss = svi.step(g, g.ndata["h0"], y)

_, g, y = next(iter(data_valid))
if torch.cuda.is_available():
g = g.to("cuda:0")
y = y.to("cuda:0")
_, g, y = next(iter(data_valid))
if torch.cuda.is_available():
g = g.to("cuda:0")
y = y.to("cuda:0")

model.eval()
with torch.no_grad():

model.eval()
swap_swa_sgd(svi.optim)
with torch.no_grad():
predictive = pyro.infer.Predictive(
model,
guide=model.guide,
num_samples=args.num_samples,
parallel=True,
return_sites=["_RETURN"],
)

predictive = pyro.infer.Predictive(
model,
guide=model.guide,
num_samples=args.num_samples,
parallel=True,
return_sites=["_RETURN"],
)
y_hat = predictive(g, g.ndata["h0"])["_RETURN"].mean(0)
rmse = float(((y_hat - y) ** 2).mean() ** 0.5)
scheduler.step(rmse)
print(rmse)

y_hat = predictive(g, g.ndata["h0"])["_RETURN"].mean(0)
rmse = float(((y_hat - y) ** 2).mean() ** 0.5)
print("RMSE: %.6f" % rmse, flush=True)
return rmse

if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--data", type=str, default="ESOL")
parser.add_argument("--batch_size", type=int, default=-1)
parser.add_argument("--hidden_features", type=int, default=25)
parser.add_argument("--hidden_features", type=int, default=100)
parser.add_argument("--embedding_features", type=int, default=20)
parser.add_argument("--activation", type=str, default="SiLU")
parser.add_argument("--learning_rate", type=float, default=1e-3)
parser.add_argument("--weight_decay", type=float, default=1e-5)
parser.add_argument("--depth", type=int, default=3)
parser.add_argument("--depth", type=int, default=1)
parser.add_argument("--num_samples", type=int, default=64)
parser.add_argument("--num_particles", type=int, default=4)
parser.add_argument("--num_heads", type=int, default=5)
parser.add_argument("--sigma_factor", type=float, default=2.0)
parser.add_argument("--t", type=float, default=1.0)
parser.add_argument("--optimizer", type=str, default="AdamW")
parser.add_argument("--kl_scale", type=float, default=1e-5)
parser.add_argument("--n_epochs", type=int, default=100)
parser.add_argument("--n_epochs", type=int, default=1000)
parser.add_argument("--adjoint", type=int, default=0)
parser.add_argument("--physique", type=int, default=0)
parser.add_argument("--gamma", type=float, default=1.0)
parser.add_argument("--readout_depth", type=int, default=1)
parser.add_argument("--swa_start", type=int, default=20)
parser.add_argument("--swa_freq", type=int, default=10)
parser.add_argument("--swa_lr", type=float, default=1e-2)
parser.add_argument("--dropout_in", type=float, default=0.0)
parser.add_argument("--dropout_out", type=float, default=0.0)
parser.add_argument("--norm", type=int, default=1)
parser.add_argument("--subsample_size", type=int, default=100)
parser.add_argument("--k", type=int, default=0)
parser.add_argument("--checkpoint", type=str, default="")
parser.add_argument("--seed", type=int, default=2666)
parser.add_argument("--lr_factor", type=float, default=0.5)
parser.add_argument("--patience", type=int, default=10)
args = parser.parse_args()
run(args)
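The training loop in this file drops the SWA optimizer in favour of Pyro's ReduceLROnPlateau scheduler, stepped on the validation RMSE. A minimal sketch of that SVI-plus-scheduler pattern, assuming Pyro's standard dictionary constructor for learning-rate schedulers; the toy model and guide below are hypothetical placeholders, not the repository's BronxModel.

import torch
import pyro

def model(x):
    # Hypothetical toy model with a single learnable location parameter.
    loc = pyro.param("loc", torch.zeros(()))
    with pyro.plate("data", x.shape[0]):
        pyro.sample("obs", pyro.distributions.Normal(loc, 1.0), obs=x)

def guide(x):
    pass  # no latent variables in this toy example

scheduler = pyro.optim.ReduceLROnPlateau({
    "optimizer": torch.optim.AdamW,
    "optim_args": {"lr": 1e-3, "weight_decay": 1e-5},
    "factor": 0.5,    # shrink the learning rate by this factor on plateau
    "patience": 10,   # epochs without improvement before shrinking
    "mode": "min",    # the monitored metric (here, a loss/RMSE) should decrease
})

svi = pyro.infer.SVI(model, guide, scheduler, loss=pyro.infer.Trace_ELBO())

x = torch.randn(16)
for epoch in range(5):
    loss = svi.step(x)
    scheduler.step(loss)  # ReduceLROnPlateau is stepped with the monitored metric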
scripts/node_classification/check.py (12 changes: 11 additions & 1 deletion)
@@ -1,10 +1,12 @@
import os
import glob
import json
from re import S
import pandas as pd
import torch
import pyro
import dgl
from types import SimpleNamespace

def check(args):
results = []
@@ -34,14 +36,22 @@ def check(args):


if args.rerun:
from run import run
config = results[0]["config"]
config["split_index"] = -1
config["lr_factor"] = 0.5
config["patience"] = 10
config = SimpleNamespace(**config)
accuracy_vl, accuracy_te = run(config)

if args.reevaluate:
if torch.cuda.is_available():
model = torch.load(results[0]["config"]["checkpoint"])
g = g.to("cuda:0")
else:
model = torch.load(results[0]["config"]["checkpoint"], map_location="cpu")
model.eval()


with torch.no_grad():
predictive = pyro.infer.Predictive(
model,
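The new --rerun branch above rebuilds an argparse-style namespace from the JSON config stored with each result, so run() can be called directly. A one-line illustration of that SimpleNamespace conversion, with a hypothetical config dict:

from types import SimpleNamespace

config = {"data": "cora", "learning_rate": 1e-3, "split_index": -1}  # hypothetical values
args = SimpleNamespace(**config)
print(args.learning_rate)  # attribute access, matching what run(config) expects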
scripts/node_classification/check_multi.py (66 changes: 66 additions & 0 deletions)
@@ -0,0 +1,66 @@
import os
import glob
import json
import pandas as pd
import torch
import pyro
import dgl

def check(args):
results = []
result_paths = glob.glob(args.path + "/*/*/result.json")
for result_path in result_paths:
try:
with open(result_path, "r") as f:
result_str = f.read()
result = json.loads(result_str)
results.append(result)
except:
pass

results = sorted(results, key=lambda x: x["_metric"]["accuracy"], reverse=True)

from run import get_graph
g = get_graph(results[0]["config"]["data"])
y = g.ndata["label"].argmax(-1)

ys_hat = []
with torch.no_grad():
for idx in range(args.first):
if torch.cuda.is_available():
model = torch.load(results[idx]["config"]["checkpoint"])
g = g.to("cuda:0")
else:
model = torch.load(results[idx]["config"]["checkpoint"], map_location="cpu")
model.eval()

predictive = pyro.infer.Predictive(
model,
guide=model.guide,
num_samples=args.num_samples,
parallel=True,
return_sites=["_RETURN"],
)

y_hat = predictive(g, g.ndata["feat"])["_RETURN"]
ys_hat.append(y_hat)

g = g.to("cpu")
y_hat = torch.cat(ys_hat, dim=0).mean(0).argmax(-1).cpu()

print(y_hat.shape, y.shape)
print(y_hat)
print(y)

accuracy_vl = (y_hat[g.ndata["val_mask"]] == y[g.ndata["val_mask"]]).float().mean()
accuracy_te = (y_hat[g.ndata["test_mask"]] == y[g.ndata["test_mask"]]).float().mean()
print(accuracy_vl, accuracy_te)

if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--path", type=str, default=".")
parser.add_argument("--first", type=int, default=16)
parser.add_argument("--num_samples", type=int, default=32)
args = parser.parse_args()
check(args)
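check_multi.py ensembles the top-ranked checkpoints by concatenating their predictive samples along the sample axis before averaging and taking the class argmax. A minimal sketch of that averaging step, with hypothetical shapes (samples x nodes x classes):

import torch

# Hypothetical predictive outputs from three checkpoints,
# each of shape (num_samples, num_nodes, num_classes).
ys_hat = [torch.randn(32, 2708, 7) for _ in range(3)]

# Concatenate along the sample axis, average, then take the class argmax,
# mirroring torch.cat(ys_hat, dim=0).mean(0).argmax(-1) in the script above.
y_hat = torch.cat(ys_hat, dim=0).mean(0).argmax(-1)
print(y_hat.shape)  # torch.Size([2708])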