Skip to content

Commit 8dee377

Browse files
committed
reconcile with new bulkScore implementations
1 parent dc752ce commit 8dee377

File tree

2 files changed

+32
-6
lines changed

2 files changed

+32
-6
lines changed

lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorer.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,9 @@ public float score(int node) throws IOException {
112112
}
113113

114114
@Override
115-
public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException {
115+
public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException {
116116
float[] scratchScores = new float[4];
117+
float maxScore = Float.NEGATIVE_INFINITY;
117118
int i = 0;
118119
final int limit = numNodes & ~3;
119120
for (; i < limit; i += 4) {
@@ -123,9 +124,13 @@ public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOExcept
123124
MemorySegment ms4 = getSegment(nodes[i + 3]);
124125
DOT_OPS.dotProductBulk(scratchScores, query, ms1, ms2, ms3, ms4, query.length);
125126
scores[i + 0] = normalizeToUnitInterval(scratchScores[0]);
127+
maxScore = Math.max(maxScore, scores[i + 0]);
126128
scores[i + 1] = normalizeToUnitInterval(scratchScores[1]);
129+
maxScore = Math.max(maxScore, scores[i + 1]);
127130
scores[i + 2] = normalizeToUnitInterval(scratchScores[2]);
131+
maxScore = Math.max(maxScore, scores[i + 2]);
128132
scores[i + 3] = normalizeToUnitInterval(scratchScores[3]);
133+
maxScore = Math.max(maxScore, scores[i + 3]);
129134
}
130135
// Handle remaining 1–3 nodes in bulk (if any)
131136
int remaining = numNodes - i;
@@ -135,9 +140,17 @@ public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOExcept
135140
MemorySegment ms3 = (remaining > 2) ? getSegment(nodes[i + 2]) : ms1;
136141
DOT_OPS.dotProductBulk(scratchScores, query, ms1, ms2, ms3, ms1, query.length);
137142
scores[i] = normalizeToUnitInterval(scratchScores[0]);
138-
if (remaining > 1) scores[i + 1] = normalizeToUnitInterval(scratchScores[1]);
139-
if (remaining > 2) scores[i + 2] = normalizeToUnitInterval(scratchScores[2]);
143+
maxScore = Math.max(maxScore, scores[i]);
144+
if (remaining > 1) {
145+
scores[i + 1] = normalizeToUnitInterval(scratchScores[1]);
146+
maxScore = Math.max(maxScore, scores[i + 1]);
147+
}
148+
if (remaining > 2) {
149+
scores[i + 2] = normalizeToUnitInterval(scratchScores[2]);
150+
maxScore = Math.max(maxScore, scores[i + 2]);
151+
}
140152
}
153+
return maxScore;
141154
}
142155
}
143156
}

lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorerSupplier.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,10 @@ public float score(int node) throws IOException {
117117
}
118118

119119
@Override
120-
public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException {
120+
public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException {
121121
// TODO checkOrdinal(node1 ....);
122122
float[] scratchScores = new float[4];
123+
float maxScore = Float.NEGATIVE_INFINITY;
123124
int i = 0;
124125
MemorySegment query = getSegment(queryOrd, queryScratch);
125126
final int limit = numNodes & ~3;
@@ -130,9 +131,13 @@ public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOExcept
130131
MemorySegment ms4 = getSegment(nodes[i + 3], scratch4);
131132
DOT_OPS.dotProductBulk(scratchScores, query, ms1, ms2, ms3, ms4, dims);
132133
scores[i + 0] = normalizeToUnitInterval(scratchScores[0]);
134+
maxScore = Math.max(maxScore, scores[i + 0]);
133135
scores[i + 1] = normalizeToUnitInterval(scratchScores[1]);
136+
maxScore = Math.max(maxScore, scores[i + 1]);
134137
scores[i + 2] = normalizeToUnitInterval(scratchScores[2]);
138+
maxScore = Math.max(maxScore, scores[i + 2]);
135139
scores[i + 3] = normalizeToUnitInterval(scratchScores[3]);
140+
maxScore = Math.max(maxScore, scores[i + 3]);
136141
}
137142
// Handle remaining 1–3 nodes in bulk (if any)
138143
int remaining = numNodes - i;
@@ -142,9 +147,17 @@ public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOExcept
142147
MemorySegment ms3 = (remaining > 2) ? getSegment(nodes[i + 2], scratch3) : ms1;
143148
DOT_OPS.dotProductBulk(scratchScores, query, ms1, ms2, ms3, ms1, dims);
144149
scores[i] = normalizeToUnitInterval(scratchScores[0]);
145-
if (remaining > 1) scores[i + 1] = normalizeToUnitInterval(scratchScores[1]);
146-
if (remaining > 2) scores[i + 2] = normalizeToUnitInterval(scratchScores[2]);
150+
maxScore = Math.max(maxScore, scores[i]);
151+
if (remaining > 1) {
152+
scores[i + 1] = normalizeToUnitInterval(scratchScores[1]);
153+
maxScore = Math.max(maxScore, scores[i + 1]);
154+
}
155+
if (remaining > 2) {
156+
scores[i + 2] = normalizeToUnitInterval(scratchScores[2]);
157+
maxScore = Math.max(maxScore, scores[i + 2]);
158+
}
147159
}
160+
return maxScore;
148161
}
149162

150163
@Override

0 commit comments

Comments
 (0)