diff --git a/bronx/layers.py b/bronx/layers.py
index 959f153..0202579 100644
--- a/bronx/layers.py
+++ b/bronx/layers.py
@@ -62,9 +62,6 @@ def forward(self, g, h, e):
         h = h.reshape(*h.shape[:-1], e.shape[-2], -1)
         g.edata["e"] = e
-        g = dgl.add_reverse_edges(g, copy_ndata=True, copy_edata=True)
-        g.update_all(fn.copy_e("e", "m"), fn.sum("m", "e_sum"))
-        g.apply_edges(lambda edges: {"e": edges.data["e"] / edges.dst["e_sum"]})
 
         node_shape = h.shape
         if self.physique is not None:
             self.odefunc.h0 = h.clone().detach()
@@ -95,14 +92,14 @@ def __init__(
         adjoint=False,
         physique=False,
         gamma=1.0,
+        temperature=1.0,
     ):
         super().__init__()
-        self.fc_mu = torch.nn.Linear(in_features, out_features)
-        self.fc_log_sigma = torch.nn.Linear(in_features, out_features)
+
+        self.fc_mu = torch.nn.Linear(in_features, in_features)
+        self.fc_log_sigma = torch.nn.Linear(in_features, in_features)
         self.fc_k = torch.nn.Linear(in_features, out_features)
-        # torch.nn.init.constant_(self.fc_k.weight, 1e-5)
-        # torch.nn.init.constant_(self.fc_log_sigma.weight, 1e-5)
-        # torch.nn.init.constant_(self.fc_mu.weight, 1e-5)
+        self.fc_q = torch.nn.Linear(in_features, out_features)
 
         self.activation = activation
         self.idx = idx
@@ -111,6 +108,7 @@ def __init__(
         self.num_heads = num_heads
         self.sigma_factor = sigma_factor
         self.kl_scale = kl_scale
+        self.temperature = temperature
         self.linear_diffusion = LinearDiffusion(
             t, adjoint=adjoint, physique=physique, gamma=gamma,
         )
@@ -118,44 +116,36 @@ def guide(self, g, h):
         g = g.local_var()
         h0 = h
-        # h = h - h.mean(-1, keepdims=True)
-        # h = torch.nn.functional.normalize(h, dim=-1)
-        mu, log_sigma, k = self.fc_mu(h), self.fc_log_sigma(h), self.fc_k(h)
-        mu = mu.reshape(*mu.shape[:-1], self.num_heads, -1)
-        log_sigma = log_sigma.reshape(
-            *log_sigma.shape[:-1], self.num_heads, -1
-        )
-        k = k.reshape(*k.shape[:-1], self.num_heads, -1)
-        parallel = h.dim() == 3
+        with pyro.plate(f"nodes{self.idx}", g.number_of_nodes()):
+            with pyro.poutine.scale(None, self.kl_scale):
+                h = pyro.sample(
+                    f"h{self.idx}",
+                    pyro.distributions.Normal(
+                        self.fc_mu(h),
+                        self.fc_log_sigma(h).exp(),
+                    ).to_event(1),
+                )
 
-        if parallel:
-            mu, log_sigma, k = mu.swapaxes(0, 1), log_sigma.swapaxes(0, 1), k.swapaxes(0, 1)
+        k = self.fc_k(h)
+        q = self.fc_q(h)
 
-        g.ndata["mu"], g.ndata["log_sigma"], g.ndata["k"] = mu, log_sigma, k
-        g.apply_edges(dgl.function.u_dot_v("k", "mu", "mu"))
-        g.apply_edges(
-            dgl.function.u_dot_v("k", "log_sigma", "log_sigma")
-        )
-        mu, log_sigma = g.edata["mu"], g.edata["log_sigma"]
+        k = k.reshape(*k.shape[:-1], self.num_heads, -1)
+        q = q.reshape(*q.shape[:-1], self.num_heads, -1)
 
+        parallel = k.dim() == 4
         if parallel:
-            mu, log_sigma = mu.swapaxes(0, 1), log_sigma.swapaxes(0, 1)
-
-        with pyro.plate(
-            f"edges{self.idx}", g.number_of_edges(), device=g.device
-        ):
-            with pyro.poutine.scale(None, self.kl_scale):
-                e = pyro.sample(
-                    f"e{self.idx}",
-                    pyro.distributions.TransformedDistribution(
-                        pyro.distributions.Normal(
-                            mu,
-                            self.sigma_factor * log_sigma.exp(),
-                        ),
-                        pyro.distributions.transforms.SigmoidTransform(),
-                    ).to_event(2),
-                )
+            k, q = k.swapaxes(0, 1), q.swapaxes(0, 1)
+
+        g.ndata["k"], g.ndata["q"] = k, q
+        g.apply_edges(dgl.function.u_dot_v("k", "q", "e"))
+        e = g.edata["e"]
+        e = e * self.temperature
+        # e = torch.zeros_like(e)
+        e = edge_softmax(g, e)
+
+        if parallel:
+            e = e.swapaxes(0, 1)
 
         h = self.linear_diffusion(g, h0, e)
         return h
@@ -163,32 +153,38 @@ def guide(self, g, h):
     def forward(self, g, h):
         g = g.local_var()
         h0 = h
-        with pyro.plate(
-            f"edges{self.idx}", g.number_of_edges(), device=g.device
-        ):
+
+        with pyro.plate(f"nodes{self.idx}", g.number_of_nodes()):
             with pyro.poutine.scale(None, self.kl_scale):
-                e = pyro.sample(
-                    f"e{self.idx}",
-                    pyro.distributions.TransformedDistribution(
-                        pyro.distributions.Normal(
-                            torch.zeros(
-                                g.number_of_edges(),
-                                self.num_heads,
-                                1,
-                                device=g.device,
-                            ),
-                            self.sigma_factor * torch.ones(
-                                g.number_of_edges(),
-                                self.num_heads,
-                                1,
-                                device=g.device,
-                            ),
-                        ),
-                        pyro.distributions.transforms.SigmoidTransform(),
-                    ).to_event(2),
+                h = pyro.sample(
+                    f"h{self.idx}",
+                    pyro.distributions.Normal(
+                        torch.ones(self.in_features, device=h.device),
+                        self.sigma_factor,
+                    ).to_event(1),
                 )
-        h = self.linear_diffusion(g, h, e)
+        k = self.fc_k(h)
+        q = self.fc_q(h)
+
+        k = k.reshape(*k.shape[:-1], self.num_heads, -1)
+        q = q.reshape(*q.shape[:-1], self.num_heads, -1)
+
+        parallel = k.dim() == 4
+        if parallel:
+            k, q = k.swapaxes(0, 1), q.swapaxes(0, 1)
+
+        g.ndata["k"], g.ndata["q"] = k, q
+        g.apply_edges(dgl.function.u_dot_v("k", "q", "e"))
+        e = g.edata["e"]
+        e = e * self.temperature
+        e = edge_softmax(g, e)
+
+
+        if parallel:
+            e = e.swapaxes(0, 1)
+
+        h = self.linear_diffusion(g, h0, e)
         return h
 
 
 class NodeRecover(pyro.nn.PyroModule):
diff --git a/bronx/models.py b/bronx/models.py
index bfb43e1..9571523 100644
--- a/bronx/models.py
+++ b/bronx/models.py
@@ -12,17 +12,16 @@ def __init__(
         embedding_features=None,
         activation=torch.nn.SiLU(),
         depth=1,
-        readout_depth=1,
         num_heads=4,
         sigma_factor=1.0,
         kl_scale=1.0,
         t=1.0,
-        alpha=0.1,
         adjoint=False,
         physique=False,
         gamma=1.0,
         dropout_in=0.0,
         dropout_out=0.0,
+        temperature=1.0,
     ):
         super().__init__()
         if embedding_features is None:
@@ -30,22 +29,6 @@ def __init__(
         self.fc_in = torch.nn.Linear(in_features, hidden_features, bias=False)
         self.fc_out = torch.nn.Linear(hidden_features, out_features, bias=False)
 
-        fc_out = []
-        for idx in range(readout_depth-1):
-            fc_out.append(activation)
-            fc_out.append(
-                torch.nn.Linear(hidden_features, hidden_features, bias=False)
-            )
-        fc_out.append(activation)
-        fc_out.append(
-            torch.nn.Linear(hidden_features, out_features, bias=False)
-        )
-        self.fc_out = torch.nn.Sequential(*fc_out)
-
-        self.alpha = alpha
-        self.log_alpha = torch.nn.Parameter(
-            torch.ones(hidden_features) * math.log(alpha)
-        )
 
         self.activation = activation
         self.depth = depth
@@ -62,6 +45,7 @@ def __init__(
                 adjoint=adjoint,
                 physique=physique,
                 gamma=gamma,
+                temperature=temperature,
             )
 
             if idx > 0:
@@ -80,6 +64,7 @@ def __init__(
         # self.edge_recover = EdgeRecover(
         #     hidden_features, embedding_features, scale=edge_recover_scale,
         # )
+
         self.dropout_in = torch.nn.Dropout(dropout_in)
         self.dropout_out = torch.nn.Dropout(dropout_out)
 
@@ -90,6 +75,7 @@ def guide(self, g, h, *args, **kwargs):
         for idx in range(self.depth):
             h = getattr(self, f"layer{idx}").guide(g, h)
         h = self.dropout_out(h)
+        h = self.fc_out(h)
         return h
 
     def forward(self, g, h, *args, **kwargs):
@@ -100,9 +86,6 @@ def forward(self, g, h, *args, **kwargs):
         for idx in range(self.depth):
            h = getattr(self, f"layer{idx}")(g, h)
         h = self.dropout_out(h)
-        # h = self.fc_out(h)
-        # self.node_recover(g, h, h0)
-        # self.edge_recover(g, h)
         h = self.fc_out(h)
         return h
 
diff --git a/scripts/node_classification/run.py b/scripts/node_classification/run.py
index 296cf94..9da41c0 100644
--- a/scripts/node_classification/run.py
+++ b/scripts/node_classification/run.py
@@ -24,10 +24,8 @@ def get_graph(data):
     g = locals()[data](verbose=False)[0]
     g = dgl.remove_self_loop(g)
-    # g = dgl.add_self_loop(g)
-    src, dst = g.edges()
-    eids = torch.where(src > dst)[0]
-    g = dgl.remove_edges(g, eids)
+    g = dgl.add_self_loop(g)
+    print(g)
 
     g.ndata["label"] = torch.nn.functional.one_hot(g.ndata["label"])
 
     if "train_mask" not in g.ndata:
@@ -70,7 +68,6 @@ def run(args):
         hidden_features=args.hidden_features,
         embedding_features=args.embedding_features,
         depth=args.depth,
-        readout_depth=args.readout_depth,
         num_heads=args.num_heads,
         sigma_factor=args.sigma_factor,
         kl_scale=args.kl_scale,
@@ -81,6 +78,7 @@ def run(args):
         gamma=args.gamma,
         dropout_in=args.dropout_in,
         dropout_out=args.dropout_out,
+        temperature=args.temperature,
     )
 
     if torch.cuda.is_available():
@@ -88,16 +86,20 @@ def run(args):
         model = model.cuda()
         g = g.to("cuda:0")
 
-    optimizer = SWA(
-        {
-            "base": getattr(torch.optim, args.optimizer),
-            "base_args": {"lr": args.learning_rate, "weight_decay": args.weight_decay},
-            "swa_args": {
-                "swa_start": args.swa_start,
-                "swa_freq": args.swa_freq,
-                "swa_lr": args.swa_lr,
-            },
-        }
+    # optimizer = SWA(
+    #     {
+    #         "base": getattr(torch.optim, args.optimizer),
+    #         "base_args": {"lr": args.learning_rate, "weight_decay": args.weight_decay},
+    #         "swa_args": {
+    #             "swa_start": args.swa_start,
+    #             "swa_freq": args.swa_freq,
+    #             "swa_lr": args.swa_lr,
+    #         },
+    #     }
+    # )
+
+    optimizer = getattr(pyro.optim, args.optimizer)(
+        {"lr": args.learning_rate, "weight_decay": args.weight_decay},
     )
 
     svi = pyro.infer.SVI(
@@ -115,29 +117,32 @@ def run(args):
             g, g.ndata["feat"], y=g.ndata["label"], mask=g.ndata["train_mask"]
         )
 
-    model.eval()
-    swap_swa_sgd(svi.optim)
-    with torch.no_grad():
-        predictive = pyro.infer.Predictive(
-            model,
-            guide=model.guide,
-            num_samples=args.num_samples,
-            parallel=True,
-            return_sites=["_RETURN"],
-        )
+        print(loss)
 
-        y_hat = predictive(g, g.ndata["feat"], mask=g.ndata["val_mask"])[
-            "_RETURN"
-        ].mean(0)
-        y = g.ndata["label"][g.ndata["val_mask"]]
-        accuracy_vl = float((y_hat.argmax(-1) == y.argmax(-1)).sum()) / len(
-            y_hat
-        )
+        model.eval()
+        # swap_swa_sgd(svi.optim)
+        with torch.no_grad():
+            predictive = pyro.infer.Predictive(
+                model,
+                guide=model.guide,
+                num_samples=args.num_samples,
+                parallel=True,
+                return_sites=["_RETURN"],
+            )
+
+            y_hat = predictive(g, g.ndata["feat"], mask=g.ndata["val_mask"])[
+                "_RETURN"
+            ].mean(0)
+
+            y = g.ndata["label"][g.ndata["val_mask"]]
+            accuracy_vl = float((y_hat.argmax(-1) == y.argmax(-1)).sum()) / len(
+                y_hat
+            )
 
-    if len(args.checkpoint) > 1:
-        torch.save(model, args.checkpoint)
+        if len(args.checkpoint) > 1:
+            torch.save(model, args.checkpoint)
 
-    print("ACCURACY: %.6f" % accuracy_vl, flush=True)
+        print("ACCURACY: %.6f" % accuracy_vl, flush=True)
     return accuracy_vl
 
 if __name__ == "__main__":
@@ -148,14 +153,14 @@ def run(args):
     parser.add_argument("--embedding_features", type=int, default=20)
     parser.add_argument("--activation", type=str, default="SiLU")
     parser.add_argument("--learning_rate", type=float, default=1e-2)
-    parser.add_argument("--weight_decay", type=float, default=1e-5)
+    parser.add_argument("--weight_decay", type=float, default=1e-10)
     parser.add_argument("--depth", type=int, default=1)
     parser.add_argument("--num_samples", type=int, default=64)
    parser.add_argument("--num_particles", type=int, default=32)
     parser.add_argument("--num_heads", type=int, default=5)
-    parser.add_argument("--sigma_factor", type=float, default=10.0)
-    parser.add_argument("--t", type=float, default=5.0)
-    parser.add_argument("--optimizer", type=str, default="AdamW")
+    parser.add_argument("--sigma_factor", type=float, default=1e-3)
+    parser.add_argument("--t", type=float, default=1.0)
+    parser.add_argument("--optimizer", type=str, default="Adam")
     parser.add_argument("--kl_scale", type=float, default=1e-5)
     parser.add_argument("--n_epochs", type=int, default=50)
     parser.add_argument("--adjoint", type=int, default=0)
@@ -166,6 +171,7 @@ def run(args):
     parser.add_argument("--swa_freq", type=int, default=10)
     parser.add_argument("--swa_lr", type=float, default=1e-2)
     parser.add_argument("--epsilon", type=float, default=1.0)
+    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--dropout_in", type=float, default=0.0)
    parser.add_argument("--dropout_out", type=float, default=0.0)
    parser.add_argument("--checkpoint", type=str, default="")
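
Note (illustrative, not part of the patch): after this change, both guide and forward above derive edge weights from a per-edge key/query dot product, scale them by temperature, and normalize them with edge_softmax before calling linear_diffusion, instead of sampling per-edge sigmoid-Normal weights. A minimal standalone sketch of just that edge-weighting step, assuming DGL and PyTorch are installed; the attention_weights helper and the toy graph below are hypothetical and do not exist in the repository, while dgl.function.u_dot_v and dgl.nn.functional.edge_softmax are real DGL APIs.

import dgl
import dgl.function as fn
import torch
from dgl.nn.functional import edge_softmax


def attention_weights(g, h, fc_k, fc_q, num_heads, temperature=1.0):
    # One score per edge and head: e[u->v] = <k[u], q[v]> * temperature,
    # normalized with a softmax over the incoming edges of each destination node.
    g = g.local_var()
    k = fc_k(h).reshape(*h.shape[:-1], num_heads, -1)
    q = fc_q(h).reshape(*h.shape[:-1], num_heads, -1)
    g.ndata["k"], g.ndata["q"] = k, q
    g.apply_edges(fn.u_dot_v("k", "q", "e"))
    e = g.edata["e"] * temperature
    return edge_softmax(g, e)


g = dgl.rand_graph(10, 40)                                    # 10 nodes, 40 random edges
h = torch.randn(10, 16)                                       # node features
fc_k, fc_q = torch.nn.Linear(16, 16), torch.nn.Linear(16, 16)
e = attention_weights(g, h, fc_k, fc_q, num_heads=4)
print(e.shape)                                                # torch.Size([40, 4, 1])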