diff --git a/.gitignore b/.gitignore index 8514af6..9f5967d 100644 --- a/.gitignore +++ b/.gitignore @@ -136,4 +136,7 @@ dmypy.json *.txt events* *.pt -core.* \ No newline at end of file +core.* + +*dataset* +*.bin \ No newline at end of file diff --git a/bronx/layers.py b/bronx/layers.py index b5bebad..8f31ff4 100644 --- a/bronx/layers.py +++ b/bronx/layers.py @@ -29,9 +29,10 @@ def forward(self, t, x): h0 = h g = self.g.local_var() g.edata["e"] = e - g.ndata["h"] = h + g.ndata["h"] = h g.update_all(fn.u_mul_e("h", "e", "m"), fn.sum("m", "h")) h = g.ndata["h"] + # h = h / g.in_degrees().float().clamp(min=1).view(-1, *((1,) * (h.dim()-1))) # h = h.tanh() h = h - h0 * self.gamma if self.h0 is not None: @@ -62,7 +63,6 @@ def forward(self, g, h, e): e, h = e.swapaxes(0, 1), h.swapaxes(0, 1) h = h.reshape(*h.shape[:-1], e.shape[-2], -1) - # e = edge_softmax(g, e) g.edata["e"] = e g.update_all(fn.copy_e("e", "m"), fn.sum("m", "e_sum")) g.apply_edges(lambda edges: {"e": edges.data["e"] / edges.dst["e_sum"]}) @@ -75,6 +75,7 @@ def forward(self, g, h, e): t = torch.tensor([0.0, self.t], device=h.device, dtype=h.dtype) x = torch.cat([h.flatten(), g.edata["e"].flatten()]) x = self.integrator(self.odefunc, x, t, method="dopri5")[-1] + # x = self.integrator(self.odefunc, x, t, method="rk4", options={"step_size": 0.1})[-1] h, e = x[:h.numel()], x[h.numel():] h = h.reshape(*node_shape) if parallel: @@ -87,7 +88,6 @@ def __init__( self, in_features, out_features, - edge_features=0, activation=torch.nn.SiLU(), idx=0, num_heads=4, @@ -158,16 +158,6 @@ def guide(self, g, h): dgl.function.u_dot_v("k", "log_sigma", "log_sigma") ) - if self.edge_features > 0: - mu_e = self.fc_mu_e(he) - log_sigma_e = self.fc_log_sigma_e(he) - mu_e = mu_e.reshape(*mu_e.shape[:-1], self.num_heads, -1) - log_sigma_e = log_sigma_e.reshape( - *log_sigma_e.shape[:-1], self.num_heads, -1 - ) - g.edata["mu"] = g.edata["mu"] + mu_e - g.edata["log_sigma"] = g.edata["log_sigma"] + log_sigma_e - mu = g.edata["mu"] log_sigma = g.edata["log_sigma"] @@ -198,7 +188,7 @@ def forward(self, g, h, he=None): if self.node_prior: if self.norm: h = self.norm(h) - h = self.dropout(h) + # h = self.dropout(h) mu, log_sigma = self.fc_mu_prior(h), self.fc_log_sigma_prior(h) src, dst = g.edges() mu = mu[..., dst, :] diff --git a/bronx/models.py b/bronx/models.py index a065902..d7446a5 100644 --- a/bronx/models.py +++ b/bronx/models.py @@ -141,7 +141,7 @@ def forward(self, g, h, y=None, mask=None): obs=y, ) - return h + return h class GraphRegressionBronxModel(BronxModel): diff --git a/scripts/graph_regression/run.py b/scripts/graph_regression/run.py index af17f01..c51e7b3 100644 --- a/scripts/graph_regression/run.py +++ b/scripts/graph_regression/run.py @@ -70,32 +70,29 @@ def run(args): data_valid, batch_size=len(data_valid), ) - optimizer = SWA( + scheduler = pyro.optim.ReduceLROnPlateau( { - "base": getattr(torch.optim, args.optimizer), - "base_args": { - "lr": args.learning_rate / batch_size, + "optimizer": getattr(torch.optim, args.optimizer), + "optim_args": { + "lr": args.learning_rate, "weight_decay": args.weight_decay }, - "swa_args": { - "swa_start": args.swa_start, - "swa_freq": args.swa_freq, - "swa_lr": args.swa_lr, - }, - } + "factor": args.lr_factor, + "patience": args.patience, + "mode": "min", + }, ) svi = pyro.infer.SVI( model, model.guide, - optimizer, + scheduler, loss=pyro.infer.TraceMeanField_ELBO( num_particles=args.num_particles, vectorize_particles=True ), ) - import tqdm - for idx in tqdm.tqdm(range(args.n_epochs)): + for idx in 
range(args.n_epochs): for _, g, y in data_train: if torch.cuda.is_available(): g = g.to("cuda:0") @@ -103,26 +100,27 @@ def run(args): model.train() loss = svi.step(g, g.ndata["h0"], y) - _, g, y = next(iter(data_valid)) - if torch.cuda.is_available(): - g = g.to("cuda:0") - y = y.to("cuda:0") + _, g, y = next(iter(data_valid)) + if torch.cuda.is_available(): + g = g.to("cuda:0") + y = y.to("cuda:0") + + model.eval() + with torch.no_grad(): - model.eval() - swap_swa_sgd(svi.optim) - with torch.no_grad(): + predictive = pyro.infer.Predictive( + model, + guide=model.guide, + num_samples=args.num_samples, + parallel=True, + return_sites=["_RETURN"], + ) - predictive = pyro.infer.Predictive( - model, - guide=model.guide, - num_samples=args.num_samples, - parallel=True, - return_sites=["_RETURN"], - ) + y_hat = predictive(g, g.ndata["h0"])["_RETURN"].mean(0) + rmse = float(((y_hat - y) ** 2).mean() ** 0.5) + scheduler.step(rmse) + print(rmse) - y_hat = predictive(g, g.ndata["h0"])["_RETURN"].mean(0) - rmse = float(((y_hat - y) ** 2).mean() ** 0.5) - print("RMSE: %.6f" % rmse, flush=True) return rmse if __name__ == "__main__": @@ -130,12 +128,12 @@ def run(args): parser = argparse.ArgumentParser() parser.add_argument("--data", type=str, default="ESOL") parser.add_argument("--batch_size", type=int, default=-1) - parser.add_argument("--hidden_features", type=int, default=25) + parser.add_argument("--hidden_features", type=int, default=100) parser.add_argument("--embedding_features", type=int, default=20) parser.add_argument("--activation", type=str, default="SiLU") parser.add_argument("--learning_rate", type=float, default=1e-3) parser.add_argument("--weight_decay", type=float, default=1e-5) - parser.add_argument("--depth", type=int, default=3) + parser.add_argument("--depth", type=int, default=1) parser.add_argument("--num_samples", type=int, default=64) parser.add_argument("--num_particles", type=int, default=4) parser.add_argument("--num_heads", type=int, default=5) @@ -143,14 +141,11 @@ def run(args): parser.add_argument("--t", type=float, default=1.0) parser.add_argument("--optimizer", type=str, default="AdamW") parser.add_argument("--kl_scale", type=float, default=1e-5) - parser.add_argument("--n_epochs", type=int, default=100) + parser.add_argument("--n_epochs", type=int, default=1000) parser.add_argument("--adjoint", type=int, default=0) parser.add_argument("--physique", type=int, default=0) parser.add_argument("--gamma", type=float, default=1.0) parser.add_argument("--readout_depth", type=int, default=1) - parser.add_argument("--swa_start", type=int, default=20) - parser.add_argument("--swa_freq", type=int, default=10) - parser.add_argument("--swa_lr", type=float, default=1e-2) parser.add_argument("--dropout_in", type=float, default=0.0) parser.add_argument("--dropout_out", type=float, default=0.0) parser.add_argument("--norm", type=int, default=1) @@ -158,5 +153,7 @@ def run(args): parser.add_argument("--k", type=int, default=0) parser.add_argument("--checkpoint", type=str, default="") parser.add_argument("--seed", type=int, default=2666) + parser.add_argument("--lr_factor", type=float, default=0.5) + parser.add_argument("--patience", type=int, default=10) args = parser.parse_args() run(args) diff --git a/scripts/node_classification/check.py b/scripts/node_classification/check.py index d3e5d4d..bba4212 100644 --- a/scripts/node_classification/check.py +++ b/scripts/node_classification/check.py @@ -1,10 +1,12 @@ import os import glob import json +from re import S import pandas as pd 
import torch import pyro import dgl +from types import SimpleNamespace def check(args): results = [] @@ -34,6 +36,15 @@ def check(args): if args.rerun: + from run import run + config = results[0]["config"] + config["split_index"] = -1 + config["lr_factor"] = 0.5 + config["patience"] = 10 + config = SimpleNamespace(**config) + accuracy_vl, accuracy_te = run(config) + + if args.reevaluate: if torch.cuda.is_available(): model = torch.load(results[0]["config"]["checkpoint"]) g = g.to("cuda:0") @@ -41,7 +52,6 @@ def check(args): model = torch.load(results[0]["config"]["checkpoint"], map_location="cpu") model.eval() - with torch.no_grad(): predictive = pyro.infer.Predictive( model, diff --git a/scripts/node_classification/check_multi.py b/scripts/node_classification/check_multi.py new file mode 100644 index 0000000..3e42cfa --- /dev/null +++ b/scripts/node_classification/check_multi.py @@ -0,0 +1,66 @@ +import os +import glob +import json +import pandas as pd +import torch +import pyro +import dgl + +def check(args): + results = [] + result_paths = glob.glob(args.path + "/*/*/result.json") + for result_path in result_paths: + try: + with open(result_path, "r") as f: + result_str = f.read() + result = json.loads(result_str) + results.append(result) + except: + pass + + results = sorted(results, key=lambda x: x["_metric"]["accuracy"], reverse=True) + + from run import get_graph + g = get_graph(results[0]["config"]["data"]) + y = g.ndata["label"].argmax(-1) + + ys_hat = [] + with torch.no_grad(): + for idx in range(args.first): + if torch.cuda.is_available(): + model = torch.load(results[idx]["config"]["checkpoint"]) + g = g.to("cuda:0") + else: + model = torch.load(results[idx]["config"]["checkpoint"], map_location="cpu") + model.eval() + + predictive = pyro.infer.Predictive( + model, + guide=model.guide, + num_samples=args.num_samples, + parallel=True, + return_sites=["_RETURN"], + ) + + y_hat = predictive(g, g.ndata["feat"])["_RETURN"] + ys_hat.append(y_hat) + + g = g.to("cpu") + y_hat = torch.cat(ys_hat, dim=0).mean(0).argmax(-1).cpu() + + print(y_hat.shape, y.shape) + print(y_hat) + print(y) + + accuracy_vl = (y_hat[g.ndata["val_mask"]] == y[g.ndata["val_mask"]]).float().mean() + accuracy_te = (y_hat[g.ndata["test_mask"]] == y[g.ndata["test_mask"]]).float().mean() + print(accuracy_vl, accuracy_te) + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--path", type=str, default=".") + parser.add_argument("--first", type=int, default=16) + parser.add_argument("--num_samples", type=int, default=32) + args = parser.parse_args() + check(args) diff --git a/scripts/node_classification/plot.ipynb b/scripts/node_classification/plot.ipynb new file mode 100644 index 0000000..f71f284 --- /dev/null +++ b/scripts/node_classification/plot.ipynb @@ -0,0 +1,280 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "! 
export PYTHONPATH=$PYTHONPATH:/data/chodera/wangyq/bronx/" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append('/data/chodera/wangyq/bronx/')" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "import bronx\n", + "import torch\n", + "import dgl\n", + "import pyro\n", + "from matplotlib import pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "model = torch.load(\"/data/chodera/wangyq/node_classification/best.pt\", map_location=torch.device('cpu'))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "NodeClassificationBronxModel(\n", + " (fc_in): Sequential(\n", + " (0): Linear(in_features=1433, out_features=42, bias=False)\n", + " )\n", + " (fc_out): Sequential(\n", + " (0): ELU(alpha=1.0)\n", + " (1): Dropout(p=0.5303411489006438, inplace=False)\n", + " (2): Linear(in_features=42, out_features=7, bias=False)\n", + " )\n", + " (activation): ELU(alpha=1.0)\n", + " (layer0): BronxLayer(\n", + " (fc_mu): Linear(in_features=42, out_features=147, bias=False)\n", + " (fc_log_sigma): Linear(in_features=42, out_features=147, bias=False)\n", + " (fc_k): Linear(in_features=42, out_features=147, bias=False)\n", + " (fc_mu_prior): Linear(in_features=42, out_features=21, bias=False)\n", + " (fc_log_sigma_prior): Linear(in_features=42, out_features=21, bias=False)\n", + " (activation): ELU(alpha=1.0)\n", + " (linear_diffusion): LinearDiffusion(\n", + " (odefunc): ODEFunc()\n", + " )\n", + " (norm): LayerNorm((42,), eps=1e-05, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.6939372020133809, inplace=False)\n", + " )\n", + " (consistency_regularizer): ConsistencyRegularizer()\n", + ")" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "def graph_forward(self, g, h, y=None, mask=None):\n", + " h = self.fc_in(h)\n", + " for idx in range(self.depth):\n", + " h = getattr(self, f\"layer{idx}\")(g, h)\n", + " return g" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "from run import get_graph\n", + "g = get_graph(\"CoraGraphDataset\")" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "g = graph_forward(model, g, g.ndata[\"feat\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "param_store = pyro.get_param_store()" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "predictive = pyro.infer.Predictive(\n", + " model,\n", + " guide=model.guide,\n", + " num_samples=4,\n", + " parallel=True,\n", + ")\n", + "\n", + "results = predictive(g, g.ndata[\"feat\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4, 10556, 21, 1])" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results[\"e0\"].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "from matplotlib 
import pyplot as plt\n", + "import seaborn as sns\n", + "e = results[\"e0\"].flatten().detach().numpy()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "metadata": {}, + "outputs": [], + "source": [ + "h = g.ndata[\"feat\"]\n", + "h = model.fc_in(h)\n", + "h = model.layer0.norm(h)\n", + "mu_prior, log_sigma_prior = model.layer0.fc_mu_prior(h), model.layer0.fc_log_sigma_prior(h)\n", + "src, dst = g.edges()\n", + "mu_prior = mu_prior[..., dst, :]\n", + "log_sigma_prior = log_sigma_prior[..., dst, :]\n", + "mu_prior, log_sigma_prior = mu_prior.unsqueeze(-1), log_sigma_prior.unsqueeze(-1)\n", + "sigma_prior = log_sigma_prior.exp() * model.layer0.sigma_factor" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": {}, + "outputs": [], + "source": [ + "from math import log\n", + "mu_posterior, log_sigma_posterior, k_posterior = model.layer0.fc_mu(h), model.layer0.fc_log_sigma(h), model.layer0.fc_k(h)\n", + "mu_posterior = mu_posterior.reshape(*mu_posterior.shape[:-1], model.layer0.num_heads, -1)\n", + "log_sigma_posterior = log_sigma_posterior.reshape(\n", + " *log_sigma_posterior.shape[:-1], model.layer0.num_heads, -1\n", + ")\n", + "k_posterior = k_posterior.reshape(*k_posterior.shape[:-1], model.layer0.num_heads, -1)\n", + "\n", + "g.ndata[\"mu_posterior\"], g.ndata[\"log_sigma_posterior\"], g.ndata[\"k_posterior\"] = mu_posterior, log_sigma_posterior, k_posterior\n", + "g.apply_edges(dgl.function.u_dot_v(\"k_posterior\", \"mu_posterior\", \"mu_posterior\"))\n", + "g.apply_edges(\n", + " dgl.function.u_dot_v(\"k_posterior\", \"log_sigma_posterior\", \"log_sigma_posterior\")\n", + ")\n", + "\n", + "mu_posterior = g.edata[\"mu_posterior\"]\n", + "log_sigma_posterior = g.edata[\"log_sigma_posterior\"] \n", + "sigma_posterior = log_sigma_posterior.exp() * model.layer0.sigma_factor" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(0.0, 1.0)" + ] + }, + "execution_count": 103, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAOAAAADFCAYAAABNTP5kAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAc20lEQVR4nO3de1RTZ7438G/uEEgiCCQgNyv1rqBWLW2naoepox2VoV0znfUexc44ta3T9XpYzozMWkdX29NjZ83Uo516PF1vR1m+vTvjZc70Qh1byznWSxXwVm+gAnILFwkJhBCS5/yxswNUVBJ28iTh91krSwibvX9u9jfPs599kzHGGAghXMh5F0DIaEYBJIQjCiAhHFEACeGIAkgIRxRAQjiiABLCkZJ3ASPhdrvR0NAAnU4HmUzGuxwSwRhjsFqtSElJgVwuXbsV1gFsaGhAWloa7zLIKFJXV4fU1FTJ5hfWAdTpdACElaLX6zlXQyJZZ2cn0tLSvNucVMI6gGK3U6/XUwBJUEi9q0ODMIRwRAEkhCMKICEcRUQAy6604Fh1G+8yCPFZWA/CiF54txyq6Bgc2bAQafFa3uWQCPK3Mw34W2U9npo5NiDzj4gWEABcboa/nWngXQaJIE6XG5sPnsc/Lprx7J7TAVlGxAQQAA5W1vMugUSQsistuNXtDOgyIiKA+TkpkMmAK802tNocvMshEeK/PD0qpTxwpzlGRAAfm2JEhmff70qzlXM1JFKcvWkBAGxcMjlgy4iIAE5L1mOiUThF6EoTBZCMXI/ThRttXQCA5dkpGBMdmPHKiAhgkl7jDeDlZhvnakgkqDLb4GZAfIwaiToNJiRKew6oKCICKJPJMNHkaQGpC0okcNnTk5pojIVMJkNqfHRAlhMRAQSASQO6oHSrUzJS4ge5uF2lxQXm+HLEBDBjrLCCrI4+WOyBHTomka/KLOzKZIkBpBbw7qJUCiTEqgEAN2/ZOVdDwl19h7ANpcUJwQvUGVYRE0AASBkjrCxx5RHir3rPh3iqGMBI7ILu3LkTM2fO9F5Qm5ubi08//dTv+Y3zBLCBAkhGwGJ3wuroA9D/oT5GqwrIsrgGMDU1Fa+99hpOnz6NU6dO4bHHHsOKFStw4cIFv+YnBrCeuqBkBMTtJz5GDa1aOP4XqJt+cb0aYtmyZYO+f/XVV7Fz504cP34c06ZN83l+4+KoC0pGTtx+xA/0QAqZy5FcLhf27t2Lrq4u5ObmDjmNw+GAw9F/rmdnZ+egn9M+IJFC/a1uAMEJIPdBmHPnziE2NhYajQbPPfcc9u/fj6lTpw457ZYtW2AwGLyv796SkLqgRAreFjBuFARw0qRJqKysxIkTJ/D888+jsLAQ33777ZDTFhcXw2KxeF91dXWDfm4yRAEA2rt74XS5A147iUyNlh4AQLJnewok7l1QtVqNrKwsAMCcOXPwzTffYPv27Xjrrbdum1aj0UCj0dxxXnFaNRRyGVxuhvauXhj1gV+BJPKYO4XdHFMQAsi9Bfwut9s9aD/PFwq5DGNjhIPxLVa6LpD4x2wVWsAkXYS3gMXFxViyZAnS09NhtVrx3nvv4ciRIygtLfV7nok6DcxWBwWQ+IUxhmZPC2jU37m3JRWuATSbzVi1ahUaGxthMBgwc+ZMlJaW4gc/+IHf80yIFVYaBZD4w+bog93pAjAKWsA///nPks8zUecJIN2agvhBbP10GiWi1YqALy/k9gFHyhtAagGJH7z7f0HofgKRGMBYagGJ/8QR0GB0P4FIDCC1gGQExBYwGAMwQAQHsJUCSPwg7gMmBekYcsQGkFpA4g/xvrLirkygRWwArY4+2HtdnKsh4abN1gsASNCpg7K8iAugTqOERin8t+gu2cRX4jYzNoZaQL/IZDJvK2imbijxkRjABOqC+o/2A4k/xJP4AXhv8BVokRlAOhZI/HCruxduzy1l42MogH6jFpD4QxyAidOqoFQEJxoUQEI8gr3/B1AACfHyjoAGaf8PiNQA0j4g8UOreAyQWsCRSaDT0YgfqAsqEbEFbLU56ElJZNjavAGkLuiIiH14R58bXXQ6Ghkm6oJKRKtWQuu5mpm6oWS42ryDMBTAEUsY0A0lZDj6W0Dqgo6Y2A0VVyohd8MYo0EYKVELSHxhc/TB0SfcTZ0CKAEKIPGFeBpajFoRlLuhiSI4gEIXtI26oGQYWjkMwAARHUBqAcnwtXI4Bgj4GcBr165JXYfkxABSC0iGQxysC4sWMCsrC4sWLcI777yDnp4eqWuSRP8oKLWA5N54jIACfgawvLwcM2fORFFREUwmE9auXYuTJ09KXduIJNAJ2cQHbRyOAQJ+BjAnJwfbt29HQ0MDdu3ahcbGRjzyyCOYPn06tm7dipaWFqnr9Jm4Iq09fXD00elo5O7CqgUUKZVKFBQUYO/evfj973+PqqoqbNiwAWlpad6nHvFiiFZBpZABoP1Acm9tHM4DBUYYwFOnTuGFF15AcnIytm7dig0bNqC6uhqHDh1CQ0MDVqxYIVWdPpPJZN5by9F+ILkXHhfjAn4+nmzr1q3YvXs3Ll++jKVLl2LPnj1YunQp5HIhz+PHj0dJSQkyMzOlrNVnY2PVaOrsoRaQ3FMLpy6oXwHcuXMnfv7zn2P16tVITk4ecpqkpKSAPP/PFzQQQ4bD0eeCtacPQPAHYfwK4KFDh5Cenu5t8USMMdTV1SE9PR1qtRqFhYWSFOkvOhhPhkPsISnlMhiiVUFdtl/7gBMmTEBra+tt77e3t2P8+PHDns+WLVswd+5c6HQ6JCUlIT8/H5cvX/anpCHR6WhkONq8B+HVkMlkQV22XwG8020ebDYboqKG/1inr776CuvWrcPx48dx6NAhOJ1OPP744+jq6vKnrNtQC0iGg9chCMDHLmhRUREAYYRx06ZN0Gq13p+5XC6cOHECOTk5w57fZ599Nuj7kpISJCUl4fTp03j00Ud9KW1IY6kFJMPA60RswMcAVlRUABBawHPnzkGt7t9hVavVyM7OxoYNG/wuxmKxAADi4+OH/LnD4YDD0d+adXZ23nV+1AKS4WjhdCI24GMAv/zySwDAM888g+3bt0Ov10tWiNvtxvr16/Hwww9j+vTpQ06zZcsWvPTSS8OeJwWQDEewnws/kF/7gLt375Y0fACwbt06nD9/Hh988MEdpykuLobFYvG+6urq7jpP8ROtvasXLjfdnpAMTbyDepIuhLugBQUFKCkpgV6vR0FBwV2n3bdvn09F/OpXv8Lf//53lJWVITU19Y7TaTQaaDTDX0nxMWrIZICbCU++4bGTTUKf2Spc0ZOkD+EAGgwG7xCtwWCQZOGMMbz44ovYv38/jhw54tMhjOFQKuSI06rR3tWLVpuDAkiGZLby64IOO4C7d+8e8uuRWLduHd577z0cPHgQOp0OTU1NAISAR0dHS7KMsTFCAGkklAyFMTZgHzD4H9B+7QPa7XZ0d3d7v6+pqcG2bdvw+eef+zSfnTt3wmKxYOHChUhOTva+PvzwQ3/KGhINxJC7sTn6YHcKl6uFdBd0oBUrVqCgoA
DPPfccOjo6MG/ePKjVarS2tmLr1q14/vnnhzWfYDy3QTwWSI8qI0MRu5+xGiW0ar/iMCJ+XxH/ve99DwDwl7/8BSaTCTU1NdizZw/eeOMNSQscKe+9YbqoC0pux7P7CfgZwO7ubuh0OgDA559/joKCAsjlcjz44IOoqamRtMCRSqRHlZG7EEdAE8MpgFlZWThw4ADq6upQWlqKxx9/HABgNpslPz44UuKxQLokiQzF2wLqgz8CCvgZwE2bNmHDhg3IzMzE/PnzkZubC0BoDWfNmiVpgSMlrlhxRRMykPcYIKcW0K+9zqeeegqPPPIIGhsbkZ2d7X3/+9//Pn784x9LVpwUjJ5jO+KKJmQgM8ezYAA/AwgAJpMJJpNp0Hvz5s0bcUFSE4eWW229cLrcUCki9mbgxA/9XdAwCmBXVxdee+01HD58GGazGW63e9DPQ+nO2fFaNZRyGfrcDC1WB1LGSHOAn0SG/i4on31AvwK4Zs0afPXVV1i5ciWSk5ODfhWxL+RyGZJ0GjRYetDc2UMBJIOEZRf0008/xccff4yHH35Y6noCIkkf5QkgDcSQfj3O/psx8WoB/dohiouLu+NFs6HI6Onf00AMGUjc/1Mr5dBHB/8sGMDPAL7yyivYtGnToPNBQ5nRcyiiuZMCSPoNPATBazfKr9i//vrrqK6uhtFoRGZmJlSqwbdyKy8vl6Q4qYgBbLJQF5T0473/B/gZwPz8fInLCCyTGMBOO+dKSChp6BC2h2SOA3N+BXDz5s1S1xFQyWOEADZ2UBeU9GvwbA/jOAbQ76PSHR0dePvtt1FcXIz29nYAQtezvr5esuKkIq7g+g57UC6BIuGh0SK0gCkGPiOggJ8t4NmzZ5GXlweDwYAbN27gl7/8JeLj47Fv3z7U1tZiz549Utc5IibPCnb0uXGr24n4mODffo6EHrELyvPYsF8tYFFREVavXo2rV68OuhP20qVLUVZWJllxUtEoFd7rAsWVTki9pwsadgH85ptvsHbt2tveHzdunPe+LqEmxbMfSAEkgPBEJPE2JWEXQI1GM+Rdqa9cuYLExMQRFxUIKQZhJVMACQA0WYTWL0olR5w2uE9EGsivAC5fvhwvv/wynE4nAOFZEbW1tfjtb3+LJ598UtICpSKOhDZYaCSUCANygPDBzPNcZr8C+Prrr8NmsyExMRF2ux0LFixAVlYWdDodXn31ValrlERqnPAgmfpb1AISoK5dOIsrNV57jykDy69RUIPBgEOHDuHo0aM4c+YMbDYbZs+ejby8PKnrk0y6Z0XXtEvz6DMS3mo9AcwItwC63W6UlJRg3759uHHjBmQyGcaPHw+TyQTGWMhemiQGsLYtPM5fJYFV49kO0jkH0KcuKGMMy5cvx5o1a1BfX48ZM2Zg2rRpqKmpwerVq0PudhQDiSu6s6cPlm4n52oIb2ILmD42jFrAkpISlJWV4fDhw1i0aNGgn33xxRfIz8/Hnj17sGrVKkmLlEK0WoFEnQYtVgdq27sxQyvN8y1IePJ2QTkH0KcW8P3338fvfve728IHAI899hg2btyId999V7LipObthrZTN3Q0s9id6PD0gtLiwiiAZ8+exQ9/+MM7/nzJkiU4c+bMiIsKlAwaiCHoHwdIiNUgRsPnQlyRTwFsb2+H0Wi848+NRiNu3bo14qICJWNsDADgRisFcDS71moDANyXEMO5Eh8D6HK5oFTe+RNDoVCgr69vxEUFyoQkYYVXmW2cKyE8XW0W/v4TkmI5V+LjIAxjDKtXr77jU2odjtC+4nxCorDCq8y2kD5kQgJL/ADOCrcAFhYW3nOaUBwBFY1PiIFcJhyKaLE5uN0Ji/BV1RKmAZTqybi8RKkUSIvXoqatG1VmGwVwFHK63N4xgFAI4Ki7T7vYDa2m/cBRqaatC31uBq1agWROT0QaiGsAy8rKsGzZMqSkpEAmk+HAgQMBX+b9RiGAl5qsAV8WCT0XGoTL6CaZdJDL+Y8BcA1gV1cXsrOzsWPHjqAtc1qKcAaM+Icgo4v4d5+eEhpnQnE9CrlkyRIsWbIkqMucliI8QPRSUydcbgZFCHwKkuC50GAB0L8d8Mb3NAAfORyOQYc6hroq/17Gj42BVq1Ad68L11psuN+ok7JEEsIYYzhf72kBx4VGCxhWgzBbtmyBwWDwvtLS0nyeh1wuw9Rk4dOPuqGjy81bdljsTqgUMu9YAG9hFcDi4mJYLBbvq66uzq/5zEgVPv0qakP3tDkivVM1wv1rp6UYoFEqOFcjCKsAajQa6PX6QS9/zM0Unuz0zQ0K4Ghy8rrw956bGce5kn5hFUCpPJAh/AEuNXWis4cuzh0tTt0QWkDxAzgUcA2gzWZDZWUlKisrAQDXr19HZWUlamtrA7rcJH0U0uO1cDOgvIZawdGgzebAVc/JFw9QAAWnTp3CrFmzMGvWLADCHbdnzZqFTZs2BXzZ88cLf4SjVa0BXxbhr+xqCwBgarI+pB5NwDWACxcuBGPstldJSUnglz0pCQDw5eWWgC+L8PflJeHvvGhyaN04elTuAwLAI/cnQCGXocps894jkkSmPpfb2wIu8nzwhopRG0BDtApzPIMxpRdC83kWRBpHq9vQ4XkqVk7aGN7lDDJqAwgAP5qZDAD425kGzpWQQDpYITyz8kczk6FUhNYmH1rVBNkTM5KhkMtw9qYF1S10eVIksvY48Zmnh7MiJ4VzNbcb1QEcG6vBoknCTvn/P1bDuRoSCHtP3UR3rwtZSbGYnR46B+BFozqAALAqNxMAsPdUHR2UjzB9LjdKvr4BAFj9UGZI3gNo1Afwe/cn4P6kWHT1uvDn/77Ouxwiob+W30RtezfiY9QomD2OdzlDGvUBlMlkWJ83EQDw9n9fg9lKzw+MBF2OPmz7x1UAwAsLJ0CrDs0r70Z9AAFgyXQTZqYa0NXrwsv/9S3vcogEtv3jChotPUiLj8Y/PZjBu5w7ogBCuEbw3348A3IZ8PezjThYWc+7JDICx6rb8Pb/CLsTLy+fjihVaFx6NBQKoMf0cQb8alEWAGDjX8+hnK4VDEv1HXa8+H4FGAOenpuGRZND68yX76IADvB/8ybi0YmJsDtdWL3rpPf+ISQ8tFgdKNx1Eq02ByabdNi0bCrvku6JAjiAQi7Df/7TbMzJiENnTx/+z9sn8D9X6WqJcGC29uBn/+84qsw2JBui8HbhAyE78DIQBfA7tGoldq2ei5y0MejodmLVrhPYeaQaLjfjXRq5g3M3Lch/86g3fO//8kGkcn7u33BRAIdgiFbhg2cfxJOzU+FmwO8/u4SC/ziK8/XUJQ0lbjfDO8dr8NR/fo0GSw/GJ8Tgg2cfRGYIPHZsuGSMsbD9aO/s7ITBYIDFYvH7/jB3wxjDR6fq8K8fX4S1R3js2pLpJvzzDyZiIt3OkKsqsw2bDp7H19VtAIBFkxKx7elZMESrArK8QG1rFMBhMHf24N8+uYiDZxrAGCCTAY9PNaIwNxO5E8aG5ClOkarJ0oP/OFKFd0/UwuVmiFLJ8dsfTkZhbmZAbzVPARxCsAIoutxkx
b8fuuI9ux4A7k+KxarcDCzPGRewT18CVJmteOurazhQWQ+nS9hk86Yk4V9+NNX75ONAogAOIdgBFF1ttmLPsRr8tVw40x4A1Eo58qYkIT9nHBZOSoJaSbvXI2Vz9OGTs4346FQdTg24eda88fFYn3c/HpqQELRaKIBD4BVA7/J7nPjr6Zt4/2QtrjT3X08Yp1Vh6Yxk5E01Ive+sSF9Jkaoaeiw4/DFZvzjohnHrrWht88NAJDLgLwpRjy3cAKXy4oogEPgHUARYwzfNnbiQEU9DlY2wGztf35FtEqBh7MS8OjEBMzJiMNkk54eCOPhdjNcb+vC6Ru3cLrmFk7VtKO6pWvQNPclxuCpOal4cnYqjByf50cBHEKoBHAgl5vhaFUrSi804YtLZjRaBl9dEaNWICd9DKanGJCVFIv7jTpkJcUiVhP6B4391eN0ob7Djtr2blQ123C52YorzVZcbbbB7nQNmlYuA2anx+H7U4zIm5KErKTYkBjkogAOIRQDOBBjDBcbrfjiUjNOXG9HZW0HrI6+IaeN06qQbIhGypgoJBuiYTJEIU6rxhitCmOiVTBoVdBHqRCtVkCrViBapQiJDdPtZmjv7kVzZw+aO3vQ0NGDm7fsuHmrGzdv2VHfYUfLgB7Bd2mUcmSnjsGczDjMSY/DnIw4xIXQfTtFFMAhhHoAv8vlZrhqtqK8pkNoAcxWXGm23XUDvZtolQIxGoUQSpXSG06tWoFotRIxasWA95SIVglfa1RyiH91778QPjDEr+F5380YunpdsPY4Yevpg7WnDxa7E83WHpg7HTBbe7yjkncTo1YgLV6LCYmxmGjUYZJJaP0z4rUhd6OkoQRqW4vcfk8IUshlmGzSY7Jp8B/QYneiocOORosdDR09aLTY0dzpQEe3ExZ7Lzq6neiwO2HtcaLH6fb+nt3puq0Lx4NMBoyN0cCo1yDZEIXUOC1S46I9L+FrQ7QqJFrsUEMBDAGGaBUM0SpMSb73J6vbzWB3utDd60J3b5/nXxfsnu/tThe6HJ6ve13odvb/TJy2t8+NgVmQyWSQAd73ZJ73AGGfTKtWIjZKCV2UEjqNEvpoFZJ0UTDqNTDqo5Co00AVBq1YKKIAhhm5XIYYjRIxGiUADe9yyAjRxxYhHFEACeGIAkgIRxRAQjiiABLCEQWQEI5CIoA7duxAZmYmoqKiMH/+fJw8eZJ3SYQEBfcAfvjhhygqKsLmzZtRXl6O7OxsLF68GGazmXdphAQc93NB58+fj7lz5+LNN98EALjdbqSlpeHFF1/Exo0bB03rcDjgcPSfN2mxWJCeno66urqwOBeUhK/Ozk6kpaWho6MDBoNBuhkzjhwOB1MoFGz//v2D3l+1ahVbvnz5bdNv3ryZwXPeML3oxeNVXV0taQa4norW2toKl8sFo9E46H2j0YhLly7dNn1xcTGKioq833d0dCAjIwO1tbXSfipFOPHTnHoOwyf2tuLj4yWdb1idC6rRaKDR3H7+o8FgoA3JD3q9ntabj+RyaYdNuA7CJCQkQKFQoLm5edD7zc3NMJlMnKoiJHi4BlCtVmPOnDk4fPiw9z23243Dhw8jNzeXY2WEBAf3LmhRUREKCwvxwAMPYN68edi2bRu6urrwzDPP3PN3NRoNNm/ePGS3lNwZrTffBWqdcT8MAQBvvvkm/vCHP6CpqQk5OTl44403MH/+fN5lERJwIRFAQkYr7mfCEDKaUQAJ4YgCSAhHFEBCOAr5APp6qdLevXsxefJkREVFYcaMGfjkk0+CVGlo8WW9lZSUCLcmHPCKiuL3HAYeysrKsGzZMqSkpEAmk+HAgQP3/J0jR45g9uzZ0Gg0yMrKQklJic/LDekA+nqp0tdff42f/exn+MUvfoGKigrk5+cjPz8f58+fD3LlfPlziZder0djY6P3VVNTE8SK+evq6kJ2djZ27NgxrOmvX7+OJ554AosWLUJlZSXWr1+PNWvWoLS01LcFS3pqt8TmzZvH1q1b5/3e5XKxlJQUtmXLliGn/8lPfsKeeOKJQe/Nnz+frV27NqB1hhpf19vu3buZwWAIUnWhD8BtV+h8129+8xs2bdq0Qe/99Kc/ZYsXL/ZpWSHbAvb29uL06dPIy8vzvieXy5GXl4djx44N+TvHjh0bND0ALF68+I7TRyJ/1hsA2Gw2ZGRkIC0tDStWrMCFCxeCUW7YkmpbC9kA3u1SpaampiF/p6mpyafpI5E/623SpEnYtWsXDh48iHfeeQdutxsPPfQQbt68GYySw9KdtrXOzk7Y7fZhz4f7uaCEv9zc3EEnvz/00EOYMmUK3nrrLbzyyiscK4t8IdsC+nOpkslkGvWXNklxiZdKpcKsWbNQVVUViBIjwp22Nb1ej+jo6GHPJ2QD6M+lSrm5uYOmB4BDhw6NqkubpLjEy+Vy4dy5c0hOTg5UmWFPsm3N1xGiYPrggw+YRqNhJSUl7Ntvv2XPPvssGzNmDGtqamKMMbZy5Uq2ceNG7/RHjx5lSqWS/fGPf2QXL15kmzdvZiqVip07d47Xf4ELX9fbSy+9xEpLS1l1dTU7ffo0e/rpp1lUVBS7cOECr/9C0FmtVlZRUcEqKioYALZ161ZWUVHBampqGGOMbdy4ka1cudI7/bVr15hWq2W//vWv2cWLF9mOHTuYQqFgn332mU/LDekAMsbYn/70J5aens7UajWbN28eO378uPdnCxYsYIWFhYOm/+ijj9jEiROZWq1m06ZNYx9//HGQKw4Nvqy39evXe6c1Go1s6dKlrLy8nEPV/Hz55ZdD3oRJXE+FhYVswYIFt/1OTk4OU6vV7L777mO7d+/2ebl0ORIhHIXsPiAhowEFkBCOKICEcEQBJIQjCiAhHFEACeGIAkgIRxRAQjiiABLCEQWQEI4ogIRw9L+Tf7SVXaR9LQAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(1, 1, figsize=(2, 2))\n", + "e_posterior = torch.distributions.Normal(mu_posterior, sigma_posterior).sample().flatten().sigmoid()\n", + "sns.kdeplot(e_posterior)\n", + "ax.set_xlim(0, 1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "bronx", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/scripts/node_classification/run.py b/scripts/node_classification/run.py index 7728ef1..c245fc9 100644 --- a/scripts/node_classification/run.py +++ b/scripts/node_classification/run.py @@ -9,7 +9,6 @@ from ray.air import session import warnings warnings.filterwarnings("ignore") -from bronx.optim import SWA, swap_swa_sgd def get_graph(data): from dgl.data import ( @@ -60,7 +59,7 @@ def get_graph(data): def run(args): pyro.clear_param_store() - torch.cuda.empty_cache() + # torch.cuda.empty_cache() if args.seed > 0: torch.manual_seed(args.seed) @@ -104,99 +103,114 @@ def run(args): model = model.cuda() g = g.to("cuda:0") - optimizer = SWA( + scheduler = pyro.optim.ReduceLROnPlateau( { - "base": getattr(torch.optim, args.optimizer), - "base_args": {"lr": args.learning_rate, "weight_decay": args.weight_decay}, - "swa_args": { - "swa_start": args.swa_start, - "swa_freq": args.swa_freq, - "swa_lr": args.swa_lr, + "optimizer": getattr(torch.optim, args.optimizer), + "optim_args": { + "lr": args.learning_rate, + "weight_decay": args.weight_decay }, - } + "factor": args.lr_factor, + "patience": args.patience, + "mode": "max", + }, ) svi = pyro.infer.SVI( model, model.guide, - optimizer, + scheduler, loss=pyro.infer.TraceMeanField_ELBO( num_particles=args.num_particles, vectorize_particles=True ), ) + accuracy_vl_max = 0.0 + accuracy_te_max = 0.0 for idx in range(args.n_epochs): model.train() loss = svi.step( g, g.ndata["feat"], y=g.ndata["label"], mask=g.ndata["train_mask"] ) - swap_swa_sgd(svi.optim) - model.eval() - with torch.no_grad(): - predictive = pyro.infer.Predictive( - model, - guide=model.guide, - num_samples=args.num_samples, - parallel=True, - return_sites=["_RETURN"], - ) - - y_hat = predictive(g, g.ndata["feat"], mask=g.ndata["val_mask"])[ - "_RETURN" - ].mean(0) - y = g.ndata["label"][g.ndata["val_mask"]] - accuracy_vl = float((y_hat.argmax(-1) == y.argmax(-1)).sum()) / len( - y_hat - ) - - y_hat = predictive(g, g.ndata["feat"], mask=g.ndata["test_mask"])[ - "_RETURN" - ].mean(0) - y = g.ndata["label"][g.ndata["test_mask"]] - accuracy_te = float((y_hat.argmax(-1) == y.argmax(-1)).sum()) / len( - y_hat - ) - + model.eval() + with torch.no_grad(): + predictive = pyro.infer.Predictive( + model, + guide=model.guide, + num_samples=args.num_samples, + parallel=True, + return_sites=["_RETURN"], + ) + + y_hat = predictive(g, g.ndata["feat"], mask=g.ndata["val_mask"])[ + "_RETURN" + ].mean(0) + y = g.ndata["label"][g.ndata["val_mask"]] + accuracy_vl = float((y_hat.argmax(-1) == y.argmax(-1)).sum()) / len( + y_hat + ) + + y_hat = predictive(g, g.ndata["feat"], mask=g.ndata["test_mask"])[ + "_RETURN" + ].mean(0) + y = g.ndata["label"][g.ndata["test_mask"]] + 
accuracy_te = float((y_hat.argmax(-1) == y.argmax(-1)).sum()) / len( + y_hat + ) + + # print(accuracy_vl, accuracy_te, flush=True) + if next(iter(scheduler.get_state().values()))["optimizer"]["param_groups"][0]["lr"] < 1e-6: + break + scheduler.step(accuracy_vl) + + if accuracy_vl > accuracy_vl_max: + accuracy_vl_max = accuracy_vl + accuracy_te_max = accuracy_te + if args.checkpoint != "": + print(args.checkpoint, flush=True) + torch.save(model, args.checkpoint) + + accuracy_vl = accuracy_vl_max + accuracy_te = accuracy_te_max print("ACCURACY,%.6f,%.6f" % (accuracy_vl, accuracy_te), flush=True) return accuracy_vl, accuracy_te if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() - parser.add_argument("--data", type=str, default="CoraGraphDataset") + parser.add_argument("--data", type=str, default="CoauthorCSDataset") parser.add_argument("--hidden_features", type=int, default=32) parser.add_argument("--embedding_features", type=int, default=32) parser.add_argument("--activation", type=str, default="ELU") parser.add_argument("--learning_rate", type=float, default=1e-3) - parser.add_argument("--weight_decay", type=float, default=1e-4) + parser.add_argument("--weight_decay", type=float, default=1e-3) parser.add_argument("--depth", type=int, default=1) - parser.add_argument("--num_samples", type=int, default=8) - parser.add_argument("--num_particles", type=int, default=8) + parser.add_argument("--num_samples", type=int, default=4) + parser.add_argument("--num_particles", type=int, default=4) parser.add_argument("--num_heads", type=int, default=4) parser.add_argument("--sigma_factor", type=float, default=5.0) parser.add_argument("--t", type=float, default=5.0) parser.add_argument("--optimizer", type=str, default="Adam") parser.add_argument("--kl_scale", type=float, default=1e-5) - parser.add_argument("--n_epochs", type=int, default=50) + parser.add_argument("--n_epochs", type=int, default=100) parser.add_argument("--adjoint", type=int, default=1) parser.add_argument("--physique", type=int, default=1) parser.add_argument("--gamma", type=float, default=1.0) parser.add_argument("--readout_depth", type=int, default=1) - parser.add_argument("--dropout_in", type=float, default=0.0) - parser.add_argument("--dropout_out", type=float, default=0.0) + parser.add_argument("--dropout_in", type=float, default=0.5) + parser.add_argument("--dropout_out", type=float, default=0.5) parser.add_argument("--consistency_temperature", type=float, default=0.1) parser.add_argument("--consistency_factor", type=float, default=1e-5) - parser.add_argument("--node_prior", type=int, default=0) - parser.add_argument("--norm", type=int, default=0) + parser.add_argument("--node_prior", type=int, default=1) + parser.add_argument("--norm", type=int, default=1) parser.add_argument("--k", type=int, default=0) parser.add_argument("--checkpoint", type=str, default="") - parser.add_argument("--swa_start", type=int, default=20) - parser.add_argument("--swa_freq", type=int, default=5) - parser.add_argument("--swa_lr", type=float, default=1e-2) parser.add_argument("--seed", type=int, default=-1) parser.add_argument("--patience", type=int, default=10) parser.add_argument("--split_index", type=int, default=-1) parser.add_argument("--edge_recover", default=0.0, type=float) + parser.add_argument("--lr_factor", default=0.5, type=float) + parser.add_argument("--__trial_index__", default=0, type=int) args = parser.parse_args() run(args) diff --git a/scripts/node_classification/tune.py b/scripts/node_classification/tune.py 
index 37d6739..8de8db4 100644 --- a/scripts/node_classification/tune.py +++ b/scripts/node_classification/tune.py @@ -34,7 +34,7 @@ def experiment(args): "hidden_features": tune.randint(1, 8), "embedding_features": tune.randint(2, 8), "num_heads": tune.randint(4, 32), - "depth": 1, # tune.randint(1, 4), + "depth": 1, "learning_rate": tune.loguniform(1e-5, 1e-2), "weight_decay": tune.loguniform(1e-10, 1e-2), "num_samples": 4, @@ -44,26 +44,23 @@ def experiment(args): "optimizer": "Adam", # tune.choice(["RMSprop", "Adam", "AdamW", "Adamax", "SGD", "Adagrad"]), "activation": "ELU", # tune.choice(["Tanh", "SiLU", "ELU", "Sigmoid", "ReLU"]), "adjoint": 1, # tune.choice([0, 1]), - "physique": 1, - "norm": 0, # tune.choice([0, 1]), - "gamma": 1.0, # tune.uniform(0.0, 1.0), + "physique": tune.choice([0, 1]), + "norm": tune.choice([0, 1]), + "gamma": 1, # tune.uniform(0, 1), "readout_depth": 1, # tune.randint(1, 4), "kl_scale": tune.loguniform(1e-5, 1e-2), "dropout_in": tune.uniform(0.0, 1.0), "dropout_out": tune.uniform(0.0, 1.0), - "consistency_factor": tune.loguniform(1e-2, 1.0), - "consistency_temperature": tune.uniform(0.0, 1.0), - "n_epochs": 100, - "swa_start": tune.randint(20, 30), - "swa_freq": tune.randint(5, 10), - "swa_lr": tune.loguniform(1e-5, 1e-1), - "node_prior": 1, # tune.choice([0, 1]), + "consistency_factor": tune.loguniform(0.01, 1.0), + "consistency_temperature": tune.uniform(0.0, 0.5), + "n_epochs": tune.randint(50, 100), + "node_prior": tune.choice([0, 1]), "edge_recover": 0.0, # tune.loguniform(1e-5, 1e-1), "seed": 2666, - "split_index": 0, "k": 0, + "split_index": -1, "patience": 10, - "checkpoint": "", + "lr_factor": 0.5, } tune_config = tune.TuneConfig( @@ -75,7 +72,7 @@ def experiment(args): run_config = air.RunConfig( name=name, - storage_path=args.data, + # storage_path=args.data, verbose=0, ) @@ -91,6 +88,6 @@ def experiment(args): if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() - parser.add_argument("--data", type=str, default="CornellDataset") + parser.add_argument("--data", type=str, default="CoraGraphDataset") args = parser.parse_args() experiment(args) diff --git a/scripts/node_classification/tune.sh b/scripts/node_classification/tune.sh index 4dd797b..40f5221 100644 --- a/scripts/node_classification/tune.sh +++ b/scripts/node_classification/tune.sh @@ -1,6 +1,6 @@ #BSUB -q gpuqueue #BSUB -o %J.stdout -#BSUB -gpu "num=1:j_exclusive=yes" +#BSUB -gpu "num=4:j_exclusive=yes" #BSUB -R "rusage[mem=5/task] span[ptile=1]" #BSUB -W 1:59 #BSUB -n 4 diff --git a/scripts/node_classification/tune_dist.py b/scripts/node_classification/tune_dist.py index 1c86d62..61f5006 100644 --- a/scripts/node_classification/tune_dist.py +++ b/scripts/node_classification/tune_dist.py @@ -5,13 +5,12 @@ import ray from ray import tune, air, train from ray.tune.trainable import session -from ray.tune.search import ConcurrencyLimiter -from ray.tune.search.optuna import OptunaSearch -from ray.tune.schedulers import ASHAScheduler +from ray.tune.search import ConcurrencyLimiter, Repeater +from ray.tune.search.hyperopt import HyperOptSearch import os ray.init(num_cpus=os.cpu_count()) LSF_COMMAND = "bsub -q gpuqueue -gpu " +\ -"\"num=1:j_exclusive=yes\" -R \"rusage[mem=5] span[ptile=1]\" -W 0:05 -Is " +"\"num=1:j_exclusive=yes\" -R \"rusage[mem=5] span[ptile=1]\" -W 0:10 -Is " PYTHON_COMMAND =\ "python /data/chodera/wangyq/bronx/scripts/node_classification/run.py" @@ -56,6 +55,7 @@ def objective(args): output = lsf_submit(command) accuracy, accuracy_te = 
parse_output(output) session.report({"accuracy": accuracy, "accuracy_te": accuracy_te}) + # return accuracy def experiment(args): name = datetime.now().strftime("%m%d%Y%H%M%S") @@ -64,7 +64,7 @@ def experiment(args): "hidden_features": tune.randint(1, 8), "embedding_features": tune.randint(2, 8), "num_heads": tune.randint(4, 32), - "depth": 1, # tune.randint(1, 4), + "depth": 1, "learning_rate": tune.loguniform(1e-5, 1e-2), "weight_decay": tune.loguniform(1e-10, 1e-2), "num_samples": 4, @@ -74,20 +74,17 @@ def experiment(args): "optimizer": "Adam", # tune.choice(["RMSprop", "Adam", "AdamW", "Adamax", "SGD", "Adagrad"]), "activation": "ELU", # tune.choice(["Tanh", "SiLU", "ELU", "Sigmoid", "ReLU"]), "adjoint": 1, # tune.choice([0, 1]), - "physique": 1, - "norm": 0, # tune.choice([0, 1]), - "gamma": 1.0, # tune.uniform(0.5, 1.0), + "physique": 1, # tune.choice([0, 1]), + "norm": 1, # tune.choice([0, 1]), + "gamma": 1, # tune.uniform(0, 1), "readout_depth": 1, # tune.randint(1, 4), "kl_scale": tune.loguniform(1e-5, 1e-2), "dropout_in": tune.uniform(0.0, 1.0), "dropout_out": tune.uniform(0.0, 1.0), - "consistency_factor": tune.loguniform(1e-2, 1.0), + "consistency_factor": tune.loguniform(0.01, 1.0), "consistency_temperature": tune.uniform(0.0, 0.5), - "n_epochs": tune.randint(50, 70), - "swa_start": tune.randint(10, 20), - "swa_freq": tune.randint(5, 10), - "swa_lr": tune.loguniform(1e-5, 1e-1), - "node_prior": 1, # tune.choice([0, 1]), + "n_epochs": tune.randint(50, 100), + "node_prior": tune.choice([0, 1]), "edge_recover": 0.0, # tune.loguniform(1e-5, 1e-1), "seed": 2666, "k": 0, @@ -96,8 +93,11 @@ def experiment(args): tune_config = tune.TuneConfig( metric="_metric/accuracy", mode="max", - search_alg=ConcurrencyLimiter(OptunaSearch(), args.concurrent), - num_samples=10000, + search_alg=ConcurrencyLimiter( + Repeater(HyperOptSearch(), repeat=3), + args.concurrent + ), + num_samples=3000, ) run_config = air.RunConfig( @@ -106,7 +106,7 @@ def experiment(args): ) tuner = tune.Tuner( - objective, + tune.with_resources(objective, {"cpu": 0.01}), param_space=param_space, tune_config=tune_config, run_config=run_config, @@ -118,6 +118,6 @@ def experiment(args): import argparse parser = argparse.ArgumentParser() parser.add_argument("--data", type=str, default="CoraGraphDataset") - parser.add_argument("--concurrent", type=int, default=200) + parser.add_argument("--concurrent", type=int, default=100) args = parser.parse_args() experiment(args) diff --git a/scripts/node_classification/tune_dist.sh b/scripts/node_classification/tune_dist.sh index 45d43fd..2a18a63 100644 --- a/scripts/node_classification/tune_dist.sh +++ b/scripts/node_classification/tune_dist.sh @@ -3,5 +3,5 @@ #BSUB -W 23:59 #BSUB -n 8 -python tune_dist.py --data CoraGraphDataset +python tune_dist.py --data CoauthorCSDataset
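
Both training loops above swap the homegrown SWA optimizer for Pyro's ReduceLROnPlateau wrapper: pyro.infer.SVI accepts the scheduler in place of a plain optimizer, and the scheduler is stepped once per epoch with the monitored quantity (validation accuracy with mode="max" in node classification, RMSE with mode="min" in graph regression). A minimal sketch of the pattern, with model, guide, and the data arguments as placeholders rather than the actual Bronx classes:

    import torch
    import pyro

    def fit(model, guide, g, h, y, mask, evaluate, n_epochs=100):
        # `evaluate` is any zero-argument callable returning the monitored
        # validation metric; see the Predictive sketch below.
        scheduler = pyro.optim.ReduceLROnPlateau({
            "optimizer": torch.optim.Adam,
            "optim_args": {"lr": 1e-3, "weight_decay": 1e-3},
            "factor": 0.5,    # shrink the learning rate on plateau (--lr_factor)
            "patience": 10,   # epochs without improvement first (--patience)
            "mode": "max",    # accuracy: higher is better
        })
        svi = pyro.infer.SVI(
            model,
            guide,
            scheduler,
            loss=pyro.infer.TraceMeanField_ELBO(
                num_particles=4, vectorize_particles=True
            ),
        )
        for _ in range(n_epochs):
            model.train()
            svi.step(g, h, y=y, mask=mask)
            scheduler.step(evaluate())  # ReduceLROnPlateau consumes the metric here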
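Evaluation in both scripts goes through pyro.infer.Predictive, with the "_RETURN" site averaged over posterior samples before the metric is taken. A sketch of an evaluate callable in the same spirit as the node-classification loop (one-hot labels and a boolean validation mask, as in run.py; the closure structure is an illustration, not the script's own layout):

    def make_evaluate(model, g, num_samples=4):
        predictive = pyro.infer.Predictive(
            model,
            guide=model.guide,
            num_samples=num_samples,
            parallel=True,
            return_sites=["_RETURN"],
        )

        def evaluate():
            model.eval()
            with torch.no_grad():
                y_hat = predictive(
                    g, g.ndata["feat"], mask=g.ndata["val_mask"]
                )["_RETURN"].mean(0)
            y = g.ndata["label"][g.ndata["val_mask"]]
            return float((y_hat.argmax(-1) == y.argmax(-1)).float().mean())

        return evaluate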
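The early-stopping test added to scripts/node_classification/run.py digs the live learning rate out of the scheduler state before deciding whether to break. Factored into a helper, that lookup reads as follows; this mirrors the exact indexing used above and assumes, as that script does, that every parameter shares one param group:

    def current_lr(scheduler):
        # get_state() maps parameter names to per-parameter optimizer state;
        # any one entry suffices, since all share the same settings.
        state = next(iter(scheduler.get_state().values()))
        return state["optimizer"]["param_groups"][0]["lr"]

    # e.g. terminate once annealing has bottomed out:
    #     if current_lr(scheduler) < 1e-6:
    #         break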