-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
112 lines (90 loc) · 3.5 KB
/
utils.py
File metadata and controls
112 lines (90 loc) · 3.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
import random
import pandas as pd
import argparse
import json
import os
import ast
from tqdm import tqdm
class SeqItemDataset(Dataset):
    """Training dataset of (uid, padded history, positive item, negative item) rows.

    Reads a CSV with columns ``uid``, ``pos_iid``, ``neg_iid`` and, optionally,
    ``pos_seq`` holding a Python-literal list of item IDs (e.g. ``"[3, 7, 9]"``).
    A missing ``pos_seq`` column or an empty (NaN) cell is treated as an empty
    history. Each history keeps its first ``max_len`` IDs and is right-padded
    with ``0`` up to ``max_len``. Rows whose ``pos_seq`` cannot be parsed are
    skipped.

    Args:
        csv_path: Path of the CSV file to load.
        max_len: Fixed length of every history sequence.
    """

    def __init__(self, csv_path, max_len=20):
        self.data = pd.read_csv(csv_path)
        self.max_len = max_len
        self.records = []
        file_name = os.path.basename(csv_path)
        has_seq = 'pos_seq' in self.data.columns
        for _, row in tqdm(self.data.iterrows(),
                           desc=f"Loading {file_name}",
                           total=len(self.data),
                           ncols=100):
            uid = row['uid']
            pos = row['pos_iid']
            neg = row['neg_iid']
            try:
                # An absent column or an empty cell (NaN) both mean "no
                # history"; literal_eval would otherwise reject the NaN float.
                src = []
                if has_seq and pd.notna(row['pos_seq']):
                    src = ast.literal_eval(row['pos_seq'])
                src_ids = [int(i) for i in src][:max_len]
            except (ValueError, TypeError, SyntaxError):
                # Malformed history literal or non-integer IDs: skip the row.
                # Narrowed from a bare ``except`` so Ctrl-C and unrelated
                # bugs are no longer swallowed.
                continue
            # Right-pad with 0; a no-op when the history already fills max_len.
            src_ids += [0] * (max_len - len(src_ids))
            self.records.append((uid, src_ids, pos, neg))

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        """Return ``(uid, src, pos, neg)`` as tensors for record ``idx``."""
        uid, src, pos, neg = self.records[idx]
        return torch.tensor(uid), torch.tensor(src), torch.tensor(pos), torch.tensor(neg)
class TestSeqItemDataset(Dataset):
    """Evaluation dataset that also carries each row's ``tgt_iids`` list.

    Reads a CSV with columns ``uid``, ``pos_iid``, ``neg_iid`` and, optionally,
    ``pos_seq`` and ``tgt_iids`` (Python-literal lists such as ``"[3, 7]"``).
    A missing column or an empty (NaN) cell is treated as an empty list. Each
    history keeps its first ``max_len`` IDs and is right-padded with ``0``.
    Rows that raise while being parsed are reported with ``print`` and skipped.

    Args:
        csv_path: Path of the CSV file to load.
        max_len: Fixed length of every history sequence.
    """

    def __init__(self, csv_path, max_len=20):
        self.data = pd.read_csv(csv_path)
        self.max_len = max_len
        self.records = []
        file_name = os.path.basename(csv_path)
        for idx, row in tqdm(self.data.iterrows(),
                             desc=f"Loading Eval {file_name}",
                             total=len(self.data),
                             ncols=100):
            try:
                uid = row['uid']
                pos = row['pos_iid']
                neg = row['neg_iid']
                src = []
                if 'pos_seq' in row and pd.notna(row['pos_seq']):
                    src = ast.literal_eval(row['pos_seq'])
                src_ids = [int(i) for i in src][:max_len]
                # Right-pad with 0; a no-op when already at max_len.
                src_ids += [0] * (max_len - len(src_ids))
                tgt_iids = []
                if 'tgt_iids' in row and pd.notna(row['tgt_iids']):
                    tgt_iids = ast.literal_eval(row['tgt_iids'])
                self.records.append((uid, src_ids, pos, neg, tgt_iids))
            except Exception as e:
                # Deliberate best-effort load: report the bad row and move on.
                print(f"Error processing row {idx}: {e}")
                continue

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        """Return ``(uid, src, pos, neg, tgt_iids)`` as tensors for record ``idx``."""
        uid, src, pos, neg, tgt_iids = self.records[idx]
        # torch.tensor([]) infers float32, whereas a non-empty list of Python
        # ints infers int64. Pin the empty case to long so it stays usable as
        # an index tensor, matching the non-empty integer case.
        tgt = (torch.tensor(tgt_iids) if tgt_iids
               else torch.tensor([], dtype=torch.long))
        return (
            torch.tensor(uid),
            torch.tensor(src),
            torch.tensor(pos),
            torch.tensor(neg),
            tgt,
        )
def sample_candidates(pos_id, pos_ids, MIN, MAX, neg_sample_size, seed=None):
    """Build a candidate array: the positive item followed by sampled negatives.

    Negatives are drawn uniformly without replacement from the integer range
    ``[MIN, MAX]`` (inclusive). Every ID in ``pos_ids`` is excluded, and so is
    ``pos_id`` itself, so the positive can never reappear as a negative.

    Args:
        pos_id: ID of the positive item; placed first in the result.
        pos_ids: Iterable of IDs that must not be sampled as negatives.
        MIN: Smallest ID eligible for sampling (inclusive).
        MAX: Largest ID eligible for sampling (inclusive).
        neg_sample_size: Number of negatives to draw.
        seed: Optional seed for ``np.random.RandomState``; identical inputs and
            seed yield identical candidates.

    Returns:
        1-D integer ``np.ndarray`` of length ``neg_sample_size + 1`` whose first
        element is ``pos_id``.

    Raises:
        ValueError: If fewer than ``neg_sample_size`` eligible IDs remain.
    """
    rng = np.random.RandomState(seed)
    pos = int(pos_id)
    exclude_ids = {int(i) for i in pos_ids}
    # Guard against the positive being drawn as its own negative when the
    # caller did not list it in pos_ids. If it was already listed, the pool
    # (and therefore the seeded draw) is unchanged.
    exclude_ids.add(pos)
    all_candidates = np.arange(MIN, MAX + 1)
    mask = ~np.isin(all_candidates, list(exclude_ids))
    valid_candidates = all_candidates[mask]
    if neg_sample_size > len(valid_candidates):
        raise ValueError(
            f"cannot sample {neg_sample_size} negatives: only "
            f"{len(valid_candidates)} eligible IDs in [{MIN}, {MAX}]"
        )
    negs = rng.choice(valid_candidates, size=neg_sample_size, replace=False)
    return np.concatenate([[pos], negs])