diff --git a/core/shared/src/test/scala/fs2/interop/flow/SubscriberStabilitySpec.scala b/core/shared/src/test/scala/fs2/interop/flow/SubscriberStabilitySpec.scala index 6d659f801e..d6f63a2e53 100644 --- a/core/shared/src/test/scala/fs2/interop/flow/SubscriberStabilitySpec.scala +++ b/core/shared/src/test/scala/fs2/interop/flow/SubscriberStabilitySpec.scala @@ -28,7 +28,8 @@ import cats.effect.std.Random import java.nio.ByteBuffer import java.util.concurrent.Flow.{Publisher, Subscriber, Subscription} -import scala.concurrent.duration._ +import java.util.concurrent.atomic.AtomicBoolean +import scala.concurrent.duration.* class SubscriberStabilitySpec extends Fs2Suite { val attempts = 100 @@ -67,4 +68,32 @@ class SubscriberStabilitySpec extends Fs2Suite { .replicateA_(attempts) } } + + test("StreamSubscriber cancels subscription on downstream cancellation") { + def makePublisher( + requestCalled: AtomicBoolean, + subscriptionCancelled: AtomicBoolean + ): Publisher[ByteBuffer] = + new Publisher[ByteBuffer] { + + class SubscriptionImpl extends Subscription { + override def request(n: Long): Unit = requestCalled.set(true) + override def cancel(): Unit = subscriptionCancelled.set(true) + } + + override def subscribe(s: Subscriber[? >: ByteBuffer]): Unit = + s.onSubscribe(new SubscriptionImpl) + } + + for { + requestCalled <- IO(new AtomicBoolean(false)) + subscriptionCancelled <- IO(new AtomicBoolean(false)) + publisher = makePublisher(requestCalled, subscriptionCancelled) + _ <- fromPublisher[IO](publisher, chunkSize = 1) + .interruptWhen(Stream.eval(IO(requestCalled.get())).repeat.spaced(10.millis)) + .compile + .drain + _ <- IO(subscriptionCancelled.get).assert + } yield () + } } diff --git a/reactive-streams/src/main/scala/fs2/interop/reactivestreams/StreamSubscriber.scala b/reactive-streams/src/main/scala/fs2/interop/reactivestreams/StreamSubscriber.scala index 3eafdae541..4cb9306820 100644 --- a/reactive-streams/src/main/scala/fs2/interop/reactivestreams/StreamSubscriber.scala +++ b/reactive-streams/src/main/scala/fs2/interop/reactivestreams/StreamSubscriber.scala @@ -216,8 +216,11 @@ object StreamSubscriber { def onComplete(): Unit = nextState(OnComplete) def onFinalize: F[Unit] = F.delay(nextState(OnFinalize)) def dequeue1: F[Either[Throwable, Option[Chunk[A]]]] = - F.async_[Either[Throwable, Option[Chunk[A]]]] { cb => - nextState(OnDequeue(out => cb(Right(out)))) + F.async[Either[Throwable, Option[Chunk[A]]]] { cb => + F.delay { + nextState(OnDequeue(out => cb(Right(out)))) + Some(F.unit) + } } } } diff --git a/reactive-streams/src/test/scala/fs2/interop/reactivestreams/SubscriberStabilitySpec.scala b/reactive-streams/src/test/scala/fs2/interop/reactivestreams/SubscriberStabilitySpec.scala index 202e78ea93..ec89f8228a 100644 --- a/reactive-streams/src/test/scala/fs2/interop/reactivestreams/SubscriberStabilitySpec.scala +++ b/reactive-streams/src/test/scala/fs2/interop/reactivestreams/SubscriberStabilitySpec.scala @@ -23,13 +23,13 @@ package fs2 package interop package reactivestreams -import cats.effect._ +import cats.effect.* import cats.effect.std.Random -import org.reactivestreams._ - -import scala.concurrent.duration._ +import org.reactivestreams.* import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicBoolean +import scala.concurrent.duration.* class SubscriberStabilitySpec extends Fs2Suite { test("StreamSubscriber has no race condition") { @@ -87,4 +87,32 @@ class SubscriberStabilitySpec extends Fs2Suite { if (failed) fail("Uncaught exception was reported") } + + test("StreamSubscriber cancels subscription on downstream cancellation") { + def makePublisher( + requestCalled: AtomicBoolean, + subscriptionCancelled: AtomicBoolean + ): Publisher[ByteBuffer] = + new Publisher[ByteBuffer] { + + class SubscriptionImpl extends Subscription { + override def request(n: Long): Unit = requestCalled.set(true) + override def cancel(): Unit = subscriptionCancelled.set(true) + } + + override def subscribe(s: Subscriber[? >: ByteBuffer]): Unit = + s.onSubscribe(new SubscriptionImpl) + } + + for { + requestCalled <- IO(new AtomicBoolean(false)) + subscriptionCancelled <- IO(new AtomicBoolean(false)) + publisher = makePublisher(requestCalled, subscriptionCancelled) + _ <- fromPublisher[IO, ByteBuffer](publisher, bufferSize = 1) + .interruptWhen(Stream.eval(IO(requestCalled.get())).repeat.spaced(10.millis)) + .compile + .drain + _ <- IO(subscriptionCancelled.get).assert + } yield () + } }