-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathdata_process.py
More file actions
123 lines (95 loc) · 3.85 KB
/
data_process.py
File metadata and controls
123 lines (95 loc) · 3.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# -*- coding: utf-8 -*-
# @Time : 2020-02-28 11:01
# @Author : WenYi
# @Contact : 1244058349@qq.com
# @Description : data_process
import argparse
import numpy as np
def read_item_index_to_entity_id_file():
    """Load the item-index / entity-id table into the shared lookup dicts.

    Reads ``./MKR-data/item_index2entity_id.txt``, a tab-separated file whose
    first column is the original item index and whose second column is the
    entity id (named ``satori_id`` below). Every data line receives the next
    consecutive integer, starting at 0, and that integer is stored in

    * ``item_index_old2new`` under the original item index (kept as a string)
    * ``entity_id2index`` under the entity id

    Both dicts are module-level names that must exist before this function
    is called. Blank lines are skipped, so a trailing empty line no longer
    raises ``IndexError``; a non-blank line with fewer than two columns still
    raises so that malformed data is not hidden.
    """
    file = './MKR-data/item_index2entity_id.txt'
    print('reading item index to entity id file: ' + file + ' ...')
    i = 0
    # The context manager closes the handle even if a line fails to parse.
    with open(file, encoding='utf-8') as reader:
        for line in reader:
            stripped = line.strip()
            if not stripped:
                continue
            # Split once and reuse the fields instead of splitting per column.
            fields = stripped.split('\t')
            item_index = fields[0]
            satori_id = fields[1]
            item_index_old2new[item_index] = i
            entity_id2index[satori_id] = i
            i += 1
def convert_rating():
    """Convert the raw rating file into ``./MKR-data/ratings_final.txt``.

    Input is ``./MKR-data/BX-Book-Ratings.csv``: the first line is a header
    and is skipped; every other line is split on ``;`` and the first and last
    character of each field (the surrounding quotation marks) are dropped.
    Column 0 is the user id, column 1 the original item index and column 2
    the rating. Ratings whose item is not a key of the module-level
    ``item_index_old2new`` dict are ignored.

    A rating ``>= 0`` is recorded as a positive interaction, anything else as
    a negative one. For every user with at least one positive item the
    output receives, tab-separated:

    * one ``<user> <item> 1`` line per positive item
    * ``<user> <item> 0`` lines for items sampled without replacement from
      the items the user has neither a positive nor a negative rating for

    Users are renumbered consecutively from 0 in the order in which they
    first appear with a positive rating. As many negatives as positives are
    sampled; when fewer candidates remain, all remaining candidates are used
    instead of letting ``np.random.choice`` raise ``ValueError``.
    """
    file = './MKR-data/BX-Book-Ratings.csv'

    print('reading rating file ...')
    item_set = set(item_index_old2new.values())
    user_pos_ratings = dict()
    user_neg_ratings = dict()

    # The context manager closes the input handle even if parsing fails.
    with open(file, encoding='utf-8') as reader:
        next(reader, None)  # skip the header line
        for line in reader:
            stripped = line.strip()
            if not stripped:
                # A blank (e.g. trailing) line carries no rating.
                continue

            # remove prefix and suffix quotation marks for BX dataset
            array = [field[1:-1] for field in stripped.split(';')]

            item_index_old = array[1]
            if item_index_old not in item_index_old2new:
                # the item is not in the final item set
                continue
            item_index = item_index_old2new[item_index_old]

            user_index_old = int(array[0])
            rating = float(array[2])
            target = user_pos_ratings if rating >= 0 else user_neg_ratings
            target.setdefault(user_index_old, set()).add(item_index)

    print('converting rating file ...')
    with open('./MKR-data/ratings_final.txt', 'w', encoding='utf-8') as writer:
        # Dict keys are unique, so the enumeration index is the new user id.
        for user_index, (user_index_old, pos_item_set) in enumerate(
                user_pos_ratings.items()):
            for item in pos_item_set:
                writer.write('%d\t%d\t1\n' % (user_index, item))

            unwatched_set = item_set - pos_item_set
            if user_index_old in user_neg_ratings:
                unwatched_set -= user_neg_ratings[user_index_old]

            # Sampling without replacement cannot draw more items than exist.
            sample_size = min(len(pos_item_set), len(unwatched_set))
            if sample_size > 0:
                for item in np.random.choice(list(unwatched_set),
                                             size=sample_size, replace=False):
                    writer.write('%d\t%d\t0\n' % (user_index, item))

    user_cnt = len(user_pos_ratings)
    print('number of users: %d' % user_cnt)
    print('number of items: %d' % len(item_set))
def convert_kg():
    """Re-index the knowledge-graph triples into ``./MKR-data/kg_final.txt``.

    Each non-blank line of ``./MKR-data/kg.txt`` is split on tabs into
    ``head``, ``relation`` and ``tail``. Triples whose head is not already a
    key of the module-level ``entity_id2index`` dict are dropped. A tail or
    relation seen for the first time is assigned the next free integer:
    entity numbering continues from the current size of ``entity_id2index``,
    relation numbering starts at 0 and is stored in the module-level
    ``relation_id2index`` dict. Every kept triple is written as
    ``head<TAB>relation<TAB>tail`` using those integers.

    The input is opened before the output, so a missing ``kg.txt`` no longer
    leaves behind a truncated ``kg_final.txt``. Blank lines are skipped
    instead of raising ``IndexError``.
    """
    print('converting kg.txt file ...')
    entity_cnt = len(entity_id2index)
    relation_cnt = 0

    # Both handles are closed by the context manager, also on error.
    with open('./MKR-data/kg.txt', encoding='utf-8') as reader, \
            open('./MKR-data/kg_final.txt', 'w', encoding='utf-8') as writer:
        for line in reader:
            stripped = line.strip()
            if not stripped:
                continue

            array = stripped.split('\t')
            head_old = array[0]
            relation_old = array[1]
            tail_old = array[2]

            # Only triples that start at an already-known entity are kept.
            if head_old not in entity_id2index:
                continue
            head = entity_id2index[head_old]

            if tail_old not in entity_id2index:
                entity_id2index[tail_old] = entity_cnt
                entity_cnt += 1
            tail = entity_id2index[tail_old]

            if relation_old not in relation_id2index:
                relation_id2index[relation_old] = relation_cnt
                relation_cnt += 1
            relation = relation_id2index[relation_old]

            writer.write('%d\t%d\t%d\n' % (head, relation, tail))

    print('number of entities (containing items): %d' % entity_cnt)
    print('number of relations: %d' % relation_cnt)
# Shared lookup tables that the conversion functions read and fill in.
# They are defined at module level rather than under the __main__ guard so
# that the functions do not raise NameError when this module is imported.
entity_id2index = dict()
relation_id2index = dict()
item_index_old2new = dict()


def main():
    """Run the three preprocessing steps in order.

    Seeds NumPy's global random generator with 555 first, so the sampling
    done in ``convert_rating`` is repeatable between runs, then loads the
    item/entity table, converts the ratings and converts the knowledge graph.
    """
    np.random.seed(555)

    read_item_index_to_entity_id_file()
    convert_rating()
    convert_kg()

    print('done')


if __name__ == '__main__':
    main()