@@ -110,16 +110,28 @@ def __init__(
110110 self .data = self .kdtree .data
111111 self .k = k
112112 self .p = p
113+ # these are both n x k+1
113114 distances , indices = self .kdtree .query (self .data , k = k + 1 , p = p )
114115 full_indices = np .arange (self .kdtree .n )
116+
117+ # if an element in the indices matrix is equal to the corresponding
118+ # index for that row, we want to mask that site from its neighbors
115119 not_self_mask = indices != full_indices .reshape (- 1 , 1 )
116- not_self_indices = indices [not_self_mask ].reshape (self .kdtree .n , k )
120+ # if there are *too many duplicates per site*, then we may get some
121+ # rows where the site index is not in the set of k+1 neighbors
122+ # So, we need to know where these sites are
123+ has_one_too_many = not_self_mask .sum (axis = 1 ) == (k + 1 )
124+ # if a site has k+1 neighbors, drop its k+1th neighbor
125+ not_self_mask [has_one_too_many , - 1 ] &= False
126+ not_self_indices = indices [not_self_mask ].reshape (self .kdtree .n , - 1 )
117127
118128 to_weight = not_self_indices
119129 if ids is None :
120130 ids = list (full_indices )
121-
122- neighbors = {idx : list (indices ) for idx , indices in zip (ids , not_self_indices )}
131+ named_indices = not_self_indices
132+ else :
133+ named_indices = np .asarray (ids )[not_self_indices ]
134+ neighbors = {idx : list (indices ) for idx , indices in zip (ids , named_indices )}
123135
124136 W .__init__ (self , neighbors , id_order = ids , ** kwargs )
125137
0 commit comments