@@ -117,9 +117,10 @@ public float score(int node) throws IOException {
117
117
}
118
118
119
119
@ Override
120
- public void bulkScore (int [] nodes , float [] scores , int numNodes ) throws IOException {
120
+ public float bulkScore (int [] nodes , float [] scores , int numNodes ) throws IOException {
121
121
// TODO checkOrdinal(node1 ....);
122
122
float [] scratchScores = new float [4 ];
123
+ float maxScore = Float .NEGATIVE_INFINITY ;
123
124
int i = 0 ;
124
125
MemorySegment query = getSegment (queryOrd , queryScratch );
125
126
final int limit = numNodes & ~3 ;
@@ -130,9 +131,13 @@ public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOExcept
130
131
MemorySegment ms4 = getSegment (nodes [i + 3 ], scratch4 );
131
132
DOT_OPS .dotProductBulk (scratchScores , query , ms1 , ms2 , ms3 , ms4 , dims );
132
133
scores [i + 0 ] = normalizeToUnitInterval (scratchScores [0 ]);
134
+ maxScore = Math .max (maxScore , scores [i + 0 ]);
133
135
scores [i + 1 ] = normalizeToUnitInterval (scratchScores [1 ]);
136
+ maxScore = Math .max (maxScore , scores [i + 1 ]);
134
137
scores [i + 2 ] = normalizeToUnitInterval (scratchScores [2 ]);
138
+ maxScore = Math .max (maxScore , scores [i + 2 ]);
135
139
scores [i + 3 ] = normalizeToUnitInterval (scratchScores [3 ]);
140
+ maxScore = Math .max (maxScore , scores [i + 3 ]);
136
141
}
137
142
// Handle remaining 1–3 nodes in bulk (if any)
138
143
int remaining = numNodes - i ;
@@ -142,9 +147,17 @@ public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOExcept
142
147
MemorySegment ms3 = (remaining > 2 ) ? getSegment (nodes [i + 2 ], scratch3 ) : ms1 ;
143
148
DOT_OPS .dotProductBulk (scratchScores , query , ms1 , ms2 , ms3 , ms1 , dims );
144
149
scores [i ] = normalizeToUnitInterval (scratchScores [0 ]);
145
- if (remaining > 1 ) scores [i + 1 ] = normalizeToUnitInterval (scratchScores [1 ]);
146
- if (remaining > 2 ) scores [i + 2 ] = normalizeToUnitInterval (scratchScores [2 ]);
150
+ maxScore = Math .max (maxScore , scores [i ]);
151
+ if (remaining > 1 ) {
152
+ scores [i + 1 ] = normalizeToUnitInterval (scratchScores [1 ]);
153
+ maxScore = Math .max (maxScore , scores [i + 1 ]);
154
+ }
155
+ if (remaining > 2 ) {
156
+ scores [i + 2 ] = normalizeToUnitInterval (scratchScores [2 ]);
157
+ maxScore = Math .max (maxScore , scores [i + 2 ]);
158
+ }
147
159
}
160
+ return maxScore ;
148
161
}
149
162
150
163
@ Override
0 commit comments