diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala
index 4402dde04f3c8..e0c7beb43001d 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala
@@ -28,10 +28,7 @@ import org.apache.spark.sql.pipelines.graph.GraphRegistrationContext
  * PipelinesHandler when CreateDataflowGraph is called, and the PipelinesHandler also supports
  * attaching flows/datasets to a graph.
  */
-// TODO(SPARK-51727): Currently DataflowGraphRegistry is a singleton, but it should instead be
-// scoped to a single SparkSession for proper isolation between pipelines that are run on the
-// same cluster.
-object DataflowGraphRegistry {
+class DataflowGraphRegistry {

   private val dataflowGraphs = new ConcurrentHashMap[String, GraphRegistrationContext]()

@@ -55,7 +52,7 @@ object DataflowGraphRegistry {

   /** Retrieves the graph for a given id, and throws if the id could not be found. */
   def getDataflowGraphOrThrow(dataflowGraphId: String): GraphRegistrationContext =
-    DataflowGraphRegistry.getDataflowGraph(dataflowGraphId).getOrElse {
+    getDataflowGraph(dataflowGraphId).getOrElse {
       throw new SparkException(
         errorClass = "DATAFLOW_GRAPH_NOT_FOUND",
         messageParameters = Map("graphId" -> dataflowGraphId),
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
index 7f92aa13944c3..9ecc22cda13f0 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
@@ -28,7 +28,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.connect.service.SessionHolder
 import org.apache.spark.sql.pipelines.Language.Python
@@ -68,7 +67,7 @@ private[connect] object PipelinesHandler extends Logging {
     cmd.getCommandTypeCase match {
       case proto.PipelineCommand.CommandTypeCase.CREATE_DATAFLOW_GRAPH =>
         val createdGraphId =
-          createDataflowGraph(cmd.getCreateDataflowGraph, sessionHolder.session)
+          createDataflowGraph(cmd.getCreateDataflowGraph, sessionHolder)
         PipelineCommandResult
           .newBuilder()
           .setCreateDataflowGraphResult(
@@ -78,15 +77,16 @@ private[connect] object PipelinesHandler extends Logging {
           .build()
       case proto.PipelineCommand.CommandTypeCase.DROP_DATAFLOW_GRAPH =>
         logInfo(s"Drop pipeline cmd received: $cmd")
-        DataflowGraphRegistry.dropDataflowGraph(cmd.getDropDataflowGraph.getDataflowGraphId)
+        sessionHolder.dataflowGraphRegistry
+          .dropDataflowGraph(cmd.getDropDataflowGraph.getDataflowGraphId)
         defaultResponse
       case proto.PipelineCommand.CommandTypeCase.DEFINE_DATASET =>
         logInfo(s"Define pipelines dataset cmd received: $cmd")
-        defineDataset(cmd.getDefineDataset, sessionHolder.session)
+        defineDataset(cmd.getDefineDataset, sessionHolder)
         defaultResponse
       case proto.PipelineCommand.CommandTypeCase.DEFINE_FLOW =>
         logInfo(s"Define pipelines flow cmd received: $cmd")
-        defineFlow(cmd.getDefineFlow, transformRelationFunc, sessionHolder.session)
+        defineFlow(cmd.getDefineFlow, transformRelationFunc, sessionHolder)
         defaultResponse
       case proto.PipelineCommand.CommandTypeCase.START_RUN =>
         logInfo(s"Start pipeline cmd received: $cmd")
@@ -94,7 +94,7 @@ private[connect] object PipelinesHandler extends Logging {
         defaultResponse
       case proto.PipelineCommand.CommandTypeCase.DEFINE_SQL_GRAPH_ELEMENTS =>
         logInfo(s"Register sql datasets cmd received: $cmd")
-        defineSqlGraphElements(cmd.getDefineSqlGraphElements, sessionHolder.session)
+        defineSqlGraphElements(cmd.getDefineSqlGraphElements, sessionHolder)
         defaultResponse
       case other => throw new UnsupportedOperationException(s"$other not supported")
     }
@@ -102,24 +102,24 @@ private[connect] object PipelinesHandler extends Logging {

   private def createDataflowGraph(
       cmd: proto.PipelineCommand.CreateDataflowGraph,
-      spark: SparkSession): String = {
+      sessionHolder: SessionHolder): String = {
     val defaultCatalog = Option
       .when(cmd.hasDefaultCatalog)(cmd.getDefaultCatalog)
       .getOrElse {
         logInfo(s"No default catalog was supplied. Falling back to the current catalog.")
-        spark.catalog.currentCatalog()
+        sessionHolder.session.catalog.currentCatalog()
       }

     val defaultDatabase = Option
       .when(cmd.hasDefaultDatabase)(cmd.getDefaultDatabase)
       .getOrElse {
         logInfo(s"No default database was supplied. Falling back to the current database.")
-        spark.catalog.currentDatabase
+        sessionHolder.session.catalog.currentDatabase
       }

     val defaultSqlConf = cmd.getSqlConfMap.asScala.toMap

-    DataflowGraphRegistry.createDataflowGraph(
+    sessionHolder.dataflowGraphRegistry.createDataflowGraph(
       defaultCatalog = defaultCatalog,
       defaultDatabase = defaultDatabase,
       defaultSqlConf = defaultSqlConf)
@@ -127,24 +127,31 @@ private[connect] object PipelinesHandler extends Logging {

   private def defineSqlGraphElements(
       cmd: proto.PipelineCommand.DefineSqlGraphElements,
-      session: SparkSession): Unit = {
+      sessionHolder: SessionHolder): Unit = {
     val dataflowGraphId = cmd.getDataflowGraphId
-    val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
+    val graphElementRegistry =
+      sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
     val sqlGraphElementRegistrationContext = new SqlGraphRegistrationContext(graphElementRegistry)
-    sqlGraphElementRegistrationContext.processSqlFile(cmd.getSqlText, cmd.getSqlFilePath, session)
+    sqlGraphElementRegistrationContext.processSqlFile(
+      cmd.getSqlText,
+      cmd.getSqlFilePath,
+      sessionHolder.session)
   }

   private def defineDataset(
       dataset: proto.PipelineCommand.DefineDataset,
-      sparkSession: SparkSession): Unit = {
+      sessionHolder: SessionHolder): Unit = {
     val dataflowGraphId = dataset.getDataflowGraphId
-    val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
+    val graphElementRegistry =
+      sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)

     dataset.getDatasetType match {
       case proto.DatasetType.MATERIALIZED_VIEW | proto.DatasetType.TABLE =>
         val tableIdentifier =
-          GraphIdentifierManager.parseTableIdentifier(dataset.getDatasetName, sparkSession)
+          GraphIdentifierManager.parseTableIdentifier(
+            dataset.getDatasetName,
+            sessionHolder.session)

         graphElementRegistry.registerTable(
           Table(
             identifier = tableIdentifier,
@@ -165,7 +172,9 @@ private[connect] object PipelinesHandler extends Logging {
             isStreamingTable = dataset.getDatasetType == proto.DatasetType.TABLE))
       case proto.DatasetType.TEMPORARY_VIEW =>
         val viewIdentifier =
-          GraphIdentifierManager.parseTableIdentifier(dataset.getDatasetName, sparkSession)
+          GraphIdentifierManager.parseTableIdentifier(
+            dataset.getDatasetName,
+            sessionHolder.session)

         graphElementRegistry.registerView(
           TemporaryView(
@@ -184,14 +193,15 @@ private[connect] object PipelinesHandler extends Logging {
   private def defineFlow(
       flow: proto.PipelineCommand.DefineFlow,
       transformRelationFunc: Relation => LogicalPlan,
-      sparkSession: SparkSession): Unit = {
+      sessionHolder: SessionHolder): Unit = {
     val dataflowGraphId = flow.getDataflowGraphId
-    val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
+    val graphElementRegistry =
+      sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)

     val isImplicitFlow = flow.getFlowName == flow.getTargetDatasetName

     val flowIdentifier = GraphIdentifierManager
-      .parseTableIdentifier(name = flow.getFlowName, spark = sparkSession)
+      .parseTableIdentifier(name = flow.getFlowName, spark = sessionHolder.session)

     // If the flow is not an implicit flow (i.e. one defined as part of dataset creation), then
     // it must be a single-part identifier.
@@ -205,7 +215,7 @@ private[connect] object PipelinesHandler extends Logging {
       new UnresolvedFlow(
         identifier = flowIdentifier,
         destinationIdentifier = GraphIdentifierManager
-          .parseTableIdentifier(name = flow.getTargetDatasetName, spark = sparkSession),
+          .parseTableIdentifier(name = flow.getTargetDatasetName, spark = sessionHolder.session),
         func =
           FlowAnalysis.createFlowFunctionFromLogicalPlan(transformRelationFunc(flow.getRelation)),
         sqlConf = flow.getSqlConfMap.asScala.toMap,
@@ -224,7 +234,8 @@ private[connect] object PipelinesHandler extends Logging {
       responseObserver: StreamObserver[ExecutePlanResponse],
       sessionHolder: SessionHolder): Unit = {
     val dataflowGraphId = cmd.getDataflowGraphId
-    val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
+    val graphElementRegistry =
+      sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
     val tableFiltersResult = createTableFilters(cmd, graphElementRegistry, sessionHolder)

     // We will use this variable to store the run failure event if it occurs. This will be set
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index ada322fd859c5..1b43ea529ec02 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.connect.common.InvalidPlanInput
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.ml.MLCache
+import org.apache.spark.sql.connect.pipelines.DataflowGraphRegistry
 import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener
 import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper
 import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC}
@@ -125,6 +126,9 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio

   private lazy val pipelineExecutions = new ConcurrentHashMap[String, PipelineUpdateContext]()

+  // Registry for dataflow graphs specific to this session
+  private[connect] lazy val dataflowGraphRegistry = new DataflowGraphRegistry()
+
   // Handles Python process clean up for streaming queries. Initialized on first use in a query.
   private[connect] lazy val streamingForeachBatchRunnerCleanerCache =
     new StreamingForeachBatchHelper.CleanerCache(this)
@@ -320,6 +324,9 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
     // Stops all pipeline execution and clears the pipeline execution cache
     removeAllPipelineExecutions()

+    // Clean up dataflow graphs
+    dataflowGraphRegistry.dropAllDataflowGraphs()
+
     // if there is a server side listener, clean up related resources
     if (streamingServersideListenerHolder.isServerSideListenerRegistered) {
       streamingServersideListenerHolder.cleanUp()
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
index a9e8f9b5245b6..1bc2172d86e55 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.pipelines
 import java.io.{BufferedReader, InputStreamReader}
 import java.nio.charset.StandardCharsets
 import java.nio.file.Paths
+import java.util.UUID
 import java.util.concurrent.TimeUnit

 import scala.collection.mutable.ArrayBuffer
@@ -28,6 +29,7 @@ import scala.util.Try
 import org.apache.spark.api.python.PythonUtils
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.connect.service.SparkConnectService
 import org.apache.spark.sql.pipelines.graph.DataflowGraph
 import org.apache.spark.sql.pipelines.utils.{EventVerificationTestHelpers, TestPipelineUpdateContextMixin}

@@ -42,6 +44,8 @@ class PythonPipelineSuite

   def buildGraph(pythonText: String): DataflowGraph = {
     val indentedPythonText = pythonText.linesIterator.map(" " + _).mkString("\n")
+    // create a unique identifier to allow identifying the session and dataflow graph
+    val customSessionIdentifier = UUID.randomUUID().toString
     val pythonCode =
       s"""
         |from pyspark.sql import SparkSession
@@ -57,6 +61,7 @@ class PythonPipelineSuite
         |spark = SparkSession.builder \\
         |  .remote("sc://localhost:$serverPort") \\
         |  .config("spark.connect.grpc.channel.timeout", "5s") \\
+        |  .config("spark.custom.identifier", "$customSessionIdentifier") \\
         |  .create()
         |
         |dataflow_graph_id = create_dataflow_graph(
@@ -78,8 +83,17 @@ class PythonPipelineSuite
       throw new RuntimeException(
         s"Python process failed with exit code $exitCode. Output: ${output.mkString("\n")}")
     }
+    val activeSessions = SparkConnectService.sessionManager.listActiveSessions

-    val dataflowGraphContexts = DataflowGraphRegistry.getAllDataflowGraphs
+    // get the session holder by finding the session with the custom UUID set in the conf
+    val sessionHolder = activeSessions
+      .map(info => SparkConnectService.sessionManager.getIsolatedSession(info.key, None))
+      .find(_.session.conf.get("spark.custom.identifier") == customSessionIdentifier)
+      .getOrElse(
+        throw new RuntimeException(s"Session with identifier $customSessionIdentifier not found"))
+
+    // get all dataflow graphs from the session holder
+    val dataflowGraphContexts = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs
     assert(dataflowGraphContexts.size == 1)

     dataflowGraphContexts.head.toDataflowGraph
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala
index 3b200c3d08aca..ef5da0c014ee1 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala
@@ -17,10 +17,13 @@

 package org.apache.spark.sql.connect.pipelines

+import java.util.UUID
+
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{DatasetType, Expression, PipelineCommand, Relation, UnresolvedTableValuedFunction}
 import org.apache.spark.connect.proto.PipelineCommand.{DefineDataset, DefineFlow}
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService}

 class SparkDeclarativePipelinesServerSuite
     extends SparkDeclarativePipelinesServerTest
@@ -41,8 +44,7 @@ class SparkDeclarativePipelinesServerSuite
           .newBuilder()
           .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId
       val definition =
-        DataflowGraphRegistry
-          .getDataflowGraphOrThrow(graphId)
+        getDefaultSessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(graphId)
       assert(definition.defaultDatabase == "test_db")
     }
   }
@@ -115,8 +117,7 @@ class SparkDeclarativePipelinesServerSuite
         |""".stripMargin)

     val definition =
-      DataflowGraphRegistry
-        .getDataflowGraphOrThrow(graphId)
+      getDefaultSessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(graphId)

     val graph = definition.toDataflowGraph.resolve()

@@ -161,8 +162,7 @@ class SparkDeclarativePipelinesServerSuite
     }

     val definition =
-      DataflowGraphRegistry
-        .getDataflowGraphOrThrow(graphId)
+      getDefaultSessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(graphId)

     registerPipelineDatasets(pipeline)
     val graph = definition.toDataflowGraph
@@ -251,4 +251,239 @@ class SparkDeclarativePipelinesServerSuite
       assert(spark.table("spark_catalog.other.tableD").count() == 5)
     }
   }
+
+  test("dataflow graphs are session-specific") {
+    withRawBlockingStub { implicit stub =>
+      // Create a dataflow graph in the default session
+      val graphId1 = createDataflowGraph
+
+      // Register a dataset in the default session
+      sendPlan(
+        buildPlanFromPipelineCommand(
+          PipelineCommand
+            .newBuilder()
+            .setDefineDataset(
+              DefineDataset
+                .newBuilder()
+                .setDataflowGraphId(graphId1)
+                .setDatasetName("session1_table")
+                .setDatasetType(DatasetType.MATERIALIZED_VIEW))
+            .build()))
+
+      // Verify the graph exists in the default session
+      assert(getDefaultSessionHolder.dataflowGraphRegistry.getAllDataflowGraphs.size == 1)
+    }
+
+    // Create a second session with different user/session ID
+    val newSessionId = UUID.randomUUID().toString
+    val newSessionUserId = "session2_user"
+
+    withRawBlockingStub { implicit stub =>
+      // Override the test context to use different session
+      val newSessionExecuteRequest = buildExecutePlanRequest(
+        buildCreateDataflowGraphPlan(
+          proto.PipelineCommand.CreateDataflowGraph
+            .newBuilder()
+            .setDefaultCatalog("spark_catalog")
+            .setDefaultDatabase("default")
+            .build())).toBuilder
+        .setUserContext(proto.UserContext
+          .newBuilder()
+          .setUserId(newSessionUserId)
+          .build())
+        .setSessionId(newSessionId)
+        .build()
+
+      val response = stub.executePlan(newSessionExecuteRequest)
+      val graphId2 =
+        response.next().getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId
+
+      // Register a different dataset in second session
+      val session2DefineRequest = buildExecutePlanRequest(
+        buildPlanFromPipelineCommand(
+          PipelineCommand
+            .newBuilder()
+            .setDefineDataset(
+              DefineDataset
+                .newBuilder()
+                .setDataflowGraphId(graphId2)
+                .setDatasetName("session2_table")
+                .setDatasetType(DatasetType.MATERIALIZED_VIEW))
+            .build())).toBuilder
+        .setUserContext(proto.UserContext
+          .newBuilder()
+          .setUserId(newSessionUserId)
+          .build())
+        .setSessionId(newSessionId)
+        .build()
+
+      stub.executePlan(session2DefineRequest).next()
+
+      // Verify session isolation - each session should only see its own graphs
+      val newSessionHolder = SparkConnectService.sessionManager
+        .getIsolatedSession(SessionKey(newSessionUserId, newSessionId), None)
+
+      val defaultSessionGraphs =
+        getDefaultSessionHolder.dataflowGraphRegistry.getAllDataflowGraphs
+      val newSessionGraphs = newSessionHolder.dataflowGraphRegistry.getAllDataflowGraphs
+
+      assert(defaultSessionGraphs.size == 1)
+      assert(newSessionGraphs.size == 1)
+
+      assert(
+        defaultSessionGraphs.head.toDataflowGraph.tables
+          .exists(_.identifier.table == "session1_table"),
+        "Session 1 should have its own table")
+      assert(
+        newSessionGraphs.head.toDataflowGraph.tables
+          .exists(_.identifier.table == "session2_table"),
+        "Session 2 should have its own table")
+    }
+  }
+
+  test("dataflow graphs are cleaned up when session is closed") {
+    val testUserId = "test_user"
+    val testSessionId = UUID.randomUUID().toString
+
+    // Create a session and dataflow graph
+    withRawBlockingStub { implicit stub =>
+      val createGraphRequest = buildExecutePlanRequest(
+        buildCreateDataflowGraphPlan(
+          proto.PipelineCommand.CreateDataflowGraph
+            .newBuilder()
+            .setDefaultCatalog("spark_catalog")
+            .setDefaultDatabase("default")
+            .build())).toBuilder
+        .setUserContext(proto.UserContext
+          .newBuilder()
+          .setUserId(testUserId)
+          .build())
+        .setSessionId(testSessionId)
+        .build()
+
+      val response = stub.executePlan(createGraphRequest)
+      val graphId =
+        response.next().getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId
+
+      // Register a dataset
+      val defineRequest = buildExecutePlanRequest(
+        buildPlanFromPipelineCommand(
+          PipelineCommand
+            .newBuilder()
+            .setDefineDataset(
+              DefineDataset
+                .newBuilder()
+                .setDataflowGraphId(graphId)
+                .setDatasetName("test_table")
+                .setDatasetType(DatasetType.MATERIALIZED_VIEW))
+            .build())).toBuilder
+        .setUserContext(proto.UserContext
+          .newBuilder()
+          .setUserId(testUserId)
+          .build())
+        .setSessionId(testSessionId)
+        .build()
+
+      stub.executePlan(defineRequest).next()
+
+      // Verify the graph exists
+      val sessionHolder = SparkConnectService.sessionManager
+        .getIsolatedSessionIfPresent(SessionKey(testUserId, testSessionId))
+        .get
+
+      val graphsBefore = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs
+      assert(graphsBefore.size == 1)
+
+      // Close the session
+      SparkConnectService.sessionManager.closeSession(SessionKey(testUserId, testSessionId))
+
+      // Verify the session is no longer available
+      val sessionAfterClose = SparkConnectService.sessionManager
+        .getIsolatedSessionIfPresent(SessionKey(testUserId, testSessionId))
+
+      assert(sessionAfterClose.isEmpty, "Session should be cleaned up after close")
+      // Verify the graph is removed
+      val graphsAfter = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs
+      assert(graphsAfter.isEmpty, "Graph should be removed after session close")
+    }
+  }
+
+  test("multiple dataflow graphs can exist in the same session") {
+    withRawBlockingStub { implicit stub =>
+      // Create two dataflow graphs in the same session
+      val graphId1 = createDataflowGraph
+      val graphId2 = createDataflowGraph
+
+      // Register datasets in both graphs
+      sendPlan(
+        buildPlanFromPipelineCommand(
+          PipelineCommand
+            .newBuilder()
+            .setDefineDataset(
+              DefineDataset
+                .newBuilder()
+                .setDataflowGraphId(graphId1)
+                .setDatasetName("graph1_table")
+                .setDatasetType(DatasetType.MATERIALIZED_VIEW))
+            .build()))
+
+      sendPlan(
+        buildPlanFromPipelineCommand(
+          PipelineCommand
+            .newBuilder()
+            .setDefineDataset(
+              DefineDataset
+                .newBuilder()
+                .setDataflowGraphId(graphId2)
+                .setDatasetName("graph2_table")
+                .setDatasetType(DatasetType.MATERIALIZED_VIEW))
+            .build()))
+
+      // Verify both graphs exist in the session
+      val sessionHolder = getDefaultSessionHolder
+      val graph1 = sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(graphId1)
+      val graph2 = sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(graphId2)
+      // Check that both graphs have their datasets registered
+      assert(graph1.toDataflowGraph.tables.exists(_.identifier.table == "graph1_table"))
+      assert(graph2.toDataflowGraph.tables.exists(_.identifier.table == "graph2_table"))
+    }
+  }
+
+  test("dropping a dataflow graph removes it from session") {
+    withRawBlockingStub { implicit stub =>
+      val graphId = createDataflowGraph
+
+      // Register a dataset
+      sendPlan(
+        buildPlanFromPipelineCommand(
+          PipelineCommand
+            .newBuilder()
+            .setDefineDataset(
+              DefineDataset
+                .newBuilder()
+                .setDataflowGraphId(graphId)
+                .setDatasetName("test_table")
+                .setDatasetType(DatasetType.MATERIALIZED_VIEW))
+            .build()))
+
+      // Verify the graph exists
+      val sessionHolder = getDefaultSessionHolder
+      val graphsBefore = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs
+      assert(graphsBefore.size == 1)
+
+      // Drop the graph
+      sendPlan(
+        buildPlanFromPipelineCommand(
+          PipelineCommand
+            .newBuilder()
+            .setDropDataflowGraph(PipelineCommand.DropDataflowGraph
+              .newBuilder()
+              .setDataflowGraphId(graphId))
+            .build()))
+
+      // Verify the graph is removed
+      val graphsAfter = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs
+      assert(graphsAfter.isEmpty, "Graph should be removed after drop")
+    }
+  }
 }
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala
index 003fd30b6075a..a31883677f92a 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala
@@ -23,7 +23,7 @@ import org.apache.spark.connect.{proto => sc}
 import org.apache.spark.connect.proto.{PipelineCommand, PipelineEvent}
 import org.apache.spark.sql.connect.{SparkConnectServerTest, SparkConnectTestUtils}
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
-import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService}
+import org.apache.spark.sql.connect.service.{SessionHolder, SessionKey, SparkConnectService}
 import org.apache.spark.sql.pipelines.utils.PipelineTest

 class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest {
@@ -31,12 +31,20 @@ class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest {
   override def afterEach(): Unit = {
     SparkConnectService.sessionManager
       .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId))
-      .foreach(_.removeAllPipelineExecutions())
-    DataflowGraphRegistry.dropAllDataflowGraphs()
+      .foreach(s => {
+        s.removeAllPipelineExecutions()
+        s.dataflowGraphRegistry.dropAllDataflowGraphs()
+      })
     PipelineTest.cleanupMetastore(spark)
     super.afterEach()
   }

+  // Helper method to get the session holder
+  protected def getDefaultSessionHolder: SessionHolder = {
+    SparkConnectService.sessionManager
+      .getIsolatedSession(SessionKey(defaultUserId, defaultSessionId), None)
+  }
+
   def buildPlanFromPipelineCommand(command: sc.PipelineCommand): sc.Plan = {
     sc.Plan
       .newBuilder()