diff --git a/apollo-execution-ktor/api/apollo-execution-ktor.api b/apollo-execution-ktor/api/apollo-execution-ktor.api index fbf4309a..91dad7b5 100644 --- a/apollo-execution-ktor/api/apollo-execution-ktor.api +++ b/apollo-execution-ktor/api/apollo-execution-ktor.api @@ -3,10 +3,18 @@ public final class com/apollographql/execution/ktor/MainKt { public static synthetic fun apolloModule$default (Lio/ktor/server/application/Application;Lcom/apollographql/apollo/execution/ExecutableSchema;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V public static final fun apolloSandboxModule (Lio/ktor/server/application/Application;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V public static synthetic fun apolloSandboxModule$default (Lio/ktor/server/application/Application;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V - public static final fun apolloSubscriptionModule (Lio/ktor/server/application/Application;Lcom/apollographql/apollo/execution/ExecutableSchema;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V - public static synthetic fun apolloSubscriptionModule$default (Lio/ktor/server/application/Application;Lcom/apollographql/apollo/execution/ExecutableSchema;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static final fun apolloSubscriptionModule (Lio/ktor/server/application/Application;Lcom/apollographql/apollo/execution/ExecutableSchema;Ljava/lang/String;Lcom/apollographql/execution/ktor/WsProtocol;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun apolloSubscriptionModule$default (Lio/ktor/server/application/Application;Lcom/apollographql/apollo/execution/ExecutableSchema;Ljava/lang/String;Lcom/apollographql/execution/ktor/WsProtocol;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V public static final fun parseAsGraphQLRequest (Lio/ktor/server/request/ApplicationRequest;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static final fun respondGraphQL (Lio/ktor/server/application/ApplicationCall;Lcom/apollographql/apollo/execution/ExecutableSchema;Lcom/apollographql/apollo/api/ExecutionContext;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun respondGraphQL$default (Lio/ktor/server/application/ApplicationCall;Lcom/apollographql/apollo/execution/ExecutableSchema;Lcom/apollographql/apollo/api/ExecutionContext;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; } +public final class com/apollographql/execution/ktor/WsProtocol : java/lang/Enum { + public static final field GraphqlWS Lcom/apollographql/execution/ktor/WsProtocol; + public static final field Legacy Lcom/apollographql/execution/ktor/WsProtocol; + public static fun getEntries ()Lkotlin/enums/EnumEntries; + public static fun valueOf (Ljava/lang/String;)Lcom/apollographql/execution/ktor/WsProtocol; + public static fun values ()[Lcom/apollographql/execution/ktor/WsProtocol; +} + diff --git a/apollo-execution-ktor/src/commonMain/kotlin/com/apollographql/execution/ktor/main.kt b/apollo-execution-ktor/src/commonMain/kotlin/com/apollographql/execution/ktor/main.kt index 28a13445..98fabe5e 100644 --- a/apollo-execution-ktor/src/commonMain/kotlin/com/apollographql/execution/ktor/main.kt +++ b/apollo-execution-ktor/src/commonMain/kotlin/com/apollographql/execution/ktor/main.kt @@ -7,9 +7,11 @@ import com.apollographql.apollo.execution.GraphQLResponse import com.apollographql.apollo.execution.parseAsGraphQLRequest import com.apollographql.execution.* import com.apollographql.execution.websocket.ConnectionInitAck +import com.apollographql.execution.websocket.GraphQLWsWebSocketHandler import com.apollographql.execution.websocket.SubscriptionWebSocketHandler import com.apollographql.execution.websocket.WebSocketBinaryMessage import com.apollographql.execution.websocket.WebSocketTextMessage +import com.apollographql.execution.websocket.WsConnectionInitAck import io.ktor.http.* import io.ktor.http.content.* import io.ktor.server.application.* @@ -99,30 +101,59 @@ fun Application.apolloModule( } } +enum class WsProtocol { + GraphqlWS, + Legacy +} fun Application.apolloSubscriptionModule( executableSchema: ExecutableSchema, path: String = "/subscription", + protocol: WsProtocol = WsProtocol.GraphqlWS, executionContext: (ApplicationRequest) -> ExecutionContext = { ExecutionContext.Empty } ) { install(WebSockets) routing { - webSocket(path, "graphql-ws") { + val transport = when (protocol) { + WsProtocol.GraphqlWS -> "graphql-transport-ws" + WsProtocol.Legacy -> "graphql-ws" + } + webSocket(path, transport) { coroutineScope { - val handler = SubscriptionWebSocketHandler( - executableSchema = executableSchema, - scope = this, - executionContext = executionContext(call.request), - sendMessage = { - when (it) { - is WebSocketBinaryMessage -> send(Frame.Binary(true, it.data)) - is WebSocketTextMessage -> send(Frame.Text(it.data)) - } - }, - connectionInitHandler = { - ConnectionInitAck + val handler = when(protocol) { + WsProtocol.GraphqlWS -> { + GraphQLWsWebSocketHandler( + executableSchema = executableSchema, + scope = this, + executionContext = executionContext(call.request), + sendMessage = { + when (it) { + is WebSocketBinaryMessage -> send(Frame.Binary(true, it.data)) + is WebSocketTextMessage -> send(Frame.Text(it.data)) + } + }, + connectionInitHandler = { + WsConnectionInitAck + } + ) } - ) + WsProtocol.Legacy -> { + SubscriptionWebSocketHandler( + executableSchema = executableSchema, + scope = this, + executionContext = executionContext(call.request), + sendMessage = { + when (it) { + is WebSocketBinaryMessage -> send(Frame.Binary(true, it.data)) + is WebSocketTextMessage -> send(Frame.Text(it.data)) + } + }, + connectionInitHandler = { + ConnectionInitAck + } + ) + } + } for (frame in incoming) { if (frame !is Frame.Text) { diff --git a/apollo-execution-runtime/api/apollo-execution-runtime.api b/apollo-execution-runtime/api/apollo-execution-runtime.api index 4c3060a2..d7faed3e 100644 --- a/apollo-execution-runtime/api/apollo-execution-runtime.api +++ b/apollo-execution-runtime/api/apollo-execution-runtime.api @@ -62,15 +62,18 @@ public final class com/apollographql/execution/websocket/ConnectionInitError : c public abstract interface class com/apollographql/execution/websocket/ConnectionInitResult { } -public final class com/apollographql/execution/websocket/SubscriptionWebSocketHandler : com/apollographql/execution/websocket/WebSocketHandler { +public final class com/apollographql/execution/websocket/GraphQLWsWebSocketHandler : com/apollographql/execution/websocket/WebSocketHandler { public fun (Lcom/apollographql/apollo/execution/ExecutableSchema;Lkotlinx/coroutines/CoroutineScope;Lcom/apollographql/apollo/api/ExecutionContext;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function2;)V public synthetic fun (Lcom/apollographql/apollo/execution/ExecutableSchema;Lkotlinx/coroutines/CoroutineScope;Lcom/apollographql/apollo/api/ExecutionContext;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function2;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun close ()V public fun handleMessage (Lcom/apollographql/execution/websocket/WebSocketMessage;)V } -public final class com/apollographql/execution/websocket/SubscriptionWebSocketHandlerKt { - public static final fun subscriptionId (Lcom/apollographql/apollo/api/ExecutionContext;)Ljava/lang/String; +public final class com/apollographql/execution/websocket/SubscriptionWebSocketHandler : com/apollographql/execution/websocket/WebSocketHandler { + public fun (Lcom/apollographql/apollo/execution/ExecutableSchema;Lkotlinx/coroutines/CoroutineScope;Lcom/apollographql/apollo/api/ExecutionContext;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function2;)V + public synthetic fun (Lcom/apollographql/apollo/execution/ExecutableSchema;Lkotlinx/coroutines/CoroutineScope;Lcom/apollographql/apollo/api/ExecutionContext;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function2;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun close ()V + public fun handleMessage (Lcom/apollographql/execution/websocket/WebSocketMessage;)V } public final class com/apollographql/execution/websocket/WebSocketBinaryMessage : com/apollographql/execution/websocket/WebSocketMessage { @@ -90,3 +93,20 @@ public final class com/apollographql/execution/websocket/WebSocketTextMessage : public final fun getData ()Ljava/lang/String; } +public final class com/apollographql/execution/websocket/WsConnectionInitAck : com/apollographql/execution/websocket/WsConnectionInitResult { + public static final field INSTANCE Lcom/apollographql/execution/websocket/WsConnectionInitAck; + public fun equals (Ljava/lang/Object;)Z + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class com/apollographql/execution/websocket/WsConnectionInitError : com/apollographql/execution/websocket/WsConnectionInitResult { + public fun ()V + public fun (Lcom/apollographql/apollo/api/Optional;)V + public synthetic fun (Lcom/apollographql/apollo/api/Optional;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getPayload ()Lcom/apollographql/apollo/api/Optional; +} + +public abstract interface class com/apollographql/execution/websocket/WsConnectionInitResult { +} + diff --git a/apollo-execution-runtime/api/apollo-execution-runtime.klib.api b/apollo-execution-runtime/api/apollo-execution-runtime.klib.api index 88f797eb..607342fb 100644 --- a/apollo-execution-runtime/api/apollo-execution-runtime.klib.api +++ b/apollo-execution-runtime/api/apollo-execution-runtime.klib.api @@ -55,6 +55,8 @@ sealed interface com.apollographql.execution.websocket/ConnectionInitResult // c sealed interface com.apollographql.execution.websocket/WebSocketMessage // com.apollographql.execution.websocket/WebSocketMessage|null[0] +sealed interface com.apollographql.execution.websocket/WsConnectionInitResult // com.apollographql.execution.websocket/WsConnectionInitResult|null[0] + final class com.apollographql.execution.websocket/ConnectionInitError : com.apollographql.execution.websocket/ConnectionInitResult { // com.apollographql.execution.websocket/ConnectionInitError|null[0] constructor (com.apollographql.apollo.api/Optional = ...) // com.apollographql.execution.websocket/ConnectionInitError.|(com.apollographql.apollo.api.Optional){}[0] @@ -62,6 +64,13 @@ final class com.apollographql.execution.websocket/ConnectionInitError : com.apol final fun (): com.apollographql.apollo.api/Optional // com.apollographql.execution.websocket/ConnectionInitError.payload.|(){}[0] } +final class com.apollographql.execution.websocket/GraphQLWsWebSocketHandler : com.apollographql.execution.websocket/WebSocketHandler { // com.apollographql.execution.websocket/GraphQLWsWebSocketHandler|null[0] + constructor (com.apollographql.apollo.execution/ExecutableSchema, kotlinx.coroutines/CoroutineScope, com.apollographql.apollo.api/ExecutionContext, kotlin.coroutines/SuspendFunction1, kotlin.coroutines/SuspendFunction1 = ...) // com.apollographql.execution.websocket/GraphQLWsWebSocketHandler.|(com.apollographql.apollo.execution.ExecutableSchema;kotlinx.coroutines.CoroutineScope;com.apollographql.apollo.api.ExecutionContext;kotlin.coroutines.SuspendFunction1;kotlin.coroutines.SuspendFunction1){}[0] + + final fun close() // com.apollographql.execution.websocket/GraphQLWsWebSocketHandler.close|close(){}[0] + final fun handleMessage(com.apollographql.execution.websocket/WebSocketMessage) // com.apollographql.execution.websocket/GraphQLWsWebSocketHandler.handleMessage|handleMessage(com.apollographql.execution.websocket.WebSocketMessage){}[0] +} + final class com.apollographql.execution.websocket/SubscriptionWebSocketHandler : com.apollographql.execution.websocket/WebSocketHandler { // com.apollographql.execution.websocket/SubscriptionWebSocketHandler|null[0] constructor (com.apollographql.apollo.execution/ExecutableSchema, kotlinx.coroutines/CoroutineScope, com.apollographql.apollo.api/ExecutionContext, kotlin.coroutines/SuspendFunction1, kotlin.coroutines/SuspendFunction1 = ...) // com.apollographql.execution.websocket/SubscriptionWebSocketHandler.|(com.apollographql.apollo.execution.ExecutableSchema;kotlinx.coroutines.CoroutineScope;com.apollographql.apollo.api.ExecutionContext;kotlin.coroutines.SuspendFunction1;kotlin.coroutines.SuspendFunction1){}[0] @@ -83,6 +92,13 @@ final class com.apollographql.execution.websocket/WebSocketTextMessage : com.apo final fun (): kotlin/String // com.apollographql.execution.websocket/WebSocketTextMessage.data.|(){}[0] } +final class com.apollographql.execution.websocket/WsConnectionInitError : com.apollographql.execution.websocket/WsConnectionInitResult { // com.apollographql.execution.websocket/WsConnectionInitError|null[0] + constructor (com.apollographql.apollo.api/Optional = ...) // com.apollographql.execution.websocket/WsConnectionInitError.|(com.apollographql.apollo.api.Optional){}[0] + + final val payload // com.apollographql.execution.websocket/WsConnectionInitError.payload|{}payload[0] + final fun (): com.apollographql.apollo.api/Optional // com.apollographql.execution.websocket/WsConnectionInitError.payload.|(){}[0] +} + final class com.apollographql.execution/CompositeResolverBuilder { // com.apollographql.execution/CompositeResolverBuilder|null[0] constructor () // com.apollographql.execution/CompositeResolverBuilder.|(){}[0] @@ -103,6 +119,11 @@ final object com.apollographql.execution.websocket/ConnectionInitAck : com.apoll final fun toString(): kotlin/String // com.apollographql.execution.websocket/ConnectionInitAck.toString|toString(){}[0] } -final fun (com.apollographql.apollo.api/ExecutionContext).com.apollographql.execution.websocket/subscriptionId(): kotlin/String // com.apollographql.execution.websocket/subscriptionId|subscriptionId@com.apollographql.apollo.api.ExecutionContext(){}[0] +final object com.apollographql.execution.websocket/WsConnectionInitAck : com.apollographql.execution.websocket/WsConnectionInitResult { // com.apollographql.execution.websocket/WsConnectionInitAck|null[0] + final fun equals(kotlin/Any?): kotlin/Boolean // com.apollographql.execution.websocket/WsConnectionInitAck.equals|equals(kotlin.Any?){}[0] + final fun hashCode(): kotlin/Int // com.apollographql.execution.websocket/WsConnectionInitAck.hashCode|hashCode(){}[0] + final fun toString(): kotlin/String // com.apollographql.execution.websocket/WsConnectionInitAck.toString|toString(){}[0] +} + final fun (com.apollographql.apollo.execution/ExecutableSchema.Builder).com.apollographql.execution/compositeResolver(kotlin/Function1): com.apollographql.apollo.execution/ExecutableSchema.Builder // com.apollographql.execution/compositeResolver|compositeResolver@com.apollographql.apollo.execution.ExecutableSchema.Builder(kotlin.Function1){}[0] final fun com.apollographql.execution/sandboxHtml(kotlin/String, kotlin/String): kotlin/String // com.apollographql.execution/sandboxHtml|sandboxHtml(kotlin.String;kotlin.String){}[0] diff --git a/apollo-execution-runtime/src/commonMain/kotlin/com/apollographql/execution/errors.kt b/apollo-execution-runtime/src/commonMain/kotlin/com/apollographql/execution/errors.kt new file mode 100644 index 00000000..f6724f2f --- /dev/null +++ b/apollo-execution-runtime/src/commonMain/kotlin/com/apollographql/execution/errors.kt @@ -0,0 +1,54 @@ +package com.apollographql.execution + +import com.apollographql.apollo.api.Error +import com.apollographql.apollo.api.json.BufferedSinkJsonWriter +import com.apollographql.apollo.api.json.JsonWriter +import com.apollographql.apollo.api.json.writeAny +import com.apollographql.apollo.api.json.writeArray +import com.apollographql.apollo.api.json.writeObject +import okio.BufferedSink +import okio.Sink +import okio.buffer + +internal fun JsonWriter.writeError(error: Error) { + writeObject { + name("message") + value(error.message) + if (error.locations != null) { + name("locations") + writeArray { + error.locations!!.forEach { + writeObject { + name("line") + value(it.line) + name("column") + value(it.column) + } + } + } + } + if (error.path != null) { + name("path") + writeArray { + error.path!!.forEach { + when (it) { + is Int -> value(it) + is String -> value(it) + else -> error("path can only contain Int and Double (found '${it::class.simpleName}')") + } + } + } + } + if (error.extensions != null) { + name("extensions") + writeObject { + error.extensions!!.entries.forEach { + name(it.key) + writeAny(it.value) + } + } + } + } +} + +internal fun Sink.jsonWriter(): JsonWriter = BufferedSinkJsonWriter(if (this is BufferedSink) this else this.buffer()) diff --git a/apollo-execution-runtime/src/commonMain/kotlin/com/apollographql/execution/websocket/GraphQLWsWebSocketHandler.kt b/apollo-execution-runtime/src/commonMain/kotlin/com/apollographql/execution/websocket/GraphQLWsWebSocketHandler.kt new file mode 100644 index 00000000..e765d5d6 --- /dev/null +++ b/apollo-execution-runtime/src/commonMain/kotlin/com/apollographql/execution/websocket/GraphQLWsWebSocketHandler.kt @@ -0,0 +1,307 @@ +package com.apollographql.execution.websocket + +import CurrentSubscription +import com.apollographql.apollo.annotations.ApolloInternal +import com.apollographql.apollo.api.Error +import com.apollographql.apollo.api.Error.* +import com.apollographql.apollo.api.ExecutionContext +import com.apollographql.apollo.api.Optional +import com.apollographql.apollo.api.json.* +import com.apollographql.apollo.execution.* +import com.apollographql.execution.jsonWriter +import com.apollographql.execution.writeError +import kotlinx.atomicfu.locks.reentrantLock +import kotlinx.atomicfu.locks.withLock +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Job +import kotlinx.coroutines.launch +import okio.Buffer +import okio.Sink + +/** + * A [WebSocketHandler] that implements https://github.com/enisdenjo/graphql-ws/blob/0c0eb499c3a0278c6d9cc799064f22c5d24d2f60/PROTOCOL.md + */ +class GraphQLWsWebSocketHandler( + private val executableSchema: ExecutableSchema, + private val scope: CoroutineScope, + private val executionContext: ExecutionContext, + private val sendMessage: suspend (WebSocketMessage) -> Unit, + private val connectionInitHandler: WsConnectionInitHandler = { WsConnectionInitAck }, +) : WebSocketHandler { + private val lock = reentrantLock() + private val activeSubscriptions = mutableMapOf() + private var isClosed: Boolean = false + private var initJob: Job? = null + + override fun handleMessage(message: WebSocketMessage) { + val clientMessage = when (message) { + is WebSocketBinaryMessage -> message.data.decodeToString() + is WebSocketTextMessage -> message.data + }.parseApolloWebsocketClientMessage() + + when (clientMessage) { + is Init -> { + initJob = lock.withLock { + scope.launch { + when(val result = connectionInitHandler.invoke(clientMessage.connectionParams)) { + is WsConnectionInitAck -> { + sendMessage(ConnectionAck.toWsMessage()) + } + is WsConnectionInitError -> { + sendMessage(ConnectionError(result.payload).toWsMessage()) + } + } + } + } + } + + is Subscribe -> { + val isActive = lock.withLock { + activeSubscriptions.containsKey(clientMessage.id) + } + if (isActive) { + scope.launch { + sendMessage(Error(id = clientMessage.id, error = Builder("Subscription ${clientMessage.id} is already active").build()).toWsMessage()) + } + return + } + + val flow = executableSchema.subscribe(clientMessage.request, executionContext + CurrentSubscription(clientMessage.id)) + + val job = scope.launch { + flow.collect { + when (it) { + is SubscriptionResponse -> { + sendMessage(Data(id = clientMessage.id, response = it.response).toWsMessage()) + } + + is SubscriptionError -> { + sendMessage(Error(id = clientMessage.id, error = it.errors.first()).toWsMessage()) + } + } + sendMessage(Complete(id = clientMessage.id).toWsMessage()) + } + sendMessage(Complete(id = clientMessage.id).toWsMessage()) + lock.withLock { + activeSubscriptions.remove(clientMessage.id)?.cancel() + } + } + + lock.withLock { + activeSubscriptions.put(clientMessage.id, job) + } + } + + is Complete -> { + lock.withLock { + activeSubscriptions.remove(clientMessage.id)?.cancel() + } + } + + is ParseError -> { + scope.launch { + sendMessage(Error(null, Builder("Cannot handle message (${clientMessage.message})").build()).toWsMessage()) + } + } + + Ping -> { + scope.launch { + sendMessage(Pong.toWsMessage()) + } + } + Pong -> { + + } + } + } + + fun close() { + lock.withLock { + if (isClosed) { + return + } + + activeSubscriptions.forEach { + it.value.cancel() + } + activeSubscriptions.clear() + + initJob?.cancel() + isClosed = true + } + } +} + +private sealed interface MessageResult + +private class ParseError( + val message: String, +) : MessageResult + +private sealed interface ClientMessage : MessageResult + +private class Init( + val connectionParams: Any?, +) : ClientMessage + +private class Subscribe( + val id: String, + val request: GraphQLRequest, +) : ClientMessage + +private class Complete( + val id: String, +) : ClientMessage, ServerMessage { + override fun serialize(sink: Sink) { + sink.writeMessage("complete") { + name("id") + value(id) + } + } +} + +private data object Ping : ClientMessage, ServerMessage { + override fun serialize(sink: Sink) { + sink.writeMessage("ping") + } +} + +private data object Pong : ClientMessage, ServerMessage { + override fun serialize(sink: Sink) { + sink.writeMessage("pong") + } +} + +private sealed interface ServerMessage { + fun serialize(sink: Sink) +} + +private fun Sink.writeMessage(type: String, block: (JsonWriter.() -> Unit)? = null) { + jsonWriter().apply { + writeObject { + name("type") + value(type) + block?.invoke(this) + } + flush() + } +} + +private data object ConnectionAck : ServerMessage { + override fun serialize(sink: Sink) { + sink.writeMessage("connection_ack") + } +} + +private class ConnectionError(private val payload: Optional) : ServerMessage { + override fun serialize(sink: Sink) { + sink.writeMessage("connection_error") { + if (payload is Optional.Present<*>) { + name("payload") + writeAny(payload.value) + } + } + } +} + +private class Data( + val id: String, + val response: GraphQLResponse, +) : ServerMessage { + override fun serialize(sink: Sink) { + sink.writeMessage("data") { + name("id") + value(id) + name("payload") + response.serialize(this) + } + } +} + +private class Error( + val id: String?, + val error: Error, +) : ServerMessage { + + override fun serialize(sink: Sink) { + sink.writeMessage("error") { + if (id != null) { + name("id") + value(id) + } + name("payload") + writeError(error) + } + } +} + +@OptIn(ApolloInternal::class) +private fun String.parseApolloWebsocketClientMessage(): MessageResult { + @Suppress("UNCHECKED_CAST") + val map = try { + Buffer().writeUtf8(this).jsonReader().readAny() as Map + } catch (e: Exception) { + return ParseError("Malformed Json: ${e.message}") + } + + val type = map["type"] + if (type == null) { + return ParseError("No 'type' found in $this") + } + if (type !is String) { + return ParseError("'type' must be a String in $this") + } + + when (type) { + "subscribe", "complete" -> { + val id = map["id"] + if (id == null) { + return ParseError("No 'id' found in $this") + } + + if (id !is String) { + return ParseError("'id' must be a String in $this") + } + + if (type == "subscribe") { + val payload = map["payload"] + if (payload == null) { + return ParseError("No 'payload' found in $this") + } + if (payload !is Map<*, *>) { + return ParseError("'payload' must be an Object in $this") + } + + @Suppress("UNCHECKED_CAST") + val request = (payload as Map).parseAsGraphQLRequest() + return request.fold( + onFailure = { ParseError("Cannot parse subscribe payload: '${it.message}'") }, + onSuccess = { Subscribe(id, request = it) } + ) + } else { + return Complete(id) + } + } + "ping" -> { + return Ping + } + "pong" -> { + return Pong + } + "connection_init" -> { + return Init(map["payload"]) + } + + else -> return ParseError("Unknown message type '$type'") + } +} + +private fun ServerMessage.toWsMessage(): WebSocketMessage { + return WebSocketTextMessage(Buffer().apply { serialize(this) }.readUtf8()) +} + +sealed interface WsConnectionInitResult +data object WsConnectionInitAck : WsConnectionInitResult +class WsConnectionInitError(val payload: Optional = Optional.absent()): WsConnectionInitResult + +typealias WsConnectionInitHandler = suspend (Any?) -> WsConnectionInitResult diff --git a/apollo-execution-runtime/src/commonMain/kotlin/com/apollographql/execution/websocket/SubscriptionWebSocketHandler.kt b/apollo-execution-runtime/src/commonMain/kotlin/com/apollographql/execution/websocket/SubscriptionWebSocketHandler.kt index d60c3b55..fee1020f 100644 --- a/apollo-execution-runtime/src/commonMain/kotlin/com/apollographql/execution/websocket/SubscriptionWebSocketHandler.kt +++ b/apollo-execution-runtime/src/commonMain/kotlin/com/apollographql/execution/websocket/SubscriptionWebSocketHandler.kt @@ -1,31 +1,21 @@ package com.apollographql.execution.websocket +import CurrentSubscription import com.apollographql.apollo.annotations.ApolloInternal import com.apollographql.apollo.api.Error import com.apollographql.apollo.api.ExecutionContext import com.apollographql.apollo.api.Optional -import com.apollographql.apollo.api.json.BufferedSinkJsonWriter -import com.apollographql.apollo.api.json.JsonWriter -import com.apollographql.apollo.api.json.jsonReader -import com.apollographql.apollo.api.json.readAny -import com.apollographql.apollo.api.json.writeAny -import com.apollographql.apollo.api.json.writeArray -import com.apollographql.apollo.api.json.writeObject -import com.apollographql.apollo.execution.ExecutableSchema -import com.apollographql.apollo.execution.GraphQLRequest -import com.apollographql.apollo.execution.GraphQLResponse -import com.apollographql.apollo.execution.SubscriptionError -import com.apollographql.apollo.execution.SubscriptionResponse -import com.apollographql.apollo.execution.parseAsGraphQLRequest +import com.apollographql.apollo.api.json.* +import com.apollographql.apollo.execution.* +import com.apollographql.execution.jsonWriter +import com.apollographql.execution.writeError import kotlinx.atomicfu.locks.reentrantLock import kotlinx.atomicfu.locks.withLock import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Job import kotlinx.coroutines.launch import okio.Buffer -import okio.BufferedSink import okio.Sink -import okio.buffer /** * A [WebSocketHandler] that implements https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md @@ -135,38 +125,34 @@ class SubscriptionWebSocketHandler( } } +private sealed interface SubscriptionWebsocketClientMessageResult - -internal sealed interface SubscriptionWebsocketClientMessageResult - -internal class SubscriptionWebsocketClientMessageParseError internal constructor( +private class SubscriptionWebsocketClientMessageParseError internal constructor( val message: String, ) : SubscriptionWebsocketClientMessageResult -internal sealed interface SubscriptionWebsocketClientMessage : SubscriptionWebsocketClientMessageResult +private sealed interface SubscriptionWebsocketClientMessage : SubscriptionWebsocketClientMessageResult -internal class SubscriptionWebsocketInit( +private class SubscriptionWebsocketInit( val connectionParams: Any?, ) : SubscriptionWebsocketClientMessage -internal class SubscriptionWebsocketStart( +private class SubscriptionWebsocketStart( val id: String, val request: GraphQLRequest, ) : SubscriptionWebsocketClientMessage -internal class SubscriptionWebsocketStop( +private class SubscriptionWebsocketStop( val id: String, ) : SubscriptionWebsocketClientMessage -internal object SubscriptionWebsocketTerminate : SubscriptionWebsocketClientMessage +private object SubscriptionWebsocketTerminate : SubscriptionWebsocketClientMessage -internal sealed interface SubscriptionWebsocketServerMessage { +private sealed interface SubscriptionWebsocketServerMessage { fun serialize(sink: Sink) } -internal fun Sink.jsonWriter(): JsonWriter = BufferedSinkJsonWriter(if (this is BufferedSink) this else this.buffer()) - private fun Sink.writeMessage(type: String, block: (JsonWriter.() -> Unit)? = null) { jsonWriter().apply { writeObject { @@ -178,13 +164,13 @@ private fun Sink.writeMessage(type: String, block: (JsonWriter.() -> Unit)? = nu } } -internal data object SubscriptionWebsocketConnectionAck : SubscriptionWebsocketServerMessage { +private data object SubscriptionWebsocketConnectionAck : SubscriptionWebsocketServerMessage { override fun serialize(sink: Sink) { sink.writeMessage("connection_ack") } } -internal class SubscriptionWebsocketConnectionError(private val payload: Optional) : SubscriptionWebsocketServerMessage { +private class SubscriptionWebsocketConnectionError(private val payload: Optional) : SubscriptionWebsocketServerMessage { override fun serialize(sink: Sink) { sink.writeMessage("connection_error") { if (payload is Optional.Present<*>) { @@ -195,7 +181,7 @@ internal class SubscriptionWebsocketConnectionError(private val payload: Optiona } } -internal class SubscriptionWebsocketData( +private class SubscriptionWebsocketData( val id: String, val response: GraphQLResponse, ) : SubscriptionWebsocketServerMessage { @@ -209,7 +195,7 @@ internal class SubscriptionWebsocketData( } } -internal class SubscriptionWebsocketError( +private class SubscriptionWebsocketError( val id: String?, val error: Error, ) : SubscriptionWebsocketServerMessage { @@ -226,7 +212,7 @@ internal class SubscriptionWebsocketError( } } -internal class SubscriptionWebsocketComplete( +private class SubscriptionWebsocketComplete( val id: String, ) : SubscriptionWebsocketServerMessage { override fun serialize(sink: Sink) { @@ -238,7 +224,7 @@ internal class SubscriptionWebsocketComplete( } @OptIn(ApolloInternal::class) -internal fun String.parseApolloWebsocketClientMessage(): SubscriptionWebsocketClientMessageResult { +private fun String.parseApolloWebsocketClientMessage(): SubscriptionWebsocketClientMessageResult { @Suppress("UNCHECKED_CAST") val map = try { Buffer().writeUtf8(this).jsonReader().readAny() as Map @@ -307,52 +293,4 @@ class ConnectionInitError(val payload: Optional = Optional.absent()): Conn typealias ConnectionInitHandler = suspend (Any?) -> ConnectionInitResult -private class CurrentSubscription(val id: String) : ExecutionContext.Element { - override val key: ExecutionContext.Key = Key - - companion object Key : ExecutionContext.Key -} - -fun ExecutionContext.subscriptionId(): String = get(CurrentSubscription)?.id ?: error("Apollo: not executing a subscription") - -internal fun JsonWriter.writeError(error: Error) { - writeObject { - name("message") - value(error.message) - if (error.locations != null) { - name("locations") - writeArray { - error.locations!!.forEach { - writeObject { - name("line") - value(it.line) - name("column") - value(it.column) - } - } - } - } - if (error.path != null) { - name("path") - writeArray { - error.path!!.forEach { - when (it) { - is Int -> value(it) - is String -> value(it) - else -> error("path can only contain Int and Double (found '${it::class.simpleName}')") - } - } - } - } - if (error.extensions != null) { - name("extensions") - writeObject { - error.extensions!!.entries.forEach { - name(it.key) - writeAny(it.value) - } - } - } - } -} \ No newline at end of file diff --git a/apollo-execution-runtime/src/commonMain/kotlin/com/apollographql/execution/websocket/subscription.kt b/apollo-execution-runtime/src/commonMain/kotlin/com/apollographql/execution/websocket/subscription.kt new file mode 100644 index 00000000..a3d50590 --- /dev/null +++ b/apollo-execution-runtime/src/commonMain/kotlin/com/apollographql/execution/websocket/subscription.kt @@ -0,0 +1,10 @@ +import com.apollographql.apollo.api.ExecutionContext + +internal class CurrentSubscription(val id: String) : ExecutionContext.Element { + + override val key: ExecutionContext.Key = Key + + companion object Key : ExecutionContext.Key +} + +internal fun ExecutionContext.subscriptionId(): String = get(CurrentSubscription)?.id ?: error("Apollo: not executing a subscription")