diff --git a/Makefile b/Makefile index a2bc3117..473c791d 100644 --- a/Makefile +++ b/Makefile @@ -103,6 +103,8 @@ dev-binder: .binder-image --workdir /home/main/notebooks $(BINDER_IMAGE) \ /home/main/start-notebook.sh --ip=0.0.0.0 +SPARK_MONITOR_JAR:=toree-spark-monitor-plugin-assembly-$(VERSION)$(SNAPSHOT).jar + target/scala-$(SCALA_VERSION)/$(ASSEMBLY_JAR): VM_WORKDIR=/src/toree-kernel target/scala-$(SCALA_VERSION)/$(ASSEMBLY_JAR): ${shell find ./*/src/main/**/*} target/scala-$(SCALA_VERSION)/$(ASSEMBLY_JAR): ${shell find ./*/build.sbt} @@ -110,7 +112,14 @@ target/scala-$(SCALA_VERSION)/$(ASSEMBLY_JAR): ${shell find ./project/*.scala} $ target/scala-$(SCALA_VERSION)/$(ASSEMBLY_JAR): dist/toree-legal project/build.properties build.sbt $(call RUN,$(ENV_OPTS) sbt root/assembly) -build: target/scala-$(SCALA_VERSION)/$(ASSEMBLY_JAR) +spark-monitor-plugin/target/scala-$(SCALA_VERSION)/$(SPARK_MONITOR_JAR): VM_WORKDIR=/src/toree-kernel +spark-monitor-plugin/target/scala-$(SCALA_VERSION)/$(SPARK_MONITOR_JAR): ${shell find ./spark-monitor-plugin/src/main/**/*} +spark-monitor-plugin/target/scala-$(SCALA_VERSION)/$(SPARK_MONITOR_JAR): spark-monitor-plugin/build.sbt +spark-monitor-plugin/target/scala-$(SCALA_VERSION)/$(SPARK_MONITOR_JAR): ${shell find ./project/*.scala} ${shell find ./project/*.sbt} +spark-monitor-plugin/target/scala-$(SCALA_VERSION)/$(SPARK_MONITOR_JAR): project/build.properties build.sbt + $(call RUN,$(ENV_OPTS) sbt sparkMonitorPlugin/assembly) + +build: target/scala-$(SCALA_VERSION)/$(ASSEMBLY_JAR) spark-monitor-plugin/target/scala-$(SCALA_VERSION)/$(SPARK_MONITOR_JAR) test: VM_WORKDIR=/src/toree-kernel test: @@ -119,9 +128,10 @@ test: sbt-%: $(call RUN,$(ENV_OPTS) sbt $(subst sbt-,,$@) ) -dist/toree/lib: target/scala-$(SCALA_VERSION)/$(ASSEMBLY_JAR) +dist/toree/lib: target/scala-$(SCALA_VERSION)/$(ASSEMBLY_JAR) spark-monitor-plugin/target/scala-$(SCALA_VERSION)/$(SPARK_MONITOR_JAR) @mkdir -p dist/toree/lib @cp target/scala-$(SCALA_VERSION)/$(ASSEMBLY_JAR) dist/toree/lib/. + @cp spark-monitor-plugin/target/scala-$(SCALA_VERSION)/$(SPARK_MONITOR_JAR) dist/toree/lib/. dist/toree/bin: ${shell find ./etc/bin/*} @mkdir -p dist/toree/bin diff --git a/README.md b/README.md index 0ab068a3..a02f4b45 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,59 @@ This results in 2 packages. NOTE: `make release` uses `docker`. Please refer to `docker` installation instructions for your system. +## Building Individual Components + +### Main Toree Assembly +To build just the main Toree assembly jar (without spark-monitor-plugin): +``` +sbt assembly +``` +This creates: `target/scala-2.12/toree-assembly-.jar` + +### Spark Monitor Plugin +To build the spark-monitor-plugin as a separate jar: +``` +sbt sparkMonitorPlugin/assembly +``` +This creates: `spark-monitor-plugin/target/scala-2.12/spark-monitor-plugin-.jar` + +### Build All Components +To compile all projects including both the main assembly and spark-monitor-plugin: +``` +sbt compile +``` + +**Note**: The spark-monitor-plugin is now built as a separate jar and is not included in the main Toree assembly. 
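+
+For example, assuming a Scala 2.12 build with the default version settings, both jars can be produced and located in one pass (the exact file names depend on the configured Toree version):
+```
+# build the main assembly and the spark-monitor-plugin assembly
+sbt assembly sparkMonitorPlugin/assembly
+
+# resulting artifacts
+ls target/scala-2.12/toree-assembly-*.jar
+ls spark-monitor-plugin/target/scala-2.12/*.jar
+```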
+ +## Using the Spark Monitor Plugin + +To enable the Spark Monitor Plugin in your Toree application, you need to specify the path to the plugin JAR when starting Toree: + +### Option 1: Command Line Parameter +```bash +# Start Toree with spark-monitor-plugin enabled +java -jar target/scala-2.12/toree-assembly-.jar --magic-url file:///path/to/spark-monitor-plugin/target/scala-2.12/spark-monitor-plugin-.jar [other-options] +``` + +### Option 2: Jupyter Kernel Installation +When installing Toree as a Jupyter kernel, you can specify the plugin: +```bash +jupyter toree install --spark_home= --kernel_name=toree_with_monitor --toree_opts="--magic-url file:///path/to/spark-monitor-plugin-.jar" +``` + +### Option 3: Configuration File +You can also specify the plugin in a configuration file and use the `--profile` option: +```json +{ + "magic_urls": ["file:///path/to/spark-monitor-plugin-.jar"] +} +``` +Then start with: `java -jar toree-assembly.jar --profile config.json` + +**Important**: +- Make sure to use the absolute path to the spark-monitor-plugin JAR file and ensure the JAR is accessible from the location where Toree is running. +- The JAR file name does not contain "toree" prefix to avoid automatic loading as an internal plugin. This allows you to control when the SparkMonitorPlugin is enabled via the `--magic-url` parameter. + Run Examples ============ To play with the example notebooks, run diff --git a/build.sbt b/build.sbt index 3fb4030b..37fa63b6 100644 --- a/build.sbt +++ b/build.sbt @@ -126,7 +126,7 @@ ThisBuild / credentials += Credentials(Path.userHome / ".ivy2" / ".credentials") lazy val root = (project in file(".")) .settings(name := "toree") .aggregate( - macros,protocol,plugins,communication,kernelApi,client,scalaInterpreter,sqlInterpreter,kernel + macros,protocol,plugins,sparkMonitorPlugin,communication,kernelApi,client,scalaInterpreter,sqlInterpreter,kernel ) .dependsOn( macros,protocol,communication,kernelApi,client,scalaInterpreter,sqlInterpreter,kernel @@ -154,6 +154,13 @@ lazy val plugins = (project in file("plugins")) .settings(name := "toree-plugins") .dependsOn(macros) +/** + * Project representing the SparkMonitor plugin for Toree. + */ +lazy val sparkMonitorPlugin = (project in file("spark-monitor-plugin")) + .settings(name := "toree-spark-monitor-plugin") + .dependsOn(macros, protocol, plugins, kernel, kernelApi) + /** * Project representing forms of communication used as input/output for the * client/kernel. diff --git a/plugins/build.sbt b/plugins/build.sbt index 88f9a494..d61c0a35 100644 --- a/plugins/build.sbt +++ b/plugins/build.sbt @@ -21,7 +21,7 @@ Test / fork := true libraryDependencies ++= Seq( Dependencies.scalaReflect.value, Dependencies.clapper, - Dependencies.slf4jApi + Dependencies.slf4jApi, ) // Test dependencies diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 08d7bf86..3b8c101a 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -84,4 +84,6 @@ object Dependencies { ) } + val py4j = "net.sf.py4j" % "py4j" % "0.10.7" % "provided" + } diff --git a/spark-monitor-plugin/build.sbt b/spark-monitor-plugin/build.sbt new file mode 100644 index 00000000..2c7d19f6 --- /dev/null +++ b/spark-monitor-plugin/build.sbt @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ + +Test / fork := true + +// Needed for SparkMonitor plugin +libraryDependencies ++= Dependencies.sparkAll.value +libraryDependencies ++= Seq( + Dependencies.playJson, + Dependencies.py4j +) + +// Test dependencies +libraryDependencies += Dependencies.scalaCompiler.value % "test" + +// Assembly configuration for separate jar +enablePlugins(AssemblyPlugin) + +assembly / assemblyMergeStrategy := { + case "module-info.class" => MergeStrategy.discard + case PathList("META-INF", "versions", "9", "module-info.class") => MergeStrategy.discard + case PathList("META-INF", xs @ _*) => MergeStrategy.discard + case x => + val oldStrategy = (assembly / assemblyMergeStrategy).value + oldStrategy(x) +} + +assembly / assemblyOption ~= { + _.withIncludeScala(false) +} + +assembly / test := {} \ No newline at end of file diff --git a/spark-monitor-plugin/src/main/scala/org/apache/toree/plugins/sparkmonitor/JupyterSparkMonitorListener.scala b/spark-monitor-plugin/src/main/scala/org/apache/toree/plugins/sparkmonitor/JupyterSparkMonitorListener.scala new file mode 100644 index 00000000..cde7e656 --- /dev/null +++ b/spark-monitor-plugin/src/main/scala/org/apache/toree/plugins/sparkmonitor/JupyterSparkMonitorListener.scala @@ -0,0 +1,781 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ + +package org.apache.toree.plugins.sparkmonitor + +import org.apache.spark.scheduler._ +import play.api.libs.json._ +import org.apache.spark._ +import org.apache.spark.JobExecutionStatus + +import scala.collection.mutable +import scala.collection.mutable.{ HashMap, HashSet, ListBuffer } +import java.util.concurrent.{BlockingQueue, LinkedBlockingQueue} +import java.util.{TimerTask,Timer} +import org.apache.toree.comm.CommWriter +import org.apache.toree.kernel.protocol.v5.MsgData +import org.slf4j.LoggerFactory + +import scala.util.Try + +/** + * A SparkListener Implementation that forwards data to a Jupyter Kernel via comm + * + * - All data is forwarded to a jupyter kernel using comm channel. + * - The listener receives notifications of the spark application's events, through the overrided methods. + * - The received data is stored and sent as JSON to the kernel via comm. 
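+ * - Each message is wrapped as MsgData("msgtype" -> "fromscala", "msg" -> <event json string>) before being written through the comm writer.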
+ * - Overrides methods that correspond to events in a spark Application. + * - The argument for each overrided method contains the received data for that event. (See SparkListener docs for more information.) + * - For each application, job, stage, and task there is a 'start' and an 'end' event. For executors, there are 'added' and 'removed' events + * + * @constructor called by the plugin system + * @param commWriter The comm writer to send messages through + */ +class JupyterSparkMonitorListener(getCommWriter: () => Option[CommWriter]) extends SparkListener { + + val logger = LoggerFactory.getLogger(this.getClass()) + logger.info("Started JupyterSparkMonitorListener for Jupyter Notebook") + + var onStageStatusActiveTask: TimerTask = null + val sparkTasksQueue: BlockingQueue[String] = new LinkedBlockingQueue[String]() + val sparkStageActiveTasksMaxMessages: Integer = 250 + val sparkStageActiveRate: Long = 1000L // 1s + + logger.info("Starting timer task for active stage monitoring") + startActiveStageMonitoring() + + /** Send a JSON message via comm channel. */ + def send(json: JsValue): Unit = { + val jsonString = Json.stringify(json) + getCommWriter().foreach { writer => + try { + // Create MsgData with the JSON content + val msgData = MsgData("msgtype" -> "fromscala", "msg" -> jsonString) + + // Send message directly using CommWriter + writer.writeMsg(msgData) + + } catch { + case exception: Throwable => + logger.error("Exception sending comm message: ", exception) + // Fallback: just log the message + logger.debug(s"SparkMonitor event: $jsonString") + } + } + + // If no comm writer, just log the message + if (getCommWriter().isEmpty) { + logger.debug(s"SparkMonitor event (no comm): $jsonString") + } + } + + /** Start the active stage monitoring task. */ + def startActiveStageMonitoring(): Unit = { + try { + val t = new Timer() + + if (onStageStatusActiveTask == null) { + onStageStatusActiveTask = new TimerTask { + def run() = { + onStageStatusActive() + } + } + } + t.schedule(onStageStatusActiveTask, sparkStageActiveRate, sparkStageActiveRate) + } catch { + case exception: Throwable => logger.error("Exception creating timer task: ", exception) + } + } + + /** Stop the active stage monitoring task. 
*/ + def stopActiveStageMonitoring(): Unit = { + logger.info("Stopping active stage monitoring") + if (onStageStatusActiveTask != null) { + onStageStatusActiveTask.cancel() + } + } + + type JobId = Int + type JobGroupId = String + type StageId = Int + type StageAttemptId = Int + + //Application + @volatile var startTime = -1L + @volatile var endTime = -1L + var appId: String = "" + + //Jobs + val activeJobs = new HashMap[JobId, UIData.JobUIData] + val completedJobs = ListBuffer[UIData.JobUIData]() + val failedJobs = ListBuffer[UIData.JobUIData]() + val jobIdToData = new HashMap[JobId, UIData.JobUIData] + val jobGroupToJobIds = new HashMap[JobGroupId, HashSet[JobId]] + + // Stages: + val pendingStages = new HashMap[StageId, StageInfo] + val activeStages = new HashMap[StageId, StageInfo] + val completedStages = ListBuffer[StageInfo]() + val skippedStages = ListBuffer[StageInfo]() + val failedStages = ListBuffer[StageInfo]() + val stageIdToData = new HashMap[(StageId, StageAttemptId), UIData.StageUIData] + val stageIdToInfo = new HashMap[StageId, StageInfo] + val stageIdToActiveJobIds = new HashMap[StageId, HashSet[JobId]] + + var numCompletedStages = 0 + var numFailedStages = 0 + var numCompletedJobs = 0 + var numFailedJobs = 0 + + val retainedStages = 1000 + val retainedJobs = 1000 + val retainedTasks = 100000 + + @volatile + var totalNumActiveTasks = 0 + val executorCores = new HashMap[String, Int] + @volatile var totalCores: Int = 0 + @volatile var numExecutors: Int = 0 + + /** + * Called when a spark application starts. + * + * The application start time and app ID are obtained here. + */ + override def onApplicationStart(appStarted: SparkListenerApplicationStart): Unit = { + startTime = appStarted.time + appId = appStarted.appId.getOrElse("null") + logger.info("Application Started: " + appId + " ...Start Time: " + appStarted.time) + val json = Json.obj( + "msgtype" -> "sparkApplicationStart", + "startTime" -> startTime, + "appId" -> appId, + "appAttemptId" -> appStarted.appAttemptId.getOrElse[String]("null"), + "appName" -> appStarted.appName, + "sparkUser" -> appStarted.sparkUser + ) + + send(json) + } + + /** + * Called when a spark application ends. + * + * Stops the active stage monitoring task. + */ + override def onApplicationEnd(appEnded: SparkListenerApplicationEnd): Unit = { + logger.info("Application ending...End Time: " + appEnded.time) + endTime = appEnded.time + val json = Json.obj( + "msgtype" -> "sparkApplicationEnd", + "endTime" -> endTime + ) + + send(json) + stopActiveStageMonitoring() + } + + /** Converts stageInfo object to a JSON object. */ + def stageInfoToJSON(stageInfo: StageInfo): JsObject = { + val completionTime: Long = stageInfo.completionTime.getOrElse(-1) + val submissionTime: Long = stageInfo.submissionTime.getOrElse(-1) + + Json.obj( + stageInfo.stageId.toString -> Json.obj( + "attemptId" -> stageInfo.attemptNumber(), + "name" -> stageInfo.name, + "numTasks" -> stageInfo.numTasks, + "completionTime" -> completionTime, + "submissionTime" -> submissionTime + ) + ) + } + + /** + * Called when a job starts. + * + * The jobStart object contains the list of planned stages. They are stored for tracking skipped stages. 
+ * The total number of tasks is also estimated from the list of planned stages, + */ + override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { + + val jobGroup = for ( + props <- Option(jobStart.properties); + group <- Option(props.getProperty("spark.jobGroup.id")) + ) yield group + + val jobData: UIData.JobUIData = + new UIData.JobUIData( + jobId = jobStart.jobId, + submissionTime = Option(jobStart.time).filter(_ >= 0), + stageIds = jobStart.stageIds, + jobGroup = jobGroup, + status = JobExecutionStatus.RUNNING) + jobGroupToJobIds.getOrElseUpdate(jobGroup.orNull, new HashSet[JobId]).add(jobStart.jobId) + jobStart.stageInfos.foreach(x => pendingStages(x.stageId) = x) + + // Merge all stage info objects into one JSON object + val stageinfojson = jobStart.stageInfos.foldLeft(Json.obj()) { (acc, stageInfo) => + acc ++ stageInfoToJSON(stageInfo) + } + + jobData.numTasks = { + val allStages = jobStart.stageInfos + val missingStages = allStages.filter(_.completionTime.isEmpty) + missingStages.map(_.numTasks).sum + } + jobIdToData(jobStart.jobId) = jobData + activeJobs(jobStart.jobId) = jobData + for (stageId <- jobStart.stageIds) { + stageIdToActiveJobIds.getOrElseUpdate(stageId, new HashSet[StageId]).add(jobStart.jobId) + } + // If there's no information for a stage, store the StageInfo received from the scheduler + // so that we can display stage descriptions for pending stages: + for (stageInfo <- jobStart.stageInfos) { + stageIdToInfo.getOrElseUpdate(stageInfo.stageId, stageInfo) + stageIdToData.getOrElseUpdate((stageInfo.stageId, stageInfo.attemptNumber()), new UIData.StageUIData) + } + val name = jobStart.properties.getProperty("callSite.short", "null") + val json = Json.obj( + "msgtype" -> "sparkJobStart", + "jobGroup" -> jobGroup.getOrElse[String]("null"), + "jobId" -> jobStart.jobId, + "status" -> "RUNNING", + "submissionTime" -> Option(jobStart.time).filter(_ >= 0), + "stageIds" -> jobStart.stageIds, + "stageInfos" -> stageinfojson, + "numTasks" -> jobData.numTasks, + "totalCores" -> totalCores, + "appId" -> appId, + "numExecutors" -> numExecutors, + "name" -> name + ) + logger.info("Job Start: " + jobStart.jobId) + logger.debug(Json.prettyPrint(json)) + send(json) + } + + /** Called when a job ends. 
*/ + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { + val jobData = activeJobs.remove(jobEnd.jobId).getOrElse { + logger.info("Job completed for unknown job: " + jobEnd.jobId) + new UIData.JobUIData(jobId = jobEnd.jobId) + } + jobData.completionTime = Option(jobEnd.time).filter(_ >= 0) + var status = "null" + jobData.stageIds.foreach(pendingStages.remove) + jobEnd.jobResult match { + case JobSucceeded => + completedJobs += jobData + trimJobsIfNecessary(completedJobs) + jobData.status = JobExecutionStatus.SUCCEEDED + status = "COMPLETED" + numCompletedJobs += 1 + case _ => + failedJobs += jobData + trimJobsIfNecessary(failedJobs) + jobData.status = JobExecutionStatus.FAILED + numFailedJobs += 1 + status = "FAILED" + } + for (stageId <- jobData.stageIds) { + stageIdToActiveJobIds.get(stageId).foreach { jobsUsingStage => + jobsUsingStage.remove(jobEnd.jobId) + if (jobsUsingStage.isEmpty) { + stageIdToActiveJobIds.remove(stageId) + } + stageIdToInfo.get(stageId).foreach { stageInfo => + if (stageInfo.submissionTime.isEmpty) { + // if this stage is pending, it won't complete, so mark it as "skipped": + skippedStages += stageInfo + trimStagesIfNecessary(skippedStages) + jobData.numSkippedStages += 1 + jobData.numSkippedTasks += stageInfo.numTasks + } + } + } + } + + val json = Json.obj( + "msgtype" -> "sparkJobEnd", + "jobId" -> jobEnd.jobId, + "status" -> status, + "completionTime" -> jobData.completionTime + ) + + logger.info("Job End: " + jobEnd.jobId) + logger.debug(Json.prettyPrint(json)) + + send(json) + } + + /** Called when a stage is completed. */ + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized { + val stage = stageCompleted.stageInfo + stageIdToInfo(stage.stageId) = stage + val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptNumber()), { + logger.info("Stage completed for unknown stage " + stage.stageId) + new UIData.StageUIData + }) + var status = "UNKNOWN" + activeStages.remove(stage.stageId) + if (stage.failureReason.isEmpty) { + completedStages += stage + numCompletedStages += 1 + trimStagesIfNecessary(completedStages) + status = "COMPLETED" + } else { + failedStages += stage + numFailedStages += 1 + trimStagesIfNecessary(failedStages) + status = "FAILED" + } + + val jobIds = stageIdToActiveJobIds.get(stage.stageId) + for ( + activeJobsDependentOnStage <- jobIds; + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveStages -= 1 + if (stage.failureReason.isEmpty) { + if (stage.submissionTime.isDefined) { + jobData.completedStageIndices.add(stage.stageId) + } + } else { + jobData.numFailedStages += 1 + } + } + val completionTime: Long = stage.completionTime.getOrElse(-1) + val submissionTime: Long = stage.submissionTime.getOrElse(-1) + val json = Json.obj( + "msgtype" -> "sparkStageCompleted", + "stageId" -> stage.stageId, + "stageAttemptId" -> stage.attemptNumber(), + "completionTime" -> completionTime, + "submissionTime" -> submissionTime, + "numTasks" -> stage.numTasks, + "numFailedTasks" -> stageData.numFailedTasks, + "numCompletedTasks" -> stageData.numCompletedTasks, + "status" -> status, + "jobIds" -> jobIds + ) + + logger.info("Stage Completed: " + stage.stageId) + logger.debug(Json.prettyPrint(json)) + send(json) + } + + /** Called when a stage is submitted for execution. 
*/ + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { + val stage = stageSubmitted.stageInfo + activeStages(stage.stageId) = stage + pendingStages.remove(stage.stageId) + stageIdToInfo(stage.stageId) = stage + val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptNumber()), new UIData.StageUIData) + stageData.description = Option(stageSubmitted.properties).flatMap { + p => Option(p.getProperty("spark.job.description")) + } + + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveStages += 1 + // If a stage retries again, it should be removed from completedStageIndices set + jobData.completedStageIndices.remove(stage.stageId) + } + val activeJobsDependentOnStage = stageIdToActiveJobIds.get(stage.stageId) + val jobIds = activeJobsDependentOnStage + val submissionTime: Long = stage.submissionTime.getOrElse(-1) + val json = Json.obj( + "msgtype" -> "sparkStageSubmitted", + "stageId" -> stage.stageId, + "stageAttemptId" -> stage.attemptNumber(), + "name" -> stage.name, + "numTasks" -> stage.numTasks, + "parentIds" -> stage.parentIds, + "submissionTime" -> submissionTime, + "jobIds" -> jobIds + ) + logger.info("Stage Submitted: " + stage.stageId) + logger.debug(Json.prettyPrint(json)) + send(json) + } + + /** Called when scheduled stage tasks update was requested */ + def onStageStatusActive(): Unit = { + // Update on status of active stages + for ((stageId, stageInfo) <- activeStages) { + val stageData = stageIdToData.getOrElseUpdate((stageInfo.stageId, stageInfo.attemptNumber()), new UIData.StageUIData) + val jobIds = stageIdToActiveJobIds.get(stageInfo.stageId) + + val json = Json.obj( + "msgtype" -> "sparkStageActive", + "stageId" -> stageInfo.stageId, + "stageAttemptId" -> stageInfo.attemptNumber(), + "name" -> stageInfo.name, + "parentIds" -> stageInfo.parentIds, + "numTasks" -> stageInfo.numTasks, + "numActiveTasks" -> stageData.numActiveTasks, + "numFailedTasks" -> stageData.numFailedTasks, + "numCompletedTasks" -> stageData.numCompletedTasks, + "jobIds" -> jobIds + ) + + logger.info("Stage Update: " + stageInfo.stageId) + logger.debug(Json.prettyPrint(json)) + send(json) + } + + // Emit sparkStageActiveTasksMaxMessages spark tasks details from queue to frontend + var count: Integer = 0 + while (sparkTasksQueue != null && !sparkTasksQueue.isEmpty() && count <= sparkStageActiveTasksMaxMessages) { + count = count + 1 + val jsonString = sparkTasksQueue.take() + // Already stringified JSON, parse and send + Try(Json.parse(jsonString)).foreach(send) + } + + if (count > 0) { + logger.info("Stage Tasks details updated: " + count) + } + } + + /** Called when a task is started. 
*/ + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { + val taskInfo = taskStart.taskInfo + if (taskInfo != null) { + val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), { + logger.info("Task start for unknown stage " + taskStart.stageId) + new UIData.StageUIData + }) + stageData.numActiveTasks += 1 + } + var jobjson = Json.obj("jobdata" -> "taskstart") + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveTasks += 1 + val jobjson = Json.obj( + "jobdata" -> Json.obj( + "jobId" -> jobData.jobId, + "numTasks" -> jobData.numTasks, + "numActiveTasks" -> jobData.numActiveTasks, + "numCompletedTasks" -> jobData.numCompletedTasks, + "numSkippedTasks" -> jobData.numSkippedTasks, + "numFailedTasks" -> jobData.numFailedTasks, + "reasonToNumKilled" -> jobData.reasonToNumKilled, + "numActiveStages" -> jobData.numActiveStages, + "numSkippedStages" -> jobData.numSkippedStages, + "numFailedStages" -> jobData.numFailedStages + ) + ) + } + val json = Json.obj( + "msgtype" -> "sparkTaskStart", + "launchTime" -> taskInfo.launchTime, + "taskId" -> taskInfo.taskId, + "stageId" -> taskStart.stageId, + "stageAttemptId" -> taskStart.stageAttemptId, + "index" -> taskInfo.index, + "attemptNumber" -> taskInfo.attemptNumber, + "executorId" -> taskInfo.executorId, + "host" -> taskInfo.host, + "status" -> taskInfo.status, + "speculative" -> taskInfo.speculative + ) + + logger.info("Task Start: " + taskInfo.taskId) + logger.debug(Json.prettyPrint(json)) + + // Buffer the message for periodic flushing + sparkTasksQueue.put(Json.stringify(json)) + } + + /** Called when a task is ended. */ + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + val info = taskEnd.taskInfo + // If stage attempt id is -1, it means the DAGScheduler had no idea which attempt this task + // completion event is for. Let's just drop it here. This means we might have some speculation + // tasks on the web ui that's never marked as complete. 
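+    // Track per-stage and per-job task counts and capture any failure reason for the sparkTaskEnd message built below.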
+ var errorMessage: Option[String] = None + if (info != null && taskEnd.stageAttemptId != -1) { + val stageData = stageIdToData.getOrElseUpdate((taskEnd.stageId, taskEnd.stageAttemptId), { + logger.info("Task end for unknown stage " + taskEnd.stageId) + new UIData.StageUIData + }) + stageData.numActiveTasks -= 1 + errorMessage = taskEnd.reason match { + case org.apache.spark.Success => + stageData.completedIndices.add(info.index) + stageData.numCompletedTasks += 1 + None + case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates + stageData.numFailedTasks += 1 + Some(e.toErrorString) + case e: TaskFailedReason => // All other failure cases + stageData.numFailedTasks += 1 + Some(e.toErrorString) + } + + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskEnd.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveTasks -= 1 + taskEnd.reason match { + case Success => + jobData.numCompletedTasks += 1 + case _ => + jobData.numFailedTasks += 1 + } + } + } + + val totalExecutionTime = info.finishTime - info.launchTime + def toProportion(time: Long) = time.toDouble / totalExecutionTime * 100 + var metricsOpt = Option(taskEnd.taskMetrics) + val shuffleReadTime = metricsOpt.map(_.shuffleReadMetrics.fetchWaitTime).getOrElse(0L) + val shuffleReadTimeProportion = toProportion(shuffleReadTime) + val shuffleWriteTime = (metricsOpt.map(_.shuffleWriteMetrics.writeTime).getOrElse(0L) / 1e6).toLong + val shuffleWriteTimeProportion = toProportion(shuffleWriteTime) + val serializationTime = metricsOpt.map(_.resultSerializationTime).getOrElse(0L) + val serializationTimeProportion = toProportion(serializationTime) + val deserializationTime = metricsOpt.map(_.executorDeserializeTime).getOrElse(0L) + val deserializationTimeProportion = toProportion(deserializationTime) + val gettingResultTime = if (info.gettingResult) { + if (info.finished) { + info.finishTime - info.gettingResultTime + } else { + 0L //currentTime - info.gettingResultTime + } + } else { + 0L + } + val gettingResultTimeProportion = toProportion(gettingResultTime) + val executorOverhead = serializationTime + deserializationTime + val executorRunTime = metricsOpt.map(_.executorRunTime).getOrElse(totalExecutionTime - executorOverhead - gettingResultTime) + val schedulerDelay = math.max(0, totalExecutionTime - executorRunTime - executorOverhead - gettingResultTime) + val schedulerDelayProportion = toProportion(schedulerDelay) + val executorComputingTime = executorRunTime - shuffleReadTime - shuffleWriteTime + val executorComputingTimeProportion = + math.max(100 - schedulerDelayProportion - shuffleReadTimeProportion - + shuffleWriteTimeProportion - serializationTimeProportion - + deserializationTimeProportion - gettingResultTimeProportion, 0) + + val schedulerDelayProportionPos = 0 + val deserializationTimeProportionPos = schedulerDelayProportionPos + schedulerDelayProportion + val shuffleReadTimeProportionPos = deserializationTimeProportionPos + deserializationTimeProportion + val executorRuntimeProportionPos = shuffleReadTimeProportionPos + shuffleReadTimeProportion + val shuffleWriteTimeProportionPos = executorRuntimeProportionPos + executorComputingTimeProportion + val serializationTimeProportionPos = shuffleWriteTimeProportionPos + shuffleWriteTimeProportion + val gettingResultTimeProportionPos = serializationTimeProportionPos + serializationTimeProportion + + val jsonMetrics = if (metricsOpt.isDefined) { + Json.obj( + "shuffleReadTime" -> 
shuffleReadTime, + "shuffleWriteTime" -> shuffleWriteTime, + "serializationTime" -> serializationTime, + "deserializationTime" -> deserializationTime, + "gettingResultTime" -> gettingResultTime, + "executorComputingTime" -> executorComputingTime, + "schedulerDelay" -> schedulerDelay, + "shuffleReadTimeProportion" -> shuffleReadTimeProportion, + "shuffleWriteTimeProportion" -> shuffleWriteTimeProportion, + "serializationTimeProportion" -> serializationTimeProportion, + "deserializationTimeProportion" -> deserializationTimeProportion, + "gettingResultTimeProportion" -> gettingResultTimeProportion, + "executorComputingTimeProportion" -> executorComputingTimeProportion, + "schedulerDelayProportion" -> schedulerDelayProportion, + "shuffleReadTimeProportionPos" -> shuffleReadTimeProportionPos, + "shuffleWriteTimeProportionPos" -> shuffleWriteTimeProportionPos, + "serializationTimeProportionPos" -> serializationTimeProportionPos, + "deserializationTimeProportionPos" -> deserializationTimeProportionPos, + "gettingResultTimeProportionPos" -> gettingResultTimeProportionPos, + "executorComputingTimeProportionPos" -> executorRuntimeProportionPos, + "schedulerDelayProportionPos" -> schedulerDelayProportionPos, + "resultSize" -> metricsOpt.get.resultSize, + "jvmGCTime" -> metricsOpt.get.jvmGCTime, + "memoryBytesSpilled" -> metricsOpt.get.memoryBytesSpilled, + "diskBytesSpilled" -> metricsOpt.get.diskBytesSpilled, + "peakExecutionMemory" -> metricsOpt.get.peakExecutionMemory, + "test" -> info.gettingResultTime + ) + } else { + Json.obj() + } + + val json = Json.obj( + "msgtype" -> "sparkTaskEnd", + "launchTime" -> info.launchTime, + "finishTime" -> info.finishTime, + "taskId" -> info.taskId, + "stageId" -> taskEnd.stageId, + "taskType" -> taskEnd.taskType, + "stageAttemptId" -> taskEnd.stageAttemptId, + "index" -> info.index, + "attemptNumber" -> info.attemptNumber, + "executorId" -> info.executorId, + "host" -> info.host, + "status" -> info.status, + "speculative" -> info.speculative, + "errorMessage" -> errorMessage, + "metrics" -> jsonMetrics + ) + + logger.info("Task Ended: " + info.taskId) + logger.debug(Json.prettyPrint(json)) + + // Buffer the message for periodic flushing + sparkTasksQueue.put(Json.stringify(json)) + } + + /** If stored stages data is too large, remove and garbage collect old stages */ + private def trimStagesIfNecessary(stages: ListBuffer[StageInfo]) = synchronized { + if (stages.size > retainedStages) { + val toRemove = calculateNumberToRemove(stages.size, retainedStages) + stages.take(toRemove).foreach { s => + stageIdToData.remove((s.stageId, s.attemptNumber())) + stageIdToInfo.remove(s.stageId) + } + stages.trimStart(toRemove) + } + } + + /** If stored jobs data is too large, remove and garbage collect old jobs */ + private def trimJobsIfNecessary(jobs: ListBuffer[UIData.JobUIData]) = synchronized { + if (jobs.size > retainedJobs) { + val toRemove = calculateNumberToRemove(jobs.size, retainedJobs) + jobs.take(toRemove).foreach { job => + // Remove the job's UI data, if it exists + jobIdToData.remove(job.jobId).foreach { removedJob => + // A null jobGroupId is used for jobs that are run without a job group + val jobGroupId = removedJob.jobGroup.orNull + // Remove the job group -> job mapping entry, if it exists + jobGroupToJobIds.get(jobGroupId).foreach { jobsInGroup => + jobsInGroup.remove(job.jobId) + // If this was the last job in this job group, remove the map entry for the job group + if (jobsInGroup.isEmpty) { + jobGroupToJobIds.remove(jobGroupId) + } + } + } + } + 
jobs.trimStart(toRemove) + } + } + + /** Calculate number of items to remove from stored data. */ + private def calculateNumberToRemove(dataSize: Int, retainedSize: Int): Int = { + math.max(retainedSize / 10, dataSize - retainedSize) + } + + /** Called when an executor is added. */ + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = synchronized { + executorCores(executorAdded.executorId) = executorAdded.executorInfo.totalCores + totalCores += executorAdded.executorInfo.totalCores + numExecutors += 1 + val json = Json.obj( + "msgtype" -> "sparkExecutorAdded", + "executorId" -> executorAdded.executorId, + "time" -> executorAdded.time, + "host" -> executorAdded.executorInfo.executorHost, + "numCores" -> executorAdded.executorInfo.totalCores, + "totalCores" -> totalCores // Sending this as browser data can be lost during reloads + ) + + logger.info("Executor Added: " + executorAdded.executorId) + logger.debug(Json.prettyPrint(json)) + send(json) + } + + /** Called when an executor is removed. */ + override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = synchronized { + totalCores -= executorCores.getOrElse(executorRemoved.executorId, 0) + numExecutors -= 1 + val json = Json.obj( + "msgtype" -> "sparkExecutorRemoved", + "executorId" -> executorRemoved.executorId, + "time" -> executorRemoved.time, + "totalCores" -> totalCores // Sending this as browser data can be lost during reloads + ) + + logger.info("Executor Removed: " + executorRemoved.executorId) + logger.debug(Json.prettyPrint(json)) + + send(json) + } +} + +/** Data Structures for storing received from listener events. */ +object UIData { + + /** + * Data about a job. + * + * This is stored to track aggregated valus such as number of stages and tasks, and to track skipped and failed stages + */ + class JobUIData( + var jobId: Int = -1, + var submissionTime: Option[Long] = None, + var completionTime: Option[Long] = None, + var stageIds: Seq[Int] = Seq.empty, + var jobGroup: Option[String] = None, + var status: JobExecutionStatus = JobExecutionStatus.UNKNOWN, + var numTasks: Int = 0, + var numActiveTasks: Int = 0, + var numCompletedTasks: Int = 0, + var numSkippedTasks: Int = 0, + var numFailedTasks: Int = 0, + var reasonToNumKilled: Map[String, Int] = Map.empty, + var numActiveStages: Int = 0, + // This needs to be a set instead of a simple count to prevent double-counting of rerun stages: + var completedStageIndices: mutable.HashSet[Int] = new mutable.HashSet[Int](), + var numSkippedStages: Int = 0, + var numFailedStages: Int = 0) + + /** + * Data about a stage. + * + * This is stored to track aggregated valus such as number of tasks. + */ + class StageUIData { + var numActiveTasks: Int = _ + var numCompletedTasks: Int = _ + var completedIndices = new HashSet[Int]() + var numFailedTasks: Int = _ + var description: Option[String] = None + } + + /** + * Data about an executor. + * + * When an executor is removed, its number of cores is not available, so it is looked up here. 
+ */ + class ExecutorData { + var numCores: Int = _ + var executorId: String = _ + var timeAdded: Long = _ + var timeRemoved: Long = _ + var executorHost: String = _ + } +} diff --git a/spark-monitor-plugin/src/main/scala/org/apache/toree/plugins/sparkmonitor/README.md b/spark-monitor-plugin/src/main/scala/org/apache/toree/plugins/sparkmonitor/README.md new file mode 100644 index 00000000..d638f13b --- /dev/null +++ b/spark-monitor-plugin/src/main/scala/org/apache/toree/plugins/sparkmonitor/README.md @@ -0,0 +1,227 @@ +# SparkMonitor Plugin + +The SparkMonitor plugin provides real-time monitoring of Apache Spark jobs, stages, and tasks through a Jupyter comm channel in Apache Toree. + +## Features + +- **SparkListener Integration**: Automatically registers a JupyterSparkMonitorListener to the SparkContext when Spark becomes ready +- **Communication Channel**: Provides a "SparkMonitor" comm target for bi-directional communication with clients +- **Real-time Updates**: Sends comprehensive real-time notifications about: + - Application start/end events + - Job start/end events with detailed information + - Stage submission/completion events with metrics + - Task start/end events with performance data + - Executor addition/removal events +- **Error Handling**: Robust error handling to prevent plugin failures from affecting the kernel +- **Performance Monitoring**: Detailed task metrics including execution time, shuffle data, and memory usage + +## Architecture + +The plugin consists of two main components: + +### SparkMonitorPlugin +- Extends the Toree `Plugin` class +- Registers the "SparkMonitor" comm target during initialization +- Listens for the "sparkReady" event to register the SparkListener +- Manages the communication channel lifecycle + +### JupyterSparkMonitorListener +- Extends Spark's `SparkListener` class +- Monitors comprehensive Spark events and sends updates through the comm channel +- Handles communication failures gracefully +- Provides detailed event tracking with JSON-formatted messages +- Includes active stage monitoring with periodic updates + +## Usage + +### Plugin Registration + +The plugin is automatically discovered and loaded by Toree's plugin system. No manual registration is required. 
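+
+### Message Format
+
+When the comm is first opened, the plugin replies with a simple `{"msgtype": "commopen"}` message. After that, every Spark event is delivered in a common envelope: the comm message's `data` carries `msgtype` set to `"fromscala"` and `msg` set to the stringified JSON of the event itself. As a rough illustration (field values are made up), a `sparkJobEnd` event arrives as:
+
+```json
+{
+  "msgtype": "fromscala",
+  "msg": "{\"msgtype\":\"sparkJobEnd\",\"jobId\":3,\"status\":\"COMPLETED\",\"completionTime\":1700000000000}"
+}
+```
+
+Clients should parse the `msg` string as JSON to obtain the event payload described under "Event Types" below.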
+ +### Client Communication + +Clients can connect to the SparkMonitor by opening a comm with the target name "SparkMonitor": + +```python +# Example using Jupyter notebook +from IPython.display import display +import ipywidgets as widgets +from traitlets import Unicode +import json + +# Create a comm to connect to SparkMonitor +comm = Comm(target_name='SparkMonitor') + +def handle_spark_event(msg): + data = msg['content']['data'] + event_type = data.get('event', 'unknown') + print(f"Spark Event: {event_type}") + print(f"Details: {data}") + +comm.on_msg(handle_spark_event) +``` + +### Event Types + +The plugin sends the following types of events: + +#### Application Events +- **sparkApplicationStart**: Fired when a Spark application starts + - `startTime`: Application start timestamp + - `appId`: Application identifier + - `appAttemptId`: Application attempt identifier + - `appName`: Application name + - `sparkUser`: Spark user + +- **sparkApplicationEnd**: Fired when a Spark application ends + - `endTime`: Application end timestamp + +#### Job Events +- **sparkJobStart**: Fired when a Spark job starts + - `jobId`: Job identifier + - `jobGroup`: Job group identifier + - `status`: Job status (RUNNING) + - `submissionTime`: Job submission timestamp + - `stageIds`: Array of stage IDs + - `stageInfos`: Detailed stage information + - `numTasks`: Total number of tasks + - `totalCores`: Total available cores + - `numExecutors`: Number of executors + - `name`: Job name/description + +- **sparkJobEnd**: Fired when a Spark job ends + - `jobId`: Job identifier + - `status`: Job completion status (COMPLETED/FAILED) + - `completionTime`: Job completion timestamp + +#### Stage Events +- **sparkStageSubmitted**: Fired when a stage is submitted + - `stageId`: Stage identifier + - `stageAttemptId`: Stage attempt identifier + - `name`: Stage name + - `numTasks`: Number of tasks in the stage + - `parentIds`: Parent stage IDs + - `submissionTime`: Submission timestamp + - `jobIds`: Associated job IDs + +- **sparkStageCompleted**: Fired when a stage completes + - `stageId`: Stage identifier + - `stageAttemptId`: Stage attempt identifier + - `completionTime`: Completion timestamp + - `submissionTime`: Submission timestamp + - `numTasks`: Total number of tasks + - `numFailedTasks`: Number of failed tasks + - `numCompletedTasks`: Number of completed tasks + - `status`: Stage status (COMPLETED/FAILED) + - `jobIds`: Associated job IDs + +- **sparkStageActive**: Periodic updates for active stages + - `stageId`: Stage identifier + - `stageAttemptId`: Stage attempt identifier + - `name`: Stage name + - `parentIds`: Parent stage IDs + - `numTasks`: Total number of tasks + - `numActiveTasks`: Number of currently active tasks + - `numFailedTasks`: Number of failed tasks + - `numCompletedTasks`: Number of completed tasks + - `jobIds`: Associated job IDs + +#### Task Events +- **sparkTaskStart**: Fired when a task starts + - `launchTime`: Task launch timestamp + - `taskId`: Task identifier + - `stageId`: Parent stage identifier + - `stageAttemptId`: Stage attempt identifier + - `index`: Task index + - `attemptNumber`: Task attempt number + - `executorId`: Executor identifier + - `host`: Host name + - `status`: Task status + - `speculative`: Whether task is speculative + +- **sparkTaskEnd**: Fired when a task ends + - `launchTime`: Task launch timestamp + - `finishTime`: Task finish timestamp + - `taskId`: Task identifier + - `stageId`: Parent stage identifier + - `taskType`: Task type + - `stageAttemptId`: Stage attempt 
identifier + - `index`: Task index + - `attemptNumber`: Task attempt number + - `executorId`: Executor identifier + - `host`: Host name + - `status`: Task status + - `speculative`: Whether task is speculative + - `errorMessage`: Error message (if failed) + - `metrics`: Detailed task metrics (execution time, shuffle data, memory usage, etc.) + +#### Executor Events +- **sparkExecutorAdded**: Fired when an executor is added + - `executorId`: Executor identifier + - `time`: Addition timestamp + - `host`: Host name + - `numCores`: Number of cores + - `totalCores`: Total cores across all executors + +- **sparkExecutorRemoved**: Fired when an executor is removed + - `executorId`: Executor identifier + - `time`: Removal timestamp + - `totalCores`: Remaining total cores + +## Example Spark Code + +To see the SparkMonitor in action, run some Spark operations: + +```scala +// Create an RDD and perform some operations +val rdd = sc.parallelize(1 to 1000, 10) +val result = rdd.map(_ * 2).filter(_ > 100).collect() + +// The SparkMonitor will send notifications about: +// - Job start/end +// - Stage submissions/completions +// - Task executions +``` + +## Configuration + +The plugin doesn't require any specific configuration. It automatically: +- Registers the "SparkMonitor" comm target during kernel initialization +- Registers the JupyterSparkMonitorListener when Spark becomes ready +- Handles comm connection lifecycle automatically +- Provides periodic active stage monitoring (every 1 second) + +### Configurable Parameters + +The following parameters can be adjusted by modifying the JupyterSparkMonitorListener class: +- `sparkStageActiveTasksMaxMessages`: Maximum number of task messages to buffer (default: 250) +- `sparkStageActiveRate`: Rate of active stage monitoring in milliseconds (default: 1000ms) + +## Error Handling + +The plugin includes comprehensive error handling: +- Graceful handling of SparkContext unavailability +- Safe communication channel operations +- Logging of errors without affecting kernel stability + +## Development + +To extend the plugin: + +1. Add new event handlers to `SparkMonitorListener` +2. Modify the `sendUpdate` method to include additional data +3. Update the comm message handlers in `SparkMonitorPlugin` + +## Testing + +Run the plugin tests using: + +```bash +sbt "project plugins" test +``` + +The test suite includes: +- Plugin initialization tests +- SparkListener registration tests +- Communication channel tests +- Error handling tests \ No newline at end of file diff --git a/spark-monitor-plugin/src/main/scala/org/apache/toree/plugins/sparkmonitor/SparkMonitorPlugin.scala b/spark-monitor-plugin/src/main/scala/org/apache/toree/plugins/sparkmonitor/SparkMonitorPlugin.scala new file mode 100644 index 00000000..d2cbe308 --- /dev/null +++ b/spark-monitor-plugin/src/main/scala/org/apache/toree/plugins/sparkmonitor/SparkMonitorPlugin.scala @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ + +package org.apache.toree.plugins.sparkmonitor + +import org.apache.toree.plugins.{AllInterpretersReady, Plugin, PluginManager, SparkReady} +import org.apache.toree.plugins.annotations.{Event, Init} +import org.apache.toree.plugins.dependencies.Dependency +import org.apache.toree.kernel.api.KernelLike +import org.apache.toree.kernel.api.Kernel +import org.apache.toree.comm.{CommRegistrar, CommWriter} +import org.apache.toree.kernel.protocol.v5.{MsgData, UUID} +import org.apache.toree.utils.ScheduledTaskManager +import org.apache.spark.SparkContext +import org.apache.log4j.Logger + +import scala.util.Try +import scala.reflect.runtime.universe +import java.lang.reflect.Field + +/** + * Plugin that registers a JupyterSparkMonitorListener to SparkContext after kernel started + * and provides communication through a "SparkMonitor" comm target. + * + * This plugin uses proper typing where possible and falls back to reflection only when necessary + * to maintain compatibility across different kernel implementations. + */ +class SparkMonitorPlugin extends Plugin { + + private val logger = Logger.getLogger(this.getClass.getName) + private var sparkMonitorListener: Option[JupyterSparkMonitorListener] = None + private var commWriter: Option[CommWriter] = None + private val taskManager = new ScheduledTaskManager() + private var sparkContextMonitorTaskId: Option[String] = None + private var currentKernel: Option[KernelLike] = None + private var pluginManager: Option[PluginManager] = None + + // Communication target name - extracted as constant for better maintainability + private val COMM_TARGET_NAME = "SparkMonitor" + + /** + * Initialize the plugin by registering the SparkMonitor comm target. + * Uses proper typing with safe casting and comprehensive error handling. + * + * @param kernel The kernel instance implementing KernelLike interface + */ + @Init + def initializePlugin(kernel: KernelLike): Unit = { + logger.info(s"Initializing SparkMonitor plugin with comm target: $COMM_TARGET_NAME") + + Try { + // Cast to concrete Kernel type to access comm property. This plugin need to use toree-kernel module + val concreteKernel = kernel.asInstanceOf[Kernel] + val commManager = concreteKernel.comm + + // Register comm target - now we can use proper types + val commRegistrar = commManager.register(COMM_TARGET_NAME) + setupCommHandlers(commRegistrar) + + // Start background process to monitor SparkContext creation + startSparkContextMonitoring() + + logger.info("SparkMonitor plugin initialized successfully") + }.recover { + case ex => logger.error(s"Failed to initialize SparkMonitor plugin: ${ex.getMessage}", ex) + } + } + + /** + * Initialize the plugin manager reference. + * + * @param manager The plugin manager instance + */ + @Init + def initializePluginManager(manager: PluginManager): Unit = { + logger.info("Initializing PluginManager reference") + pluginManager = Some(manager) + } + + /** + * Sets up all communication handlers for the registered comm target. 
+ */ + private def setupCommHandlers(commRegistrar: CommRegistrar): Unit = { + // Add open handler + commRegistrar.addOpenHandler { (commWriter: CommWriter, commId: UUID, targetName: String, data: MsgData) => + Try { + logger.info(s"SparkMonitor comm opened - ID: $commId, Target: $targetName") + this.commWriter = Some(commWriter) + + // Send initial connection message + val message = MsgData("msgtype" -> "commopen") + commWriter.writeMsg(message) + }.recover { + case ex => logger.warn(s"Error in comm open handler: ${ex.getMessage}", ex) + } + } + + // Add message handler + commRegistrar.addMsgHandler { (commWriter: CommWriter, commId: UUID, data: MsgData) => + logger.debug(s"SparkMonitor received message from comm $commId: $data") + // Handle incoming messages from client if needed + // This can be extended for bidirectional communication + } + + // Add close handler + commRegistrar.addCloseHandler { (commWriter: CommWriter, commId: UUID, data: MsgData) => + logger.info(s"SparkMonitor comm closed - ID: $commId") + this.commWriter = None + } + } + + /** + * Starts a background task to monitor SparkContext creation. + * This task runs periodically to check if SparkContext becomes available using reflection. + */ + private def startSparkContextMonitoring(): Unit = { + logger.info("Starting SparkContext monitoring background task") + + val taskId = taskManager.addTask( + executionDelay = 1000, // Start checking after 1 second + timeInterval = 2000, // Check every 2 seconds + task = { + try { + logger.info("Task execution started - this should appear every 2 seconds") + checkSparkContextAndNotify() + logger.info("Task execution completed") + } catch { + case ex: Exception => + logger.error("Task execution failed", ex) + } + } + ) + + sparkContextMonitorTaskId = Some(taskId) + logger.debug(s"SparkContext monitoring task started with ID: $taskId") + } + + /** + * Checks if SparkContext is available using reflection to access private activeContext field. + * Once SparkContext is found, stops the monitoring task and fires SparkReady event. + * + * Uses reflection to safely access SparkContext.activeContext without triggering instantiation. + */ + private def checkSparkContextAndNotify(): Unit = { + Try { + logger.debug("checkSparkContextAndNotify is running") + getActiveSparkContext() match { + case Some(sparkContext) if !sparkContext.isStopped => + logger.info("SparkContext detected! Firing SparkReady event.") + + // Stop the monitoring task since SparkContext is now available + stopSparkContextMonitoring() + + // Fire SparkReady event through plugin manager to notify all plugins + fireSparkReadyEvent() + + case Some(sparkContext) => + logger.debug("SparkContext exists but is stopped, continuing to monitor...") + + case None => + logger.debug("No SparkContext found, continuing to monitor...") + } + }.recover { + case ex => + logger.debug(s"Error checking SparkContext availability: ${ex.getMessage}") + } + } + + /** + * Uses reflection to safely access the private activeContext field from SparkContext. + * This approach doesn't trigger SparkContext instantiation. 
+ */ + private def getActiveSparkContext(): Option[SparkContext] = { + Try { + val runtimeMirror = universe.runtimeMirror(getClass.getClassLoader) + val moduleSymbol = runtimeMirror.staticModule("org.apache.spark.SparkContext") + val moduleMirror = runtimeMirror.reflectModule(moduleSymbol) + val sparkContext = moduleMirror.instance + + val activeContextField: Field = sparkContext.getClass().getDeclaredField("org$apache$spark$SparkContext$$activeContext") + activeContextField.setAccessible(true) + + val activeContextRef = activeContextField.get(sparkContext).asInstanceOf[java.util.concurrent.atomic.AtomicReference[SparkContext]] + Option(activeContextRef.get()) + }.recover { + case ex => + logger.error(s"Failed to access activeContext field via reflection: ${ex.getMessage} \n $ex") + None + }.getOrElse(None) + } + + /** + * Fires a SparkReady event through the plugin manager to notify all listening plugins. + */ + private def fireSparkReadyEvent(): Unit = { + Try { + // Create a KernelLike dependency to pass along with the event + pluginManager match { + case Some(manager) => + manager.fireEvent(SparkReady) + logger.info("SparkReady event fired to all plugins") + case _ => + logger.warn("Cannot fire SparkReady event: kernel or plugin manager not available") + } + }.recover { + case ex => logger.warn(s"Failed to fire SparkReady event: ${ex.getMessage}") + } + } + + /** + * Stops the SparkContext monitoring background task. + */ + private def stopSparkContextMonitoring(): Unit = { + sparkContextMonitorTaskId.foreach { taskId => + if (taskManager.removeTask(taskId)) { + logger.info(s"SparkContext monitoring task stopped (ID: $taskId)") + } + sparkContextMonitorTaskId = None + } + } + + /** + * Handle the SparkReady event to register the JupyterSparkMonitorListener. + * This method is called when Spark becomes available in the kernel. + * Uses direct access to SparkContext when possible. + */ + @Event(name = "sparkReady") + def onReady(kernel: KernelLike): Unit = { + logger.info("SparkReady event received, registering JupyterSparkMonitorListener") + + // Stop monitoring task if still running since SparkContext is ready + stopSparkContextMonitoring() + + Try { + val sparkContext = kernel.sparkContext + // Pass a callback function that always gets the current commWriter + val listener = new JupyterSparkMonitorListener(() => commWriter) + + sparkContext.addSparkListener(listener) + sparkMonitorListener = Some(listener) + + logger.info("JupyterSparkMonitorListener registered successfully") + // notifySparkListenerRegistration() + }.recover { + case ex => logger.error(s"Failed to register JupyterSparkMonitorListener: ${ex.getMessage}", ex) + } + } + + /** + * Notifies clients that the Spark listener has been registered. + */ + private def notifySparkListenerRegistration(): Unit = { + commWriter.foreach { writer => + Try { + val message = MsgData("text/plain" -> "JupyterSparkMonitorListener registered") + writer.writeMsg(message) + }.recover { + case ex => logger.warn(s"Failed to send SparkListener registration notification: ${ex.getMessage}") + } + } + } +} \ No newline at end of file