@@ -15,9 +15,10 @@ torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
1515 auto num_neighbors = row_end - row_start;
1616
1717 int64_t size = count;
18- if (count < 1 ) {
18+ if (count < 1 )
1919 size = int64_t (ceil (factor * float (num_neighbors)));
20- }
20+ if (size > num_neighbors)
21+ size = num_neighbors;
2122
2223 // If the number of neighbors is approximately equal to the number of
2324 // neighbors which are requested, we use `randperm` to sample without
@@ -26,16 +27,16 @@ torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
2627 std::unordered_set<int64_t > set;
2728 if (size < 0.7 * float (num_neighbors)) {
2829 while (int64_t (set.size ()) < size) {
29- int64_t sample = ( rand () % num_neighbors) + row_start ;
30- set.insert (sample);
30+ int64_t sample = rand () % num_neighbors;
31+ set.insert (sample + row_start );
3132 }
3233 std::vector<int64_t > v (set.begin (), set.end ());
3334 e_ids.insert (e_ids.end (), v.begin (), v.end ());
3435 } else {
35- auto sample = at ::randperm (num_neighbors, start.options ()) + row_start ;
36+ auto sample = torch ::randperm (num_neighbors, start.options ());
3637 auto sample_data = sample.data_ptr <int64_t >();
3738 for (auto j = 0 ; j < size; j++) {
38- e_ids.push_back (sample_data[j]);
39+ e_ids.push_back (sample_data[j] + row_start );
3940 }
4041 }
4142 }
0 commit comments