Sde #4

Open · wants to merge 11 commits into base: main

Changes from all commits
7 changes: 7 additions & 0 deletions .gitignore
@@ -127,3 +127,10 @@ dmypy.json

# Pyre type checker
.pyre/

# cluster job logs
*.stdout
*.out

# data
*.csv
90 changes: 80 additions & 10 deletions bronx/layers.py
@@ -1,21 +1,91 @@
from typing import Optional, Callable
import torch
import dgl
from dgl.nn import DotGatConv, GraphConv, GATConv
import torchsde

# class GraphConv(torch.nn.Module):
#     def __init__(
#         self,
#         in_features: int,
#         out_features: int,
#         **kwargs,
#     ):
#         super().__init__()
#         # self.W = torch.nn.Parameter(torch.randn(in_features, out_features))
#         # torch.nn.init.xavier_uniform_(self.W)

#     def forward(self, graph, x):
#         with graph.local_scope():
#             norm = torch.pow(graph.in_degrees().float().clamp(min=1), -1).unsqueeze(-1)
#             graph.ndata["x"] = x * norm
#             graph.update_all(dgl.function.copy_src(src="x", out="m"), dgl.function.sum(msg="m", out="x"))
#             x = graph.ndata["x"]
#             # x = x @ self.W.tanh()
#             return x


class _GraphConv(GraphConv):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        in_feats, out_feats = self.weight.shape
        del self.weight
        self.w = torch.nn.Parameter(torch.eye(in_feats))
        self.d = torch.nn.Parameter(torch.zeros(in_feats))

    @property
    def weight(self):
        # Reparameterize the weight as W diag(sigmoid(d)) W^T: symmetric and
        # positive semi-definite (assumes in_feats == out_feats).
        d = self.d.sigmoid()
        w = torch.mm(self.w * d, self.w.T)
        return w

class BronxLayer(torchsde.SDEStratonovich):
    def __init__(self, hidden_features, num_heads=1, gamma=0.0, gain=0.0):
        super().__init__(noise_type="scalar")
        # self.gcn = GraphConv(hidden_features, hidden_features // 4)
        # self.gcn2 = _GraphConv(hidden_features, hidden_features)
        self.gcn = GraphConv(hidden_features, hidden_features // 2)
        self.fc_mu = torch.nn.Linear(2, hidden_features)
        self.fc_log_sigma = torch.nn.Linear(2, hidden_features)
        self.w = torch.nn.Parameter(torch.zeros(hidden_features, hidden_features))
        torch.nn.init.xavier_uniform_(self.w, gain=gain)

        # The graphs are attached externally (see BronxModel and run.py).
        self.graph = None
        self.graph2 = None
        # self.graph3 = None
        self.num_heads = num_heads
        self.gamma = gamma

    def ty(self, t, y):
        # Sinusoidal time features concatenated onto the state.
        # y = torch.nn.functional.normalize(y, dim=-1)
        t = torch.broadcast_to(t, (*y.shape[:-1], 1))
        ty = torch.cat([t.cos(), t.sin(), y], dim=-1)
        return ty

    def f(self, t, y):
        # Posterior drift: a gated two-scale graph convolution with a
        # skew-symmetric linear term and a -gamma * y decay.
        t = torch.broadcast_to(t, (*y.shape[:-1], 1))
        t = torch.cat([t.cos(), t.sin()], dim=-1)
        y = torch.nn.functional.normalize(y, dim=-1)
        mu = self.fc_mu(t).sigmoid()
        w = self.w - self.w.T  # skew-symmetric
        y1 = self.gcn(self.graph, y)   # .tanh()
        y2 = self.gcn(self.graph2, y)  # .tanh()

        if y1.dim() > 2:
            y1 = y1.flatten(-2, -1)
            y2 = y2.flatten(-2, -1)

        y12 = torch.cat([y1, y2], dim=-1)
        # y12 = torch.nn.functional.silu(y12)
        y = y @ w + y12 - self.gamma * y
        return torch.nn.functional.tanh(y) * mu

    def g(self, t, y):
        # Diffusion: time-dependent, state-independent noise, shaped
        # (nodes, hidden_features, 1) as required by noise_type="scalar".
        t = torch.broadcast_to(t, (*y.shape[:-1], 1))
        t = torch.cat([t.cos(), t.sin()], dim=-1)
        return self.fc_log_sigma(t).unsqueeze(-1)

    def h(self, t, y):
        # Prior drift (Ornstein-Uhlenbeck pull toward zero), used by logqp.
        return -y

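Not part of the diff: a minimal smoke test for the layer above, under assumed toy sizes (10 nodes, hidden_features=8, dt=0.05). It checks the diffusion shape that noise_type="scalar" requires and integrates once with torchsde; the graphs are attached externally, as in BronxModel below.

import dgl
import torch
import torchsde

from bronx.layers import BronxLayer

hidden_features = 8
layer = BronxLayer(hidden_features)

g = dgl.add_self_loop(dgl.rand_graph(10, 30))  # toy graph, 10 nodes
layer.graph = g
layer.graph2 = dgl.khop_graph(g, 2)

y0 = torch.randn(10, hidden_features)  # one state vector per node
ts = torch.tensor([0.0, 1.0])

# With noise_type="scalar", g(t, y) must return (nodes, state, 1).
assert layer.g(ts[0], y0).shape == (10, hidden_features, 1)

ys = torchsde.sdeint(layer, y0, ts, dt=0.05)
print(ys.shape)  # (len(ts), 10, hidden_features)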
44 changes: 38 additions & 6 deletions bronx/models.py
@@ -1,17 +1,49 @@
import torch
from .layers import BronxLayer
import torchsde
from torchsde import BrownianInterval

class BronxModel(torch.nn.Module):
    def __init__(
        self,
        in_features, hidden_features, out_features, num_heads=1,
        dropout0=0.0, dropout1=0.0, gamma=0.0, gain=0.0,
    ):
        super().__init__()
        self.fc_in = torch.nn.Linear(in_features, hidden_features, bias=False)
        self.fc_out = torch.nn.Linear(hidden_features, out_features, bias=False)
        self.sde = BronxLayer(hidden_features, gamma=gamma, gain=gain)
        self.dropout0 = torch.nn.Dropout(dropout0)
        self.dropout1 = torch.nn.Dropout(dropout1)

    def forward(self, g, h):
        # self.sde.graph = g
        h = self.dropout0(h)
        h = self.fc_in(h)  # .tanh()
        t = torch.tensor([0.0, 1.0], device=h.device, dtype=h.dtype)
        h = torchsde.sdeint(
            self.sde,
            h,
            t,
            # bm=BrownianInterval(
            #     t0=t[0],
            #     t1=t[-1],
            #     size=(h.shape[0], 1),
            #     device=h.device,
            #     cache_size=None,
            #     pool_size=4,
            # ),
            dt=0.05,
            logqp=self.training,
        )

        # With logqp=True, sdeint also returns the path-space KL (log-ratio) term.
        if self.training:
            h, kl = h
        else:
            kl = 0.0

        h = h[-1]
        # h = torch.nn.functional.silu(h)
        h = self.dropout1(h)
        h = self.fc_out(h)
        return h, kl
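Not part of the diff: a sketch of how the (h, kl) pair returned in training mode is consumed. It mirrors scripts/planetoid/run.py below; g, feats, labels, and mask are hypothetical placeholders.

model.train()
y_hat, kl = model(g, feats)  # kl is the logqp path-space KL term
loss = torch.nn.CrossEntropyLoss()(y_hat[mask], labels[mask]) + kl.mean()
loss.backward()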
14 changes: 14 additions & 0 deletions scripts/planetoid/check.py
@@ -0,0 +1,14 @@
import ray
from ray.tune import ExperimentAnalysis

def run(path):
    analysis = ExperimentAnalysis(path)
    trial = analysis.get_best_trial("_metric/mean_accuracy", "max")
    accuracy = analysis.results[trial.trial_id]["_metric"]["mean_accuracy"]
    print(accuracy)

if __name__ == "__main__":
    import sys
    run(sys.argv[1])


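Usage sketch (assumption: the argument is a Ray Tune experiment directory, e.g. python check.py ~/ray_results/<experiment>); the metric key _metric/mean_accuracy must match whatever the training script reports to Tune.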
43 changes: 43 additions & 0 deletions scripts/planetoid/meta_run.sbatch
@@ -0,0 +1,43 @@
for hidden_features in 128; do
for learning_rate in 1e-2; do
for weight_decay in 1e-4; do
for num_heads in 4; do
for dropout0 in 0.6; do
for dropout1 in 0.6; do
for gamma in 0.2; do
for factor in 0.8; do
for patience in 10; do
for gain in 0.1; do


sbatch \
--nodes=1 \
--ntasks-per-node=1 \
--cpus-per-task=1 \
--time=00:30:00 \
--mem=5GB \
--job-name=aperol \
--gres=gpu:v100:1 \
--output=%A.out \
--wrap "python run.py \
--hidden_features $hidden_features \
--learning_rate $learning_rate \
--weight_decay $weight_decay \
--num_heads $num_heads \
--dropout0 $dropout0 \
--dropout1 $dropout1 \
--gamma $gamma \
--factor $factor \
--patience $patience \
--gain $gain"

done
done
done
done
done
done
done
done
done
done
59 changes: 35 additions & 24 deletions scripts/planetoid/meta_run.sh
@@ -1,26 +1,37 @@
for hidden_features in 256; do
for learning_rate in 1e-2; do
for weight_decay in 1e-4; do
for num_heads in 4; do
for dropout0 in 0.2 0.4 0.6; do
for dropout1 in 0.2 0.4 0.6; do
for gamma in 0.2 0.4 0.6 0.8; do
for patience in 5 10 15; do
for factor in 0.4 0.6 0.8; do

bsub \
-q gpuqueue \
-o %J.stdout \
-gpu "num=1:j_exclusive=yes" \
-R "rusage[mem=5] span[ptile=1]" \
-W 0:5 \
-n 1 \
python run.py \
--hidden_features $hidden_features \
--learning_rate $learning_rate \
--weight_decay $weight_decay \
--num_heads $num_heads \
--dropout0 $dropout0 \
--dropout1 $dropout1 \
--gamma $gamma \
--patience $patience \
--factor $factor

done
done
done
done
done
done
done
done
done
6 changes: 6 additions & 0 deletions scripts/planetoid/performance.csv
@@ -0,0 +1,6 @@
,data,hidden_features,learning_rate,weight_decay,num_heads,dropout0,dropout1,gamma,patience,factor,gain,accuracy_vl,accuracy_te
0,cora,128,0.01,0.0001,4.0,0.6,0.6,0.2,10,0.8,0.1,0.788,0.817
0,cora,128,0.01,0.0001,4.0,0.6,0.6,0.2,10,0.8,0.1,0.788,0.786
0,cora,128,0.01,0.0001,4.0,0.6,0.6,0.2,10,0.8,0.1,0.802,0.792
0,cora,128,0.01,0.0001,4.0,0.6,0.6,0.2,10,0.8,0.1,0.784,0.794
0,cora,128,0.01,0.0001,4.0,0.6,0.6,0.2,10,0.8,0.1,0.776,0.793
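Across the five Cora repeats above (identical hyperparameters), validation accuracy spans 0.776 to 0.802 and test accuracy 0.786 to 0.817, with a mean test accuracy of about 0.796.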
49 changes: 36 additions & 13 deletions scripts/planetoid/run.py
@@ -1,3 +1,4 @@
#!/usr/bin/env python
import numpy as np
import torch
import dgl
@@ -8,42 +9,60 @@ def run(args):
    from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
    g = locals()[f"{args.data.capitalize()}GraphDataset"]()[0]
    g = dgl.remove_self_loop(g)
    # g.ndata["feat"] = torch.cat([g.ndata["feat"], dgl.laplacian_pe(g, 10)], dim=-1)
    # g = dgl.add_self_loop(g)

    model = BronxModel(
        in_features=g.ndata["feat"].shape[-1],
        out_features=g.ndata["label"].max() + 1,
        hidden_features=args.hidden_features,
        num_heads=args.num_heads,
        dropout0=args.dropout0,
        dropout1=args.dropout1,
        gamma=args.gamma,
        gain=args.gain,
    )

    if torch.cuda.is_available():
        model = model.cuda()
        g = g.to("cuda:0")

    model.sde.graph = g
    model.sde.graph2 = dgl.khop_graph(g, 2)
    # model.sde.graph3 = dgl.khop_graph(g, 3)
    # model.sde.graph4 = dgl.khop_graph(g, 4)
    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", patience=args.patience, factor=args.factor)
    accuracy_vl = []
    accuracy_te = []

    # import tqdm
    for idx_epoch in range(100):
        model.train()
        optimizer.zero_grad()
        y_hat, kl = model(g, g.ndata['feat'])
        y_hat = y_hat[g.ndata['train_mask']]
        kl = kl.squeeze(-2)[g.ndata['train_mask']]
        y = g.ndata['label'][g.ndata['train_mask']]
        kl = kl.mean()
        # ELBO-style objective: cross-entropy plus the path-space KL penalty.
        loss = torch.nn.CrossEntropyLoss()(y_hat, y) + kl
        loss.backward()
        optimizer.step()
        # scheduler.step()
        model.eval()

        with torch.no_grad():
            # Average four stochastic forward passes before scoring.
            _y_hat = torch.stack([model(g, g.ndata["feat"])[0] for _ in range(4)], 0).mean(0)
            y_hat = _y_hat[g.ndata["val_mask"]]
            y = g.ndata["label"][g.ndata["val_mask"]]
            accuracy = float((y_hat.argmax(-1) == y).sum()) / len(y_hat)
            print(accuracy, kl.item(), flush=True)
            accuracy_vl.append(accuracy)
            scheduler.step(accuracy)

            y_hat = _y_hat[g.ndata["test_mask"]]
            y = g.ndata["label"][g.ndata["test_mask"]]
            accuracy = float((y_hat.argmax(-1) == y).sum()) / len(y_hat)
            accuracy_te.append(accuracy)
@@ -66,12 +85,16 @@ def run(args):
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default="cora")
    parser.add_argument("--hidden_features", type=int, default=256)
    parser.add_argument("--learning_rate", type=float, default=1e-2)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--num_heads", type=int, default=4)
    parser.add_argument("--dropout0", type=float, default=0.6)
    parser.add_argument("--dropout1", type=float, default=0.6)
    parser.add_argument("--gamma", type=float, default=0.2)
    parser.add_argument("--patience", type=int, default=15)
    parser.add_argument("--factor", type=float, default=0.4)
    parser.add_argument("--gain", type=float, default=0.1)
    args = parser.parse_args()
    print(args)
    run(args)