22
33import numpy as np
44from gensim .models import Word2Vec
5- from numba import get_num_threads , jit , prange
6- from numba . np . ufunc . parallel import _get_thread_id
5+ from numba import jit , prange
6+ from numba_progress import ProgressBar
77from pecanpy .graph import DenseGraph , SparseGraph
88from pecanpy .wrappers import Timer
99
@@ -53,8 +53,7 @@ def __init__(self, p, q, workers, verbose, extend=False):
5353 a localized neighborhood.
5454 workers (int): number of threads to be spawned for runing node2vec
5555 including walk generation and word2vec embedding.
56- verbose (bool): (not implemented yet due to issue with numba jit)
57- whether or not to display walk generation progress.
56+ verbose (bool): show progress bar for walk generation.
5857 extend (bool): ``True`` if use node2vec+ extension, default is ``False``
5958
6059 """
@@ -65,6 +64,19 @@ def __init__(self, p, q, workers, verbose, extend=False):
6564 self .verbose = verbose
6665 self .extend = extend
6766
67+ def _map_walk (self , walk_idx_ary ):
68+ """Map walk from node index to node ID.
69+
70+ Note:
71+ The last element in the ``walk_idx_ary`` encodes the effective walk
72+ length. Only walk indices up to the effective walk length are
73+ translated (mapped to node IDs).
74+
75+ """
76+ end_idx = walk_idx_ary [- 1 ]
77+ walk = [self .IDlst [i ] for i in walk_idx_ary [:end_idx ]]
78+ return walk
79+
6880 def simulate_walks (self , num_walks , walk_length , n_ckpts , pb_len ):
6981 """Generate walks starting from each nodes ``num_walks`` time.
7082
@@ -84,27 +96,22 @@ def simulate_walks(self, num_walks, walk_length, n_ckpts, pb_len):
8496 nodes = np .array (range (num_nodes ), dtype = np .uint32 )
8597 start_node_idx_ary = np .concatenate ([nodes ] * num_walks )
8698 np .random .shuffle (start_node_idx_ary )
99+ tot_num_jobs = start_node_idx_ary .size
87100
88101 move_forward = self .get_move_forward ()
89102 has_nbrs = self .get_has_nbrs ()
90103 verbose = self .verbose
91104
92105 @jit (parallel = True , nogil = True , nopython = True )
93- def node2vec_walks ():
106+ def node2vec_walks (num_iter , progress_proxy ):
94107 """Simulate a random walk starting from start node."""
95- tot_num_jobs = start_node_idx_ary .size
96108 # use the last entry of each walk index array to keep track of the
97109 # effective walk length
98- walk_idx_mat = np .zeros ((tot_num_jobs , walk_length + 2 ), dtype = np .uint32 )
110+ walk_idx_mat = np .zeros ((num_iter , walk_length + 2 ), dtype = np .uint32 )
99111 walk_idx_mat [:, 0 ] = start_node_idx_ary # initialize seeds
100112 walk_idx_mat [:, - 1 ] = walk_length + 1 # set to full walk length by default
101113
102- # progress bar parameters
103- num_threads = get_num_threads ()
104- checkpoint = tot_num_jobs / num_threads // n_ckpts
105- private_count = 0
106-
107- for i in prange (tot_num_jobs ):
114+ for i in prange (num_iter ):
108115 # initialize first step as normal random walk
109116 start_node_idx = walk_idx_mat [i , 0 ]
110117 if has_nbrs (start_node_idx ):
@@ -123,23 +130,16 @@ def node2vec_walks():
123130 walk_idx_mat [i , - 1 ] = j
124131 break
125132
126- if verbose :
127- thread_id = _get_thread_id ()
128- private_count += 1
129- progress_log (
130- tot_num_jobs ,
131- private_count ,
132- checkpoint ,
133- pb_len ,
134- num_threads ,
135- thread_id ,
136- )
133+ progress_proxy .update (1 )
137134
138135 return walk_idx_mat
139136
140- walks = [
141- [self .IDlst [idx ] for idx in walk [: walk [- 1 ]]] for walk in node2vec_walks ()
142- ]
137+ # Acquire numba progress proxy for displaying the progress bar
138+ with ProgressBar (total = tot_num_jobs , disable = not verbose ) as progress :
139+ walk_idx_mat = node2vec_walks (tot_num_jobs , progress )
140+
141+ # Map node index back to node ID
142+ walks = [self ._map_walk (walk_idx_ary ) for walk_idx_ary in walk_idx_mat ]
143143
144144 return walks
145145
@@ -464,50 +464,6 @@ def move_forward(cur_idx, prev_idx=None):
464464 return move_forward
465465
466466
467- @jit (nopython = True , nogil = True )
468- def progress_log (
469- tot_num_jobs ,
470- curr_iter ,
471- checkpoint ,
472- progress_bar_length ,
473- num_threads ,
474- thread_id ,
475- ):
476- """Monitor the progress of random walk generation.
477-
478- Manually construct the progress bar for the current thread and print.
479-
480- Args:
481- tot_num_jobs (int): total number of jobs
482- curr_iter (int): current iteration number.
483- checkpoint (int): intervals for reporting progress.
484- progress_bar_length (int): full length of the progress bar
485- num_threads (int): total number of threads
486- thread_id (int): id of the current thread
487-
488- """
489- # TODO: make monitoring less messy, i.e. flush line
490- if curr_iter % checkpoint == 0 :
491- progress = (
492- curr_iter / tot_num_jobs * progress_bar_length * num_threads
493- )
494-
495- # manuually construct progress bar since fstring not supported
496- progress_bar = "|"
497- for k in range (progress_bar_length ):
498- progress_bar += "#" if k < progress else " "
499- progress_bar += "|"
500-
501- print (
502- "Thread # " if thread_id < 10 else "Thread #" ,
503- thread_id ,
504- "progress:" ,
505- progress_bar ,
506- num_threads * curr_iter * 10000 // tot_num_jobs / 100 ,
507- "%" ,
508- )
509-
510-
511467@jit (nopython = True , nogil = True )
512468def alias_setup (probs ):
513469 """Construct alias lookup table.
0 commit comments