[SPARK-52432][SDP][SQL] Scope DataflowGraphRegistry to Session #51544

Open · wants to merge 6 commits into master · Changes from all commits

DataflowGraphRegistry.scala
@@ -28,10 +28,7 @@ import org.apache.spark.sql.pipelines.graph.GraphRegistrationContext
* PipelinesHandler when CreateDataflowGraph is called, and the PipelinesHandler also supports
* attaching flows/datasets to a graph.
*/
// TODO(SPARK-51727): Currently DataflowGraphRegistry is a singleton, but it should instead be
// scoped to a single SparkSession for proper isolation between pipelines that are run on the
// same cluster.
object DataflowGraphRegistry {
class DataflowGraphRegistry {

private val dataflowGraphs = new ConcurrentHashMap[String, GraphRegistrationContext]()

@@ -55,7 +52,7 @@

/** Retrieves the graph for a given id, and throws if the id could not be found. */
def getDataflowGraphOrThrow(dataflowGraphId: String): GraphRegistrationContext =
DataflowGraphRegistry.getDataflowGraph(dataflowGraphId).getOrElse {
getDataflowGraph(dataflowGraphId).getOrElse {
throw new SparkException(
errorClass = "DATAFLOW_GRAPH_NOT_FOUND",
messageParameters = Map("graphId" -> dataflowGraphId),
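The net effect of this first file's change is that DataflowGraphRegistry is no longer a process-wide singleton: it becomes a plain class, and getDataflowGraphOrThrow resolves graphs against the instance's own map rather than the global object. A minimal sketch of the isolation this gives (illustrative only, not part of the patch; positional arguments are used because only the SessionHolder wrapper's parameter names appear in this diff):

import org.apache.spark.sql.connect.pipelines.DataflowGraphRegistry

// Two independent registries, e.g. one per Spark Connect session.
val registryA = new DataflowGraphRegistry()
val registryB = new DataflowGraphRegistry()

// Register a graph in the first instance only.
val graphId = registryA.createDataflowGraph("spark_catalog", "default", Map.empty[String, String])

assert(registryA.getDataflowGraph(graphId).isDefined)
// Other instances never see it; getDataflowGraphOrThrow on registryB would
// fail with error class DATAFLOW_GRAPH_NOT_FOUND.
assert(registryB.getDataflowGraph(graphId).isEmpty)
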
PipelinesHandler.scala
@@ -28,7 +28,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.sql.pipelines.Language.Python
@@ -68,7 +67,7 @@ private[connect] object PipelinesHandler extends Logging {
cmd.getCommandTypeCase match {
case proto.PipelineCommand.CommandTypeCase.CREATE_DATAFLOW_GRAPH =>
val createdGraphId =
createDataflowGraph(cmd.getCreateDataflowGraph, sessionHolder.session)
createDataflowGraph(cmd.getCreateDataflowGraph, sessionHolder)
PipelineCommandResult
.newBuilder()
.setCreateDataflowGraphResult(
@@ -78,73 +77,78 @@
.build()
case proto.PipelineCommand.CommandTypeCase.DROP_DATAFLOW_GRAPH =>
logInfo(s"Drop pipeline cmd received: $cmd")
DataflowGraphRegistry.dropDataflowGraph(cmd.getDropDataflowGraph.getDataflowGraphId)
sessionHolder.dropDataflowGraph(cmd.getDropDataflowGraph.getDataflowGraphId)
defaultResponse
case proto.PipelineCommand.CommandTypeCase.DEFINE_DATASET =>
logInfo(s"Define pipelines dataset cmd received: $cmd")
defineDataset(cmd.getDefineDataset, sessionHolder.session)
defineDataset(cmd.getDefineDataset, sessionHolder)
defaultResponse
case proto.PipelineCommand.CommandTypeCase.DEFINE_FLOW =>
logInfo(s"Define pipelines flow cmd received: $cmd")
defineFlow(cmd.getDefineFlow, transformRelationFunc, sessionHolder.session)
defineFlow(cmd.getDefineFlow, transformRelationFunc, sessionHolder)
defaultResponse
case proto.PipelineCommand.CommandTypeCase.START_RUN =>
logInfo(s"Start pipeline cmd received: $cmd")
startRun(cmd.getStartRun, responseObserver, sessionHolder)
defaultResponse
case proto.PipelineCommand.CommandTypeCase.DEFINE_SQL_GRAPH_ELEMENTS =>
logInfo(s"Register sql datasets cmd received: $cmd")
defineSqlGraphElements(cmd.getDefineSqlGraphElements, sessionHolder.session)
defineSqlGraphElements(cmd.getDefineSqlGraphElements, sessionHolder)
defaultResponse
case other => throw new UnsupportedOperationException(s"$other not supported")
}
}

private def createDataflowGraph(
cmd: proto.PipelineCommand.CreateDataflowGraph,
spark: SparkSession): String = {
sessionHolder: SessionHolder): String = {
val defaultCatalog = Option
.when(cmd.hasDefaultCatalog)(cmd.getDefaultCatalog)
.getOrElse {
logInfo(s"No default catalog was supplied. Falling back to the current catalog.")
spark.catalog.currentCatalog()
sessionHolder.session.catalog.currentCatalog()
}

val defaultDatabase = Option
.when(cmd.hasDefaultDatabase)(cmd.getDefaultDatabase)
.getOrElse {
logInfo(s"No default database was supplied. Falling back to the current database.")
spark.catalog.currentDatabase
sessionHolder.session.catalog.currentDatabase
}

val defaultSqlConf = cmd.getSqlConfMap.asScala.toMap

DataflowGraphRegistry.createDataflowGraph(
sessionHolder.createDataflowGraph(
defaultCatalog = defaultCatalog,
defaultDatabase = defaultDatabase,
defaultSqlConf = defaultSqlConf)
}

private def defineSqlGraphElements(
cmd: proto.PipelineCommand.DefineSqlGraphElements,
session: SparkSession): Unit = {
sessionHolder: SessionHolder): Unit = {
val dataflowGraphId = cmd.getDataflowGraphId

val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
val graphElementRegistry = sessionHolder.getDataflowGraphOrThrow(dataflowGraphId)
val sqlGraphElementRegistrationContext = new SqlGraphRegistrationContext(graphElementRegistry)
sqlGraphElementRegistrationContext.processSqlFile(cmd.getSqlText, cmd.getSqlFilePath, session)
sqlGraphElementRegistrationContext.processSqlFile(
cmd.getSqlText,
cmd.getSqlFilePath,
sessionHolder.session)
}

private def defineDataset(
dataset: proto.PipelineCommand.DefineDataset,
sparkSession: SparkSession): Unit = {
sessionHolder: SessionHolder): Unit = {
val dataflowGraphId = dataset.getDataflowGraphId
val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
val graphElementRegistry = sessionHolder.getDataflowGraphOrThrow(dataflowGraphId)

dataset.getDatasetType match {
case proto.DatasetType.MATERIALIZED_VIEW | proto.DatasetType.TABLE =>
val tableIdentifier =
GraphIdentifierManager.parseTableIdentifier(dataset.getDatasetName, sparkSession)
GraphIdentifierManager.parseTableIdentifier(
dataset.getDatasetName,
sessionHolder.session)
graphElementRegistry.registerTable(
Table(
identifier = tableIdentifier,
@@ -165,7 +169,9 @@
isStreamingTable = dataset.getDatasetType == proto.DatasetType.TABLE))
case proto.DatasetType.TEMPORARY_VIEW =>
val viewIdentifier =
GraphIdentifierManager.parseTableIdentifier(dataset.getDatasetName, sparkSession)
GraphIdentifierManager.parseTableIdentifier(
dataset.getDatasetName,
sessionHolder.session)

graphElementRegistry.registerView(
TemporaryView(
@@ -184,14 +190,14 @@
private def defineFlow(
flow: proto.PipelineCommand.DefineFlow,
transformRelationFunc: Relation => LogicalPlan,
sparkSession: SparkSession): Unit = {
sessionHolder: SessionHolder): Unit = {
val dataflowGraphId = flow.getDataflowGraphId
val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
val graphElementRegistry = sessionHolder.getDataflowGraphOrThrow(dataflowGraphId)

val isImplicitFlow = flow.getFlowName == flow.getTargetDatasetName

val flowIdentifier = GraphIdentifierManager
.parseTableIdentifier(name = flow.getFlowName, spark = sparkSession)
.parseTableIdentifier(name = flow.getFlowName, spark = sessionHolder.session)

// If the flow is not an implicit flow (i.e. one defined as part of dataset creation), then
// it must be a single-part identifier.
@@ -205,7 +211,7 @@
new UnresolvedFlow(
identifier = flowIdentifier,
destinationIdentifier = GraphIdentifierManager
.parseTableIdentifier(name = flow.getTargetDatasetName, spark = sparkSession),
.parseTableIdentifier(name = flow.getTargetDatasetName, spark = sessionHolder.session),
func =
FlowAnalysis.createFlowFunctionFromLogicalPlan(transformRelationFunc(flow.getRelation)),
sqlConf = flow.getSqlConfMap.asScala.toMap,
@@ -224,7 +230,7 @@
responseObserver: StreamObserver[ExecutePlanResponse],
sessionHolder: SessionHolder): Unit = {
val dataflowGraphId = cmd.getDataflowGraphId
val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
val graphElementRegistry = sessionHolder.getDataflowGraphOrThrow(dataflowGraphId)
val tableFiltersResult = createTableFilters(cmd, graphElementRegistry, sessionHolder)

// We will use this variable to store the run failure event if it occurs. This will be set
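At the handler level, every pipeline command now carries the SessionHolder instead of a bare SparkSession, and all graph lookups are delegated to that holder. A rough sketch of what this buys across two Spark Connect sessions (holderA and holderB are hypothetical SessionHolder instances; the helper methods are the ones added to SessionHolder later in this diff):

// A graph created through one session's holder...
val graphId = holderA.createDataflowGraph(
  defaultCatalog = "spark_catalog",
  defaultDatabase = "default",
  defaultSqlConf = Map.empty[String, String])

// ...is only resolvable through that same holder.
holderA.getDataflowGraphOrThrow(graphId) // returns the GraphRegistrationContext
// holderB.getDataflowGraphOrThrow(graphId) would throw a SparkException with
// error class DATAFLOW_GRAPH_NOT_FOUND, because each holder has its own registry.
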
SessionHolder.scala
@@ -38,10 +38,11 @@ import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.ml.MLCache
import org.apache.spark.sql.connect.pipelines.DataflowGraphRegistry
import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener
import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper
import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC}
import org.apache.spark.sql.pipelines.graph.PipelineUpdateContext
import org.apache.spark.sql.pipelines.graph.{GraphRegistrationContext, PipelineUpdateContext}
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.util.{SystemClock, Utils}

@@ -125,6 +126,9 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
private lazy val pipelineExecutions =
new ConcurrentHashMap[String, PipelineUpdateContext]()

// Registry for dataflow graphs specific to this session
private lazy val dataflowGraphRegistry: DataflowGraphRegistry = new DataflowGraphRegistry()

// Handles Python process clean up for streaming queries. Initialized on first use in a query.
private[connect] lazy val streamingForeachBatchRunnerCleanerCache =
new StreamingForeachBatchHelper.CleanerCache(this)
@@ -320,6 +324,9 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
// Stops all pipeline execution and clears the pipeline execution cache
removeAllPipelineExecutions()

// Clean up dataflow graphs
dropAllDataflowGraphs()

// if there is a server side listener, clean up related resources
if (streamingServersideListenerHolder.isServerSideListenerRegistered) {
streamingServersideListenerHolder.cleanUp()
@@ -486,6 +493,48 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
Option(pipelineExecutions.get(graphId))
}

private[connect] def createDataflowGraph(
defaultCatalog: String,
defaultDatabase: String,
defaultSqlConf: Map[String, String]): String = {
dataflowGraphRegistry.createDataflowGraph(defaultCatalog, defaultDatabase, defaultSqlConf)
}

/**
* Retrieves the dataflow graph for the given graph ID.
*/
private[connect] def getDataflowGraph(graphId: String): Option[GraphRegistrationContext] = {
dataflowGraphRegistry.getDataflowGraph(graphId)
}

/**
* Retrieves the dataflow graph for the given graph ID, throwing if not found.
*/
private[connect] def getDataflowGraphOrThrow(graphId: String): GraphRegistrationContext = {
dataflowGraphRegistry.getDataflowGraphOrThrow(graphId)
}

/**
* Removes the dataflow graph with the given ID.
*/
private[connect] def dropDataflowGraph(graphId: String): Unit = {
dataflowGraphRegistry.dropDataflowGraph(graphId)
}

/**
* Returns all dataflow graphs in this session.
*/
private[connect] def getAllDataflowGraphs: Seq[GraphRegistrationContext] = {
dataflowGraphRegistry.getAllDataflowGraphs
}

/**
* Removes all dataflow graphs from this session. Called during session cleanup.
*/
private[connect] def dropAllDataflowGraphs(): Unit = {
dataflowGraphRegistry.dropAllDataflowGraphs()
}

/**
* An accumulator for Python executors.
*
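Taken together, the new SessionHolder members give every session its own registry lifecycle: graphs are created and resolved through the holder, and dropAllDataflowGraphs() wipes them when the session is cleaned up (the call added in the cleanup path above). A condensed sketch of that lifecycle, assuming a SessionHolder named holder (these helpers are private[connect], so this only compiles inside the connect package):

val graphId = holder.createDataflowGraph(
  defaultCatalog = "spark_catalog",
  defaultDatabase = "default",
  defaultSqlConf = Map("spark.sql.ansi.enabled" -> "true"))

holder.getDataflowGraph(graphId)   // Some(graph registration context)
holder.getAllDataflowGraphs.size   // 1

// Session teardown now also clears the registry, so nothing leaks across sessions.
holder.dropAllDataflowGraphs()
holder.getDataflowGraph(graphId)   // None
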
PythonPipelineSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.pipelines
import java.io.{BufferedReader, InputStreamReader}
import java.nio.charset.StandardCharsets
import java.nio.file.Paths
import java.util.UUID
import java.util.concurrent.TimeUnit

import scala.collection.mutable.ArrayBuffer
@@ -28,6 +29,7 @@ import scala.util.Try
import org.apache.spark.api.python.PythonUtils
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.connect.service.SparkConnectService
import org.apache.spark.sql.pipelines.graph.DataflowGraph
import org.apache.spark.sql.pipelines.utils.{EventVerificationTestHelpers, TestPipelineUpdateContextMixin}

@@ -42,6 +44,8 @@ class PythonPipelineSuite

def buildGraph(pythonText: String): DataflowGraph = {
val indentedPythonText = pythonText.linesIterator.map(" " + _).mkString("\n")
// create a unique identifier to allow identifying the session and dataflow graph
val identifier = UUID.randomUUID().toString
val pythonCode =
s"""
|from pyspark.sql import SparkSession
@@ -57,6 +61,7 @@
|spark = SparkSession.builder \\
| .remote("sc://localhost:$serverPort") \\
| .config("spark.connect.grpc.channel.timeout", "5s") \\
| .config("spark.custom.identifier", "$identifier") \\
| .create()
|
|dataflow_graph_id = create_dataflow_graph(
@@ -78,8 +83,16 @@
throw new RuntimeException(
s"Python process failed with exit code $exitCode. Output: ${output.mkString("\n")}")
}
val activateSessions = SparkConnectService.sessionManager.listActiveSessions

val dataflowGraphContexts = DataflowGraphRegistry.getAllDataflowGraphs
// get the session holder by finding the session with the custom UUID set in the conf
val sessionHolder = activateSessions
.map(info => SparkConnectService.sessionManager.getIsolatedSessionIfPresent(info.key).get)
.find(_.session.conf.get("spark.custom.identifier") == identifier)
.getOrElse(throw new RuntimeException(s"Session with app name $identifier not found"))

// get all dataflow graphs from the session holder
val dataflowGraphContexts = sessionHolder.getAllDataflowGraphs
assert(dataflowGraphContexts.size == 1)

dataflowGraphContexts.head.toDataflowGraph