Skip to content

Commit 9fdf8dc

Browse files
committed
Ensure thread-safe ByteBuffer handling in WebSocket sessions
Replace direct ByteBuffer usage with asReadOnlyBuffer() in binary message sending to prevent concurrent modification issues when sharing buffers across multiple sessions. Signed-off-by: xeroman.k <[email protected]>
1 parent 33897bb commit 9fdf8dc

File tree

4 files changed

+99
-2
lines changed

4 files changed

+99
-2
lines changed

spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ protected void sendTextMessage(TextMessage message) throws IOException {
205205

206206
@Override
207207
protected void sendBinaryMessage(BinaryMessage message) throws IOException {
208-
useSession((session, callback) -> session.sendBinary(message.getPayload(), callback));
208+
useSession((session, callback) -> session.sendBinary(message.getPayload().asReadOnlyBuffer(), callback));
209209
}
210210

211211
@Override

spring-websocket/src/main/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSession.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ protected void sendTextMessage(TextMessage message) throws IOException {
208208

209209
@Override
210210
protected void sendBinaryMessage(BinaryMessage message) throws IOException {
211-
getNativeSession().getBasicRemote().sendBinary(message.getPayload(), message.isLast());
211+
getNativeSession().getBasicRemote().sendBinary(message.getPayload().asReadOnlyBuffer(), message.isLast());
212212
}
213213

214214
@Override

spring-websocket/src/test/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSessionTests.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,23 @@
1616

1717
package org.springframework.web.socket.adapter.jetty;
1818

19+
import java.nio.ByteBuffer;
1920
import java.util.Map;
21+
import java.util.function.BiConsumer;
2022

23+
import org.eclipse.jetty.websocket.api.Callback;
2124
import org.eclipse.jetty.websocket.api.Session;
2225
import org.eclipse.jetty.websocket.api.UpgradeRequest;
2326
import org.eclipse.jetty.websocket.api.UpgradeResponse;
2427
import org.junit.jupiter.api.Test;
2528

2629
import org.springframework.core.testfixture.security.TestPrincipal;
30+
import org.springframework.web.socket.BinaryMessage;
2731

2832
import static org.assertj.core.api.Assertions.assertThat;
33+
import static org.mockito.ArgumentMatchers.any;
2934
import static org.mockito.BDDMockito.given;
35+
import static org.mockito.Mockito.doAnswer;
3036
import static org.mockito.Mockito.mock;
3137
import static org.mockito.Mockito.reset;
3238
import static org.mockito.Mockito.verifyNoMoreInteractions;
@@ -117,4 +123,48 @@ void getAcceptedProtocol() {
117123
verifyNoMoreInteractions(nativeSession);
118124
}
119125

126+
@Test
127+
void binaryMessageWithSharedBufferSendsToMultipleSessions() throws Exception {
128+
byte[] data = {1, 2, 3, 4, 5};
129+
ByteBuffer sharedBuffer = ByteBuffer.wrap(data);
130+
BinaryMessage message = new BinaryMessage(sharedBuffer);
131+
132+
ByteBuffer[] captured = new ByteBuffer[2];
133+
JettyWebSocketSession session1 = createMockSession((buffer, idx) -> captured[0] = buffer);
134+
JettyWebSocketSession session2 = createMockSession((buffer, idx) -> captured[1] = buffer);
135+
136+
session1.sendMessage(message);
137+
session2.sendMessage(message);
138+
139+
assertThat(captured[0].array()).isEqualTo(data);
140+
assertThat(captured[1].array()).isEqualTo(data);
141+
142+
assertThat(sharedBuffer.position()).isEqualTo(0);
143+
}
144+
145+
private JettyWebSocketSession createMockSession(BiConsumer<ByteBuffer, Integer> captureFunction) {
146+
Session mockSession = mock(Session.class);
147+
148+
given(mockSession.getUpgradeRequest()).willReturn(request);
149+
given(mockSession.getUpgradeResponse()).willReturn(response);
150+
given(mockSession.isOpen()).willReturn(true);
151+
given(request.getUserPrincipal()).willReturn(null);
152+
given(response.getAcceptedSubProtocol()).willReturn(null);
153+
154+
doAnswer(invocation -> {
155+
ByteBuffer buffer = invocation.getArgument(0);
156+
Callback callback = invocation.getArgument(1);
157+
ByteBuffer copy = ByteBuffer.allocate(buffer.remaining());
158+
copy.put(buffer);
159+
copy.flip();
160+
captureFunction.accept(copy, 0);
161+
callback.succeed();
162+
return null;
163+
}).when(mockSession).sendBinary(any(ByteBuffer.class), any(Callback.class));
164+
165+
JettyWebSocketSession session = new JettyWebSocketSession(attributes);
166+
session.initializeNativeSession(mockSession);
167+
return session;
168+
}
169+
120170
}

spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSessionTests.java

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,24 @@
1616

1717
package org.springframework.web.socket.adapter.standard;
1818

19+
import java.nio.ByteBuffer;
1920
import java.util.HashMap;
2021
import java.util.Map;
22+
import java.util.function.BiConsumer;
2123

24+
import jakarta.websocket.RemoteEndpoint.Basic;
2225
import jakarta.websocket.Session;
2326
import org.junit.jupiter.api.Test;
2427

2528
import org.springframework.core.testfixture.security.TestPrincipal;
2629
import org.springframework.http.HttpHeaders;
30+
import org.springframework.web.socket.BinaryMessage;
2731

2832
import static org.assertj.core.api.Assertions.assertThat;
33+
import static org.mockito.ArgumentMatchers.any;
34+
import static org.mockito.ArgumentMatchers.anyBoolean;
2935
import static org.mockito.BDDMockito.given;
36+
import static org.mockito.Mockito.doAnswer;
3037
import static org.mockito.Mockito.mock;
3138
import static org.mockito.Mockito.reset;
3239
import static org.mockito.Mockito.verifyNoMoreInteractions;
@@ -106,4 +113,44 @@ void addAttributesWithNullKeyOrValue() {
106113
.hasSize(1).containsEntry("foo", "bar");
107114
}
108115

116+
@Test
117+
void binaryMessageWithSharedBufferSendsToMultipleSessions() throws Exception {
118+
byte[] data = {1, 2, 3, 4, 5};
119+
ByteBuffer sharedBuffer = ByteBuffer.wrap(data);
120+
BinaryMessage message = new BinaryMessage(sharedBuffer);
121+
122+
ByteBuffer[] captured = new ByteBuffer[2];
123+
StandardWebSocketSession session1 = createMockSession((buffer, idx) -> captured[0] = buffer);
124+
StandardWebSocketSession session2 = createMockSession((buffer, idx) -> captured[1] = buffer);
125+
126+
session1.sendMessage(message);
127+
session2.sendMessage(message);
128+
129+
assertThat(captured[0].array()).isEqualTo(data);
130+
assertThat(captured[1].array()).isEqualTo(data);
131+
132+
assertThat(sharedBuffer.position()).isEqualTo(0);
133+
}
134+
135+
private StandardWebSocketSession createMockSession(BiConsumer<ByteBuffer, Integer> captureFunction) throws Exception {
136+
Session nativeSession = mock(Session.class);
137+
Basic basicRemote = mock(Basic.class);
138+
139+
given(nativeSession.getBasicRemote()).willReturn(basicRemote);
140+
given(nativeSession.isOpen()).willReturn(true);
141+
142+
doAnswer(invocation -> {
143+
ByteBuffer buffer = invocation.getArgument(0);
144+
ByteBuffer copy = ByteBuffer.allocate(buffer.remaining());
145+
copy.put(buffer);
146+
copy.flip();
147+
captureFunction.accept(copy, 0);
148+
return null;
149+
}).when(basicRemote).sendBinary(any(ByteBuffer.class), anyBoolean());
150+
151+
StandardWebSocketSession session = new StandardWebSocketSession(this.headers, this.attributes, null, null);
152+
session.initializeNativeSession(nativeSession);
153+
return session;
154+
}
155+
109156
}

0 commit comments

Comments
 (0)