diff --git a/core/shared/src/main/scala/cats/effect/IO.scala b/core/shared/src/main/scala/cats/effect/IO.scala index ed0491e842..dce5b2f69e 100644 --- a/core/shared/src/main/scala/cats/effect/IO.scala +++ b/core/shared/src/main/scala/cats/effect/IO.scala @@ -1471,7 +1471,16 @@ object IO extends IOCompanionPlatform with IOLowPriorityImplicits with TuplePara def async_[A](k: (Either[Throwable, A] => Unit) => Unit): IO[A] = { val body = new Cont[IO, A, A] { def apply[G[_]](implicit G: MonadCancel[G, Throwable]) = { (resume, get, lift) => - G.uncancelable(_ => lift(IO.delay(k(resume))).flatMap(_ => get)) + G.uncancelable(_ => + try { + k(resume) + get + } catch { + case t if UnsafeNonFatal(t) => + lift(IO.raiseError(t)) + case t: Throwable => + throw t + }) } } diff --git a/ioapp-tests/src/test/scala/IOAppSpec.scala b/ioapp-tests/src/test/scala/IOAppSpec.scala index 7d0eff4e1a..8bd7dfa309 100644 --- a/ioapp-tests/src/test/scala/IOAppSpec.scala +++ b/ioapp-tests/src/test/scala/IOAppSpec.scala @@ -203,13 +203,6 @@ class IOAppSpec extends Specification { h.stdout() must not(contain("sadness")) } - "exit on raising a fatal error inside a map" in { - val h = platform("RaiseFatalErrorMap", List.empty) - h.awaitStatus() mustEqual 1 - h.stderr() must contain("Boom!") - h.stdout() must not(contain("sadness")) - } - "exit on raising a fatal error inside a flatMap" in { val h = platform("RaiseFatalErrorFlatMap", List.empty) h.awaitStatus() mustEqual 1 @@ -217,6 +210,30 @@ class IOAppSpec extends Specification { h.stdout() must not(contain("sadness")) } + "exit on fatal error from CompletableFuture" in { + if (platform == JVM) { + val h = platform("FatalErrorFromCompletableFuture", List.empty) + h.awaitStatus() mustEqual 1 + h.stderr() must contain("Boom from CompletableFuture!") + h.stdout() must not(contain("sadness")) + } else { + // CompletableFuture is JVM-only + ok + } + } + + "exit on fatal error from async_" in { + if (platform == JVM) { + val h = platform("FatalErrorFromAsync", List.empty) + h.awaitStatus() mustEqual 1 + h.stderr() must contain("Boom from async!") + h.stdout() must not(contain("sadness")) + } else { + // Fatal error testing is JVM-only + ok + } + } + "warn on global runtime collision" in { val h = platform("GlobalRacingInit", List.empty) h.awaitStatus() mustEqual 0 diff --git a/kernel/jvm/src/main/scala/cats/effect/kernel/AsyncPlatform.scala b/kernel/jvm/src/main/scala/cats/effect/kernel/AsyncPlatform.scala index b0e6492770..6c8ac1665e 100644 --- a/kernel/jvm/src/main/scala/cats/effect/kernel/AsyncPlatform.scala +++ b/kernel/jvm/src/main/scala/cats/effect/kernel/AsyncPlatform.scala @@ -17,10 +17,18 @@ package cats package effect.kernel +import scala.util.control.ControlThrowable + import java.util.concurrent.{CompletableFuture, CompletionException, CompletionStage} private[kernel] trait AsyncPlatform[F[_]] extends Serializable { this: Async[F] => + private def isNonFatal(t: Throwable): Boolean = t match { + case _: VirtualMachineError | _: ThreadDeath | _: LinkageError | _: ControlThrowable => + false + case _ => true + } + def fromCompletionStage[A](completionStage: F[CompletionStage[A]]): F[A] = fromCompletableFuture(flatMap(completionStage) { cs => delay(cs.toCompletableFuture()) }) @@ -50,10 +58,15 @@ private[kernel] trait AsyncPlatform[F[_]] extends Serializable { this: Async[F] cf.handle[Unit] { case (a, null) => resume(Right(a)) case (_, t) => - resume(Left(t match { + val actualThrowable = t match { case e: CompletionException if e.getCause ne null => e.getCause case _ => t - })) + } + if (isNonFatal(actualThrowable)) { + resume(Left(actualThrowable)) + } else { + throw actualThrowable + } } } diff --git a/tests/jvm/src/main/scala/catseffect/examplesplatform.scala b/tests/jvm/src/main/scala/catseffect/examplesplatform.scala index fddb6e17e1..a0c9bb3f50 100644 --- a/tests/jvm/src/main/scala/catseffect/examplesplatform.scala +++ b/tests/jvm/src/main/scala/catseffect/examplesplatform.scala @@ -92,4 +92,24 @@ package examples { val run = IO.cede.foreverM.start >> IO(Thread.sleep(2.seconds.toMillis)) } + + object FatalErrorFromCompletableFuture extends IOApp { + def run(args: List[String]): IO[ExitCode] = { + import java.util.concurrent.CompletableFuture + + IO.fromCompletableFuture(IO(CompletableFuture.runAsync(() => { + throw new OutOfMemoryError("Boom from CompletableFuture!") + }))) + .flatMap(_ => IO.println("sadness")) + .as(ExitCode.Success) + } + } + + object FatalErrorFromAsync extends IOApp { + def run(args: List[String]): IO[ExitCode] = { + IO.async_[Unit] { cb => cb(Left(new OutOfMemoryError("Boom from async!"))) } + .flatMap(_ => IO.println("sadness")) + .as(ExitCode.Success) + } + } }