Commit 50e446b

[FEATURE] add DKVMN (#24)
* [feat] add DKVMN
* Update AUTHOR.md
* [fix] add the examples of DKVMN and fix tests
* [fix] fix __init__.py
* [style] optimize code format
* [fix] modify the usage of gpu and the epoch num for test
* [fix] modify DKVMN.py
* [style] modify the annotations to follow numpy-style
* [style] change variable name from CamelCase to UnderScoreCase
* [docs] add DKVMN.md
1 parent d024d59 commit 50e446b

12 files changed, +1771 -0 lines changed


AUTHORS.md

Lines changed: 2 additions & 0 deletions
@@ -8,4 +8,6 @@
 
 [Fangzhou Yao](https://github.com/fannazya)
 
+[Jie Ouyang](https://github.com/0russwest0)
+
 The starred is the corresponding author

EduKTM/DKVMN/DKVMN.py

Lines changed: 370 additions & 0 deletions
@@ -0,0 +1,370 @@
# coding: utf-8
# 2022/3/18 @ ouyangjie


import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from sklearn import metrics
from tqdm import tqdm
from EduKTM import KTM


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Cell(nn.Module):
    def __init__(self, memory_size, memory_state_dim):
        super(Cell, self).__init__()
        self.memory_size = memory_size
        self.memory_state_dim = memory_state_dim

    def addressing(self, control_input, memory):
        """
        Parameters
        ----------
        control_input: tensor
            embedding vector of input exercise, shape = (batch_size, control_state_dim)
        memory: tensor
            key memory, shape = (memory_size, memory_state_dim)

        Returns
        -------
        correlation_weight: tensor
            correlation weight, shape = (batch_size, memory_size)
        """
        similarity_score = torch.matmul(control_input, torch.t(memory))
        correlation_weight = F.softmax(similarity_score, dim=1)  # Shape: (batch_size, memory_size)
        return correlation_weight

    def read(self, memory, read_weight):
        """
        Parameters
        ----------
        memory: tensor
            value memory, shape = (batch_size, memory_size, memory_state_dim)
        read_weight: tensor
            correlation weight, shape = (batch_size, memory_size)

        Returns
        -------
        read_content: tensor
            read content, shape = (batch_size, memory_state_dim)
        """
        read_weight = read_weight.view(-1, 1)
        memory = memory.view(-1, self.memory_state_dim)
        rc = torch.mul(read_weight, memory)
        read_content = rc.view(-1, self.memory_size, self.memory_state_dim)
        read_content = torch.sum(read_content, dim=1)
        return read_content


class WriteCell(Cell):
    def __init__(self, memory_size, memory_state_dim):
        super(WriteCell, self).__init__(memory_size, memory_state_dim)
        self.erase = torch.nn.Linear(memory_state_dim, memory_state_dim, bias=True)
        self.add = torch.nn.Linear(memory_state_dim, memory_state_dim, bias=True)
        nn.init.kaiming_normal_(self.erase.weight)
        nn.init.kaiming_normal_(self.add.weight)
        nn.init.constant_(self.erase.bias, 0)
        nn.init.constant_(self.add.bias, 0)

    def write(self, control_input, memory, write_weight):
        """
        Parameters
        ----------
        control_input: tensor
            embedding vector of input exercise and students' answer, shape = (batch_size, control_state_dim)
        memory: tensor
            value memory, shape = (batch_size, memory_size, memory_state_dim)
        write_weight: tensor
            correlation weight, shape = (batch_size, memory_size)

        Returns
        -------
        new_memory: tensor
            updated value memory, shape = (batch_size, memory_size, memory_state_dim)
        """
        erase_signal = torch.sigmoid(self.erase(control_input))
        add_signal = torch.tanh(self.add(control_input))
        erase_reshape = erase_signal.view(-1, 1, self.memory_state_dim)
        add_reshape = add_signal.view(-1, 1, self.memory_state_dim)
        write_weight_reshape = write_weight.view(-1, self.memory_size, 1)
        erase_mult = torch.mul(erase_reshape, write_weight_reshape)
        add_mul = torch.mul(add_reshape, write_weight_reshape)
        new_memory = memory * (1 - erase_mult) + add_mul
        return new_memory


class DKVMNCell(nn.Module):
    def __init__(self, memory_size, key_memory_state_dim, value_memory_state_dim, init_key_memory):
        """
        Parameters
        ----------
        memory_size: int
            size of memory
        key_memory_state_dim: int
            dimension of key memory
        value_memory_state_dim: int
            dimension of value memory
        init_key_memory: tensor
            initial key memory
        """
        super(DKVMNCell, self).__init__()
        self.memory_size = memory_size
        self.key_memory_state_dim = key_memory_state_dim
        self.value_memory_state_dim = value_memory_state_dim

        self.key_head = Cell(memory_size=self.memory_size, memory_state_dim=self.key_memory_state_dim)
        self.value_head = WriteCell(memory_size=self.memory_size, memory_state_dim=self.value_memory_state_dim)

        self.key_memory = init_key_memory
        self.value_memory = None

    def init_value_memory(self, value_memory):
        self.value_memory = value_memory

    def attention(self, control_input):
        correlation_weight = self.key_head.addressing(control_input=control_input, memory=self.key_memory)
        return correlation_weight

    def read(self, read_weight):
        read_content = self.value_head.read(memory=self.value_memory, read_weight=read_weight)
        return read_content

    def write(self, write_weight, control_input):
        value_memory = self.value_head.write(control_input=control_input,
                                             memory=self.value_memory,
                                             write_weight=write_weight)
        self.value_memory = nn.Parameter(value_memory.data)

        return self.value_memory


class Net(nn.Module):
    def __init__(self, n_question, batch_size, key_embedding_dim, value_embedding_dim,
                 memory_size, key_memory_state_dim, value_memory_state_dim, final_fc_dim, student_num=None):
        super(Net, self).__init__()
        self.n_question = n_question
        self.batch_size = batch_size
        self.key_embedding_dim = key_embedding_dim
        self.value_embedding_dim = value_embedding_dim
        self.memory_size = memory_size
        self.key_memory_state_dim = key_memory_state_dim
        self.value_memory_state_dim = value_memory_state_dim
        self.final_fc_dim = final_fc_dim
        self.student_num = student_num

        self.input_embed_linear = nn.Linear(self.key_embedding_dim, self.final_fc_dim, bias=True)
        self.read_embed_linear = nn.Linear(self.value_memory_state_dim + self.final_fc_dim,
                                           self.final_fc_dim, bias=True)
        self.predict_linear = nn.Linear(self.final_fc_dim, 1, bias=True)
        self.init_key_memory = nn.Parameter(torch.randn(self.memory_size, self.key_memory_state_dim))
        nn.init.kaiming_normal_(self.init_key_memory)
        self.init_value_memory = nn.Parameter(torch.randn(self.memory_size, self.value_memory_state_dim))
        nn.init.kaiming_normal_(self.init_value_memory)

        self.mem = DKVMNCell(memory_size=self.memory_size, key_memory_state_dim=self.key_memory_state_dim,
                             value_memory_state_dim=self.value_memory_state_dim, init_key_memory=self.init_key_memory)

        value_memory = nn.Parameter(torch.cat([self.init_value_memory.unsqueeze(0) for _ in range(batch_size)], 0).data)
        self.mem.init_value_memory(value_memory)

        self.q_embed = nn.Embedding(self.n_question + 1, self.key_embedding_dim, padding_idx=0)
        self.qa_embed = nn.Embedding(2 * self.n_question + 1, self.value_embedding_dim, padding_idx=0)

    def init_params(self):
        nn.init.kaiming_normal_(self.predict_linear.weight)
        nn.init.kaiming_normal_(self.read_embed_linear.weight)
        nn.init.constant_(self.read_embed_linear.bias, 0)
        nn.init.constant_(self.predict_linear.bias, 0)

    def init_embeddings(self):
        nn.init.kaiming_normal_(self.q_embed.weight)
        nn.init.kaiming_normal_(self.qa_embed.weight)

    def forward(self, q_data, qa_data, target):
        batch_size = q_data.shape[0]
        seqlen = q_data.shape[1]
        q_embed_data = self.q_embed(q_data)
        qa_embed_data = self.qa_embed(qa_data)

        value_memory = nn.Parameter(torch.cat([self.init_value_memory.unsqueeze(0) for _ in range(batch_size)], 0).data)
        self.mem.init_value_memory(value_memory)

        slice_q_embed_data = torch.chunk(q_embed_data, seqlen, 1)
        slice_qa_embed_data = torch.chunk(qa_embed_data, seqlen, 1)

        value_read_content_l = []
        input_embed_l = []
        for i in range(seqlen):
            # Attention
            q = slice_q_embed_data[i].squeeze(1)
            correlation_weight = self.mem.attention(q)

            # Read Process
            read_content = self.mem.read(correlation_weight)
            value_read_content_l.append(read_content)
            input_embed_l.append(q)
            # Write Process
            qa = slice_qa_embed_data[i].squeeze(1)
            self.mem.write(correlation_weight, qa)

        all_read_value_content = torch.cat([value_read_content_l[i].unsqueeze(1) for i in range(seqlen)], 1)
        input_embed_content = torch.cat([input_embed_l[i].unsqueeze(1) for i in range(seqlen)], 1)

        predict_input = torch.cat([all_read_value_content, input_embed_content], 2)
        read_content_embed = torch.tanh(self.read_embed_linear(predict_input.view(batch_size * seqlen, -1)))

        pred = self.predict_linear(read_content_embed)
        target_1d = target  # [batch_size * seq_len, 1]
        mask = target_1d.ge(0)  # [batch_size * seq_len, 1]
        pred_1d = pred.view(-1, 1)  # [batch_size * seq_len, 1]

        filtered_pred = torch.masked_select(pred_1d, mask)
        filtered_target = torch.masked_select(target_1d, mask)
        loss = F.binary_cross_entropy_with_logits(filtered_pred, filtered_target)

        return loss, torch.sigmoid(filtered_pred), filtered_target


class DKVMN(KTM):
    def __init__(self, n_question, batch_size, key_embedding_dim, value_embedding_dim,
                 memory_size, key_memory_state_dim, value_memory_state_dim, final_fc_dim, student_num=None):
        super(DKVMN, self).__init__()
        self.batch_size = batch_size
        self.n_question = n_question
        self.model = Net(n_question, batch_size, key_embedding_dim, value_embedding_dim,
                         memory_size, key_memory_state_dim, value_memory_state_dim, final_fc_dim, student_num)

    def train_epoch(self, epoch, model, params, optimizer, q_data, qa_data):
        N = int(math.floor(len(q_data) / params['batch_size']))

        pred_list = []
        target_list = []
        epoch_loss = 0

        model.train()

        for idx in tqdm(range(N), "Epoch %s" % epoch):
            q_one_seq = q_data[idx * params['batch_size']:(idx + 1) * params['batch_size'], :]
            qa_batch_seq = qa_data[idx * params['batch_size']:(idx + 1) * params['batch_size'], :]
            target = qa_data[idx * params['batch_size']:(idx + 1) * params['batch_size'], :]

            target = (target - 1) / params['n_question']
            target = np.floor(target)
            input_q = torch.LongTensor(q_one_seq).to(device)
            input_qa = torch.LongTensor(qa_batch_seq).to(device)
            target = torch.FloatTensor(target).to(device)
            target_to_1d = torch.chunk(target, params['batch_size'], 0)
            target_1d = torch.cat([target_to_1d[i] for i in range(params['batch_size'])], 1)
            target_1d = target_1d.permute(1, 0)

            model.zero_grad()
            loss, filtered_pred, filtered_target = model.forward(input_q, input_qa, target_1d)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), params['maxgradnorm'])
            optimizer.step()
            epoch_loss += loss.item()

            right_target = np.asarray(filtered_target.data.tolist())
            right_pred = np.asarray(filtered_pred.data.tolist())
            pred_list.append(right_pred)
            target_list.append(right_target)

        all_pred = np.concatenate(pred_list, axis=0)
        all_target = np.concatenate(target_list, axis=0)
        auc = metrics.roc_auc_score(all_target, all_pred)
        all_pred[all_pred >= 0.5] = 1.0
        all_pred[all_pred < 0.5] = 0.0
        accuracy = metrics.accuracy_score(all_target, all_pred)

        return epoch_loss / N, accuracy, auc

    def train(self, params, train_data, test_data=None):
        q_data, qa_data = train_data

        model = self.model
        model.init_embeddings()
        model.init_params()
        optimizer = torch.optim.Adam(params=model.parameters(), lr=params['lr'], betas=(0.9, 0.9))

        model.to(device)

        all_valid_loss = {}
        all_valid_accuracy = {}
        all_valid_auc = {}
        best_valid_auc = 0

        for idx in range(params['max_iter']):
            train_loss, train_accuracy, train_auc = self.train_epoch(idx, model, params, optimizer, q_data, qa_data)
            print('Epoch %d/%d, loss : %3.5f, auc : %3.5f, accuracy : %3.5f' %
                  (idx + 1, params['max_iter'], train_loss, train_auc, train_accuracy))
            if test_data is not None:
                valid_loss, valid_accuracy, valid_auc = self.eval(params, test_data)
                all_valid_loss[idx + 1] = valid_loss
                all_valid_accuracy[idx + 1] = valid_accuracy
                all_valid_auc[idx + 1] = valid_auc
                # output the epoch with the best validation auc
                if valid_auc > best_valid_auc:
                    print('valid auc improve: %3.4f to %3.4f' % (best_valid_auc, valid_auc))
                    best_valid_auc = valid_auc

    def eval(self, params, data):
        q_data, qa_data = data
        model = self.model
        N = int(math.floor(len(q_data) / params['batch_size']))

        pred_list = []
        target_list = []
        epoch_loss = 0
        model.eval()

        for idx in tqdm(range(N), "Evaluating"):
            q_one_seq = q_data[idx * params['batch_size']:(idx + 1) * params['batch_size'], :]
            qa_batch_seq = qa_data[idx * params['batch_size']:(idx + 1) * params['batch_size'], :]
            target = qa_data[idx * params['batch_size']:(idx + 1) * params['batch_size'], :]

            target = (target - 1) / params['n_question']
            target = np.floor(target)

            input_q = torch.LongTensor(q_one_seq).to(device)
            input_qa = torch.LongTensor(qa_batch_seq).to(device)
            target = torch.FloatTensor(target).to(device)

            target_to_1d = torch.chunk(target, params['batch_size'], 0)
            target_1d = torch.cat([target_to_1d[i] for i in range(params['batch_size'])], 1)
            target_1d = target_1d.permute(1, 0)

            loss, filtered_pred, filtered_target = model.forward(input_q, input_qa, target_1d)

            right_target = np.asarray(filtered_target.data.tolist())
            right_pred = np.asarray(filtered_pred.data.tolist())
            pred_list.append(right_pred)
            target_list.append(right_target)
            epoch_loss += loss.item()

        all_pred = np.concatenate(pred_list, axis=0)
        all_target = np.concatenate(target_list, axis=0)

        auc = metrics.roc_auc_score(all_target, all_pred)
        all_pred[all_pred >= 0.5] = 1.0
        all_pred[all_pred < 0.5] = 0.0
        accuracy = metrics.accuracy_score(all_target, all_pred)
        print('valid auc : %3.5f, valid accuracy : %3.5f' % (auc, accuracy))

        return epoch_loss / N, accuracy, auc

    def save(self, filepath):
        torch.save(self.model.state_dict(), filepath)
        logging.info("save parameters to %s" % filepath)

    def load(self, filepath):
        self.model.load_state_dict(torch.load(filepath))
        logging.info("load parameters from %s" % filepath)
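The Cell and WriteCell classes above implement the two halves of the DKVMN memory: content-based addressing over a shared key memory, and an erase-then-add update of the per-student value memory. The following shape check is not part of this commit; it is a minimal sketch assuming the classes are imported from the module added above:

import torch

from EduKTM.DKVMN.DKVMN import Cell, WriteCell

batch_size, memory_size, key_dim, value_dim = 4, 20, 50, 100

key_head = Cell(memory_size=memory_size, memory_state_dim=key_dim)
value_head = WriteCell(memory_size=memory_size, memory_state_dim=value_dim)

key_memory = torch.randn(memory_size, key_dim)                  # key memory, shared across students
value_memory = torch.randn(batch_size, memory_size, value_dim)  # value memory, one copy per student
q_embed = torch.randn(batch_size, key_dim)                      # embedded exercise
qa_embed = torch.randn(batch_size, value_dim)                   # embedded (exercise, answer) pair

w = key_head.addressing(q_embed, key_memory)   # (batch_size, memory_size), each row sums to 1
r = value_head.read(value_memory, w)           # (batch_size, value_dim), weighted sum over memory slots
new_value_memory = value_head.write(qa_embed, value_memory, w)  # erase-then-add, same shape as value_memory

Because both the erase and add signals are computed from the (exercise, answer) embedding, a correct and an incorrect response to the same exercise push the addressed memory slots in different directions.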

EduKTM/DKVMN/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
# coding: utf-8
# 2022/3/18 @ ouyangjie


from .DKVMN import DKVMN

EduKTM/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,3 +9,4 @@
 from .AKT import AKT
 from .LPKT import LPKT
 from .GKT import GKT
+from .DKVMN import DKVMN
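The commit message notes that runnable examples and tests are added in separate files; the sketch below is only an illustration, with synthetic data and made-up hyperparameter values, of the input encoding and of the params keys that DKVMN.train and train_epoch read (batch_size, n_question, lr, max_iter, maxgradnorm):

import numpy as np

from EduKTM import DKVMN

n_question, batch_size, seqlen, num_students = 50, 8, 20, 64

# q_data holds exercise ids in [1, n_question]; qa_data holds the combined
# (exercise, answer) id: q + n_question for a correct answer, q otherwise.
# Index 0 is reserved for padding in both embeddings.
rng = np.random.RandomState(0)
q_data = rng.randint(1, n_question + 1, size=(num_students, seqlen))
correct = rng.randint(0, 2, size=(num_students, seqlen))
qa_data = q_data + n_question * correct
# train_epoch recovers the 0/1 label as floor((qa - 1) / n_question);
# padded positions (qa == 0) map to -1 and are masked out in Net.forward via target.ge(0).

params = {
    'n_question': n_question,
    'batch_size': batch_size,
    'lr': 0.01,
    'max_iter': 5,
    'maxgradnorm': 10.0,
}

dkvmn = DKVMN(n_question=n_question, batch_size=batch_size,
              key_embedding_dim=50, value_embedding_dim=100,
              memory_size=20, key_memory_state_dim=50, value_memory_state_dim=100,
              final_fc_dim=50)
dkvmn.train(params, train_data=(q_data, qa_data), test_data=(q_data, qa_data))
dkvmn.save("dkvmn.params")

For the shapes in Net.forward to line up, key_embedding_dim must match key_memory_state_dim and final_fc_dim, and value_embedding_dim must match value_memory_state_dim, as in the values chosen above.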
