Skip to content

Commit d0f71fc

Browse files
Stop instantiating RankFeaturePhase unnecessarily (elastic#115724)
We should not create the phase instance when we know we won't be doing any rank feature execution up-front. An instance of these isn't free and entails creating an array of searched_shard_count size which along is non-trivial. Also, this needlessly obscured the threading logic for fetch which has already led to a bug before.
1 parent 2522c98 commit d0f71fc

File tree

4 files changed

+33
-134
lines changed

4 files changed

+33
-134
lines changed

server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,17 @@ public class RankFeaturePhase extends SearchPhase {
4242
final SearchPhaseResults<SearchPhaseResult> rankPhaseResults;
4343
private final AggregatedDfs aggregatedDfs;
4444
private final SearchProgressListener progressListener;
45-
private final Client client;
45+
private final RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext;
4646

4747
RankFeaturePhase(
4848
SearchPhaseResults<SearchPhaseResult> queryPhaseResults,
4949
AggregatedDfs aggregatedDfs,
5050
SearchPhaseContext context,
51-
Client client
51+
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext
5252
) {
5353
super("rank-feature");
54+
assert rankFeaturePhaseRankCoordinatorContext != null;
55+
this.rankFeaturePhaseRankCoordinatorContext = rankFeaturePhaseRankCoordinatorContext;
5456
if (context.getNumShards() != queryPhaseResults.getNumShards()) {
5557
throw new IllegalStateException(
5658
"number of shards must match the length of the query results but doesn't:"
@@ -65,17 +67,10 @@ public class RankFeaturePhase extends SearchPhase {
6567
this.rankPhaseResults = new ArraySearchPhaseResults<>(context.getNumShards());
6668
context.addReleasable(rankPhaseResults);
6769
this.progressListener = context.getTask().getProgressListener();
68-
this.client = client;
6970
}
7071

7172
@Override
7273
public void run() {
73-
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext = coordinatorContext(context.getRequest().source());
74-
if (rankFeaturePhaseRankCoordinatorContext == null) {
75-
moveToNextPhase(queryPhaseResults, null);
76-
return;
77-
}
78-
7974
context.execute(new AbstractRunnable() {
8075
@Override
8176
protected void doRun() throws Exception {
@@ -122,7 +117,7 @@ void innerRun(RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordin
122117
}
123118
}
124119

125-
private RankFeaturePhaseRankCoordinatorContext coordinatorContext(SearchSourceBuilder source) {
120+
static RankFeaturePhaseRankCoordinatorContext coordinatorContext(SearchSourceBuilder source, Client client) {
126121
return source == null || source.rankBuilder() == null
127122
? null
128123
: source.rankBuilder().buildRankFeaturePhaseCoordinatorContext(source.size(), source.from(), client);
@@ -175,7 +170,6 @@ private void onPhaseDone(
175170
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext,
176171
SearchPhaseController.ReducedQueryPhase reducedQueryPhase
177172
) {
178-
assert rankFeaturePhaseRankCoordinatorContext != null;
179173
ThreadedActionListener<RankFeatureDoc[]> rankResultListener = new ThreadedActionListener<>(context, new ActionListener<>() {
180174
@Override
181175
public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) {

server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ protected SearchPhase getNextPhase(final SearchPhaseResults<DfsSearchResult> res
105105
aggregatedDfs,
106106
mergedKnnResults,
107107
queryPhaseResultConsumer,
108-
(queryResults) -> new RankFeaturePhase(queryResults, aggregatedDfs, context, client),
108+
(queryResults) -> SearchQueryThenFetchAsyncAction.nextPhase(client, context, queryResults, aggregatedDfs),
109109
context
110110
);
111111
}

server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1919
import org.elasticsearch.search.SearchPhaseResult;
2020
import org.elasticsearch.search.SearchShardTarget;
21+
import org.elasticsearch.search.dfs.AggregatedDfs;
2122
import org.elasticsearch.search.internal.AliasFilter;
2223
import org.elasticsearch.search.internal.SearchContext;
2324
import org.elasticsearch.search.internal.ShardSearchRequest;
@@ -125,9 +126,22 @@ && getRequest().scroll() == null
125126
super.onShardResult(result, shardIt);
126127
}
127128

129+
static SearchPhase nextPhase(
130+
Client client,
131+
SearchPhaseContext context,
132+
SearchPhaseResults<SearchPhaseResult> queryResults,
133+
AggregatedDfs aggregatedDfs
134+
) {
135+
var rankFeaturePhaseCoordCtx = RankFeaturePhase.coordinatorContext(context.getRequest().source(), client);
136+
if (rankFeaturePhaseCoordCtx == null) {
137+
return new FetchSearchPhase(queryResults, aggregatedDfs, context, null);
138+
}
139+
return new RankFeaturePhase(queryResults, aggregatedDfs, context, rankFeaturePhaseCoordCtx);
140+
}
141+
128142
@Override
129143
protected SearchPhase getNextPhase(final SearchPhaseResults<SearchPhaseResult> results, SearchPhaseContext context) {
130-
return new RankFeaturePhase(results, null, this, client);
144+
return nextPhase(client, this, results, null);
131145
}
132146

133147
private ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {

server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java

Lines changed: 12 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -287,95 +287,6 @@ public void sendExecuteRankFeature(
287287
}
288288
}
289289

290-
public void testRankFeaturePhaseNoNeedForFetchingFieldData() {
291-
AtomicBoolean phaseDone = new AtomicBoolean(false);
292-
final ScoreDoc[][] finalResults = new ScoreDoc[1][1];
293-
294-
// build the appropriate RankBuilder; using a null rankFeaturePhaseRankShardContext
295-
// and non-field based rankFeaturePhaseRankCoordinatorContext
296-
RankBuilder rankBuilder = rankBuilder(
297-
DEFAULT_RANK_WINDOW_SIZE,
298-
defaultQueryPhaseRankShardContext(Collections.emptyList(), DEFAULT_RANK_WINDOW_SIZE),
299-
negatingScoresQueryFeaturePhaseRankCoordinatorContext(DEFAULT_SIZE, DEFAULT_FROM, DEFAULT_RANK_WINDOW_SIZE),
300-
null,
301-
null
302-
);
303-
// create a SearchSource to attach to the request
304-
SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(rankBuilder);
305-
306-
SearchPhaseController controller = searchPhaseController();
307-
SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null);
308-
309-
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1);
310-
mockSearchPhaseContext.getRequest().source(searchSourceBuilder);
311-
try (SearchPhaseResults<SearchPhaseResult> results = searchPhaseResults(controller, mockSearchPhaseContext)) {
312-
// generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase
313-
// here we have 2 results, with doc ids 1 and 2
314-
final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123);
315-
QuerySearchResult queryResult = new QuerySearchResult(ctx, shard1Target, null);
316-
317-
try {
318-
queryResult.setShardIndex(shard1Target.getShardId().getId());
319-
int totalHits = randomIntBetween(2, 100);
320-
final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F), new ScoreDoc(2, 9.0F) };
321-
populateQuerySearchResult(queryResult, totalHits, shard1Docs);
322-
results.consumeResult(queryResult, () -> {});
323-
// do not make an actual http request, but rather generate the response
324-
// as if we would have read it from the RankFeatureShardPhase
325-
mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) {
326-
@Override
327-
public void sendExecuteRankFeature(
328-
Transport.Connection connection,
329-
final RankFeatureShardRequest request,
330-
SearchTask task,
331-
final ActionListener<RankFeatureResult> listener
332-
) {
333-
// make sure to match the context id generated above, otherwise we throw
334-
if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1, 2 })) {
335-
listener.onFailure(new UnsupportedOperationException("should not have reached here"));
336-
} else {
337-
listener.onFailure(new MockDirectoryWrapper.FakeIOException());
338-
}
339-
}
340-
};
341-
} finally {
342-
queryResult.decRef();
343-
}
344-
// override the RankFeaturePhase to skip moving to next phase
345-
RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone);
346-
try {
347-
rankFeaturePhase.run();
348-
mockSearchPhaseContext.assertNoFailure();
349-
assertTrue(mockSearchPhaseContext.failures.isEmpty());
350-
assertTrue(phaseDone.get());
351-
352-
// in this case there was no additional "RankFeature" results on shards, so we shortcut directly to queryPhaseResults
353-
SearchPhaseResults<SearchPhaseResult> rankPhaseResults = rankFeaturePhase.queryPhaseResults;
354-
assertNotNull(rankPhaseResults.getAtomicArray());
355-
assertEquals(1, rankPhaseResults.getAtomicArray().length());
356-
assertEquals(1, rankPhaseResults.getSuccessfulResults().count());
357-
358-
SearchPhaseResult shardResult = rankPhaseResults.getAtomicArray().get(0);
359-
assertTrue(shardResult instanceof QuerySearchResult);
360-
QuerySearchResult rankResult = (QuerySearchResult) shardResult;
361-
assertNull(rankResult.rankFeatureResult());
362-
assertNotNull(rankResult.queryResult());
363-
364-
List<ExpectedRankFeatureDoc> expectedFinalResults = List.of(
365-
new ExpectedRankFeatureDoc(2, 1, -9.0F, null),
366-
new ExpectedRankFeatureDoc(1, 2, -10.0F, null)
367-
);
368-
assertFinalResults(finalResults[0], expectedFinalResults);
369-
} finally {
370-
rankFeaturePhase.rankPhaseResults.close();
371-
}
372-
} finally {
373-
if (mockSearchPhaseContext.searchResponse.get() != null) {
374-
mockSearchPhaseContext.searchResponse.get().decRef();
375-
}
376-
}
377-
}
378-
379290
public void testRankFeaturePhaseOneShardFails() {
380291
AtomicBoolean phaseDone = new AtomicBoolean(false);
381292
final ScoreDoc[][] finalResults = new ScoreDoc[1][1];
@@ -534,7 +445,12 @@ public void sendExecuteRankFeature(
534445
queryResult.decRef();
535446
}
536447
// override the RankFeaturePhase to raise an exception
537-
RankFeaturePhase rankFeaturePhase = new RankFeaturePhase(results, null, mockSearchPhaseContext, null) {
448+
RankFeaturePhase rankFeaturePhase = new RankFeaturePhase(
449+
results,
450+
null,
451+
mockSearchPhaseContext,
452+
defaultRankFeaturePhaseRankCoordinatorContext(DEFAULT_SIZE, DEFAULT_FROM, DEFAULT_RANK_WINDOW_SIZE)
453+
) {
538454
@Override
539455
void innerRun(RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext) {
540456
throw new IllegalArgumentException("simulated failure");
@@ -890,36 +806,6 @@ public RankFeatureDoc[] rankAndPaginate(RankFeatureDoc[] rankFeatureDocs) {
890806
};
891807
}
892808

893-
private QueryPhaseRankCoordinatorContext negatingScoresQueryFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) {
894-
return new QueryPhaseRankCoordinatorContext(rankWindowSize) {
895-
@Override
896-
public ScoreDoc[] rankQueryPhaseResults(
897-
List<QuerySearchResult> rankSearchResults,
898-
SearchPhaseController.TopDocsStats topDocsStats
899-
) {
900-
List<ScoreDoc> docScores = new ArrayList<>();
901-
for (QuerySearchResult phaseResults : rankSearchResults) {
902-
docScores.addAll(Arrays.asList(phaseResults.topDocs().topDocs.scoreDocs));
903-
}
904-
ScoreDoc[] sortedDocs = docScores.toArray(new ScoreDoc[0]);
905-
// negating scores
906-
Arrays.stream(sortedDocs).forEach(doc -> doc.score *= -1);
907-
908-
Arrays.sort(sortedDocs, Comparator.comparing((ScoreDoc doc) -> doc.score).reversed());
909-
sortedDocs = Arrays.stream(sortedDocs).limit(rankWindowSize).toArray(ScoreDoc[]::new);
910-
RankFeatureDoc[] topResults = new RankFeatureDoc[Math.max(0, Math.min(size, sortedDocs.length - from))];
911-
// perform pagination
912-
for (int rank = 0; rank < topResults.length; ++rank) {
913-
ScoreDoc base = sortedDocs[from + rank];
914-
topResults[rank] = new RankFeatureDoc(base.doc, base.score, base.shardIndex);
915-
topResults[rank].rank = from + rank + 1;
916-
}
917-
topDocsStats.fetchHits = topResults.length;
918-
return topResults;
919-
}
920-
};
921-
}
922-
923809
private RankFeaturePhaseRankShardContext defaultRankFeaturePhaseRankShardContext(String field) {
924810
return new RankFeaturePhaseRankShardContext(field) {
925811
@Override
@@ -1134,7 +1020,12 @@ private RankFeaturePhase rankFeaturePhase(
11341020
AtomicBoolean phaseDone
11351021
) {
11361022
// override the RankFeaturePhase to skip moving to next phase
1137-
return new RankFeaturePhase(results, null, mockSearchPhaseContext, null) {
1023+
return new RankFeaturePhase(
1024+
results,
1025+
null,
1026+
mockSearchPhaseContext,
1027+
RankFeaturePhase.coordinatorContext(mockSearchPhaseContext.getRequest().source(), null)
1028+
) {
11381029
@Override
11391030
public void moveToNextPhase(
11401031
SearchPhaseResults<SearchPhaseResult> phaseResults,

0 commit comments

Comments
 (0)