diff --git a/lucene/core/src/java/org/apache/lucene/search/BlockMaxConjunctionBulkScorer.java b/lucene/core/src/java/org/apache/lucene/search/BlockMaxConjunctionBulkScorer.java index 21f8af990b85..48d5ef8bdbe1 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BlockMaxConjunctionBulkScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/BlockMaxConjunctionBulkScorer.java @@ -187,7 +187,7 @@ private void scoreWindowScoreFirst( docAndScoreAccBuffer, sumOfOtherClause, scorable.minCompetitiveScore, scorers.length); } - ScorerUtil.applyRequiredClause(docAndScoreAccBuffer, iterators[i], scorables[i]); + scorers[i].applyAsRequiredClause(docAndScoreAccBuffer); } for (int i = 0; i < docAndScoreAccBuffer.size; ++i) { diff --git a/lucene/core/src/java/org/apache/lucene/search/ConstantScoreScorer.java b/lucene/core/src/java/org/apache/lucene/search/ConstantScoreScorer.java index 4ae8ef09017f..a24a14da906c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ConstantScoreScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/ConstantScoreScorer.java @@ -163,4 +163,28 @@ public void nextDocsAndScores(int upTo, Bits liveDocs, DocAndFloatFeatureBuffer Arrays.fill(buffer.features, 0, size, score); buffer.size = size; } + + @Override + public void applyAsRequiredClause(DocAndScoreAccBuffer buffer) throws IOException { + int intersectionSize = 0; + int curDoc = disi.docID(); + for (int i = 0; i < buffer.size; ++i) { + int targetDoc = buffer.docs[i]; + if (curDoc < targetDoc) { + curDoc = disi.advance(targetDoc); + } + if (curDoc == targetDoc) { + buffer.docs[intersectionSize] = targetDoc; + buffer.scores[intersectionSize] = buffer.scores[i]; + intersectionSize++; + } + } + + buffer.size = intersectionSize; + if (score != 0) { + for (int i = 0; i < intersectionSize; ++i) { + buffer.scores[i] += score; + } + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/search/MaxScoreBulkScorer.java b/lucene/core/src/java/org/apache/lucene/search/MaxScoreBulkScorer.java index c723cbe00f13..9290e418d7ca 100644 --- a/lucene/core/src/java/org/apache/lucene/search/MaxScoreBulkScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/MaxScoreBulkScorer.java @@ -262,8 +262,7 @@ private void scoreInnerWindowAsConjunction(LeafCollector collector, Bits acceptD allScorers.length); } - DisiWrapper scorer = allScorers[i]; - ScorerUtil.applyRequiredClause(docAndScoreAccBuffer, scorer.iterator, scorer.scorable); + allScorers[i].scorer.applyAsRequiredClause(docAndScoreAccBuffer); } scoreNonEssentialClauses(collector, docAndScoreAccBuffer, firstRequiredScorer); diff --git a/lucene/core/src/java/org/apache/lucene/search/Scorer.java b/lucene/core/src/java/org/apache/lucene/search/Scorer.java index fc540c30cc42..2e3bb0648544 100644 --- a/lucene/core/src/java/org/apache/lucene/search/Scorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/Scorer.java @@ -130,4 +130,27 @@ public void nextDocsAndScores(int upTo, Bits liveDocs, DocAndFloatFeatureBuffer } buffer.size = size; } + + /** + * Apply this {@link Scorer} as a required clause on the given {@link DocAndScoreAccBuffer}. This + * filters out documents from the buffer that do not match this scorer, and adds the scores of + * this {@link Scorer} to the scores. + */ + public void applyAsRequiredClause(DocAndScoreAccBuffer buffer) throws IOException { + DocIdSetIterator iterator = iterator(); + int intersectionSize = 0; + int curDoc = iterator.docID(); + for (int i = 0; i < buffer.size; ++i) { + int targetDoc = buffer.docs[i]; + if (curDoc < targetDoc) { + curDoc = iterator.advance(targetDoc); + } + if (curDoc == targetDoc) { + buffer.docs[intersectionSize] = targetDoc; + buffer.scores[intersectionSize] = buffer.scores[i] + score(); + intersectionSize++; + } + } + buffer.size = intersectionSize; + } } diff --git a/lucene/core/src/java/org/apache/lucene/search/TermScorer.java b/lucene/core/src/java/org/apache/lucene/search/TermScorer.java index 46cf17534410..b05c7d2fc1e4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TermScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/TermScorer.java @@ -26,6 +26,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.IntsRef; import org.apache.lucene.util.LongsRef; /** @@ -41,6 +42,7 @@ public final class TermScorer extends Scorer { private final NumericDocValues norms; private final ImpactsDISI impactsDisi; private final MaxScoreCache maxScoreCache; + private int[] freqs = IntsRef.EMPTY_INTS; private long[] normValues = LongsRef.EMPTY_LONGS; /** Construct a {@link TermScorer} that will iterate all documents. */ @@ -171,4 +173,41 @@ public void nextDocsAndScores(int upTo, Bits liveDocs, DocAndFloatFeatureBuffer bulkScorer.score(buffer.size, buffer.features, normValues, buffer.features); } + + @Override + public void applyAsRequiredClause(DocAndScoreAccBuffer buffer) throws IOException { + int size = buffer.size; + if (freqs.length < size) { + freqs = ArrayUtil.growNoCopy(freqs, size); + normValues = new long[freqs.length]; + } + + int intersectionSize = 0; + int curDoc = docID(); + for (int i = 0; i < size; i++) { + int targetDoc = buffer.docs[i]; + if (curDoc < targetDoc) { + curDoc = postingsEnum.advance(targetDoc); + } + if (curDoc == targetDoc) { + buffer.docs[intersectionSize] = targetDoc; + buffer.scores[intersectionSize] = buffer.scores[i]; + freqs[intersectionSize] = postingsEnum.freq(); + intersectionSize++; + } + } + buffer.size = intersectionSize; + + for (int i = 0; i < intersectionSize; i++) { + if (norms == null || norms.advanceExact(buffer.docs[i]) == false) { + normValues[i] = 1L; + } else { + normValues[i] = norms.longValue(); + } + } + + for (int i = 0; i < intersectionSize; i++) { + buffer.scores[i] += scorer.score(freqs[i], normValues[i]); + } + } }