diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index a41ea344cbd4c..5b25bcdffca8f 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -31,6 +31,7 @@ import org.scalatest.concurrent.Futures.timeout
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.{SparkException, SparkThrowable}
+import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, ArtifactStatusesRequest, ArtifactStatusesResponse, ExecutePlanRequest, ExecutePlanResponse, Relation, SparkConnectServiceGrpc, SQL}
 import org.apache.spark.sql.connect.SparkSession
@@ -80,6 +81,28 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
     assert(client.userId == "abc123")
   }
 
+  test("Pass client env details in request") {
+    startDummyServer(0)
+    client = SparkConnectClient
+      .builder()
+      .connectionString(s"sc://localhost:${server.getPort}/;use_ssl=true")
+      .retryPolicy(RetryPolicy(maxRetries = Some(0), canRetry = _ => false, name = "TestPolicy"))
+      .build()
+
+    val request = AnalyzePlanRequest.newBuilder().setSessionId("abc123").build()
+
+    // The SSL handshake fails because the dummy server has no server credentials installed.
+    assertThrows[SparkException] {
+      client.analyze(request)
+    }
+
+    // Verify that the client env attached to the request includes these details.
+    val env = request.getClientEnv
+    assert(env.getSparkVersion == SPARK_VERSION)
+    assert(env.getScalaEnv.getScalaVersion == util.Properties.versionNumberString)
+
+  }
+
   // Use 0 to start the server at a random port
   private def testClientConnection(serverPort: Int = 0)(
       clientBuilder: Int => SparkConnectClient): Unit = {
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/base.proto b/sql/connect/common/src/main/protobuf/spark/connect/base.proto
index 7f317defd47b5..118cffe53ff75 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -58,6 +58,26 @@ message UserContext {
   repeated google.protobuf.Any extensions = 999;
 }
 
+// ClientEnv provides information about the client environment. It overlaps in part with
+// client_type, but is intended for server-side processing.
+message ClientEnv {
+  message ScalaEnv {
+    string scala_version = 1;
+    string java_version = 2;
+  }
+
+  message PythonEnv {
+    string python_version = 1;
+  }
+
+  string spark_version = 1;
+
+  oneof env {
+    ScalaEnv scala_env = 2;
+    PythonEnv python_env = 3;
+  }
+}
+
 // Request to perform plan analyze, optionally to explain the plan.
 message AnalyzePlanRequest {
   // (Required)
@@ -99,6 +119,10 @@ message AnalyzePlanRequest {
     JsonToDDL json_to_ddl = 18;
   }
 
+  // Provides optional information about the client environment similar to client_type but is
+  // intended for server side processing
+  optional ClientEnv client_env = 19;
+
   message Schema {
     // (Required) The logical plan to be analyzed.
     Plan plan = 1;
@@ -342,6 +366,10 @@ message ExecutePlanRequest {
   // Tags cannot contain ',' character and cannot be empty strings.
   // Used by Interrupt with interrupt.tag.
repeated string tags = 7; + + // Provides optional information about the client environment similar to client_type but is + // intended for server side processing + optional ClientEnv client_env = 9; } // The response of a query, can be one or more for each request. Responses belonging to the @@ -520,6 +548,10 @@ message ConfigRequest { // can be used for language or version specific information and is only intended for // logging purposes and will not be interpreted by the server. optional string client_type = 4; + + // Provides optional information about the client environment similar to client_type but is + // intended for server side processing + optional ClientEnv client_env = 5; message Operation { oneof op_type { @@ -616,7 +648,11 @@ message AddArtifactsRequest { // can be used for language or version specific information and is only intended for // logging purposes and will not be interpreted by the server. optional string client_type = 6; - + + // Provides optional information about the client environment similar to client_type but is + // intended for server side processing + optional ClientEnv client_env = 8; + // A chunk of an Artifact. message ArtifactChunk { // Data chunk. @@ -727,6 +763,10 @@ message ArtifactStatusesRequest { // The relative path of the file on the server's filesystem will be the same as the name of // the provided artifact) repeated string names = 4; + + // Provides optional information about the client environment similar to client_type but is + // intended for server side processing + optional ClientEnv client_env = 6; } // Response to checking artifact statuses. @@ -792,6 +832,10 @@ message InterruptRequest { // if interrupt_tag == INTERRUPT_TYPE_OPERATION_ID, interrupt operation with this operation_id. string operation_id = 6; } + + // Provides optional information about the client environment similar to client_type but is + // intended for server side processing + optional ClientEnv client_env = 8; } // Next ID: 4 @@ -857,6 +901,10 @@ message ReattachExecuteRequest { // that are far behind the latest returned response, so this can't be used to arbitrarily // scroll back the cursor. If the response is no longer available, this will result in an error. optional string last_response_id = 5; + + // Provides optional information about the client environment similar to client_type but is + // intended for server side processing + optional ClientEnv client_env = 7; } message ReleaseExecuteRequest { @@ -905,6 +953,10 @@ message ReleaseExecuteRequest { ReleaseAll release_all = 5; ReleaseUntil release_until = 6; } + + // Provides optional information about the client environment similar to client_type but is + // intended for server side processing + optional ClientEnv client_env = 8; } // Next ID: 4 @@ -952,6 +1004,10 @@ message ReleaseSessionRequest { // reconnecting to a released session. The client must ensure that any queries executed do not // rely on the session state prior to its release. bool allow_reconnect = 4; + + // Provides optional information about the client environment similar to client_type but is + // intended for server side processing + optional ClientEnv client_env = 5; } // Next ID: 3 @@ -987,6 +1043,10 @@ message FetchErrorDetailsRequest { // can be used for language or version specific information and is only intended for // logging purposes and will not be interpreted by the server. 
optional string client_type = 4; + + // Provides optional information about the client environment similar to client_type but is + // intended for server side processing + optional ClientEnv client_env = 6; } // Next ID: 5 diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala index e9411dc3db61b..28dc36b34838b 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala @@ -169,6 +169,7 @@ class ArtifactManager( .newBuilder() .setUserContext(clientConfig.userContext) .setClientType(clientConfig.userAgent) + .setClientEnv(clientConfig.clientEnv) .setSessionId(sessionId) .addAllNames(Arrays.asList(artifactName)) .build() @@ -318,6 +319,7 @@ class ArtifactManager( .newBuilder() .setUserContext(clientConfig.userContext) .setClientType(clientConfig.userAgent) + .setClientEnv(clientConfig.clientEnv) .setSessionId(sessionId) artifacts.foreach { artifact => val in = new CheckedInputStream(artifact.storage.stream, new CRC32) @@ -374,6 +376,7 @@ class ArtifactManager( .newBuilder() .setUserContext(clientConfig.userContext) .setClientType(clientConfig.userAgent) + .setClientEnv(clientConfig.clientEnv) .setSessionId(sessionId) val in = new CheckedInputStream(artifact.storage.stream, new CRC32) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala index d7867229248b8..76ed778c7eb5c 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala @@ -37,7 +37,8 @@ private[connect] class CustomSparkConnectBlockingStub( grpcExceptionConverter.convert( request.getSessionId, request.getUserContext, - request.getClientType) { + request.getClientType, + request.getClientEnv) { grpcExceptionConverter.convertIterator[ExecutePlanResponse]( request.getSessionId, request.getUserContext, @@ -47,7 +48,8 @@ private[connect] class CustomSparkConnectBlockingStub( r => { stubState.responseValidator.wrapIterator( CloseableIterator(stub.executePlan(r).asScala)) - })) + }), + request.getClientEnv) } } @@ -56,14 +58,16 @@ private[connect] class CustomSparkConnectBlockingStub( grpcExceptionConverter.convert( request.getSessionId, request.getUserContext, - request.getClientType) { + request.getClientType, + request.getClientEnv) { grpcExceptionConverter.convertIterator[ExecutePlanResponse]( request.getSessionId, request.getUserContext, request.getClientType, stubState.responseValidator.wrapIterator( // ExecutePlanResponseReattachableIterator does all retries by itself, don't wrap it here - new ExecutePlanResponseReattachableIterator(request, channel, stubState.retryHandler))) + new ExecutePlanResponseReattachableIterator(request, channel, stubState.retryHandler)), + request.getClientEnv) } } @@ -71,7 +75,8 @@ private[connect] class CustomSparkConnectBlockingStub( grpcExceptionConverter.convert( request.getSessionId, request.getUserContext, - request.getClientType) { + request.getClientType, + request.getClientEnv) { retryHandler.retry { stubState.responseValidator.verifyResponse { stub.analyzePlan(request) 
@@ -84,7 +89,8 @@ private[connect] class CustomSparkConnectBlockingStub( grpcExceptionConverter.convert( request.getSessionId, request.getUserContext, - request.getClientType) { + request.getClientType, + request.getClientEnv) { retryHandler.retry { stubState.responseValidator.verifyResponse { stub.config(request) @@ -97,7 +103,8 @@ private[connect] class CustomSparkConnectBlockingStub( grpcExceptionConverter.convert( request.getSessionId, request.getUserContext, - request.getClientType) { + request.getClientType, + request.getClientEnv) { retryHandler.retry { stubState.responseValidator.verifyResponse { stub.interrupt(request) @@ -110,7 +117,8 @@ private[connect] class CustomSparkConnectBlockingStub( grpcExceptionConverter.convert( request.getSessionId, request.getUserContext, - request.getClientType) { + request.getClientType, + request.getClientEnv) { retryHandler.retry { stubState.responseValidator.verifyResponse { stub.releaseSession(request) @@ -123,7 +131,8 @@ private[connect] class CustomSparkConnectBlockingStub( grpcExceptionConverter.convert( request.getSessionId, request.getUserContext, - request.getClientType) { + request.getClientType, + request.getClientEnv) { retryHandler.retry { stubState.responseValidator.verifyResponse { stub.artifactStatus(request) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala index f3c13c9c2c4d8..4252f03305d30 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala @@ -300,6 +300,10 @@ class ExecutePlanResponseReattachableIterator( reattach.setClientType(initialRequest.getClientType) } + if (initialRequest.hasClientEnv) { + reattach.setClientEnv(initialRequest.getClientEnv) + } + if (lastReturnedResponseId.isDefined) { reattach.setLastResponseId(lastReturnedResponseId.get) } @@ -317,6 +321,10 @@ class ExecutePlanResponseReattachableIterator( release.setClientType(initialRequest.getClientType) } + if (initialRequest.hasClientEnv) { + release.setClientEnv(initialRequest.getClientEnv) + } + untilResponseId match { case None => release.setReleaseAll(proto.ReleaseExecuteRequest.ReleaseAll.newBuilder().build()) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index d3dae47f4c471..d0db61a25e457 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -28,7 +28,7 @@ import org.json4s.{DefaultFormats, Formats} import org.json4s.jackson.JsonMethods import org.apache.spark.{QueryContext, QueryContextType, SparkArithmeticException, SparkArrayIndexOutOfBoundsException, SparkDateTimeException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} -import org.apache.spark.connect.proto.{FetchErrorDetailsRequest, FetchErrorDetailsResponse, SparkConnectServiceGrpc, UserContext} +import org.apache.spark.connect.proto.{ClientEnv, FetchErrorDetailsRequest, 
FetchErrorDetailsResponse, SparkConnectServiceGrpc, UserContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException, TempTableAlreadyExistsException} @@ -53,12 +53,13 @@ private[client] class GrpcExceptionConverter(channel: ManagedChannel) extends Lo val grpcStub = SparkConnectServiceGrpc.newBlockingStub(channel) - def convert[T](sessionId: String, userContext: UserContext, clientType: String)(f: => T): T = { + def convert[T](sessionId: String, userContext: UserContext, clientType: String, + clientEnv: ClientEnv)(f: => T): T = { try { f } catch { case e: StatusRuntimeException => - throw toThrowable(e, sessionId, userContext, clientType) + throw toThrowable(e, sessionId, userContext, clientType, clientEnv) } } @@ -66,25 +67,26 @@ private[client] class GrpcExceptionConverter(channel: ManagedChannel) extends Lo sessionId: String, userContext: UserContext, clientType: String, - iter: CloseableIterator[T]): CloseableIterator[T] = { + iter: CloseableIterator[T], + clientEnv: ClientEnv): CloseableIterator[T] = { new WrappedCloseableIterator[T] { override def innerIterator: Iterator[T] = iter override def hasNext: Boolean = { - convert(sessionId, userContext, clientType) { + convert(sessionId, userContext, clientType, clientEnv) { iter.hasNext } } override def next(): T = { - convert(sessionId, userContext, clientType) { + convert(sessionId, userContext, clientType, clientEnv) { iter.next() } } override def close(): Unit = { - convert(sessionId, userContext, clientType) { + convert(sessionId, userContext, clientType, clientEnv) { iter.close() } } @@ -99,7 +101,8 @@ private[client] class GrpcExceptionConverter(channel: ManagedChannel) extends Lo info: ErrorInfo, sessionId: String, userContext: UserContext, - clientType: String): Option[Throwable] = { + clientType: String, + clientEnv: ClientEnv): Option[Throwable] = { val errorId = info.getMetadataOrDefault("errorId", null) if (errorId == null) { logWarning("Unable to fetch enriched error since errorId is missing") @@ -114,6 +117,7 @@ private[client] class GrpcExceptionConverter(channel: ManagedChannel) extends Lo .setErrorId(errorId) .setUserContext(userContext) .setClientType(clientType) + .setClientEnv(clientEnv) .build()) if (!errorDetailsResponse.hasRootErrorIdx) { @@ -136,7 +140,8 @@ private[client] class GrpcExceptionConverter(channel: ManagedChannel) extends Lo ex: StatusRuntimeException, sessionId: String, userContext: UserContext, - clientType: String): Throwable = { + clientType: String, + clientEnv: ClientEnv): Throwable = { val status = StatusProto.fromThrowable(ex) // Extract the ErrorInfo from the StatusProto, if present. @@ -147,7 +152,7 @@ private[client] class GrpcExceptionConverter(channel: ManagedChannel) extends Lo if (errorInfoOpt.isDefined) { // If ErrorInfo is found, try to fetch enriched error details by an additional RPC. 
val enrichedErrorOpt = - fetchEnrichedError(errorInfoOpt.get, sessionId, userContext, clientType) + fetchEnrichedError(errorInfoOpt.get, sessionId, userContext, clientType, clientEnv) if (enrichedErrorOpt.isDefined) { return enrichedErrorOpt.get } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index e844237a3bb44..7fcdf9feb1e25 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -51,6 +51,8 @@ private[sql] class SparkConnectClient( private[client] def userAgent: String = configuration.userAgent + private[client] def clientEnv: proto.ClientEnv = configuration.clientEnv + /** * Placeholder method. * @return @@ -136,6 +138,7 @@ private[sql] class SparkConnectClient( .setUserContext(userContext) .setSessionId(sessionId) .setClientType(userAgent) + .setClientEnv(clientEnv) .addAllTags(tags.get.toSeq.asJava) serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session)) operationId.foreach { opId => @@ -163,6 +166,7 @@ private[sql] class SparkConnectClient( .setOperation(operation) .setSessionId(sessionId) .setClientType(userAgent) + .setClientEnv(clientEnv) .setUserContext(userContext) serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session)) bstub.config(request.build()) @@ -252,6 +256,7 @@ private[sql] class SparkConnectClient( .setUserContext(userContext) .setSessionId(sessionId) .setClientType(userAgent) + .setClientEnv(clientEnv) serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session)) analyze(request.build()) } @@ -262,6 +267,7 @@ private[sql] class SparkConnectClient( .setUserContext(userContext) .setSessionId(sessionId) .setClientType(userAgent) + .setClientEnv(clientEnv) .setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL) serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session)) bstub.interrupt(request.build()) @@ -273,6 +279,7 @@ private[sql] class SparkConnectClient( .setUserContext(userContext) .setSessionId(sessionId) .setClientType(userAgent) + .setClientEnv(clientEnv) .setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_TAG) .setOperationTag(tag) serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session)) @@ -285,6 +292,7 @@ private[sql] class SparkConnectClient( .setUserContext(userContext) .setSessionId(sessionId) .setClientType(userAgent) + .setClientEnv(clientEnv) .setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID) .setOperationId(id) serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session)) @@ -297,6 +305,7 @@ private[sql] class SparkConnectClient( .setUserContext(userContext) .setSessionId(sessionId) .setClientType(userAgent) + .setClientEnv(clientEnv) bstub.releaseSession(request.build()) } @@ -702,12 +711,14 @@ object SparkConnectClient { def build(): SparkConnectClient = _configuration.toSparkConnectClient } + def javaVersion: String = System.getProperty("java.version").split("_")(0) + /** * Appends the Spark, Scala & JVM version, and the used OS to the user-provided user agent. 
   */
  private def genUserAgent(value: String): String = {
    val scalaVersion = Properties.versionNumberString
-    val jvmVersion = System.getProperty("java.version").split("_")(0)
+    val jvmVersion = javaVersion
    val osName = {
      val os = System.getProperty("os.name").toLowerCase(Locale.ROOT)
      if (os.contains("mac")) "darwin"
@@ -748,6 +759,19 @@ object SparkConnectClient {
 
     private def isLocal = host.equals("localhost")
 
+    lazy val clientEnv: proto.ClientEnv = {
+      val builder = proto.ClientEnv.newBuilder()
+        .setSparkVersion(SPARK_VERSION)
+        // For the Scala client, always report the Scala and Java versions via scala_env.
+        .setScalaEnv(
+          proto.ClientEnv.ScalaEnv.newBuilder()
+            .setScalaVersion(Properties.versionNumberString)
+            .setJavaVersion(SparkConnectClient.javaVersion)
+            .build()
+        )
+      builder.build()
+    }
+
     def userContext: proto.UserContext = {
       val builder = proto.UserContext.newBuilder()
       if (userId != null) {
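
A note on intended consumption (illustrative only, not part of this patch): the comments above describe client_env as "intended for server side processing", but the diff adds only the client-side plumbing. Under that assumption, a server-side reader might look roughly like the sketch below. The object and method names are hypothetical; the only thing assumed is the standard protobuf-java accessors generated from the ClientEnv message defined above.

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ClientEnv.EnvCase

// Hypothetical helper: summarize the client environment reported by a request, if any.
object ClientEnvSummary {
  def describe(request: proto.ExecutePlanRequest): String = {
    if (!request.hasClientEnv) {
      // Older clients simply omit the optional field.
      "client_env not provided"
    } else {
      val env = request.getClientEnv
      val runtime = env.getEnvCase match {
        case EnvCase.SCALA_ENV =>
          s"Scala ${env.getScalaEnv.getScalaVersion} on Java ${env.getScalaEnv.getJavaVersion}"
        case EnvCase.PYTHON_ENV =>
          s"Python ${env.getPythonEnv.getPythonVersion}"
        case _ =>
          "unknown client runtime"
      }
      s"Spark ${env.getSparkVersion} ($runtime)"
    }
  }
}

Because client_env is a singular message field, hasClientEnv distinguishes "field never sent" from a present-but-default value, so a server can degrade gracefully for clients that predate this change.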