Skip to content

Commit 6078946

Browse files
committed
Add test to reproduce ByteBuffer concurrency issue in WebSocket sessions
- Test concurrent sending of shared BinaryMessage across multiple sessions - Demonstrates ByteBuffer position corruption in multi-threaded environment - Validates the need for ByteBuffer.duplicate() when handling binary messages
1 parent 33897bb commit 6078946

File tree

1 file changed

+122
-0
lines changed

1 file changed

+122
-0
lines changed

spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,21 @@
1717
package org.springframework.web.socket.handler;
1818

1919
import java.io.IOException;
20+
import java.nio.ByteBuffer;
21+
import java.util.ArrayList;
22+
import java.util.List;
2023
import java.util.concurrent.CountDownLatch;
24+
import java.util.concurrent.ExecutorService;
2125
import java.util.concurrent.Executors;
2226
import java.util.concurrent.TimeUnit;
27+
import java.util.concurrent.atomic.AtomicInteger;
2328

2429
import org.junit.jupiter.api.Test;
2530

31+
import org.springframework.web.socket.BinaryMessage;
2632
import org.springframework.web.socket.CloseStatus;
2733
import org.springframework.web.socket.TextMessage;
34+
import org.springframework.web.socket.WebSocketMessage;
2835
import org.springframework.web.socket.WebSocketSession;
2936
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator.OverflowStrategy;
3037

@@ -226,4 +233,119 @@ void configuredProperties() {
226233
assertThat(sessionDecorator.getOverflowStrategy()).isEqualTo(OverflowStrategy.DROP);
227234
}
228235

236+
@Test
237+
void concurrentBinaryMessageSharingAcrossSessions() throws Exception {
238+
byte[] originalData = new byte[100];
239+
for (int i = 0; i < originalData.length; i++) {
240+
originalData[i] = (byte) i;
241+
}
242+
ByteBuffer buffer = ByteBuffer.wrap(originalData);
243+
BinaryMessage sharedMessage = new BinaryMessage(buffer);
244+
245+
int sessionCount = 5;
246+
int messagesPerSession = 3;
247+
CountDownLatch startLatch = new CountDownLatch(1);
248+
CountDownLatch completeLatch = new CountDownLatch(sessionCount * messagesPerSession);
249+
AtomicInteger corruptedBuffers = new AtomicInteger(0);
250+
251+
List<TestBinaryMessageCapturingSession> sessions = new ArrayList<>();
252+
List<ConcurrentWebSocketSessionDecorator> decorators = new ArrayList<>();
253+
254+
for (int i = 0; i < sessionCount; i++) {
255+
TestBinaryMessageCapturingSession session = new TestBinaryMessageCapturingSession();
256+
session.setOpen(true);
257+
session.setId("session-" + i);
258+
sessions.add(session);
259+
260+
ConcurrentWebSocketSessionDecorator decorator =
261+
new ConcurrentWebSocketSessionDecorator(session, 10000, 10240);
262+
decorators.add(decorator);
263+
}
264+
265+
ExecutorService executor = Executors.newFixedThreadPool(sessionCount * messagesPerSession);
266+
267+
try {
268+
for (ConcurrentWebSocketSessionDecorator decorator : decorators) {
269+
for (int j = 0; j < messagesPerSession; j++) {
270+
executor.submit(() -> {
271+
try {
272+
startLatch.await();
273+
decorator.sendMessage(sharedMessage);
274+
} catch (Exception e) {
275+
e.printStackTrace();
276+
} finally {
277+
completeLatch.countDown();
278+
}
279+
});
280+
}
281+
}
282+
283+
startLatch.countDown();
284+
assertThat(completeLatch.await(5, TimeUnit.SECONDS)).isTrue();
285+
286+
for (TestBinaryMessageCapturingSession session : sessions) {
287+
List<ByteBuffer> capturedBuffers = session.getCapturedBuffers();
288+
289+
for (ByteBuffer capturedBuffer : capturedBuffers) {
290+
byte[] capturedData = new byte[capturedBuffer.remaining()];
291+
capturedBuffer.get(capturedData);
292+
293+
boolean isCorrupted = false;
294+
if (capturedData.length != originalData.length) {
295+
isCorrupted = true;
296+
} else {
297+
for (int j = 0; j < originalData.length; j++) {
298+
if (capturedData[j] != originalData[j]) {
299+
isCorrupted = true;
300+
break;
301+
}
302+
}
303+
}
304+
305+
if (isCorrupted) {
306+
corruptedBuffers.incrementAndGet();
307+
}
308+
}
309+
}
310+
311+
assertThat(corruptedBuffers.get())
312+
.as("No ByteBuffer corruption should occur with duplicate() fix")
313+
.isEqualTo(0);
314+
} finally {
315+
executor.shutdown();
316+
}
317+
}
318+
319+
static class TestBinaryMessageCapturingSession extends TestWebSocketSession {
320+
private final List<ByteBuffer> capturedBuffers = new ArrayList<>();
321+
322+
@Override
323+
public void sendMessage(WebSocketMessage<?> message) throws IOException {
324+
if (message instanceof BinaryMessage) {
325+
ByteBuffer payload = ((BinaryMessage) message).getPayload();
326+
ByteBuffer captured = ByteBuffer.allocate(payload.remaining());
327+
328+
while (payload.hasRemaining()) {
329+
captured.put(payload.get());
330+
}
331+
captured.flip();
332+
333+
synchronized (capturedBuffers) {
334+
capturedBuffers.add(captured);
335+
}
336+
337+
try {
338+
Thread.sleep(1);
339+
} catch (InterruptedException e) {
340+
Thread.currentThread().interrupt();
341+
}
342+
}
343+
super.sendMessage(message);
344+
}
345+
346+
public synchronized List<ByteBuffer> getCapturedBuffers() {
347+
return new ArrayList<>(capturedBuffers);
348+
}
349+
}
350+
229351
}

0 commit comments

Comments
 (0)