Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/pytorch/lightgcn/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,12 @@ def build_train_dataset(self):
users = users.to(self.device)
pos_i = pos_i.to(self.device)
neg_i = neg_i.to(self.device)
users, pos_i, neg_i = shuffle(users, pos_i, neg_i)
# users, pos_i, neg_i = shuffle(users, pos_i, neg_i)
# users = torch.load("/home/xty/dgl-ascend/examples/pytorch/lightgcn/debug_user.pt", map_location="cpu")
# pos_i = torch.load("/home/xty/dgl-ascend/examples/pytorch/lightgcn/debug_pos_items.pt", map_location="cpu")
# neg_i = torch.load("/home/xty/dgl-ascend/examples/pytorch/lightgcn/debug_neg_items.pt", map_location="cpu")
self.train_dataset = TensorDataset(users, pos_i, neg_i)
self.train_loader = DataLoader(self.train_dataset, batch_size=args.batch, shuffle=False)
# Merge into an [N, 3] tensor; each row is (user, pos_i, neg_i)
merged = torch.stack((users, pos_i, neg_i), dim=1)
self.train_dataset = TensorDataset(merged)
self.train_loader = DataLoader(self.train_dataset, batch_size=args.batch, shuffle=True)
return self.train_loader
2 changes: 1 addition & 1 deletion examples/pytorch/lightgcn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@ def bprLoss(self, user_emb, pos_emb, neg_emb, batch_user, batch_pos, batch_neg):
reg_loss = (1/2)*(userEmb0.norm(2).pow(2) +
posEmb0.norm(2).pow(2) +
negEmb0.norm(2).pow(2))/float(len(batch_user))
return loss.cpu(), reg_loss.cpu()
return loss, reg_loss
14 changes: 9 additions & 5 deletions examples/pytorch/lightgcn/procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
from config import args
from model import *
import tqdm

def train_lightgcn(dataset):
g = dataset.graph
Expand All @@ -29,17 +30,19 @@ def train_lightgcn(dataset):
total_batch = n_edges // args.batch + 1
# Epoch start
epoch_start = time.time()

for batch_idx, (batch_user, batch_pos, batch_neg) in enumerate(train_loader):
batch_start = time.time()
last_batch_time = epoch_start
for (batch,) in tqdm.tqdm(train_loader):
batch_start = time.time()
batch_user = batch[:, 0]
batch_pos = batch[:, 1]
batch_neg = batch[:, 2]
# Batch start
all_emb = model(g)
user_emb = all_emb[batch_user]
pos_emb = all_emb[batch_pos + n_users]
neg_emb = all_emb[batch_neg + n_users]

loss, reg_loss = model.bprLoss(user_emb, pos_emb, neg_emb, batch_user, batch_pos, batch_neg)
batch_stage1 = time.time()
reg_loss = reg_loss*args.decay
loss=loss + reg_loss
optimizer.zero_grad()
Expand All @@ -50,7 +53,8 @@ def train_lightgcn(dataset):

# Batch end
batch_time = time.time()
# print(f"Epoch {epoch+1}, Batch {batch_idx+1}/{len(train_loader)}, stage1 time: {batch_stage1 - batch_start:.4f}s, stage2 time: {batch_time - batch_stage1:.4f}s, Batch time: {batch_time - batch_start:.4f}s, Loss: {loss.item():.4f}")
# print(f"Epoch {epoch+1}, , extra time: {batch_start-last_batch_time:.4f}s, Batch time: {batch_time - batch_start:.4f}s, Loss: {loss.item():.4f}")
last_batch_time = batch_time

# Epoch end
epoch_time = time.time() - epoch_start
Expand Down