diff --git a/.github/workflows/ray_nightly_test.yml b/.github/workflows/ray_nightly_test.yml
index 0a75563a..69bb51a5 100644
--- a/.github/workflows/ray_nightly_test.yml
+++ b/.github/workflows/ray_nightly_test.yml
@@ -32,7 +32,7 @@ jobs:
matrix:
os: [ ubuntu-latest ]
python-version: [3.9, 3.10.14]
- spark-version: [3.3.2, 3.4.0, 3.5.0]
+ spark-version: [4.0.0]
runs-on: ${{ matrix.os }}
diff --git a/.github/workflows/raydp.yml b/.github/workflows/raydp.yml
index 999f184d..8b7f6067 100644
--- a/.github/workflows/raydp.yml
+++ b/.github/workflows/raydp.yml
@@ -33,7 +33,7 @@ jobs:
matrix:
os: [ubuntu-latest]
python-version: [3.9, 3.10.14]
- spark-version: [3.3.2, 3.4.0, 3.5.0]
+ spark-version: [4.0.0]
ray-version: [2.34.0, 2.40.0]
runs-on: ${{ matrix.os }}
@@ -74,8 +74,8 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install wheel
- pip install "numpy<1.24" "click<8.3.0"
- pip install "pydantic<2.0"
+ pip install "numpy<1.24"
+ pip install "pydantic<2.0" "click<8.3.0"
SUBVERSION=$(python -c 'import sys; print(sys.version_info[1])')
if [ "$(uname -s)" == "Linux" ]
then
diff --git a/core/pom.xml b/core/pom.xml
index 5e1185d9..6662cd5e 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -18,6 +18,7 @@
3.3.0
3.4.0
3.5.0
+ 4.0.0
1.1.10.4
4.1.94.Final
1.10.0
@@ -29,9 +30,9 @@
UTF-8
1.8
1.8
- 2.12.15
- 2.15.0
- 2.12
+ 2.13.12
+ 2.18.2
+ 2.13
5.10.1
@@ -151,16 +152,19 @@
com.fasterxml.jackson.core
jackson-core
${jackson.version}
+ provided
com.fasterxml.jackson.core
jackson-databind
${jackson.version}
+ provided
com.fasterxml.jackson.core
jackson-annotations
${jackson.version}
+ provided
@@ -168,6 +172,7 @@
com.fasterxml.jackson.module
jackson-module-scala_${scala.binary.version}
${jackson.version}
+ provided
com.google.guava
@@ -179,6 +184,7 @@
com.fasterxml.jackson.module
jackson-module-jaxb-annotations
${jackson.version}
+ provided
diff --git a/core/raydp-main/pom.xml b/core/raydp-main/pom.xml
index 3c791a65..78effa21 100644
--- a/core/raydp-main/pom.xml
+++ b/core/raydp-main/pom.xml
@@ -134,24 +134,20 @@
com.fasterxml.jackson.core
jackson-core
- ${jackson.version}
com.fasterxml.jackson.core
jackson-databind
- ${jackson.version}
com.fasterxml.jackson.core
jackson-annotations
- ${jackson.version}
com.fasterxml.jackson.module
jackson-module-scala_${scala.binary.version}
- ${jackson.version}
com.google.guava
@@ -162,7 +158,6 @@
com.fasterxml.jackson.module
jackson-module-jaxb-annotations
- ${jackson.version}
diff --git a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala
index f4cc823d..f835106a 100644
--- a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala
+++ b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala
@@ -22,15 +22,15 @@ import java.text.SimpleDateFormat
import java.util.{Date, Locale}
import javax.xml.bind.DatatypeConverter
-import scala.collection.JavaConverters._
import scala.collection.mutable.HashMap
+import scala.jdk.CollectionConverters._
+import com.fasterxml.jackson.core.JsonFactory
+import com.fasterxml.jackson.databind.ObjectMapper
import io.ray.api.{ActorHandle, PlacementGroups, Ray}
import io.ray.api.id.PlacementGroupId
import io.ray.api.placementgroup.PlacementGroup
import io.ray.runtime.config.RayConfig
-import org.json4s._
-import org.json4s.jackson.JsonMethods._
import org.apache.spark.{RayDPException, SecurityManager, SparkConf}
import org.apache.spark.executor.RayDPExecutor
@@ -39,6 +39,7 @@ import org.apache.spark.raydp.{RayExecutorUtils, SparkOnRayConfigs}
import org.apache.spark.rpc._
import org.apache.spark.util.Utils
+
class RayAppMaster(host: String,
port: Int,
actorExtraClasspath: String) extends Serializable with Logging {
@@ -298,7 +299,7 @@ class RayAppMaster(host: String,
.map { case (name, amount) => (name, Double.box(amount)) }.asJava,
placementGroup,
getNextBundleIndex,
- seqAsJavaList(appInfo.desc.command.javaOpts))
+ appInfo.desc.command.javaOpts.asJava)
appInfo.addPendingRegisterExecutor(executorId, handler, sparkCoresPerExecutor, memory)
}
@@ -356,11 +357,15 @@ object RayAppMaster extends Serializable {
val ACTOR_NAME = "RAY_APP_MASTER"
def setProperties(properties: String): Unit = {
- implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats
- val parsed = parse(properties).extract[Map[String, String]]
- parsed.foreach{ case (key, value) =>
- System.setProperty(key, value)
+ // Use Jackson ObjectMapper directly to avoid JSON4S version conflicts
+ val mapper = new ObjectMapper()
+ val javaMap = mapper.readValue(properties, classOf[java.util.Map[String, Object]])
+ val scalaMap = javaMap.asScala.toMap
+ scalaMap.foreach{ case (key, value) =>
+ // Convert all values to strings since System.setProperty expects String
+ System.setProperty(key, value.toString)
}
+
// Use the same session dir as the python side
RayConfig.create().setSessionDir(System.getProperty("ray.session-dir"))
}
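
The setProperties rewrite above replaces json4s parsing with plain Jackson so the driver no longer depends on a json4s version that can clash with Spark 4's. A minimal standalone sketch of the same parse-and-set pattern (the object name JsonPropsSketch and the example path are illustrative, not RayDP code):

```scala
import com.fasterxml.jackson.databind.ObjectMapper
import scala.jdk.CollectionConverters._

object JsonPropsSketch {
  def setFromJson(json: String): Unit = {
    // jackson-databind is the only dependency needed; no json4s on the classpath.
    val mapper = new ObjectMapper()
    val parsed = mapper.readValue(json, classOf[java.util.Map[String, Object]])
    // System.setProperty takes Strings, so stringify whatever value Jackson produced.
    parsed.asScala.foreach { case (k, v) => System.setProperty(k, v.toString) }
  }
}

// e.g. JsonPropsSketch.setFromJson("""{"ray.session-dir": "/tmp/ray/session_xyz"}""")
```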
diff --git a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala
index 949618e5..6905e711 100644
--- a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala
+++ b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala
@@ -27,9 +27,9 @@ import java.util.function.{Function => JFunction}
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.arrow.vector.types.pojo.Schema
-import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
import org.apache.spark.{RayDPException, SparkContext}
import org.apache.spark.deploy.raydp._
@@ -85,13 +85,16 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
* Save the DataFrame to Ray object store with Apache Arrow format.
*/
def save(useBatch: Boolean, ownerName: String): List[RecordBatch] = {
- val conf = df.queryExecution.sparkSession.sessionState.conf
+ val sparkSession = df.sparkSession
+ val conf = sparkSession.sessionState.conf
val timeZoneId = conf.getConf(SQLConf.SESSION_LOCAL_TIMEZONE)
var batchSize = conf.getConf(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)
if (!useBatch) {
batchSize = 0
}
val schema = df.schema
+ val arrowSchemaJson = SparkShimLoader.getSparkShims.toArrowSchema(
+ schema, timeZoneId, sparkSession).toJson
val objectIds = df.queryExecution.toRdd.mapPartitions{ iter =>
val queue = ObjectRefHolder.getQueue(uuid)
@@ -103,7 +106,8 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
Iterator(iter)
}
- val arrowSchema = SparkShimLoader.getSparkShims.toArrowSchema(schema, timeZoneId)
+ // Reconstruct arrow schema from JSON on executor
+ val arrowSchema = Schema.fromJSON(arrowSchemaJson)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
s"ray object store writer", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
@@ -213,9 +217,9 @@ object ObjectStoreWriter {
}
def toArrowSchema(df: DataFrame): Schema = {
- val conf = df.queryExecution.sparkSession.sessionState.conf
+ val conf = df.sparkSession.sessionState.conf
val timeZoneId = conf.getConf(SQLConf.SESSION_LOCAL_TIMEZONE)
- SparkShimLoader.getSparkShims.toArrowSchema(df.schema, timeZoneId)
+ SparkShimLoader.getSparkShims.toArrowSchema(df.schema, timeZoneId, df.sparkSession)
}
def fromSparkRDD(df: DataFrame, storageLevel: StorageLevel): Array[Array[Byte]] = {
@@ -225,10 +229,10 @@ object ObjectStoreWriter {
}
val uuid = dfToId.getOrElseUpdate(df, UUID.randomUUID())
val queue = ObjectRefHolder.getQueue(uuid)
- val rdd = df.toArrowBatchRdd
+ val rdd = SparkShimLoader.getSparkShims.toArrowBatchRDD(df)
rdd.persist(storageLevel)
rdd.count()
- var executorIds = df.sqlContext.sparkContext.getExecutorIds.toArray
+ val executorIds = df.sparkSession.sparkContext.getExecutorIds.toArray
val numExecutors = executorIds.length
val appMasterHandle = Ray.getActor(RayAppMaster.ACTOR_NAME)
.get.asInstanceOf[ActorHandle[RayAppMaster]]
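
Computing the Arrow schema on the driver and shipping it as JSON keeps the SparkSession and the version-specific shim call out of the task closure; executors only rebuild a Schema from a String via Schema.fromJSON. A self-contained sketch of that round trip, using a hand-built schema rather than RayDP's:

```scala
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
import scala.jdk.CollectionConverters._

// Driver side: build (or derive) the schema once and keep only its JSON form.
val schema = new Schema(Seq(
  Field.nullable("id", new ArrowType.Int(64, true)),
  Field.nullable("name", ArrowType.Utf8.INSTANCE)
).asJava)
val schemaJson: String = schema.toJson   // a plain String, cheap and safe to capture in a closure

// Executor side: rebuild an equivalent Schema from the captured string.
val rebuilt: Schema = Schema.fromJSON(schemaJson)
assert(rebuilt == schema)
```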
diff --git a/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala
index 2ca83522..c1f47fc2 100644
--- a/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala
+++ b/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala
@@ -21,6 +21,7 @@ import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.executor.RayDPExecutorBackendFactory
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SparkSession}
@@ -39,5 +40,7 @@ trait SparkShims {
def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext
- def toArrowSchema(schema : StructType, timeZoneId : String) : Schema
+ def toArrowSchema(schema : StructType, timeZoneId : String, sparkSession: SparkSession) : Schema
+
+ def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]]
}
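
With the trait widened, callers pass the SparkSession so each shim can read session-level Arrow options, and they obtain Arrow batches through toArrowBatchRDD instead of touching Dataset.toArrowBatchRdd directly. A hypothetical call-site sketch (the SparkShimLoader import path is assumed from its use in ObjectStoreWriter):

```scala
import com.intel.raydp.shims.SparkShimLoader
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame

object ArrowViaShims {
  // Everything version-specific is resolved behind the shim interface.
  def arrowViewOf(df: DataFrame): (Schema, RDD[Array[Byte]]) = {
    val shims = SparkShimLoader.getSparkShims
    val timeZoneId = df.sparkSession.conf.get("spark.sql.session.timeZone")
    val schema = shims.toArrowSchema(df.schema, timeZoneId, df.sparkSession)
    val batches = shims.toArrowBatchRDD(df)
    (schema, batches)
  }
}
```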
diff --git a/core/shims/pom.xml b/core/shims/pom.xml
index c013538b..ac16dba7 100644
--- a/core/shims/pom.xml
+++ b/core/shims/pom.xml
@@ -21,10 +21,11 @@
spark330
spark340
spark350
+ spark400
- 2.12
+ 2.13
4.3.0
3.2.2
diff --git a/core/shims/spark322/pom.xml b/core/shims/spark322/pom.xml
index faff6ac5..295b3d73 100644
--- a/core/shims/spark322/pom.xml
+++ b/core/shims/spark322/pom.xml
@@ -16,8 +16,7 @@
jar
- 2.12.15
- 2.13.5
+ 2.13.12
diff --git a/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala
index 6ea817db..6c423e33 100644
--- a/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala
+++ b/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.spark322.SparkSqlUtils
import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.StructType
class Spark322Shims extends SparkShims {
@@ -46,7 +47,14 @@ class Spark322Shims extends SparkShims {
TaskContextUtils.getDummyTaskContext(partitionId, env)
}
- override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
- SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
+ override def toArrowSchema(
+ schema : StructType,
+ timeZoneId : String,
+ sparkSession: SparkSession) : Schema = {
+ SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId, session = sparkSession)
+ }
+
+ override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = {
+ SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession)
}
}
diff --git a/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
index be9b409c..609c7112 100644
--- a/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
+++ b/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.spark322
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.StructType
@@ -29,7 +30,11 @@ object SparkSqlUtils {
ArrowConverters.toDataFrame(rdd, schema, new SQLContext(session))
}
- def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
+ def toArrowSchema(schema : StructType, timeZoneId : String, session: SparkSession) : Schema = {
ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
}
+
+ def toArrowRDD(dataFrame: DataFrame, sparkSession: SparkSession): RDD[Array[Byte]] = {
+ dataFrame.toArrowBatchRdd
+ }
}
diff --git a/core/shims/spark330/pom.xml b/core/shims/spark330/pom.xml
index 4443f658..6972fef1 100644
--- a/core/shims/spark330/pom.xml
+++ b/core/shims/spark330/pom.xml
@@ -16,8 +16,7 @@
jar
- 2.12.15
- 2.13.5
+ 2.13.12
diff --git a/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala
index 4f1a50b5..26197052 100644
--- a/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala
+++ b/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.spark330.SparkSqlUtils
import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.StructType
class Spark330Shims extends SparkShims {
@@ -46,7 +47,18 @@ class Spark330Shims extends SparkShims {
TaskContextUtils.getDummyTaskContext(partitionId, env)
}
- override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
- SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
+ override def toArrowSchema(
+ schema : StructType,
+ timeZoneId : String,
+ sparkSession: SparkSession) : Schema = {
+ SparkSqlUtils.toArrowSchema(
+ schema = schema,
+ timeZoneId = timeZoneId,
+ sparkSession = sparkSession
+ )
+ }
+
+ override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = {
+ SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession)
}
}
diff --git a/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
index 162371ad..8c937dcd 100644
--- a/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
+++ b/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.spark330
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.StructType
@@ -29,7 +30,11 @@ object SparkSqlUtils {
ArrowConverters.toDataFrame(rdd, schema, session)
}
- def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
+ def toArrowSchema(schema : StructType, timeZoneId : String, sparkSession: SparkSession) : Schema = {
ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
}
+
+ def toArrowRDD(dataFrame: DataFrame, sparkSession: SparkSession): RDD[Array[Byte]] = {
+ dataFrame.toArrowBatchRdd
+ }
}
diff --git a/core/shims/spark340/pom.xml b/core/shims/spark340/pom.xml
index 1b312747..52af6ed5 100644
--- a/core/shims/spark340/pom.xml
+++ b/core/shims/spark340/pom.xml
@@ -16,8 +16,7 @@
jar
- 2.12.15
- 2.13.5
+ 2.13.12
diff --git a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala
index c444373f..26717840 100644
--- a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala
+++ b/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.spark340.SparkSqlUtils
import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.StructType
class Spark340Shims extends SparkShims {
@@ -46,7 +47,18 @@ class Spark340Shims extends SparkShims {
TaskContextUtils.getDummyTaskContext(partitionId, env)
}
- override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
- SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
+ override def toArrowSchema(
+ schema : StructType,
+ timeZoneId : String,
+ sparkSession: SparkSession) : Schema = {
+ SparkSqlUtils.toArrowSchema(
+ schema = schema,
+ timeZoneId = timeZoneId,
+ sparkSession = sparkSession
+ )
+ }
+
+ override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = {
+ SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession)
}
}
diff --git a/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
index eb52d8e7..3ec33569 100644
--- a/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
+++ b/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.spark340
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.TaskContext
import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types._
@@ -39,7 +40,11 @@ object SparkSqlUtils {
session.internalCreateDataFrame(rdd.setName("arrow"), schema)
}
- def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
+ def toArrowSchema(schema : StructType, timeZoneId : String, sparkSession: SparkSession) : Schema = {
ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
}
+
+ def toArrowRDD(dataFrame: DataFrame, sparkSession: SparkSession): RDD[Array[Byte]] = {
+ dataFrame.toArrowBatchRdd
+ }
}
diff --git a/core/shims/spark350/pom.xml b/core/shims/spark350/pom.xml
index 2368daa2..0afa3289 100644
--- a/core/shims/spark350/pom.xml
+++ b/core/shims/spark350/pom.xml
@@ -16,8 +16,7 @@
jar
- 2.12.15
- 2.13.5
+ 2.13.12
diff --git a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala
index 721d6923..5b2f2eec 100644
--- a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala
+++ b/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.spark350.SparkSqlUtils
import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.StructType
class Spark350Shims extends SparkShims {
@@ -46,7 +47,17 @@ class Spark350Shims extends SparkShims {
TaskContextUtils.getDummyTaskContext(partitionId, env)
}
- override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
- SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
+ override def toArrowSchema(schema : StructType,
+ timeZoneId : String,
+ sparkSession: SparkSession) : Schema = {
+ SparkSqlUtils.toArrowSchema(
+ schema = schema,
+ timeZoneId = timeZoneId,
+ sparkSession = sparkSession
+ )
+ }
+
+ override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = {
+ SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession)
}
}
diff --git a/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
index dfd063f7..a12c4256 100644
--- a/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
+++ b/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.spark350
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.TaskContext
import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types._
@@ -39,7 +40,18 @@ object SparkSqlUtils {
session.internalCreateDataFrame(rdd.setName("arrow"), schema)
}
- def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
- ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId, errorOnDuplicatedFieldNames = false)
+ def toArrowSchema(schema : StructType, timeZoneId : String, sparkSession: SparkSession) : Schema = {
+ val errorOnDuplicatedFieldNames =
+ sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy"
+
+ ArrowUtils.toArrowSchema(
+ schema = schema,
+ timeZoneId = timeZoneId,
+ errorOnDuplicatedFieldNames = errorOnDuplicatedFieldNames
+ )
+ }
+
+ def toArrowRDD(dataFrame: DataFrame, sparkSession: SparkSession): RDD[Array[Byte]] = {
+ dataFrame.toArrowBatchRdd
}
}
diff --git a/core/shims/spark400/pom.xml b/core/shims/spark400/pom.xml
new file mode 100644
index 00000000..fd3f8494
--- /dev/null
+++ b/core/shims/spark400/pom.xml
@@ -0,0 +1,98 @@
+
+
+
+ 4.0.0
+
+
+ com.intel
+ raydp-shims
+ 1.7.0-SNAPSHOT
+ ../pom.xml
+
+
+ raydp-shims-spark400
+ RayDP Shims for Spark 4.0.0
+ jar
+
+
+ 2.13.12
+
+
+
+
+
+ org.scalastyle
+ scalastyle-maven-plugin
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+ 3.2.2
+
+
+ scala-compile-first
+ process-resources
+
+ compile
+
+
+
+ scala-test-compile-first
+ process-test-resources
+
+ testCompile
+
+
+
+
+
+
+
+
+ src/main/resources
+
+
+
+
+
+
+ com.intel
+ raydp-shims-common
+ ${project.version}
+ compile
+
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+ ${spark400.version}
+ provided
+
+
+ org.apache.spark
+ spark-core_${scala.binary.version}
+ ${spark400.version}
+ provided
+
+
+ org.xerial.snappy
+ snappy-java
+
+
+ io.netty
+ netty-handler
+
+
+
+
+ org.xerial.snappy
+ snappy-java
+ ${snappy.version}
+
+
+ io.netty
+ netty-handler
+ ${netty.version}
+
+
+
diff --git a/core/shims/spark400/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider b/core/shims/spark400/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider
new file mode 100644
index 00000000..f88bbd7a
--- /dev/null
+++ b/core/shims/spark400/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider
@@ -0,0 +1 @@
+com.intel.raydp.shims.spark400.SparkShimProvider
diff --git a/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
new file mode 100644
index 00000000..6652c182
--- /dev/null
+++ b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.raydp.shims.spark400
+
+import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor}
+
+object SparkShimProvider {
+ val SPARK400_DESCRIPTOR = SparkShimDescriptor(4, 0, 0)
+ val SPARK401_DESCRIPTOR = SparkShimDescriptor(4, 0, 1)
+ val DESCRIPTOR_STRINGS = Seq(
+ s"$SPARK400_DESCRIPTOR", s"$SPARK401_DESCRIPTOR"
+ )
+ val DESCRIPTOR = SPARK400_DESCRIPTOR
+}
+
+class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider {
+ def createShim: SparkShims = {
+ new Spark400Shims()
+ }
+
+ def matches(version: String): Boolean = {
+ SparkShimProvider.DESCRIPTOR_STRINGS.contains(version)
+ }
+}
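
Version matching is an exact string comparison against the rendered descriptors, so each Spark 4.0.x patch release must be listed explicitly. A hedged usage sketch, assuming SparkShimDescriptor(4, 0, 0) renders as "4.0.0":

```scala
val provider = new com.intel.raydp.shims.spark400.SparkShimProvider()

assert(provider.matches("4.0.0"))
assert(provider.matches("4.0.1"))
// Not listed in DESCRIPTOR_STRINGS, so a future 4.0.2 would need a new entry.
assert(!provider.matches("4.0.2"))
```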
diff --git a/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShims.scala
new file mode 100644
index 00000000..540edd2f
--- /dev/null
+++ b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShims.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.raydp.shims.spark400
+
+import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.executor.{RayDPExecutorBackendFactory, RayDPSpark400ExecutorBackendFactory}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.spark400.SparkSqlUtils
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.spark400.TaskContextUtils
+import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
+import org.apache.spark.rdd.RDD
+
+class Spark400Shims extends SparkShims {
+ override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
+
+ override def toDataFrame(
+ rdd: JavaRDD[Array[Byte]],
+ schema: String,
+ session: SparkSession): DataFrame = {
+ SparkSqlUtils.toDataFrame(rdd, schema, session)
+ }
+
+ override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = {
+ new RayDPSpark400ExecutorBackendFactory()
+ }
+
+ override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
+ TaskContextUtils.getDummyTaskContext(partitionId, env)
+ }
+
+ override def toArrowSchema(
+ schema: StructType,
+ timeZoneId: String,
+ sparkSession: SparkSession): Schema = {
+ SparkSqlUtils.toArrowSchema(schema, timeZoneId, sparkSession)
+ }
+
+ override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = {
+ SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession)
+ }
+}
diff --git a/core/shims/spark400/src/main/scala/org/apache/spark/TaskContextUtils.scala b/core/shims/spark400/src/main/scala/org/apache/spark/TaskContextUtils.scala
new file mode 100644
index 00000000..287105cd
--- /dev/null
+++ b/core/shims/spark400/src/main/scala/org/apache/spark/TaskContextUtils.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.spark400
+
+import java.util.Properties
+
+import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
+import org.apache.spark.memory.TaskMemoryManager
+
+object TaskContextUtils {
+ def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
+ new TaskContextImpl(0, 0, partitionId, -1024, 0, 0,
+ new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem)
+ }
+}
diff --git a/core/shims/spark400/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala b/core/shims/spark400/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala
new file mode 100644
index 00000000..2e6b5e25
--- /dev/null
+++ b/core/shims/spark400/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import java.net.URL
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.rpc.RpcEnv
+
+class RayCoarseGrainedExecutorBackend(
+ rpcEnv: RpcEnv,
+ driverUrl: String,
+ executorId: String,
+ bindAddress: String,
+ hostname: String,
+ cores: Int,
+ userClassPath: Seq[URL],
+ env: SparkEnv,
+ resourcesFileOpt: Option[String],
+ resourceProfile: ResourceProfile)
+ extends CoarseGrainedExecutorBackend(
+ rpcEnv,
+ driverUrl,
+ executorId,
+ bindAddress,
+ hostname,
+ cores,
+ env,
+ resourcesFileOpt,
+ resourceProfile) {
+
+ override def getUserClassPath: Seq[URL] = userClassPath
+
+}
diff --git a/core/shims/spark400/src/main/scala/org/apache/spark/executor/RayDPSpark400ExecutorBackendFactory.scala b/core/shims/spark400/src/main/scala/org/apache/spark/executor/RayDPSpark400ExecutorBackendFactory.scala
new file mode 100644
index 00000000..eed998bd
--- /dev/null
+++ b/core/shims/spark400/src/main/scala/org/apache/spark/executor/RayDPSpark400ExecutorBackendFactory.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.rpc.RpcEnv
+
+import java.net.URL
+
+class RayDPSpark400ExecutorBackendFactory
+ extends RayDPExecutorBackendFactory {
+ override def createExecutorBackend(
+ rpcEnv: RpcEnv,
+ driverUrl: String,
+ executorId: String,
+ bindAddress: String,
+ hostname: String,
+ cores: Int,
+ userClassPath: Seq[URL],
+ env: SparkEnv,
+ resourcesFileOpt: Option[String],
+ resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = {
+ new RayCoarseGrainedExecutorBackend(
+ rpcEnv,
+ driverUrl,
+ executorId,
+ bindAddress,
+ hostname,
+ cores,
+ userClassPath,
+ env,
+ resourcesFileOpt,
+ resourceProfile)
+ }
+}
diff --git a/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
new file mode 100644
index 00000000..aab0e2fe
--- /dev/null
+++ b/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.spark400
+
+import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.spark.TaskContext
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.classic.ClassicConversions.castToImpl
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.execution.arrow.ArrowConverters
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils
+
+object SparkSqlUtils {
+ def toDataFrame(
+ arrowBatchRDD: JavaRDD[Array[Byte]],
+ schemaString: String,
+ session: SparkSession): DataFrame = {
+ val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+ val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
+ val internalRowRdd = arrowBatchRDD.rdd.mapPartitions { iter =>
+ val context = TaskContext.get()
+ ArrowConverters.fromBatchIterator(
+ arrowBatchIter = iter,
+ schema = schema,
+ timeZoneId = timeZoneId,
+ errorOnDuplicatedFieldNames = false,
+ largeVarTypes = false,
+ context = context)
+ }
+ session.internalCreateDataFrame(internalRowRdd.setName("arrow"), schema)
+ }
+
+ def toArrowRDD(dataFrame: DataFrame, sparkSession: SparkSession): RDD[Array[Byte]] = {
+ dataFrame.toArrowBatchRdd
+ }
+
+ def toArrowSchema(schema : StructType, timeZoneId : String, sparkSession: SparkSession) : Schema = {
+ val errorOnDuplicatedFieldNames =
+ sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy"
+ val largeVarTypes =
+ sparkSession.sessionState.conf.arrowUseLargeVarTypes
+
+ ArrowUtils.toArrowSchema(
+ schema = schema,
+ timeZoneId = timeZoneId,
+ errorOnDuplicatedFieldNames = errorOnDuplicatedFieldNames,
+ largeVarTypes = largeVarTypes
+ )
+ }
+}
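
On Spark 4.0, ArrowUtils.toArrowSchema takes two extra flags that this shim derives from session configuration, so the Arrow types it produces follow the SQL options in effect. A hedged usage sketch (the config key is assumed to be spark.sql.execution.arrow.useLargeVarTypes, matching Spark's SQLConf naming):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.spark400.SparkSqlUtils

val spark = SparkSession.builder()
  .master("local[1]")
  .appName("arrow-schema-demo")
  .getOrCreate()

// With large var types enabled, string columns should map to Arrow LargeUtf8 instead of Utf8.
spark.conf.set("spark.sql.execution.arrow.useLargeVarTypes", "true")

val df = spark.range(3).selectExpr("CAST(id AS STRING) AS s")
println(SparkSqlUtils.toArrowSchema(df.schema, "UTC", spark))

spark.stop()
```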
diff --git a/python/raydp/tests/conftest.py b/python/raydp/tests/conftest.py
index 79c2d115..e0df0e35 100644
--- a/python/raydp/tests/conftest.py
+++ b/python/raydp/tests/conftest.py
@@ -64,7 +64,8 @@ def jdk17_extra_spark_configs() -> Dict[str, str]:
@pytest.fixture(scope="function")
def spark_session(request, jdk17_extra_spark_configs):
- builder = SparkSession.builder.master("local[2]").appName("RayDP test")
+ builder = SparkSession.builder.master("local[2]").appName("RayDP test") \
+ .config("spark.sql.ansi.enabled", "false")
for k, v in jdk17_extra_spark_configs.items():
builder = builder.config(k, v)
spark = builder.getOrCreate()
@@ -94,6 +95,7 @@ def spark_on_ray_small(request, jdk17_extra_spark_configs):
extra_configs = {
"spark.driver.host": node_ip,
"spark.driver.bindAddress": node_ip,
+ "spark.sql.ansi.enabled": "false",
**jdk17_extra_spark_configs
}
spark = raydp.init_spark("test", 1, 1, "500M", configs=extra_configs)
@@ -119,6 +121,7 @@ def spark_on_ray_2_executors(request, jdk17_extra_spark_configs):
extra_configs = {
"spark.driver.host": node_ip,
"spark.driver.bindAddress": node_ip,
+ "spark.sql.ansi.enabled": "false",
**jdk17_extra_spark_configs
}
spark = raydp.init_spark("test", 2, 1, "500M", configs=extra_configs)
diff --git a/python/raydp/tests/test_spark_cluster.py b/python/raydp/tests/test_spark_cluster.py
index cd279cf4..c9872dd5 100644
--- a/python/raydp/tests/test_spark_cluster.py
+++ b/python/raydp/tests/test_spark_cluster.py
@@ -155,6 +155,7 @@ def test_ray_dataset_roundtrip(jdk17_extra_spark_configs):
# always get the same sparkContext between tests.
# So we need to re-set the resource explicitly here.
"spark.ray.raydp_spark_executor.actor.resource.spark_executor": "0",
+ "spark.sql.ansi.enabled": "false",
**jdk17_extra_spark_configs
}
spark = raydp.init_spark(app_name="test_ray_dataset_roundtrip", num_executors=2,
diff --git a/python/raydp/tf/estimator.py b/python/raydp/tf/estimator.py
index 9fa21f1d..6fb3244c 100644
--- a/python/raydp/tf/estimator.py
+++ b/python/raydp/tf/estimator.py
@@ -17,6 +17,7 @@
from packaging import version
import tempfile
+import warnings
from typing import Any, List, NoReturn, Optional, Union, Dict
import tensorflow as tf
@@ -188,11 +189,15 @@ def train_func(config):
if config["evaluate"]:
test_history = multi_worker_model.evaluate(eval_tf_dataset, callbacks=callbacks)
results.append(test_history)
- with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
- multi_worker_model.save(temp_checkpoint_dir, save_format="tf")
- checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
- session.report({}, checkpoint=checkpoint)
+ # Only save checkpoint from the chief worker to avoid race conditions
+ checkpoint = None
+ if session.get_world_rank() == 0:
+ with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+ multi_worker_model.save(temp_checkpoint_dir, save_format="tf")
+ checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
+
+ session.report({}, checkpoint=checkpoint)
def fit(self,
train_ds: Dataset,
diff --git a/python/setup.py b/python/setup.py
index 49c5b8c7..0ce7ba33 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -100,7 +100,7 @@ def run(self):
"psutil",
"pyarrow >= 4.0.1",
"ray >= 2.1.0",
- "pyspark >= 3.1.1, <=3.5.7",
+ "pyspark >= 4.0.0",
"netifaces",
"protobuf > 3.19.5"
]