Commit da51415

Merge pull request #4 from bigdata-ustc/DKT+
[FEATURE] Add DKT+
2 parents dd368eb + 7e76aa6 commit da51415

26 files changed: +1269 -3 lines changed

.gitignore (+2 -1)
@@ -109,4 +109,5 @@ venv.bak/
 # Pyre type checker
 .pyre/

-# User Definition
+# User Definition
+data/

CHANGE.txt (+4)
@@ -1,3 +1,7 @@
+v0.0.5:
+* add DKT+
+* add some util functions
+
 v0.0.4:
 * fix potential ModuleNotFoundError

EduKTM/DKTPlus/DKTPlus.py (+119)
@@ -0,0 +1,119 @@
# coding: utf-8
# 2021/5/25 @ tongshiwei

import logging
import torch
from EduKTM import KTM
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
from EduKTM.utils import sequence_mask, SLMLoss, tensor2list, pick
from sklearn.metrics import roc_auc_score, accuracy_score
import numpy as np


class DKTNet(nn.Module):
    def __init__(self, ku_num, hidden_num, add_embedding_layer=False, embedding_dim=None, dropout=0.0, **kwargs):
        super(DKTNet, self).__init__()
        self.ku_num = ku_num
        self.hidden_dim = hidden_num
        self.output_dim = ku_num
        if add_embedding_layer is True:
            embedding_dim = self.hidden_dim if embedding_dim is None else embedding_dim
            self.embeddings = nn.Sequential(
                nn.Embedding(ku_num * 2, embedding_dim),
                nn.Dropout(kwargs.get("embedding_dropout", 0.2))
            )
            rnn_input_dim = embedding_dim
        else:
            # one-hot encode the (knowledge unit, correctness) response ids when no embedding layer is used
            self.embeddings = lambda x: F.one_hot(x, num_classes=self.output_dim * 2).float()
            rnn_input_dim = ku_num * 2

        self.rnn = nn.RNN(rnn_input_dim, hidden_num, 1, batch_first=True, nonlinearity='tanh')
        self.fc = nn.Linear(self.hidden_dim, self.output_dim)
        self.dropout = nn.Dropout(dropout)
        self.sig = nn.Sigmoid()

    def forward(self, responses, mask=None, begin_state=None):
        responses = self.embeddings(responses)
        output, hn = self.rnn(responses, begin_state)
        output = self.sig(self.fc(self.dropout(output)))
        if mask is not None:
            output = sequence_mask(output, mask)
        return output, hn


class DKTPlus(KTM):
    def __init__(self, ku_num, hidden_num, net_params: dict = None, loss_params=None):
        super(DKTPlus, self).__init__()
        self.dkt_net = DKTNet(
            ku_num,
            hidden_num,
            **(net_params if net_params is not None else {})
        )
        self.loss_params = loss_params if loss_params is not None else {}

    def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
        loss_function = SLMLoss(**self.loss_params)

        trainer = torch.optim.Adam(self.dkt_net.parameters(), lr)

        for e in range(epoch):
            losses = []
            for (data, data_mask, label, pick_index, label_mask) in tqdm(train_data, "Epoch %s" % e):
                # convert to device
                data: torch.Tensor = data.to(device)
                data_mask: torch.Tensor = data_mask.to(device)
                label: torch.Tensor = label.to(device)
                pick_index: torch.Tensor = pick_index.to(device)
                label_mask: torch.Tensor = label_mask.to(device)

                # real training
                predicted_response, _ = self.dkt_net(data, data_mask)
                loss = loss_function(predicted_response, pick_index, label, label_mask)

                # back propagation
                trainer.zero_grad()
                loss.backward()
                trainer.step()

                losses.append(loss.mean().item())
            print("[Epoch %d] SLMLoss: %.6f" % (e, float(np.mean(losses))))

            if test_data is not None:
                auc, accuracy = self.eval(test_data)
                print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy))

    def eval(self, test_data, device="cpu") -> tuple:
        self.dkt_net.eval()
        y_true = []
        y_pred = []

        for (data, data_mask, label, pick_index, label_mask) in tqdm(test_data, "evaluating"):
            # convert to device
            data: torch.Tensor = data.to(device)
            data_mask: torch.Tensor = data_mask.to(device)
            label: torch.Tensor = label.to(device)
            pick_index: torch.Tensor = pick_index.to(device)
            label_mask: torch.Tensor = label_mask.to(device)

            # real evaluating
            output, _ = self.dkt_net(data, data_mask)
            # the prediction made after step t is scored against the response at step t + 1
            output = output[:, :-1]
            output = pick(output, pick_index.to(output.device))
            pred = tensor2list(output)
            label = tensor2list(label)
            for i, length in enumerate(label_mask.numpy().tolist()):
                length = int(length)
                y_true.extend(label[i][:length])
                y_pred.extend(pred[i][:length])
        self.dkt_net.train()
        return roc_auc_score(y_true, y_pred), accuracy_score(y_true, np.array(y_pred) >= 0.5)

    def save(self, filepath) -> ...:
        torch.save(self.dkt_net.state_dict(), filepath)
        logging.info("save parameters to %s" % filepath)

    def load(self, filepath):
        self.dkt_net.load_state_dict(torch.load(filepath))
        logging.info("load parameters from %s" % filepath)

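For orientation, a minimal usage sketch of the class above (not part of this diff; the data paths, ku_num and the loss weights are placeholder values, and batches are assumed to come from the etl helper added in this commit):

from EduKTM import DKTPlus
from EduKTM.DKTPlus import etl

train_batches = etl("data/train.json", batch_size=32)  # hypothetical path
valid_batches = etl("data/valid.json", batch_size=32)  # hypothetical path

dkt_plus = DKTPlus(ku_num=124, hidden_num=64,
                   loss_params={"lr": 0.1, "lw1": 0.003, "lw2": 3.0})
dkt_plus.train(train_batches, valid_batches, epoch=2)
dkt_plus.save("dkt_plus.params")
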
EduKTM/DKTPlus/__init__.py (+5)
@@ -0,0 +1,5 @@
# coding: utf-8
# 2021/5/25 @ tongshiwei

from .DKTPlus import DKTPlus
from .etl import etl

EduKTM/DKTPlus/etl.py (+69)
@@ -0,0 +1,69 @@
# coding: utf-8
# 2021/5/25 @ tongshiwei

import torch
import json
from tqdm import tqdm
from EduKTM.utils.torch_utils import PadSequence, FixedBucketSampler


def extract(data_src):  # pragma: no cover
    responses = []
    step = 200
    with open(data_src) as f:
        for line in tqdm(f, "reading data from %s" % data_src):
            data = json.loads(line)
            for i in range(0, len(data), step):
                if len(data[i: i + step]) < 2:
                    continue
                responses.append(data[i: i + step])

    return responses


def transform(raw_data, batch_size, num_buckets=100):
    # data transformation interface
    # raw_data --> batch_data

    responses = raw_data

    batch_idxes = FixedBucketSampler([len(rs) for rs in responses], batch_size, num_buckets=num_buckets)
    batch = []

    def index(r):
        # encode a (knowledge unit, correctness) pair as a single response id
        correct = 0 if r[1] <= 0 else 1
        return r[0] * 2 + correct

    for batch_idx in tqdm(batch_idxes, "batchify"):
        batch_rs = []
        batch_pick_index = []
        batch_labels = []
        for idx in batch_idx:
            batch_rs.append([index(r) for r in responses[idx]])
            if len(responses[idx]) <= 1:  # pragma: no cover
                pick_index, labels = [], []
            else:
                pick_index, labels = zip(*[(r[0], 0 if r[1] <= 0 else 1) for r in responses[idx][1:]])
            batch_pick_index.append(list(pick_index))
            batch_labels.append(list(labels))

        max_len = max([len(rs) for rs in batch_rs])
        padder = PadSequence(max_len, pad_val=0)
        batch_rs, data_mask = zip(*[(padder(rs), len(rs)) for rs in batch_rs])

        max_len = max([len(rs) for rs in batch_labels])
        padder = PadSequence(max_len, pad_val=0)
        batch_labels, label_mask = zip(*[(padder(labels), len(labels)) for labels in batch_labels])
        batch_pick_index = [padder(pick_index) for pick_index in batch_pick_index]
        # Load
        batch.append(
            [torch.tensor(batch_rs), torch.tensor(data_mask), torch.tensor(batch_labels),
             torch.tensor(batch_pick_index),
             torch.tensor(label_mask)])

    return batch


def etl(data_src, batch_size, **kwargs):  # pragma: no cover
    raw_data = extract(data_src)
    return transform(raw_data, batch_size, **kwargs)

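To make the expected input concrete: each line of data_src is a JSON list of (knowledge unit id, correctness) pairs, and transform encodes every pair as a single response id. A small sketch of that encoding with illustrative values only:

# Hypothetical interaction sequence, as one line of the JSON-lines file read by extract():
sequence = [(3, 1), (1, 0), (3, 1)]  # (knowledge unit id, answered correctly?)

# transform() encodes each pair as ku * 2 + correct:
encoded = [ku * 2 + (1 if correct > 0 else 0) for ku, correct in sequence]
print(encoded)  # [7, 2, 7]

# Targets start from the second step: pick_index = [1, 3], labels = [0, 1],
# i.e. predict correctness on the next knowledge unit from the current state.
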
EduKTM/__init__.py (+1)
@@ -4,3 +4,4 @@
 from .meta import KTM
 from .KPT import KPT
 from .DKT import DKT
+from .DKTPlus import DKTPlus

EduKTM/utils/__init__.py (+6)
@@ -0,0 +1,6 @@
# coding: utf-8
# 2021/5/24 @ tongshiwei

from .utils import *
from .loss import SequenceLogisticMaskLoss as SLMLoss
from .torch_utils import *

EduKTM/utils/loss.py (+79)
@@ -0,0 +1,79 @@
# coding: utf-8
# 2021/5/24 @ tongshiwei
__all__ = ["SequenceLogisticMaskLoss", "LogisticMaskLoss"]

import torch
from torch import nn

from .torch_utils import pick, sequence_mask


class SequenceLogisticMaskLoss(nn.Module):
    """
    Notes
    -----
    The loss is already averaged, so when calling the trainer's step method, batch_size should be 1.
    """

    def __init__(self, lr=0.0, lw1=0.0, lw2=0.0):
        """

        Parameters
        ----------
        lr: weight of the reconstruction loss term
        lw1: weight of the L1 waviness regularization term
        lw2: weight of the L2 waviness regularization term
        """
        super(SequenceLogisticMaskLoss, self).__init__()
        self.lr = lr
        self.lw1 = lw1
        self.lw2 = lw2
        self.loss = torch.nn.BCELoss(reduction='none')

    def forward(self, pred_rs, pick_index, label, label_mask):
        if self.lw1 > 0.0 or self.lw2 > 0.0:
            post_pred_rs = pred_rs[:, 1:]
            pre_pred_rs = pred_rs[:, :-1]
            diff = post_pred_rs - pre_pred_rs
            diff = sequence_mask(diff, label_mask)
            w1 = torch.mean(torch.norm(diff, 1, -1)) / diff.shape[-1]
            w2 = torch.mean(torch.norm(diff, 2, -1)) / diff.shape[-1]
            # w2 = F.mean(F.sqrt(diff ** 2))
            w1 = w1 * self.lw1 if self.lw1 > 0.0 else 0.0
            w2 = w2 * self.lw2 if self.lw2 > 0.0 else 0.0
        else:
            w1 = 0.0
            w2 = 0.0

        if self.lr > 0.0:
            re_pred_rs = pred_rs[:, 1:]
            re_pred_rs = pick(re_pred_rs, pick_index)
            wr = sequence_mask(self.loss(re_pred_rs, label.float()), label_mask)
            wr = torch.mean(wr) * self.lr
        else:
            wr = 0.0

        pred_rs = pred_rs[:, 1:]
        pred_rs = pick(pred_rs, pick_index)
        loss = sequence_mask(self.loss(pred_rs, label.float()), label_mask)
        # loss = F.sum(loss, axis=-1)
        loss = torch.mean(loss) + w1 + w2 + wr
        return loss


class LogisticMaskLoss(nn.Module):  # pragma: no cover
    """
    Notes
    -----
    The loss is already averaged, so when calling the trainer's step method, batch_size should be 1.
    """

    def __init__(self):
        super(LogisticMaskLoss, self).__init__()

        self.loss = torch.nn.BCELoss()

    def forward(self, pred_rs, label, label_mask, *args, **kwargs):
        loss = sequence_mask(self.loss(pred_rs, label), label_mask)
        loss = torch.mean(loss)
        return loss

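A hedged sketch of how SequenceLogisticMaskLoss combines its terms; shapes follow the batches produced by etl, and the tensors here are random placeholders rather than real model output:

import torch
from EduKTM.utils import SLMLoss

batch_size, seq_len, ku_num = 2, 4, 5
# stand-in for DKTNet output: per-step probabilities for every knowledge unit
pred_rs = torch.rand(batch_size, seq_len, ku_num, requires_grad=True)
pick_index = torch.randint(0, ku_num, (batch_size, seq_len - 1))  # next knowledge unit per step
label = torch.randint(0, 2, (batch_size, seq_len - 1))            # correctness of that next step
label_mask = torch.tensor([3, 2])                                 # valid prediction steps per sequence

# total = next-step BCE + lr * reconstruction BCE + lw1 * L1 waviness + lw2 * L2 waviness
loss_fn = SLMLoss(lr=0.1, lw1=0.003, lw2=3.0)
loss = loss_fn(pred_rs, pick_index, label, label_mask)
loss.backward()
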
EduKTM/utils/tests.py (+16)
@@ -0,0 +1,16 @@
# coding: utf-8
# 2021/5/26 @ tongshiwei

def pseudo_data_generation(ku_num, record_num=10, max_length=20):
    # generate a pseudo data stream for testing
    import random
    random.seed(10)

    raw_data = [
        [
            (random.randint(0, ku_num - 1), random.randint(-1, 1))
            for _ in range(random.randint(2, max_length))
        ] for _ in range(record_num)
    ]

    return raw_data

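A sketch of how this pseudo data could drive a quick smoke test of the new module; the hyper-parameters are arbitrary, and the batching relies on the transform and FixedBucketSampler helpers added in this commit:

from EduKTM import DKTPlus
from EduKTM.DKTPlus.etl import transform
from EduKTM.utils.tests import pseudo_data_generation

ku_num = 10
raw_data = pseudo_data_generation(ku_num)           # 10 pseudo learners, lengths 2..20
batches = transform(raw_data, batch_size=4, num_buckets=2)

dkt_plus = DKTPlus(ku_num, hidden_num=8)
dkt_plus.train(batches, test_data=batches, epoch=1)
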
EduKTM/utils/torch_utils/__init__.py (+5)
@@ -0,0 +1,5 @@
# coding: utf-8
# 2021/5/25 @ tongshiwei

from .extlib import *
from .functional import *

@@ -0,0 +1,5 @@
# coding: utf-8
# 2021/5/26 @ tongshiwei

from .data import *
from .sampler import *

+48
@@ -0,0 +1,48 @@
# coding: utf-8
# 2021/5/25 @ tongshiwei
# These codes are modified from gluonnlp

__all__ = ["PadSequence"]


class PadSequence:
    """Pad the sequence.

    Pad the sequence to the given `length` by inserting `pad_val`. If `clip` is set,
    a sequence that is longer than `length` will be clipped.

    Parameters
    ----------
    length : int
        The maximum length to pad/clip the sequence
    pad_val : number
        The pad value. Default 0
    clip : bool
        Whether to clip sequences longer than `length`. Default True
    """

    def __init__(self, length, pad_val=0, clip=True):
        self._length = length
        self._pad_val = pad_val
        self._clip = clip

    def __call__(self, sample):
        """

        Parameters
        ----------
        sample : list of number or mx.nd.NDArray or np.ndarray

        Returns
        -------
        ret : list of number or mx.nd.NDArray or np.ndarray
        """
        sample_length = len(sample)
        if sample_length >= self._length:
            if self._clip and sample_length > self._length:
                return sample[:self._length]
            else:
                return sample
        else:
            return sample + [
                self._pad_val for _ in range(self._length - sample_length)
            ]

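For reference, a small sketch of the padding/clipping behaviour of PadSequence on a plain Python list:

from EduKTM.utils.torch_utils import PadSequence

padder = PadSequence(length=5, pad_val=0)
print(padder([3, 1, 4]))               # [3, 1, 4, 0, 0]  -> padded up to length 5
print(padder([3, 1, 4, 1, 5, 9, 2]))   # [3, 1, 4, 1, 5]  -> clipped, since clip=True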