@@ -98,10 +98,8 @@ def quantile_or_topk(
98
98
param = np .atleast_1d (param )
99
99
param = np .reshape (param , (param .size ,) + (1 ,) * array .ndim )
100
100
101
- if is_scalar_param :
102
- idxshape = array .shape [:- 1 ] + (actual_sizes .shape [- 1 ],)
103
- else :
104
- idxshape = (param .shape [0 ],) + array .shape [:- 1 ] + (actual_sizes .shape [- 1 ],)
101
+ # For topk(.., k=+1 or -1), we always return the singleton dimension.
102
+ idxshape = (param .shape [0 ],) + array .shape [:- 1 ] + (actual_sizes .shape [- 1 ],)
105
103
106
104
if q is not None :
107
105
# This is numpy's method="linear"
@@ -110,6 +108,7 @@ def quantile_or_topk(
110
108
111
109
if is_scalar_param :
112
110
virtual_index = virtual_index .squeeze (axis = 0 )
111
+ idxshape = array .shape [:- 1 ] + (actual_sizes .shape [- 1 ],)
113
112
114
113
lo_ = np .floor (
115
114
virtual_index , casting = "unsafe" , out = np .empty (virtual_index .shape , dtype = np .int64 )
@@ -122,7 +121,7 @@ def quantile_or_topk(
122
121
else :
123
122
virtual_index = inv_idx [:- 1 ] + ((actual_sizes - k ) if k > 0 else abs (k ) - 1 )
124
123
kth = np .unique (virtual_index )
125
- kth = kth [kth > 0 ]
124
+ kth = kth [kth >= 0 ]
126
125
k_offset = param .reshape ((abs (k ),) + (1 ,) * virtual_index .ndim )
127
126
lo_ = k_offset + virtual_index [np .newaxis , ...]
128
127
@@ -147,12 +146,18 @@ def quantile_or_topk(
147
146
result = _lerp (loval , hival , t = gamma , out = out , dtype = dtype )
148
147
else :
149
148
result = loval
150
- result [lo_ < 0 ] = fill_value
149
+ # This happens if numel in group < abs(k)
150
+ badmask = lo_ < 0
151
+ if badmask .any ():
152
+ result [badmask ] = fill_value
153
+
151
154
if not skipna and np .any (nanmask ):
152
155
result [..., nanmask ] = fill_value
156
+
153
157
if k is not None :
154
158
result = result .astype (dtype , copy = False )
155
- np .copyto (out , result )
159
+ if out is not None :
160
+ np .copyto (out , result )
156
161
return result
157
162
158
163
0 commit comments