Skip to content

Commit 61976a8

Browse files
wrap stream ids on reaching maximum value (#68)
1 parent f340b29 commit 61976a8

File tree

4 files changed

+66
-86
lines changed

4 files changed

+66
-86
lines changed

rsocket-core/src/main/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ internal class RSocketRequester(
9090

9191
private fun handleFireAndForget(payload: Payload): Completable {
9292
return Completable.fromRunnable {
93-
val streamId = streamIds.nextStreamId()
93+
val streamId = streamIds.nextStreamId(receivers)
9494
val requestFrame = Frame.Request.from(
9595
streamId,
9696
FrameType.FIRE_AND_FORGET,
@@ -102,7 +102,7 @@ internal class RSocketRequester(
102102

103103
private fun handleRequestResponse(payload: Payload): Single<Payload> {
104104
return Single.defer {
105-
val streamId = streamIds.nextStreamId()
105+
val streamId = streamIds.nextStreamId(receivers)
106106
val requestFrame = Frame.Request.from(
107107
streamId, FrameType.REQUEST_RESPONSE, payload, 1)
108108

@@ -119,7 +119,7 @@ internal class RSocketRequester(
119119

120120
private fun handleRequestStream(payload: Payload): Flowable<Payload> {
121121
return Flowable.defer {
122-
val streamId = streamIds.nextStreamId()
122+
val streamId = streamIds.nextStreamId(receivers)
123123
val receiver = StreamReceiver.create()
124124
receivers[streamId] = receiver
125125
val reqN = Cond()
@@ -148,7 +148,7 @@ internal class RSocketRequester(
148148
private fun handleChannel(request: Flowable<Payload>): Flowable<Payload> {
149149
return Flowable.defer {
150150
val receiver = StreamReceiver.create()
151-
val streamId = streamIds.nextStreamId()
151+
val streamId = streamIds.nextStreamId(receivers)
152152
val reqN = Cond()
153153

154154
receiver.doOnRequestIfActive { requestN ->
@@ -259,7 +259,7 @@ internal class RSocketRequester(
259259
}
260260
else -> unsupportedFrame(streamId, frame)
261261
}
262-
} ?: missingReceiver(streamId, type, frame)
262+
}
263263
}
264264

265265
private fun unsupportedFrame(streamId: Int, frame: Frame) {
@@ -268,21 +268,6 @@ internal class RSocketRequester(
268268
"$streamId : $frame"))
269269
}
270270

271-
private fun missingReceiver(streamId: Int, type: FrameType, frame: Frame) {
272-
if (!streamIds.isBeforeOrCurrent(streamId)) {
273-
val err = if (type === FrameType.ERROR) {
274-
IllegalStateException(
275-
"Client received error for non-existent stream: " +
276-
"$streamId Message: ${frame.dataUtf8}")
277-
} else {
278-
IllegalStateException(
279-
"Client received message for non-existent stream: " +
280-
"$streamId, frame type: $type")
281-
}
282-
errorConsumer(err)
283-
}
284-
}
285-
286271
private inner class Lifecycle {
287272
private val terminated = AtomicReference<Throwable>()
288273

rsocket-core/src/main/kotlin/io/rsocket/kotlin/internal/StreamIds.kt

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,38 @@
1616

1717
package io.rsocket.kotlin.internal
1818

19-
internal sealed class StreamIds(private var streamId: Int) {
19+
import java.util.concurrent.atomic.AtomicLongFieldUpdater
2020

21-
@Synchronized
22-
fun nextStreamId(): Int {
23-
streamId += 2
21+
22+
internal sealed class StreamIds(streamId: Int) {
23+
24+
@JvmField
25+
@Volatile
26+
internal var streamId: Long = streamId.toLong()
27+
28+
fun nextStreamId(streamIds: Map<Int, *>): Int {
29+
var streamId: Int
30+
do {
31+
val next = STREAM_ID.addAndGet(this, 2)
32+
if (next <= MAX_STREAM_ID) {
33+
return next.toInt()
34+
}
35+
streamId = (next and MASK).toInt()
36+
} while (streamId == 0 || streamIds.containsKey(streamId))
2437
return streamId
2538
}
2639

27-
@Synchronized
28-
fun isBeforeOrCurrent(streamId: Int): Boolean =
29-
this.streamId >= streamId && streamId > 0
40+
companion object {
41+
private val STREAM_ID = AtomicLongFieldUpdater.newUpdater(StreamIds::class.java, "streamId")
42+
private const val MASK: Long = 0x7FFFFFFF
43+
internal const val MAX_STREAM_ID = Int.MAX_VALUE
44+
45+
46+
}
3047
}
3148

3249
internal class ClientStreamIds : StreamIds(-1)
3350

3451
internal class ServerStreamIds : StreamIds(0)
52+
53+
internal class TestStreamIds(streamId: Int) : StreamIds(streamId)

rsocket-core/src/test/kotlin/io/rsocket/kotlin/RSocketRequesterTest.kt

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,6 @@ class RSocketRequesterTest {
4242
@get:Rule
4343
val rule = ClientSocketRule()
4444

45-
@Test(timeout = 2000)
46-
fun testInvalidFrameOnStream0() {
47-
48-
rule.receiver.onNext(Frame.RequestN.from(0, 10))
49-
val errors = rule.errors
50-
51-
assertThat("Unexpected errors.",
52-
errors,
53-
hasSize<Throwable>(1))
54-
assertThat(
55-
"Unexpected error received.",
56-
errors,
57-
contains(instanceOf<Throwable>(IllegalStateException::class.java)))
58-
}
59-
6045
@Test(timeout = 2000)
6146
fun testStreamInitialN() {
6247
val stream = rule.requester.requestStream(DefaultPayload.EMPTY)

rsocket-core/src/test/kotlin/io/rsocket/kotlin/StreamIdsTest.kt

Lines changed: 35 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -18,71 +18,62 @@ package io.rsocket.kotlin
1818

1919
import io.rsocket.kotlin.internal.ClientStreamIds
2020
import io.rsocket.kotlin.internal.ServerStreamIds
21+
import io.rsocket.kotlin.internal.StreamIds
22+
import io.rsocket.kotlin.internal.TestStreamIds
2123
import org.junit.Assert.assertEquals
22-
import org.junit.Assert.assertFalse
23-
import org.junit.Assert.assertTrue
24-
2524
import org.junit.Test
25+
import java.util.*
26+
import java.util.concurrent.ConcurrentHashMap
2627

2728
class StreamIdsTest {
2829

2930
@Test
3031
fun testClientSequence() {
32+
val map = Collections.emptyMap<Int, Any>()
3133
val s = ClientStreamIds()
32-
assertEquals(1, s.nextStreamId().toLong())
33-
assertEquals(3, s.nextStreamId().toLong())
34-
assertEquals(5, s.nextStreamId().toLong())
34+
assertEquals(1, s.nextStreamId(map).toLong())
35+
assertEquals(3, s.nextStreamId(map).toLong())
36+
assertEquals(5, s.nextStreamId(map).toLong())
3537
}
3638

3739
@Test
3840
fun testServerSequence() {
41+
val map = Collections.emptyMap<Int, Any>()
3942
val s = ServerStreamIds()
40-
assertEquals(2, s.nextStreamId().toLong())
41-
assertEquals(4, s.nextStreamId().toLong())
42-
assertEquals(6, s.nextStreamId().toLong())
43+
assertEquals(2, s.nextStreamId(map).toLong())
44+
assertEquals(4, s.nextStreamId(map).toLong())
45+
assertEquals(6, s.nextStreamId(map).toLong())
4346
}
4447

4548
@Test
46-
fun testClientIsValid() {
47-
val s = ClientStreamIds()
48-
49-
assertFalse(s.isBeforeOrCurrent(1))
50-
assertFalse(s.isBeforeOrCurrent(3))
51-
52-
s.nextStreamId()
53-
assertTrue(s.isBeforeOrCurrent(1))
54-
assertFalse(s.isBeforeOrCurrent(3))
55-
56-
s.nextStreamId()
57-
assertTrue(s.isBeforeOrCurrent(3))
49+
fun testClientSequenceWrap() {
50+
val map = ConcurrentHashMap<Int, Any>()
51+
val s = TestStreamIds(Integer.MAX_VALUE - 2)
5852

59-
// negative
60-
assertFalse(s.isBeforeOrCurrent(-1))
61-
// connection
62-
assertFalse(s.isBeforeOrCurrent(0))
63-
// server also accepted (checked externally)
64-
assertTrue(s.isBeforeOrCurrent(2))
53+
assertEquals(2147483647, s.nextStreamId(map).toLong())
54+
assertEquals(1, s.nextStreamId(map).toLong())
55+
assertEquals(3, s.nextStreamId(map).toLong())
6556
}
6657

6758
@Test
68-
fun testServerIsValid() {
69-
val s = ServerStreamIds()
70-
71-
assertFalse(s.isBeforeOrCurrent(2))
72-
assertFalse(s.isBeforeOrCurrent(4))
73-
74-
s.nextStreamId()
75-
assertTrue(s.isBeforeOrCurrent(2))
76-
assertFalse(s.isBeforeOrCurrent(4))
59+
fun testServerSequenceWrap() {
60+
val map = ConcurrentHashMap<Int, Any>()
61+
val s = TestStreamIds(Integer.MAX_VALUE - 3)
7762

78-
s.nextStreamId()
79-
assertTrue(s.isBeforeOrCurrent(4))
63+
assertEquals(2147483646, s.nextStreamId(map).toLong())
64+
assertEquals(2, s.nextStreamId(map).toLong())
65+
assertEquals(4, s.nextStreamId(map).toLong())
66+
}
8067

81-
// negative
82-
assertFalse(s.isBeforeOrCurrent(-2))
83-
// connection
84-
assertFalse(s.isBeforeOrCurrent(0))
85-
// client also accepted (checked externally)
86-
assertTrue(s.isBeforeOrCurrent(1))
68+
@Test
69+
fun testSequenceSkipsExistingStreamIds() {
70+
val map = ConcurrentHashMap<Int, Any>()
71+
map.put(5, Any())
72+
map.put(9, Any())
73+
val s = TestStreamIds(StreamIds.MAX_STREAM_ID)
74+
assertEquals(1, s.nextStreamId(map).toLong())
75+
assertEquals(3, s.nextStreamId(map).toLong())
76+
assertEquals(7, s.nextStreamId(map).toLong())
77+
assertEquals(11, s.nextStreamId(map).toLong())
8778
}
8879
}

0 commit comments

Comments
 (0)