diff --git a/examples/pytorch/lightgcn/dataset.py b/examples/pytorch/lightgcn/dataset.py index bfb1ddccee96..6603398fb612 100644 --- a/examples/pytorch/lightgcn/dataset.py +++ b/examples/pytorch/lightgcn/dataset.py @@ -77,10 +77,12 @@ def build_train_dataset(self): users = users.to(self.device) pos_i = pos_i.to(self.device) neg_i = neg_i.to(self.device) - users, pos_i, neg_i = shuffle(users, pos_i, neg_i) + # users, pos_i, neg_i = shuffle(users, pos_i, neg_i) # users = torch.load("/home/xty/dgl-ascend/examples/pytorch/lightgcn/debug_user.pt", map_location="cpu") # pos_i = torch.load("/home/xty/dgl-ascend/examples/pytorch/lightgcn/debug_pos_items.pt", map_location="cpu") # neg_i = torch.load("/home/xty/dgl-ascend/examples/pytorch/lightgcn/debug_neg_items.pt", map_location="cpu") - self.train_dataset = TensorDataset(users, pos_i, neg_i) - self.train_loader = DataLoader(self.train_dataset, batch_size=args.batch, shuffle=False) + # Merge into [N, 3]; each row is (user, pos_i, neg_i) + merged = torch.stack((users, pos_i, neg_i), dim=1) + self.train_dataset = TensorDataset(merged) + self.train_loader = DataLoader(self.train_dataset, batch_size=args.batch, shuffle=True) return self.train_loader diff --git a/examples/pytorch/lightgcn/model.py b/examples/pytorch/lightgcn/model.py index 1766bccc4e89..58518c393a95 100644 --- a/examples/pytorch/lightgcn/model.py +++ b/examples/pytorch/lightgcn/model.py @@ -61,4 +61,4 @@ def bprLoss(self, user_emb, pos_emb, neg_emb, batch_user, batch_pos, batch_neg): reg_loss = (1/2)*(userEmb0.norm(2).pow(2) + posEmb0.norm(2).pow(2) + negEmb0.norm(2).pow(2))/float(len(batch_user)) - return loss.cpu(), reg_loss.cpu() \ No newline at end of file + return loss, reg_loss \ No newline at end of file diff --git a/examples/pytorch/lightgcn/procedure.py b/examples/pytorch/lightgcn/procedure.py index 9da257ae8bc3..3cd06c4c8b2b 100644 --- a/examples/pytorch/lightgcn/procedure.py +++ b/examples/pytorch/lightgcn/procedure.py @@ -10,6 +10,7 @@ import numpy as np from 
config import args from model import * +import tqdm def train_lightgcn(dataset): g = dataset.graph @@ -29,9 +30,12 @@ def train_lightgcn(dataset): total_batch = n_edges // args.batch + 1 # Epoch start epoch_start = time.time() - - for batch_idx, (batch_user, batch_pos, batch_neg) in enumerate(train_loader): - batch_start = time.time() + last_batch_time = epoch_start + for (batch,) in tqdm.tqdm(train_loader): + batch_start = time.time() + batch_user = batch[:, 0] + batch_pos = batch[:, 1] + batch_neg = batch[:, 2] # Batch start all_emb = model(g) user_emb = all_emb[batch_user] @@ -39,7 +43,6 @@ neg_emb = all_emb[batch_neg + n_users] loss, reg_loss = model.bprLoss(user_emb, pos_emb, neg_emb, batch_user, batch_pos, batch_neg) - batch_stage1 = time.time() reg_loss = reg_loss*args.decay loss=loss + reg_loss optimizer.zero_grad() @@ -50,7 +53,8 @@ # Batch end batch_time = time.time() - # print(f"Epoch {epoch+1}, Batch {batch_idx+1}/{len(train_loader)}, stage1 time: {batch_stage1 - batch_start:.4f}s, stage2 time: {batch_time - batch_stage1:.4f}s, Batch time: {batch_time - batch_start:.4f}s, Loss: {loss.item():.4f}") + # print(f"Epoch {epoch+1}, extra time: {batch_start-last_batch_time:.4f}s, Batch time: {batch_time - batch_start:.4f}s, Loss: {loss.item():.4f}") + last_batch_time = batch_time # Epoch end epoch_time = time.time() - epoch_start