@@ -137,3 +137,168 @@ random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,

  return std::make_tuple(n_out, e_out);
}


void compute_cdf(const int64_t *rowptr, const float_t *edge_weight,
                 float_t *edge_weight_cdf, int64_t numel) {
  /* Convert edge weights to CDF as given in [1]

  [1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L148
  */
  // numel is the length of rowptr, so there are numel - 1 rows to process.
  at::parallel_for(0, numel - 1, at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; i++) {
      int64_t row_start = rowptr[i], row_end = rowptr[i + 1];
      float_t acc = 0.0;

      // Running prefix sum of the row's edge weights.
      for (int64_t j = row_start; j < row_end; j++) {
        acc += edge_weight[j];
        edge_weight_cdf[j] = acc;
      }
    }
  });
}
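
/* Hypothetical usage sketch, not part of the commit: get_offset() below draws
   a value in [0, 1], so compute_cdf() implicitly assumes that the weights of
   each row already sum to one. The toy graph here (two nodes, hand-picked
   normalized weights) is chosen accordingly. */
static void compute_cdf_example() {
  int64_t rowptr[] = {0, 3, 4};                      // two rows: edges [0, 3) and [3, 4)
  float_t edge_weight[] = {0.2f, 0.5f, 0.3f, 1.0f};  // normalized per row
  float_t edge_weight_cdf[4];

  compute_cdf(rowptr, edge_weight, edge_weight_cdf, 3);  // numel = length of rowptr
  // edge_weight_cdf is now {0.2, 0.7, 1.0, 1.0}: within a row, a uniform
  // draw in [0, 1] lands in bucket j with probability edge_weight[j].
}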


int64_t get_offset(const float_t *edge_weight, int64_t start, int64_t end) {
  /*
  The implementation given in [1] uses NumPy's `searchsorted` function, which
  is also available in PyTorch and its C++ API (via `at::searchsorted()`).
  However, that implementation is adapted to the general case where the
  searched values can be a multidimensional tensor. In our case, we have a 1D
  tensor of edge weights (in the form of a cumulative distribution function)
  and a single value whose position we want to compute. To eliminate the
  overhead introduced in the PyTorch implementation, one can examine the
  source code of `searchsorted` [2] and find that for our case the whole
  function call reduces to calling the `cus_lower_bound()` function.
  Unfortunately, we cannot access it directly (the namespace is not exposed in
  the public API), but the implementation is just a simple binary search. The
  code was copied here and reduced to the bare minimum.

  [1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69
  [2] https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Bucketization.cpp
  */
  float_t value = ((float_t)rand() / RAND_MAX); // uniform in [0, 1]
  int64_t original_start = start;

  // Classic lower-bound binary search: find the first CDF entry >= value.
  while (start < end) {
    const int64_t mid = start + ((end - start) >> 1);
    const float_t mid_val = edge_weight[mid];
    if (!(mid_val >= value)) {
      start = mid + 1;
    }
    else {
      end = mid;
    }
  }

  return start - original_start;
}
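
// For reference (not part of the commit): for a fixed `value`, the binary
// search above returns the same offset as a std::lower_bound call over the
// row's CDF segment, i.e. the position of the first entry >= value:
//
//   int64_t offset = std::lower_bound(edge_weight + start, edge_weight + end,
//                                     value) - (edge_weight + start);
//
// get_offset() additionally draws `value` itself before searching, which is
// why the random draw lives inside the function.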

// See: https://louisabraham.github.io/articles/node2vec-sampling.html
// See also: https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69
void rejection_sampling_weighted(const int64_t *rowptr, const int64_t *col,
                                 const float_t *edge_weight_cdf, int64_t *start,
                                 int64_t *n_out, int64_t *e_out,
                                 const int64_t numel, const int64_t walk_length,
                                 const double p, const double q) {

  double max_prob = fmax(fmax(1. / p, 1.), 1. / q);
  double prob_0 = 1. / p / max_prob;
  double prob_1 = 1. / max_prob;
  double prob_2 = 1. / q / max_prob;
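
  // Acceptance thresholds for node2vec rejection sampling (see the links
  // above): a candidate x drawn from the weighted distribution is accepted
  // with probability proportional to 1/p if it returns to the previous node
  // t, to 1 if x is also a neighbor of t, and to 1/q otherwise; dividing by
  // max_prob rescales the largest of the three to exactly 1.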

  int64_t grain_size = at::internal::GRAIN_SIZE / walk_length;
  at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
    for (auto n = begin; n < end; n++) {
      int64_t t = start[n], v, x, e_cur, row_start, row_end;

      n_out[n * (walk_length + 1)] = t;

      row_start = rowptr[t], row_end = rowptr[t + 1];

      if (row_end - row_start == 0) {
        // Isolated start node: the walk stays in place.
        e_cur = -1;
        v = t;
      } else {
        e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
        v = col[e_cur];
      }
      n_out[n * (walk_length + 1) + 1] = v;
      e_out[n * walk_length] = e_cur;

      for (auto l = 1; l < walk_length; l++) {
        row_start = rowptr[v], row_end = rowptr[v + 1];

        if (row_end - row_start == 0) {
          // No outgoing edges: stay in place.
          e_cur = -1;
          x = v;
        } else if (row_end - row_start == 1) {
          // Only one outgoing edge: no sampling needed.
          e_cur = row_start;
          x = col[e_cur];
        } else {
          if (p == 1 && q == 1) {
            // Unbiased case: sample directly from the weighted distribution.
            e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
            x = col[e_cur];
          }
          else {
            // Rejection sampling: draw a weighted candidate, then accept it
            // with the probability matching its relation to the previous node.
            while (true) {
              e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
              x = col[e_cur];

              auto r = ((double)rand() / (RAND_MAX)); // uniform in [0, 1]

              if (x == t && r < prob_0)
                break;
              else if (is_neighbor(rowptr, col, x, t) && r < prob_1)
                break;
              else if (r < prob_2)
                break;
            }
          }
        }

        n_out[n * (walk_length + 1) + (l + 1)] = x;
        e_out[n * walk_length + l] = e_cur;
        t = v;
        v = x;
      }
    }
  });
}
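
// Layout note (not part of the commit): walk n occupies the slice
// n_out[n * (walk_length + 1) .. n * (walk_length + 1) + walk_length], and its
// traversed edges occupy e_out[n * walk_length .. n * walk_length + walk_length - 1].
// An entry of -1 in e_out marks a step where the walk stayed in place because
// the current node had no outgoing edge.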


std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted_cpu(torch::Tensor rowptr, torch::Tensor col,
                         torch::Tensor edge_weight, torch::Tensor start,
                         int64_t walk_length, double p, double q) {
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(edge_weight);
  CHECK_CPU(start);

  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  CHECK_INPUT(edge_weight.dim() == 1);
  CHECK_INPUT(start.dim() == 1);

  auto n_out = torch::empty({start.size(0), walk_length + 1}, start.options());
  auto e_out = torch::empty({start.size(0), walk_length}, start.options());

  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  auto edge_weight_data = edge_weight.data_ptr<float_t>();
  auto start_data = start.data_ptr<int64_t>();
  auto n_out_data = n_out.data_ptr<int64_t>();
  auto e_out_data = e_out.data_ptr<int64_t>();

  auto edge_weight_cdf = torch::empty({edge_weight.size(0)}, edge_weight.options());
  auto edge_weight_cdf_data = edge_weight_cdf.data_ptr<float_t>();

  compute_cdf(rowptr_data, edge_weight_data, edge_weight_cdf_data, rowptr.numel());

  rejection_sampling_weighted(rowptr_data, col_data, edge_weight_cdf_data,
                              start_data, n_out_data, e_out_data, start.numel(),
                              walk_length, p, q);

  return std::make_tuple(n_out, e_out);
}
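
/* Hypothetical usage sketch, not part of the commit: a directed triangle with
   per-row-normalized weights, three walks of length 4 starting at each node.
   The tensor values and p = q = 1 are illustrative assumptions only, and the
   kFloat dtype assumes float_t is float (the common case). */
static void random_walk_weighted_cpu_example() {
  auto rowptr = torch::tensor({0, 2, 4, 6}, torch::kLong);
  auto col = torch::tensor({1, 2, 0, 2, 0, 1}, torch::kLong);
  auto edge_weight = torch::tensor({0.5, 0.5, 0.5, 0.5, 0.5, 0.5}, torch::kFloat);
  auto start = torch::tensor({0, 1, 2}, torch::kLong);

  torch::Tensor n_out, e_out;
  std::tie(n_out, e_out) =
      random_walk_weighted_cpu(rowptr, col, edge_weight, start,
                               /*walk_length=*/4, /*p=*/1.0, /*q=*/1.0);
  // n_out: [3, 5] node sequences; e_out: [3, 4] indices of traversed edges.
}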