Skip to content

Commit ccf9893

Browse files
Fix race condition in RemoteClusterService.collectNodes() (#131937)
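It is possible for a linked remote cluster to be unlinked between the containsKey() and get() calls in collectNodes(). When that happens, get() returns null and the subsequent connection.collectNodes() call throws a NullPointerException. This change adds a test that reproduces the NullPointerException and fixes the race by looking up each connection only once and checking the result for null.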
1 parent a78737d commit ccf9893

File tree

3 files changed

+99
-23
lines changed

3 files changed

+99
-23
lines changed

docs/changelog/131937.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 131937
2+
summary: Fix race condition in `RemoteClusterService.collectNodes()`
3+
area: Distributed
4+
type: bug
5+
issues: []

server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.action.support.CountDownActionListener;
1818
import org.elasticsearch.action.support.IndicesOptions;
1919
import org.elasticsearch.action.support.PlainActionFuture;
20+
import org.elasticsearch.action.support.RefCountingListener;
2021
import org.elasticsearch.action.support.RefCountingRunnable;
2122
import org.elasticsearch.client.internal.RemoteClusterClient;
2223
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
@@ -29,7 +30,6 @@
2930
import org.elasticsearch.common.settings.Setting;
3031
import org.elasticsearch.common.settings.Settings;
3132
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
32-
import org.elasticsearch.common.util.concurrent.CountDown;
3333
import org.elasticsearch.common.util.concurrent.EsExecutors;
3434
import org.elasticsearch.core.IOUtils;
3535
import org.elasticsearch.core.TimeValue;
@@ -567,36 +567,26 @@ public void collectNodes(Set<String> clusters, ActionListener<BiFunction<String,
567567
"this node does not have the " + DiscoveryNodeRole.REMOTE_CLUSTER_CLIENT_ROLE.roleName() + " role"
568568
);
569569
}
570+
final var connectionsMap = new HashMap<String, RemoteClusterConnection>();
570571
for (String cluster : clusters) {
571-
if (this.remoteClusters.containsKey(cluster) == false) {
572+
final var connection = this.remoteClusters.get(cluster);
573+
if (connection == null) {
572574
listener.onFailure(new NoSuchRemoteClusterException(cluster));
573575
return;
574576
}
577+
connectionsMap.put(cluster, connection);
575578
}
576579

577580
final Map<String, Function<String, DiscoveryNode>> clusterMap = new HashMap<>();
578-
CountDown countDown = new CountDown(clusters.size());
579-
Function<String, DiscoveryNode> nullFunction = s -> null;
580-
for (final String cluster : clusters) {
581-
RemoteClusterConnection connection = this.remoteClusters.get(cluster);
582-
connection.collectNodes(new ActionListener<Function<String, DiscoveryNode>>() {
583-
@Override
584-
public void onResponse(Function<String, DiscoveryNode> nodeLookup) {
585-
synchronized (clusterMap) {
586-
clusterMap.put(cluster, nodeLookup);
587-
}
588-
if (countDown.countDown()) {
589-
listener.onResponse((clusterAlias, nodeId) -> clusterMap.getOrDefault(clusterAlias, nullFunction).apply(nodeId));
590-
}
591-
}
592-
593-
@Override
594-
public void onFailure(Exception e) {
595-
if (countDown.fastForward()) { // we need to check if it's true since we could have multiple failures
596-
listener.onFailure(e);
597-
}
581+
final var finalListener = listener.<Void>safeMap(
582+
ignored -> (clusterAlias, nodeId) -> clusterMap.getOrDefault(clusterAlias, s -> null).apply(nodeId)
583+
);
584+
try (var refs = new RefCountingListener(finalListener)) {
585+
connectionsMap.forEach((cluster, connection) -> connection.collectNodes(refs.acquire(nodeLookup -> {
586+
synchronized (clusterMap) {
587+
clusterMap.put(cluster, nodeLookup);
598588
}
599-
});
589+
})));
600590
}
601591
}
602592

server/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
package org.elasticsearch.transport;
1010

1111
import org.apache.logging.log4j.Level;
12+
import org.apache.lucene.store.AlreadyClosedException;
1213
import org.elasticsearch.TransportVersion;
1314
import org.elasticsearch.action.ActionListener;
15+
import org.elasticsearch.action.LatchedActionListener;
1416
import org.elasticsearch.action.OriginalIndices;
1517
import org.elasticsearch.action.support.ActionTestUtils;
1618
import org.elasticsearch.action.support.IndicesOptions;
@@ -1060,6 +1062,85 @@ public void onFailure(Exception e) {
10601062
}
10611063
}
10621064

1065+
public void testCollectNodesConcurrentWithSettingsChanges() throws IOException {
1066+
final List<DiscoveryNode> knownNodes_c1 = new CopyOnWriteArrayList<>();
1067+
1068+
try (
1069+
var c1N1 = startTransport(
1070+
"cluster_1_node_1",
1071+
knownNodes_c1,
1072+
VersionInformation.CURRENT,
1073+
TransportVersion.current(),
1074+
Settings.EMPTY
1075+
);
1076+
var transportService = MockTransportService.createNewService(
1077+
Settings.EMPTY,
1078+
VersionInformation.CURRENT,
1079+
TransportVersion.current(),
1080+
threadPool,
1081+
null
1082+
)
1083+
) {
1084+
final var c1N1Node = c1N1.getLocalNode();
1085+
knownNodes_c1.add(c1N1Node);
1086+
final var seedList = List.of(c1N1Node.getAddress().toString());
1087+
transportService.start();
1088+
transportService.acceptIncomingRequests();
1089+
1090+
try (RemoteClusterService service = new RemoteClusterService(createSettings("cluster_1", seedList), transportService)) {
1091+
service.initializeRemoteClusters();
1092+
assertTrue(service.isCrossClusterSearchEnabled());
1093+
final var numTasks = between(3, 5);
1094+
final var taskLatch = new CountDownLatch(numTasks);
1095+
1096+
ESTestCase.startInParallel(numTasks, threadNumber -> {
1097+
if (threadNumber == 0) {
1098+
taskLatch.countDown();
1099+
boolean isLinked = true;
1100+
while (taskLatch.getCount() != 0) {
1101+
final var future = new PlainActionFuture<RemoteClusterService.RemoteClusterConnectionStatus>();
1102+
final var settings = createSettings("cluster_1", isLinked ? Collections.emptyList() : seedList);
1103+
service.updateRemoteCluster("cluster_1", settings, future);
1104+
safeGet(future);
1105+
isLinked = isLinked == false;
1106+
}
1107+
return;
1108+
}
1109+
1110+
// Verify collectNodes() always invokes the listener, even if the remote cluster is concurrently being unlinked.
1111+
try {
1112+
for (int i = 0; i < 10; ++i) {
1113+
final var latch = new CountDownLatch(1);
1114+
final var exRef = new AtomicReference<Exception>();
1115+
service.collectNodes(Set.of("cluster_1"), new LatchedActionListener<>(new ActionListener<>() {
1116+
@Override
1117+
public void onResponse(BiFunction<String, String, DiscoveryNode> func) {
1118+
assertEquals(c1N1Node, func.apply("cluster_1", c1N1Node.getId()));
1119+
}
1120+
1121+
@Override
1122+
public void onFailure(Exception e) {
1123+
exRef.set(e);
1124+
}
1125+
}, latch));
1126+
safeAwait(latch);
1127+
if (exRef.get() != null) {
1128+
assertThat(
1129+
exRef.get(),
1130+
either(instanceOf(TransportException.class)).or(instanceOf(NoSuchRemoteClusterException.class))
1131+
.or(instanceOf(AlreadyClosedException.class))
1132+
.or(instanceOf(NoSeedNodeLeftException.class))
1133+
);
1134+
}
1135+
}
1136+
} finally {
1137+
taskLatch.countDown();
1138+
}
1139+
});
1140+
}
1141+
}
1142+
}
1143+
10631144
public void testRemoteClusterSkipIfDisconnectedSetting() {
10641145
{
10651146
Settings settings = Settings.builder()

0 commit comments

Comments
 (0)