From 48a2b88c68988daf3dd78d36e3e884b268e18e1e Mon Sep 17 00:00:00 2001 From: Max Komarychev Date: Wed, 28 May 2025 00:02:20 +0200 Subject: [PATCH] fix: respect custom coroutine scope when calling blocking functions in a data fetcher When running blocking function data fetcher it does not respect coroutine scope provided in the GraphQL context. In my project we're creating custom a coroutine scope in order to propagate MDC and OTEL contexts in a coroutine like this: ``` CoroutineScope(MDCContext() + Dispatchers.IO + Context.current().asContextElement()) ``` (here `Context` is an instance of OTEL context). This works well in suspended functions since they are using the context however normal functions are not using this context and therefore none of the OTEL context (trace id etc) is propagated. --- .../execution/FunctionDataFetcher.kt | 13 ++- .../execution/FunctionDataFetcherTest.kt | 102 ++++++++++++++++-- 2 files changed, 104 insertions(+), 11 deletions(-) diff --git a/generator/graphql-kotlin-schema-generator/src/main/kotlin/com/expediagroup/graphql/generator/execution/FunctionDataFetcher.kt b/generator/graphql-kotlin-schema-generator/src/main/kotlin/com/expediagroup/graphql/generator/execution/FunctionDataFetcher.kt index 645ea39718..173d6b469e 100644 --- a/generator/graphql-kotlin-schema-generator/src/main/kotlin/com/expediagroup/graphql/generator/execution/FunctionDataFetcher.kt +++ b/generator/graphql-kotlin-schema-generator/src/main/kotlin/com/expediagroup/graphql/generator/execution/FunctionDataFetcher.kt @@ -24,6 +24,7 @@ import graphql.schema.DataFetcher import graphql.schema.DataFetchingEnvironment import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.future.future +import kotlinx.coroutines.runBlocking import java.lang.reflect.InvocationTargetException import java.util.concurrent.CompletableFuture import kotlin.coroutines.EmptyCoroutineContext @@ -60,7 +61,7 @@ open class FunctionDataFetcher( if (fn.isSuspend) { runSuspendingFunction(environment, parameterValues) } else { - runBlockingFunction(parameterValues) + runBlockingFunction(environment, parameterValues) } } else { null @@ -123,8 +124,14 @@ open class FunctionDataFetcher( * Once all parameters values are properly converted, this function will be called to run a simple blocking function. * If you need to override the exception handling you can override this method. */ - protected open fun runBlockingFunction(parameterValues: Map): Any? = try { - fn.callBy(parameterValues) + protected open fun runBlockingFunction( + environment: DataFetchingEnvironment, + parameterValues: Map + ): Any? = try { + val coroutineScope = environment.graphQlContext.getOrDefault(CoroutineScope(EmptyCoroutineContext)) + runBlocking(coroutineScope.coroutineContext) { + fn.callBy(parameterValues) + } } catch (exception: InvocationTargetException) { throw exception.cause ?: exception } diff --git a/generator/graphql-kotlin-schema-generator/src/test/kotlin/com/expediagroup/graphql/generator/execution/FunctionDataFetcherTest.kt b/generator/graphql-kotlin-schema-generator/src/test/kotlin/com/expediagroup/graphql/generator/execution/FunctionDataFetcherTest.kt index 1907788742..91f8c5364f 100644 --- a/generator/graphql-kotlin-schema-generator/src/test/kotlin/com/expediagroup/graphql/generator/execution/FunctionDataFetcherTest.kt +++ b/generator/graphql-kotlin-schema-generator/src/test/kotlin/com/expediagroup/graphql/generator/execution/FunctionDataFetcherTest.kt @@ -22,14 +22,15 @@ import graphql.GraphQLException import graphql.schema.DataFetchingEnvironment import io.mockk.every import io.mockk.mockk +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.asCoroutineDispatcher import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.runBlocking import org.junit.jupiter.api.Test import java.util.concurrent.CompletableFuture -import kotlin.test.assertEquals -import kotlin.test.assertFalse -import kotlin.test.assertNotNull -import kotlin.test.assertNull -import kotlin.test.assertTrue +import java.util.concurrent.Executors +import java.util.concurrent.ThreadFactory +import kotlin.test.* class FunctionDataFetcherTest { @@ -37,6 +38,15 @@ class FunctionDataFetcherTest { fun print(string: String): String } + val threadFactory = object: ThreadFactory { + override fun newThread(r: Runnable): Thread? { + val thread = Thread(r) + thread.name = "custom-thread-1" + return thread + } + } + val customCoroutineDispatcher = Executors.newSingleThreadExecutor(threadFactory).asCoroutineDispatcher() + class MyClass : MyInterface { override fun print(string: String) = string @@ -52,7 +62,9 @@ class FunctionDataFetcherTest { string } - fun throwException() { throw GraphQLException("Test Exception") } + fun throwException() { + throw GraphQLException("Test Exception") + } suspend fun suspendThrow(): String = coroutineScope { throw GraphQLException("Suspended Exception") @@ -75,14 +87,21 @@ class FunctionDataFetcherTest { is OptionalInput.Undefined -> "optional was UNDEFINED" is OptionalInput.Defined -> "optional was ${input.optional.value}" } + + fun threadNameSync(): String { + return Thread.currentThread().name + } + + fun threadNameAsync(): String { + return Thread.currentThread().name + } } data class InputWrapper(val required: String, val optional: OptionalInput) @GraphQLName("MyInputClassRenamed") data class MyInputClass( - @GraphQLName("jacksonField") - val field1: String + @GraphQLName("jacksonField") val field1: String ) @Test @@ -101,6 +120,7 @@ class FunctionDataFetcherTest { every { getSource() } returns MyClass() every { arguments } returns mapOf("string" to "hello") every { containsArgument("string") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "hello", actual = dataFetcher.get(mockEnvironment)) } @@ -111,6 +131,7 @@ class FunctionDataFetcherTest { val mockEnvironment: DataFetchingEnvironment = mockk { every { arguments } returns mapOf("string" to "hello") every { containsArgument("string") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "hello", actual = dataFetcher.get(mockEnvironment)) } @@ -137,6 +158,7 @@ class FunctionDataFetcherTest { every { getSource() } returns MyClass() every { arguments } returns mapOf("string" to "hello") every { containsArgument("string") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "hello", actual = dataFetcher.get(mockEnvironment)) } @@ -148,6 +170,7 @@ class FunctionDataFetcherTest { every { getSource() } returns MyClass() every { arguments } returns emptyMap() every { containsArgument(any()) } returns false + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "hello", actual = dataFetcher.get(mockEnvironment)) } @@ -159,6 +182,7 @@ class FunctionDataFetcherTest { every { getSource() } returns MyClass() every { arguments } returns mapOf("string" to "foo") every { containsArgument("string") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "foo", actual = dataFetcher.get(mockEnvironment)) } @@ -170,6 +194,7 @@ class FunctionDataFetcherTest { every { getSource() } returns MyClass() every { arguments } returns mapOf("string" to null) every { containsArgument("string") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertNull(dataFetcher.get(mockEnvironment)) } @@ -180,6 +205,7 @@ class FunctionDataFetcherTest { val mockEnvironment: DataFetchingEnvironment = mockk { every { arguments } returns mapOf("items" to listOf("foo", "bar")) every { containsArgument("items") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "foo:bar", actual = dataFetcher.get(mockEnvironment)) @@ -194,6 +220,7 @@ class FunctionDataFetcherTest { every { field } returns mockk { every { name } returns "fooBarBaz" } + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "fooBarBaz", actual = dataFetcher.get(mockEnvironment)) } @@ -219,6 +246,7 @@ class FunctionDataFetcherTest { val mockEnvironment: DataFetchingEnvironment = mockk { every { arguments } returns emptyMap() every { containsArgument(any()) } returns false + every { graphQlContext } returns GraphQLContext.newContext().build() } try { @@ -256,6 +284,7 @@ class FunctionDataFetcherTest { val mockEnvironment: DataFetchingEnvironment = mockk { every { arguments } returns mapOf("myCustomArgument" to mapOf("jacksonField" to "foo")) every { containsArgument("myCustomArgument") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "You sent foo", actual = dataFetcher.get(mockEnvironment)) } @@ -266,6 +295,7 @@ class FunctionDataFetcherTest { val mockEnvironment: DataFetchingEnvironment = mockk { every { arguments } returns mapOf("input" to "hello") every { containsArgument("input") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "input was hello", actual = dataFetcher.get(mockEnvironment)) } @@ -276,6 +306,7 @@ class FunctionDataFetcherTest { val mockEnvironment: DataFetchingEnvironment = mockk { every { arguments } returns mapOf("input" to null) every { containsArgument("input") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "input was null", actual = dataFetcher.get(mockEnvironment)) } @@ -286,6 +317,7 @@ class FunctionDataFetcherTest { val mockEnvironment: DataFetchingEnvironment = mockk { every { arguments } returns emptyMap() every { containsArgument(any()) } returns false + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "input was UNDEFINED", actual = dataFetcher.get(mockEnvironment)) } @@ -296,6 +328,7 @@ class FunctionDataFetcherTest { val mockEnvironment: DataFetchingEnvironment = mockk { every { arguments } returns mapOf("input" to listOf(linkedMapOf("jacksonField" to "foo"))) every { containsArgument("input") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } val result = dataFetcher.get(mockEnvironment) assertEquals(expected = "first input was foo", actual = result) @@ -307,6 +340,7 @@ class FunctionDataFetcherTest { val mockEnvironment: DataFetchingEnvironment = mockk { every { arguments } returns mapOf("input" to mapOf("required" to "hello", "optional" to "hello")) every { containsArgument("input") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "optional was hello", actual = dataFetcher.get(mockEnvironment)) } @@ -317,7 +351,59 @@ class FunctionDataFetcherTest { val mockEnvironment = mockk { every { arguments } returns mapOf("input" to mapOf("required" to "hello")) every { containsArgument("input") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "optional was UNDEFINED", actual = dataFetcher.get(mockEnvironment)) } + + @Test + fun `use default scope for sync function`() { + val dataFetcher = FunctionDataFetcher(target = MyClass(), fn = MyClass::threadNameSync) + val mockEnvironment = mockk { + every { arguments } returns mapOf("input" to mapOf("required" to "hello")) + every { containsArgument("input") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() + } + val result = dataFetcher.get(mockEnvironment) as String + assertContains(charSequence = result, other = "Test worker") + } + + @Test + fun `use provided scope for sync function`() = runBlocking { + val dataFetcher = FunctionDataFetcher(target = MyClass(), fn = MyClass::threadNameSync) + val scope = CoroutineScope(customCoroutineDispatcher) + val mockEnvironment = mockk { + every { arguments } returns mapOf("input" to mapOf("required" to "hello")) + every { containsArgument("input") } returns true + every { graphQlContext } returns GraphQLContext.newContext().put(CoroutineScope::class, scope).build() + } + val result = dataFetcher.get(mockEnvironment) as String + assertContains(charSequence = result, other = "custom-thread-1") + } + + + @Test + fun `use default scope for async function`() { + val dataFetcher = FunctionDataFetcher(target = MyClass(), fn = MyClass::threadNameAsync) + val mockEnvironment = mockk { + every { arguments } returns mapOf("input" to mapOf("required" to "hello")) + every { containsArgument("input") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() + } + val result = dataFetcher.get(mockEnvironment) as String + assertContains(charSequence = result, other = "Test worker") + } + + @Test + fun `use provided scope for async function`() = runBlocking { + val dataFetcher = FunctionDataFetcher(target = MyClass(), fn = MyClass::threadNameAsync) + val scope = CoroutineScope(customCoroutineDispatcher) + val mockEnvironment = mockk { + every { arguments } returns mapOf("input" to mapOf("required" to "hello")) + every { containsArgument("input") } returns true + every { graphQlContext } returns GraphQLContext.newContext().put(CoroutineScope::class, scope).build() + } + val result = dataFetcher.get(mockEnvironment) as String + assertContains(charSequence = result, other = "custom-thread-1") + } }