graphsage/utils.py (25 changes: 12 additions & 13 deletions)
@@ -5,23 +5,22 @@
 import json
 import sys
 import os
-
+import pdb
 import networkx as nx
 from networkx.readwrite import json_graph
-version_info = list(map(int, nx.__version__.split('.')))
-major = version_info[0]
-minor = version_info[1]
-assert (major <= 1) and (minor <= 11), "networkx major version > 1.11"
 
 WALK_LEN=5
 N_WALKS=50
 
 def load_data(prefix, normalize=True, load_walks=False):
     G_data = json.load(open(prefix + "-G.json"))
     G = json_graph.node_link_graph(G_data)
-    if isinstance(G.nodes()[0], int):
-        conversion = lambda n : int(n)
-    else:
+    try:
+        if isinstance(G.nodes()[0], dict):
+            conversion = lambda n : int(n)
+        else:
+            print("Something wrong when load graph")
+    except:
         conversion = lambda n : n
 
     if os.path.exists(prefix + "-feats.npy"):
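This hunk drops the hard pin to networkx <= 1.11 and reworks the node-label conversion, because the two API generations disagree here: in 1.x, G.nodes() returns a plain list, so G.nodes()[0] is the first node label, while in 2.x it returns a NodeView, where [0] is a key lookup that yields the attribute dict of a node literally named 0 (or raises KeyError if no such node exists). The new try/except leans on exactly that behavior. (The added import pdb looks like a leftover debugging aid.) A minimal version-agnostic sketch of the same label-conversion decision, not part of the diff:

    import networkx as nx

    G = nx.Graph()
    G.add_node(0, val=False, test=False)

    # list(G.nodes()) is a list of node labels on both networkx 1.x and
    # 2.x, so indexing it avoids the NodeView key-lookup pitfall.
    first_label = list(G.nodes())[0]
    conversion = int if isinstance(first_label, int) else (lambda n: n)
    print(conversion("0"))  # JSON string keys cast back to int labels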
@@ -44,7 +43,7 @@ def load_data(prefix, normalize=True, load_walks=False):
     ## (necessary because of networkx weirdness with the Reddit data)
     broken_count = 0
     for node in G.nodes():
-        if not 'val' in G.node[node] or not 'test' in G.node[node]:
+        if not 'val' in G.nodes[node] or not 'test' in G.nodes[node]:
             G.remove_node(node)
             broken_count += 1
     print("Removed {:d} nodes that lacked proper annotations due to networkx versioning issues".format(broken_count))
@@ -53,15 +52,15 @@ def load_data(prefix, normalize=True, load_walks=False):
     ## (some datasets might already have this..)
     print("Loaded data.. now preprocessing..")
     for edge in G.edges():
-        if (G.node[edge[0]]['val'] or G.node[edge[1]]['val'] or
-            G.node[edge[0]]['test'] or G.node[edge[1]]['test']):
+        if (G.nodes[edge[0]]['val'] or G.nodes[edge[1]]['val'] or
+            G.nodes[edge[0]]['test'] or G.nodes[edge[1]]['test']):
             G[edge[0]][edge[1]]['train_removed'] = True
         else:
             G[edge[0]][edge[1]]['train_removed'] = False
 
     if normalize and not feats is None:
         from sklearn.preprocessing import StandardScaler
-        train_ids = np.array([id_map[n] for n in G.nodes() if not G.node[n]['val'] and not G.node[n]['test']])
+        train_ids = np.array([id_map[n] for n in G.nodes() if not G.nodes[n]['val'] and not G.nodes[n]['test']])
         train_feats = feats[train_ids]
         scaler = StandardScaler()
         scaler.fit(train_feats)
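In the normalization branch, the hunk only re-spells the train-node filter; the surrounding logic is the standard leakage-safe pattern of fitting the scaler on training rows and (in the lines collapsed below this hunk) presumably transforming every row with it. A sketch of that pattern, with feats and train_ids as hypothetical stand-ins for the arrays built above:

    import numpy as np
    from sklearn.preprocessing import StandardScaler

    feats = np.random.rand(5, 3)       # hypothetical (N, D) feature matrix
    train_ids = np.array([0, 2, 4])    # hypothetical rows of training nodes

    # Fit mean/std on training rows only, then apply to all rows, so
    # val/test statistics never leak into the normalization.
    scaler = StandardScaler().fit(feats[train_ids])
    feats = scaler.transform(feats)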
@@ -97,7 +96,7 @@ def run_random_walks(G, nodes, num_walks=N_WALKS):
     out_file = sys.argv[2]
     G_data = json.load(open(graph_file))
     G = json_graph.node_link_graph(G_data)
-    nodes = [n for n in G.nodes() if not G.node[n]["val"] and not G.node[n]["test"]]
+    nodes = [n for n in G.nodes() if not G.nodes[n]["val"] and not G.nodes[n]["test"]]
     G = G.subgraph(nodes)
     pairs = run_random_walks(G, nodes)
     with open(out_file, "w") as fp:
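The __main__ block filters to training nodes before walking; run_random_walks itself sits outside the shown hunks, so the sketch below is an assumption from the call site and the WALK_LEN/N_WALKS constants above, not the repository's implementation. One migration-relevant detail: on networkx 2.x, G.neighbors() returns an iterator, so random.choice needs a list() around it.

    import random

    def run_random_walks_sketch(G, nodes, num_walks=50, walk_len=5):
        # From each start node, take num_walks uniform random walks of
        # walk_len steps, recording (start, visited) co-occurrence pairs.
        pairs = []
        for node in nodes:
            if G.degree(node) == 0:  # isolated node: nowhere to walk
                continue
            for _ in range(num_walks):
                curr = node
                for _ in range(walk_len):
                    curr = random.choice(list(G.neighbors(curr)))
                    if curr != node:
                        pairs.append((node, curr))
        return pairs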