Add client env proto to spark connect client requests #51529

Open · wants to merge 2 commits into master
@@ -31,6 +31,7 @@ import org.scalatest.concurrent.Futures.timeout
import org.scalatest.time.SpanSugar._

import org.apache.spark.{SparkException, SparkThrowable}
import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, ArtifactStatusesRequest, ArtifactStatusesResponse, ExecutePlanRequest, ExecutePlanResponse, Relation, SparkConnectServiceGrpc, SQL}
import org.apache.spark.sql.connect.SparkSession
@@ -80,6 +81,28 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
assert(client.userId == "abc123")
}

test("Pass client env details in request") {
startDummyServer(0)
client = SparkConnectClient
.builder()
.connectionString(s"sc://localhost:${server.getPort}/;use_ssl=true")
.retryPolicy(RetryPolicy(maxRetries = Some(0), canRetry = _ => false, name = "TestPolicy"))
.build()

val request = AnalyzePlanRequest.newBuilder().setSessionId("abc123").build()

// The SSL handshake fails because the dummy server has no server credentials installed.
assertThrows[SparkException] {
client.analyze(request)
}

// Verify that the request carried the client environment details.
val env = request.getClientEnv()
assert(env.getSparkVersion == SPARK_VERSION)
assert(env.getScalaEnv.getScalaVersion == util.Properties.versionString)
}

// Use 0 to start the server at a random port
private def testClientConnection(serverPort: Int = 0)(
clientBuilder: Int => SparkConnectClient): Unit = {
62 changes: 61 additions & 1 deletion sql/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -58,6 +58,26 @@ message UserContext {
repeated google.protobuf.Any extensions = 999;
}

// ClientEnv provides information about the client environment. It overlaps with
// client_type, but is intended for server-side processing.
message ClientEnv {
message ScalaEnv {
string scala_version = 1;
string java_version = 2;
}

message PythonEnv {
string python_version = 1;
}

string spark_version = 1;

oneof env {
ScalaEnv scala_env = 2;
PythonEnv python_env = 3;
}
}

// Request to perform plan analyze, optionally to explain the plan.
message AnalyzePlanRequest {
// (Required)
@@ -99,6 +119,10 @@ message AnalyzePlanRequest {
JsonToDDL json_to_ddl = 18;
}

// Provides optional information about the client environment. Similar to client_type,
// but intended for server-side processing.
optional ClientEnv client_env = 19;

message Schema {
// (Required) The logical plan to be analyzed.
Plan plan = 1;
@@ -342,6 +366,10 @@ message ExecutePlanRequest {
// Tags cannot contain ',' character and cannot be empty strings.
// Used by Interrupt with interrupt.tag.
repeated string tags = 7;

// Provides optional information about the client environment. Similar to client_type,
// but intended for server-side processing.
optional ClientEnv client_env = 9;
}

// The response of a query, can be one or more for each request. Responses belonging to the
@@ -520,6 +548,10 @@ message ConfigRequest {
// can be used for language or version specific information and is only intended for
// logging purposes and will not be interpreted by the server.
optional string client_type = 4;

// Provides optional information about the client environment. Similar to client_type,
// but intended for server-side processing.
optional ClientEnv client_env = 5;

message Operation {
oneof op_type {
@@ -616,7 +648,11 @@ message AddArtifactsRequest {
// can be used for language or version specific information and is only intended for
// logging purposes and will not be interpreted by the server.
optional string client_type = 6;

// Provides optional information about the client environment. Similar to client_type,
// but intended for server-side processing.
optional ClientEnv client_env = 8;

// A chunk of an Artifact.
message ArtifactChunk {
// Data chunk.
@@ -727,6 +763,10 @@ message ArtifactStatusesRequest {
// The relative path of the file on the server's filesystem will be the same as the name of
// the provided artifact)
repeated string names = 4;

// Provides optional information about the client environment. Similar to client_type,
// but intended for server-side processing.
optional ClientEnv client_env = 6;
}

// Response to checking artifact statuses.
@@ -792,6 +832,10 @@ message InterruptRequest {
// if interrupt_tag == INTERRUPT_TYPE_OPERATION_ID, interrupt operation with this operation_id.
string operation_id = 6;
}

// Provides optional information about the client environment. Similar to client_type,
// but intended for server-side processing.
optional ClientEnv client_env = 8;
}

// Next ID: 4
@@ -857,6 +901,10 @@ message ReattachExecuteRequest {
// that are far behind the latest returned response, so this can't be used to arbitrarily
// scroll back the cursor. If the response is no longer available, this will result in an error.
optional string last_response_id = 5;

// Provides optional information about the client environment. Similar to client_type,
// but intended for server-side processing.
optional ClientEnv client_env = 7;
}

message ReleaseExecuteRequest {
@@ -905,6 +953,10 @@ message ReleaseExecuteRequest {
ReleaseAll release_all = 5;
ReleaseUntil release_until = 6;
}

// Provides optional information about the client environment. Similar to client_type,
// but intended for server-side processing.
optional ClientEnv client_env = 8;
}

// Next ID: 4
@@ -952,6 +1004,10 @@ message ReleaseSessionRequest {
// reconnecting to a released session. The client must ensure that any queries executed do not
// rely on the session state prior to its release.
bool allow_reconnect = 4;

// Provides optional information about the client environment. Similar to client_type,
// but intended for server-side processing.
optional ClientEnv client_env = 5;
}

// Next ID: 3
@@ -987,6 +1043,10 @@ message FetchErrorDetailsRequest {
// can be used for language or version specific information and is only intended for
// logging purposes and will not be interpreted by the server.
optional string client_type = 4;

// Provides optional information about the client environment. Similar to client_type,
// but intended for server-side processing.
optional ClientEnv client_env = 6;
}

// Next ID: 5
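The ClientEnv message above is what each request now carries. A minimal sketch of how a Scala client could populate it, assuming the Java classes protoc generates from base.proto and the SparkBuildInfo constant already imported in the test above:

import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
import org.apache.spark.connect.proto.ClientEnv

// Describe a Scala client: the Spark version plus the Scala and Java runtime
// versions, built with the generated protobuf builders.
val clientEnv: ClientEnv = ClientEnv
  .newBuilder()
  .setSparkVersion(SPARK_VERSION)
  .setScalaEnv(
    ClientEnv.ScalaEnv
      .newBuilder()
      .setScalaVersion(scala.util.Properties.versionString)
      .setJavaVersion(System.getProperty("java.version")))
  .build()

A request builder then attaches it with setClientEnv, as the ArtifactManager changes below do with clientConfig.clientEnv.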
@@ -169,6 +169,7 @@ class ArtifactManager(
.newBuilder()
.setUserContext(clientConfig.userContext)
.setClientType(clientConfig.userAgent)
.setClientEnv(clientConfig.clientEnv)
.setSessionId(sessionId)
.addAllNames(Arrays.asList(artifactName))
.build()
@@ -318,6 +319,7 @@ class ArtifactManager(
.newBuilder()
.setUserContext(clientConfig.userContext)
.setClientType(clientConfig.userAgent)
.setClientEnv(clientConfig.clientEnv)
.setSessionId(sessionId)
artifacts.foreach { artifact =>
val in = new CheckedInputStream(artifact.storage.stream, new CRC32)
@@ -374,6 +376,7 @@ class ArtifactManager(
.newBuilder()
.setUserContext(clientConfig.userContext)
.setClientType(clientConfig.userAgent)
.setClientEnv(clientConfig.clientEnv)
.setSessionId(sessionId)

val in = new CheckedInputStream(artifact.storage.stream, new CRC32)
@@ -37,7 +37,8 @@ private[connect] class CustomSparkConnectBlockingStub(
grpcExceptionConverter.convert(
request.getSessionId,
request.getUserContext,
request.getClientType,
request.getClientEnv) {
grpcExceptionConverter.convertIterator[ExecutePlanResponse](
request.getSessionId,
request.getUserContext,
@@ -47,7 +48,8 @@
r => {
stubState.responseValidator.wrapIterator(
CloseableIterator(stub.executePlan(r).asScala))
}),
request.getClientEnv)
}
}

Expand All @@ -56,22 +58,25 @@ private[connect] class CustomSparkConnectBlockingStub(
grpcExceptionConverter.convert(
request.getSessionId,
request.getUserContext,
request.getClientType,
request.getClientEnv) {
grpcExceptionConverter.convertIterator[ExecutePlanResponse](
request.getSessionId,
request.getUserContext,
request.getClientType,
stubState.responseValidator.wrapIterator(
// ExecutePlanResponseReattachableIterator does all retries by itself, don't wrap it here
new ExecutePlanResponseReattachableIterator(request, channel, stubState.retryHandler)),
request.getClientEnv)
}
}

def analyzePlan(request: AnalyzePlanRequest): AnalyzePlanResponse = {
grpcExceptionConverter.convert(
request.getSessionId,
request.getUserContext,
request.getClientType,
request.getClientEnv) {
retryHandler.retry {
stubState.responseValidator.verifyResponse {
stub.analyzePlan(request)
@@ -84,7 +89,8 @@
grpcExceptionConverter.convert(
request.getSessionId,
request.getUserContext,
request.getClientType,
request.getClientEnv) {
retryHandler.retry {
stubState.responseValidator.verifyResponse {
stub.config(request)
@@ -97,7 +103,8 @@
grpcExceptionConverter.convert(
request.getSessionId,
request.getUserContext,
request.getClientType,
request.getClientEnv) {
retryHandler.retry {
stubState.responseValidator.verifyResponse {
stub.interrupt(request)
@@ -110,7 +117,8 @@
grpcExceptionConverter.convert(
request.getSessionId,
request.getUserContext,
request.getClientType,
request.getClientEnv) {
retryHandler.retry {
stubState.responseValidator.verifyResponse {
stub.releaseSession(request)
@@ -123,7 +131,8 @@
grpcExceptionConverter.convert(
request.getSessionId,
request.getUserContext,
request.getClientType,
request.getClientEnv) {
retryHandler.retry {
stubState.responseValidator.verifyResponse {
stub.artifactStatus(request)
@@ -300,6 +300,10 @@ class ExecutePlanResponseReattachableIterator(
reattach.setClientType(initialRequest.getClientType)
}

if (initialRequest.hasClientEnv) {
reattach.setClientEnv(initialRequest.getClientEnv)
}

if (lastReturnedResponseId.isDefined) {
reattach.setLastResponseId(lastReturnedResponseId.get)
}
@@ -317,6 +321,10 @@
release.setClientType(initialRequest.getClientType)
}

if (initialRequest.hasClientEnv) {
release.setClientEnv(initialRequest.getClientEnv)
}

untilResponseId match {
case None =>
release.setReleaseAll(proto.ReleaseExecuteRequest.ReleaseAll.newBuilder().build())
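Because the client env is copied onto ReattachExecuteRequest and ReleaseExecuteRequest as well, the server can inspect it at any point in an execution's lifecycle. A hypothetical helper, not part of this diff, that renders a ClientEnv for server-side logging by switching on the env oneof:

import org.apache.spark.connect.proto

// Hypothetical: summarize a ClientEnv as one log-friendly string.
def describeClientEnv(env: proto.ClientEnv): String = {
  val lang = env.getEnvCase match {
    case proto.ClientEnv.EnvCase.SCALA_ENV =>
      s"Scala ${env.getScalaEnv.getScalaVersion}, Java ${env.getScalaEnv.getJavaVersion}"
    case proto.ClientEnv.EnvCase.PYTHON_ENV =>
      s"Python ${env.getPythonEnv.getPythonVersion}"
    case _ =>
      "unknown client language"
  }
  s"client Spark ${env.getSparkVersion} ($lang)"
}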