Skip to content

Commit d024d59

Browse files
authored
[FEATURE] add GKT (#22)
* add GKT * Update AUTHORS.md * Delete mgkt * add GKT docs
1 parent b651a35 commit d024d59

File tree

16 files changed

+942
-0
lines changed

16 files changed

+942
-0
lines changed

AUTHORS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@
66

77
[Xiaonan Zeng](https://github.com/sone47)
88

9+
[Fangzhou Yao](https://github.com/fannazya)
10+
911
The starred is the corresponding author

EduKTM/GKT/GKT.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# coding: utf-8
2+
# 2022/2/25 @ fannazya
3+
4+
import logging
5+
import numpy as np
6+
import torch
7+
from tqdm import tqdm
8+
from EduKTM import KTM
9+
from .GKTNet import GKTNet
10+
from EduKTM.utils import SLMLoss, tensor2list, pick
11+
from sklearn.metrics import roc_auc_score, accuracy_score
12+
13+
14+
class GKT(KTM):
    """Graph-based Knowledge Tracing (GKT) model.

    Thin wrapper around :class:`GKTNet` providing training, evaluation and
    parameter persistence.
    """

    def __init__(self, ku_num, graph, hidden_num, net_params: dict = None, loss_params=None):
        """
        Parameters
        ----------
        ku_num:
            Number of knowledge units (concepts).
        graph:
            Path to the JSON edge list describing the concept graph,
            forwarded to GKTNet.
        hidden_num:
            Hidden state size of the underlying GKTNet.
        net_params:
            Optional extra keyword arguments forwarded to GKTNet.
        loss_params:
            Optional keyword arguments forwarded to SLMLoss.
        """
        super(GKT, self).__init__()
        self.gkt_model = GKTNet(
            ku_num,
            graph,
            hidden_num,
            **(net_params if net_params is not None else {})
        )
        self.loss_params = loss_params if loss_params is not None else {}

    def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
        """Train the model with Adam; optionally evaluate after each epoch.

        Parameters
        ----------
        train_data / test_data:
            Iterables of (question, data, data_mask, label, pick_index,
            label_mask) tensor batches, as produced by ``etl``.
        epoch:
            Number of training epochs (keyword-only).
        device:
            Torch device string; batches and the model are moved there.
        lr:
            Adam learning rate.
        """
        loss_function = SLMLoss(**self.loss_params)
        # BUGFIX: the model itself must live on `device`; previously only the
        # batch tensors were moved, so training on CUDA raised a device
        # mismatch error.
        self.gkt_model.to(device)
        trainer = torch.optim.Adam(self.gkt_model.parameters(), lr)

        for e in range(epoch):
            losses = []
            for (question, data, data_mask, label, pick_index, label_mask) in tqdm(train_data, "Epoch %s" % e):
                # convert to device
                question: torch.Tensor = question.to(device)
                data: torch.Tensor = data.to(device)
                data_mask: torch.Tensor = data_mask.to(device)
                label: torch.Tensor = label.to(device)
                pick_index: torch.Tensor = pick_index.to(device)
                label_mask: torch.Tensor = label_mask.to(device)

                # real training
                predicted_response, _ = self.gkt_model(question, data, data_mask)

                loss = loss_function(predicted_response, pick_index, label, label_mask)

                # back propagation
                trainer.zero_grad()
                loss.backward()
                trainer.step()

                losses.append(loss.mean().item())
            # BUGFIX: message previously misspelled the loss name as "SLMoss"
            print("[Epoch %d] SLMLoss: %.6f" % (e, float(np.mean(losses))))

            if test_data is not None:
                # BUGFIX: evaluate on the same device the model was trained on
                # (previously eval always defaulted to "cpu").
                auc, accuracy = self.eval(test_data, device=device)
                print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy))

    def eval(self, test_data, device="cpu") -> tuple:
        """Evaluate on ``test_data``; return ``(auc, accuracy)``.

        Accuracy uses a fixed 0.5 decision threshold on the predicted
        probabilities. The model is switched to eval mode for the duration
        and restored to train mode afterwards.
        """
        self.gkt_model.eval()
        y_true = []
        y_pred = []

        for (question, data, data_mask, label, pick_index, label_mask) in tqdm(test_data, "evaluating"):
            # convert to device
            question: torch.Tensor = question.to(device)
            data: torch.Tensor = data.to(device)
            data_mask: torch.Tensor = data_mask.to(device)
            label: torch.Tensor = label.to(device)
            pick_index: torch.Tensor = pick_index.to(device)
            label_mask: torch.Tensor = label_mask.to(device)

            # real evaluating
            output, _ = self.gkt_model(question, data, data_mask)
            # drop the prediction made after the final step; predictions are
            # aligned with the *next* question
            output = output[:, :-1]
            output = pick(output, pick_index.to(output.device))
            pred = tensor2list(output)
            label = tensor2list(label)
            # BUGFIX: .numpy() on a CUDA tensor raises; go through .cpu()
            # first (a no-op when already on CPU).
            for i, length in enumerate(label_mask.cpu().numpy().tolist()):
                length = int(length)
                # only the first `length` positions of each row are valid
                y_true.extend(label[i][:length])
                y_pred.extend(pred[i][:length])
        self.gkt_model.train()
        return roc_auc_score(y_true, y_pred), accuracy_score(y_true, np.array(y_pred) >= 0.5)

    def save(self, filepath) -> ...:
        """Serialize model parameters (state_dict) to ``filepath``."""
        torch.save(self.gkt_model.state_dict(), filepath)
        logging.info("save parameters to %s" % filepath)

    def load(self, filepath):
        """Load model parameters previously written by :meth:`save`."""
        self.gkt_model.load_state_dict(torch.load(filepath))
        logging.info("load parameters from %s" % filepath)

EduKTM/GKT/GKTNet.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# coding: utf-8
2+
# 2022/3/1 @ fannazya
3+
__all__ = ["GKTNet"]
4+
5+
import json
6+
import networkx as nx
7+
import torch
8+
from torch import nn
9+
import torch.nn.functional as F
10+
from EduKTM.utils import GRUCell, begin_states, get_states, expand_tensor, \
11+
format_sequence, mask_sequence_variable_length
12+
13+
14+
class GKTNet(nn.Module):
    """Graph-based Knowledge Tracing network.

    Maintains one hidden state per knowledge unit and propagates updates
    along a directed concept graph at every interaction step.
    """

    def __init__(self, ku_num, graph, hidden_num=None, latent_dim=None, dropout=0.0):
        """
        Parameters
        ----------
        ku_num:
            Number of knowledge units; one graph node per unit.
        graph:
            Path to a JSON file holding the edge list. Edges may be
            ``[u, v, w]`` (weighted) or ``[u, v]`` (weight defaults to 1.0
            via the ValueError fallback below).
        hidden_num / latent_dim:
            Hidden state size / embedding size; both default to ``ku_num``.
        dropout:
            Dropout probability applied to the updated states.
        """
        super(GKTNet, self).__init__()
        self.ku_num = int(ku_num)
        self.hidden_num = self.ku_num if hidden_num is None else int(hidden_num)
        self.latent_dim = self.ku_num if latent_dim is None else int(latent_dim)
        self.neighbor_dim = self.hidden_num + self.latent_dim
        self.graph = nx.DiGraph()
        self.graph.add_nodes_from(list(range(ku_num)))
        try:
            # first try the weighted [u, v, w] format
            with open(graph) as f:
                self.graph.add_weighted_edges_from(json.load(f))
        except ValueError:
            # fall back: plain [u, v] pairs, weight 1.0
            with open(graph) as f:
                self.graph.add_weighted_edges_from([e + [1.0] for e in json.load(f)])

        self.rnn = GRUCell(self.hidden_num)
        # responses are (question, correctness) pairs -> 2 * ku_num ids
        self.response_embedding = nn.Embedding(2 * self.ku_num, self.latent_dim)
        self.concept_embedding = nn.Embedding(self.ku_num, self.latent_dim)
        self.f_self = nn.Linear(self.neighbor_dim, self.hidden_num)
        self.n_out = nn.Linear(2 * self.neighbor_dim, self.hidden_num)
        self.n_in = nn.Linear(2 * self.neighbor_dim, self.hidden_num)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(self.hidden_num, 1)

    def in_weight(self, x, ordinal=True, with_weight=True):
        """Incoming-edge weights of node(s) ``x``.

        With ``ordinal=True`` returns a dense length-``ku_num`` vector
        (weight per predecessor, 0 elsewhere); otherwise the predecessor
        id list. Lists/tensors are handled element-wise.
        """
        if isinstance(x, torch.Tensor):
            x = x.numpy().tolist()
        if isinstance(x, list):
            # BUGFIX: propagate ordinal/with_weight through the recursion
            # (they were silently reset to their defaults before).
            return [self.in_weight(_x, ordinal=ordinal, with_weight=with_weight) for _x in x]
        elif isinstance(x, (int, float)):
            if not ordinal:
                return list(self.graph.predecessors(int(x)))
            else:
                _ret = [0] * self.ku_num
                for i in self.graph.predecessors(int(x)):
                    if with_weight:
                        # edge direction is i -> x
                        _ret[i] = self.graph[i][x]['weight']
                    else:
                        _ret[i] = 1
                return _ret
        else:
            raise TypeError("cannot handle %s" % type(x))

    def out_weight(self, x, ordinal=True, with_weight=True):
        """Outgoing-edge weights of node(s) ``x``; see :meth:`in_weight`."""
        if isinstance(x, torch.Tensor):
            x = x.numpy().tolist()
        if isinstance(x, list):
            # BUGFIX: propagate ordinal/with_weight through the recursion.
            return [self.out_weight(_x, ordinal=ordinal, with_weight=with_weight) for _x in x]
        elif isinstance(x, (int, float)):
            if not ordinal:
                return list(self.graph.successors(int(x)))
            else:
                _ret = [0] * self.ku_num
                for i in self.graph.successors(int(x)):
                    if with_weight:
                        # edge direction is x -> i
                        _ret[i] = self.graph[x][i]['weight']
                    else:
                        _ret[i] = 1
                return _ret
        else:
            raise TypeError("cannot handle %s" % type(x))

    def neighbors(self, x, ordinal=True, with_weight=False):
        """Neighbor mask/weights of node(s) ``x``.

        For a DiGraph, ``graph.neighbors`` yields successors; with
        ``ordinal=True`` returns a dense length-``ku_num`` 0/1 mask
        (or edge weights when ``with_weight=True``).
        """
        if isinstance(x, torch.Tensor):
            x = x.numpy().tolist()
        if isinstance(x, list):
            # BUGFIX: propagate ordinal/with_weight through the recursion.
            return [self.neighbors(_x, ordinal=ordinal, with_weight=with_weight) for _x in x]
        elif isinstance(x, (int, float)):
            if not ordinal:
                return list(self.graph.neighbors(int(x)))
            else:
                _ret = [0] * self.ku_num
                for i in self.graph.neighbors(int(x)):
                    if with_weight:
                        # BUGFIX: DiGraph.neighbors yields successors, so the
                        # edge is x -> i; the lookup was graph[i][x] before,
                        # which reads the reverse edge (KeyError if absent).
                        _ret[i] = self.graph[int(x)][i]['weight']
                    else:
                        _ret[i] = 1
                return _ret
        else:
            raise TypeError("cannot handle %s" % type(x))

    def forward(self, questions, answers, valid_length=None, compressed_out=True, layout="NTC"):
        """Run the GKT recurrence over a batch of interaction sequences.

        Parameters
        ----------
        questions / answers:
            Integer tensors of question ids and encoded responses,
            laid out per ``layout`` (default batch x time).
        valid_length:
            Optional per-sequence lengths; when given, outputs are masked
            to the valid steps.
        compressed_out:
            When True and valid_length is given, final states are dropped.

        Returns
        -------
        (outputs, states): per-step predicted probabilities and the last
        hidden states (or None, see ``compressed_out``).
        """
        length = questions.shape[1]
        inputs, axis, batch_size = format_sequence(length, questions, layout, False)
        answers, _, _ = format_sequence(length, answers, layout, False)

        # one hidden vector per knowledge unit, per batch element
        states = begin_states([(batch_size, self.ku_num, self.hidden_num)])[0]
        outputs = []
        all_states = []
        for i in range(length):
            # neighbors - aggregate
            inputs_i = inputs[i].reshape([batch_size, ])
            answer_i = answers[i].reshape([batch_size, ])

            _neighbors = self.neighbors(inputs_i)
            neighbors_mask = expand_tensor(torch.Tensor(_neighbors), -1, self.hidden_num)
            _neighbors_mask = expand_tensor(torch.Tensor(_neighbors), -1, self.hidden_num + self.latent_dim)

            # get concept embedding
            concept_embeddings = self.concept_embedding.weight.data
            concept_embeddings = expand_tensor(concept_embeddings, 0, batch_size)

            agg_states = torch.cat((concept_embeddings, states), dim=-1)

            # aggregate: keep only the neighbors of the current question
            _neighbors_states = _neighbors_mask * agg_states

            # self - aggregate: current concept state + response embedding
            _concept_embedding = get_states(inputs_i, states)
            _self_hidden_states = torch.cat((_concept_embedding, self.response_embedding(answer_i)), dim=-1)

            _self_mask = F.one_hot(inputs_i, self.ku_num)  # p
            _self_mask = expand_tensor(_self_mask, -1, self.hidden_num)

            self_hidden_states = expand_tensor(_self_hidden_states, 1, self.ku_num)

            # aggregate
            _hidden_states = torch.cat((_neighbors_states, self_hidden_states), dim=-1)

            # messages along incoming and outgoing edges, weighted by the graph
            _in_state = self.n_in(_hidden_states)
            _out_state = self.n_out(_hidden_states)
            in_weight = expand_tensor(torch.Tensor(self.in_weight(inputs_i)), -1, self.hidden_num)
            out_weight = expand_tensor(torch.Tensor(self.out_weight(inputs_i)), -1, self.hidden_num)

            next_neighbors_states = in_weight * _in_state + out_weight * _out_state

            # self - update
            next_self_states = self.f_self(_self_hidden_states)
            next_self_states = expand_tensor(next_self_states, 1, self.ku_num)
            next_self_states = _self_mask * next_self_states

            next_states = neighbors_mask * next_neighbors_states + next_self_states

            next_states, _ = self.rnn(next_states, [states])
            # untouched concepts (neither self nor neighbor) keep their state
            next_states = (_self_mask + neighbors_mask) * next_states + (1 - _self_mask - neighbors_mask) * states

            states = self.dropout(next_states)
            output = torch.sigmoid(self.out(states).squeeze(axis=-1))  # p
            outputs.append(output)
            if valid_length is not None and not compressed_out:
                all_states.append([states])

        if valid_length is not None:
            if compressed_out:
                states = None
            outputs = mask_sequence_variable_length(torch, outputs, length, valid_length, axis, merge=True)

        return outputs, states

EduKTM/GKT/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# coding: utf-8
2+
# 2022/2/25 @ fannazya
3+
4+
5+
from .GKT import GKT
6+
from .etl import etl

EduKTM/GKT/etl.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# coding: utf-8
2+
# 2022/2/25 @ fannazya
3+
4+
5+
import torch
6+
import json
7+
from tqdm import tqdm
8+
from EduKTM.utils.torch import PadSequence, FixedBucketSampler
9+
10+
11+
def extract(data_src, max_step=200):  # pragma: no cover
    """Read one JSON-encoded response sequence per line from ``data_src``.

    Each sequence is chopped into windows of at most ``max_step`` records;
    windows with fewer than two records are discarded. When ``max_step``
    is None, sequences are kept whole.
    """
    sequences = []
    with open(data_src) as src:
        for raw_line in tqdm(src, "reading data from %s" % data_src):
            seq = json.loads(raw_line)
            if max_step is None:
                sequences.append(seq)
                continue
            for start in range(0, len(seq), max_step):
                window = seq[start: start + max_step]
                # a window needs at least a (history, target) pair
                if len(window) >= 2:
                    sequences.append(window)
    return sequences
26+
27+
28+
def transform(raw_data, batch_size, num_buckets=100):
29+
# 定义数据转换接口
30+
# raw_data --> batch_data
31+
32+
responses = raw_data
33+
34+
batch_idxes = FixedBucketSampler([len(rs) for rs in responses], batch_size, num_buckets=num_buckets)
35+
batch = []
36+
37+
def index(r):
38+
correct = 0 if r[1] <= 0 else 1
39+
return r[0] * 2 + correct
40+
41+
for batch_idx in tqdm(batch_idxes, "batchify"):
42+
batch_qs = []
43+
batch_rs = []
44+
batch_pick_index = []
45+
batch_labels = []
46+
for idx in batch_idx:
47+
batch_qs.append([r[0] for r in responses[idx]])
48+
batch_rs.append([index(r) for r in responses[idx]])
49+
if len(responses[idx]) <= 1: # pragma: no cover
50+
pick_index, labels = [], []
51+
else:
52+
pick_index, labels = zip(*[(r[0], 0 if r[1] <= 0 else 1) for r in responses[idx][1:]])
53+
batch_pick_index.append(list(pick_index))
54+
batch_labels.append(list(labels))
55+
56+
max_len = max([len(rs) for rs in batch_rs])
57+
padder = PadSequence(max_len, pad_val=0)
58+
batch_qs = [padder(qs) for qs in batch_qs]
59+
batch_rs, data_mask = zip(*[(padder(rs), len(rs)) for rs in batch_rs])
60+
61+
max_len = max([len(rs) for rs in batch_labels])
62+
padder = PadSequence(max_len, pad_val=0)
63+
batch_labels, label_mask = zip(*[(padder(labels), len(labels)) for labels in batch_labels])
64+
batch_pick_index = [padder(pick_index) for pick_index in batch_pick_index]
65+
# Load
66+
batch.append(
67+
[torch.tensor(batch_qs), torch.tensor(batch_rs), torch.tensor(data_mask), torch.tensor(batch_labels),
68+
torch.tensor(batch_pick_index),
69+
torch.tensor(label_mask)])
70+
71+
return batch
72+
73+
74+
def etl(data_src, cfg=None, batch_size=None, **kwargs):  # pragma: no cover
    """Extract sequences from ``data_src`` and transform them into batches.

    ``batch_size`` falls back to ``cfg.batch_size`` when not given
    explicitly; extra keyword arguments are forwarded to ``transform``.
    """
    if batch_size is None:
        batch_size = cfg.batch_size
    return transform(extract(data_src), batch_size, **kwargs)

EduKTM/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
from .DKTPlus import DKTPlus
99
from .AKT import AKT
1010
from .LPKT import LPKT
11+
from .GKT import GKT

EduKTM/utils/torch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33

44
from .extlib import *
55
from .functional import *
6+
from .rnn import *

0 commit comments

Comments
 (0)