diff --git a/.github/workflows/ray_nightly_test.yml b/.github/workflows/ray_nightly_test.yml index 0a75563a..69bb51a5 100644 --- a/.github/workflows/ray_nightly_test.yml +++ b/.github/workflows/ray_nightly_test.yml @@ -32,7 +32,7 @@ jobs: matrix: os: [ ubuntu-latest ] python-version: [3.9, 3.10.14] - spark-version: [3.3.2, 3.4.0, 3.5.0] + spark-version: [4.0.0] runs-on: ${{ matrix.os }} diff --git a/.github/workflows/raydp.yml b/.github/workflows/raydp.yml index 999f184d..8b7f6067 100644 --- a/.github/workflows/raydp.yml +++ b/.github/workflows/raydp.yml @@ -33,7 +33,7 @@ jobs: matrix: os: [ubuntu-latest] python-version: [3.9, 3.10.14] - spark-version: [3.3.2, 3.4.0, 3.5.0] + spark-version: [4.0.0] ray-version: [2.34.0, 2.40.0] runs-on: ${{ matrix.os }} @@ -74,8 +74,8 @@ jobs: run: | python -m pip install --upgrade pip pip install wheel - pip install "numpy<1.24" "click<8.3.0" - pip install "pydantic<2.0" + pip install "numpy<1.24" + pip install "pydantic<2.0" "click<8.3.0" SUBVERSION=$(python -c 'import sys; print(sys.version_info[1])') if [ "$(uname -s)" == "Linux" ] then diff --git a/core/pom.xml b/core/pom.xml index 5e1185d9..6662cd5e 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -18,6 +18,7 @@ 3.3.0 3.4.0 3.5.0 + 4.0.0 1.1.10.4 4.1.94.Final 1.10.0 @@ -29,9 +30,9 @@ UTF-8 1.8 1.8 - 2.12.15 - 2.15.0 - 2.12 + 2.13.12 + 2.18.2 + 2.13 5.10.1 @@ -151,16 +152,19 @@ com.fasterxml.jackson.core jackson-core ${jackson.version} + provided com.fasterxml.jackson.core jackson-databind ${jackson.version} + provided com.fasterxml.jackson.core jackson-annotations ${jackson.version} + provided @@ -168,6 +172,7 @@ com.fasterxml.jackson.module jackson-module-scala_${scala.binary.version} ${jackson.version} + provided com.google.guava @@ -179,6 +184,7 @@ com.fasterxml.jackson.module jackson-module-jaxb-annotations ${jackson.version} + provided diff --git a/core/raydp-main/pom.xml b/core/raydp-main/pom.xml index 3c791a65..78effa21 100644 --- a/core/raydp-main/pom.xml +++ b/core/raydp-main/pom.xml @@ -134,24 +134,20 @@ com.fasterxml.jackson.core jackson-core - ${jackson.version} com.fasterxml.jackson.core jackson-databind - ${jackson.version} com.fasterxml.jackson.core jackson-annotations - ${jackson.version} com.fasterxml.jackson.module jackson-module-scala_${scala.binary.version} - ${jackson.version} com.google.guava @@ -162,7 +158,6 @@ com.fasterxml.jackson.module jackson-module-jaxb-annotations - ${jackson.version} diff --git a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala index f4cc823d..f835106a 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala @@ -22,15 +22,15 @@ import java.text.SimpleDateFormat import java.util.{Date, Locale} import javax.xml.bind.DatatypeConverter -import scala.collection.JavaConverters._ import scala.collection.mutable.HashMap +import scala.jdk.CollectionConverters._ +import com.fasterxml.jackson.core.JsonFactory +import com.fasterxml.jackson.databind.ObjectMapper import io.ray.api.{ActorHandle, PlacementGroups, Ray} import io.ray.api.id.PlacementGroupId import io.ray.api.placementgroup.PlacementGroup import io.ray.runtime.config.RayConfig -import org.json4s._ -import org.json4s.jackson.JsonMethods._ import org.apache.spark.{RayDPException, SecurityManager, SparkConf} import org.apache.spark.executor.RayDPExecutor @@ -39,6 +39,7 @@ import 
org.apache.spark.raydp.{RayExecutorUtils, SparkOnRayConfigs} import org.apache.spark.rpc._ import org.apache.spark.util.Utils + class RayAppMaster(host: String, port: Int, actorExtraClasspath: String) extends Serializable with Logging { @@ -298,7 +299,7 @@ class RayAppMaster(host: String, .map { case (name, amount) => (name, Double.box(amount)) }.asJava, placementGroup, getNextBundleIndex, - seqAsJavaList(appInfo.desc.command.javaOpts)) + appInfo.desc.command.javaOpts.asJava) appInfo.addPendingRegisterExecutor(executorId, handler, sparkCoresPerExecutor, memory) } @@ -356,11 +357,15 @@ object RayAppMaster extends Serializable { val ACTOR_NAME = "RAY_APP_MASTER" def setProperties(properties: String): Unit = { - implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats - val parsed = parse(properties).extract[Map[String, String]] - parsed.foreach{ case (key, value) => - System.setProperty(key, value) + // Use Jackson ObjectMapper directly to avoid JSON4S version conflicts + val mapper = new ObjectMapper() + val javaMap = mapper.readValue(properties, classOf[java.util.Map[String, Object]]) + val scalaMap = javaMap.asScala.toMap + scalaMap.foreach{ case (key, value) => + // Convert all values to strings since System.setProperty expects String + System.setProperty(key, value.toString) } + // Use the same session dir as the python side RayConfig.create().setSessionDir(System.getProperty("ray.session-dir")) } diff --git a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala index 949618e5..6905e711 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala @@ -27,9 +27,9 @@ import java.util.function.{Function => JFunction} import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.ipc.ArrowStreamWriter import org.apache.arrow.vector.types.pojo.Schema -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ import org.apache.spark.{RayDPException, SparkContext} import org.apache.spark.deploy.raydp._ @@ -85,13 +85,16 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable { * Save the DataFrame to Ray object store with Apache Arrow format. 
*/ def save(useBatch: Boolean, ownerName: String): List[RecordBatch] = { - val conf = df.queryExecution.sparkSession.sessionState.conf + val sparkSession = df.sparkSession + val conf = sparkSession.sessionState.conf val timeZoneId = conf.getConf(SQLConf.SESSION_LOCAL_TIMEZONE) var batchSize = conf.getConf(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) if (!useBatch) { batchSize = 0 } val schema = df.schema + val arrowSchemaJson = SparkShimLoader.getSparkShims.toArrowSchema( + schema, timeZoneId, sparkSession).toJson val objectIds = df.queryExecution.toRdd.mapPartitions{ iter => val queue = ObjectRefHolder.getQueue(uuid) @@ -103,7 +106,8 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable { Iterator(iter) } - val arrowSchema = SparkShimLoader.getSparkShims.toArrowSchema(schema, timeZoneId) + // Reconstruct arrow schema from JSON on executor + val arrowSchema = Schema.fromJSON(arrowSchemaJson) val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"ray object store writer", 0, Long.MaxValue) val root = VectorSchemaRoot.create(arrowSchema, allocator) @@ -213,9 +217,9 @@ object ObjectStoreWriter { } def toArrowSchema(df: DataFrame): Schema = { - val conf = df.queryExecution.sparkSession.sessionState.conf + val conf = df.sparkSession.sessionState.conf val timeZoneId = conf.getConf(SQLConf.SESSION_LOCAL_TIMEZONE) - SparkShimLoader.getSparkShims.toArrowSchema(df.schema, timeZoneId) + SparkShimLoader.getSparkShims.toArrowSchema(df.schema, timeZoneId, df.sparkSession) } def fromSparkRDD(df: DataFrame, storageLevel: StorageLevel): Array[Array[Byte]] = { @@ -225,10 +229,10 @@ object ObjectStoreWriter { } val uuid = dfToId.getOrElseUpdate(df, UUID.randomUUID()) val queue = ObjectRefHolder.getQueue(uuid) - val rdd = df.toArrowBatchRdd + val rdd = SparkShimLoader.getSparkShims.toArrowBatchRDD(df) rdd.persist(storageLevel) rdd.count() - var executorIds = df.sqlContext.sparkContext.getExecutorIds.toArray + val executorIds = df.sparkSession.sparkContext.getExecutorIds.toArray val numExecutors = executorIds.length val appMasterHandle = Ray.getActor(RayAppMaster.ACTOR_NAME) .get.asInstanceOf[ActorHandle[RayAppMaster]] diff --git a/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala index 2ca83522..c1f47fc2 100644 --- a/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -21,6 +21,7 @@ import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.java.JavaRDD import org.apache.spark.executor.RayDPExecutorBackendFactory +import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, SparkSession} @@ -39,5 +40,7 @@ trait SparkShims { def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext - def toArrowSchema(schema : StructType, timeZoneId : String) : Schema + def toArrowSchema(schema : StructType, timeZoneId : String, sparkSession: SparkSession) : Schema + + def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] } diff --git a/core/shims/pom.xml b/core/shims/pom.xml index c013538b..ac16dba7 100644 --- a/core/shims/pom.xml +++ b/core/shims/pom.xml @@ -21,10 +21,11 @@ spark330 spark340 spark350 + spark400 - 2.12 + 2.13 4.3.0 3.2.2 diff --git a/core/shims/spark322/pom.xml b/core/shims/spark322/pom.xml index faff6ac5..295b3d73 100644 --- 
a/core/shims/spark322/pom.xml +++ b/core/shims/spark322/pom.xml @@ -16,8 +16,7 @@ jar - 2.12.15 - 2.13.5 + 2.13.12 diff --git a/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala index 6ea817db..6c423e33 100644 --- a/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.spark322.SparkSqlUtils import com.intel.raydp.shims.{ShimDescriptor, SparkShims} import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.StructType class Spark322Shims extends SparkShims { @@ -46,7 +47,14 @@ class Spark322Shims extends SparkShims { TaskContextUtils.getDummyTaskContext(partitionId, env) } - override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) + override def toArrowSchema( + schema : StructType, + timeZoneId : String, + sparkSession: SparkSession) : Schema = { + SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId, session = sparkSession) + } + + override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = { + SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession) } } diff --git a/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala index be9b409c..609c7112 100644 --- a/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala +++ b/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.spark322 import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.types.StructType @@ -29,7 +30,11 @@ object SparkSqlUtils { ArrowConverters.toDataFrame(rdd, schema, new SQLContext(session)) } - def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { + def toArrowSchema(schema : StructType, timeZoneId : String, session: SparkSession) : Schema = { ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) } + + def toArrowRDD(dataFrame: DataFrame, sparkSession: SparkSession): RDD[Array[Byte]] = { + dataFrame.toArrowBatchRdd + } } diff --git a/core/shims/spark330/pom.xml b/core/shims/spark330/pom.xml index 4443f658..6972fef1 100644 --- a/core/shims/spark330/pom.xml +++ b/core/shims/spark330/pom.xml @@ -16,8 +16,7 @@ jar - 2.12.15 - 2.13.5 + 2.13.12 diff --git a/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala index 4f1a50b5..26197052 100644 --- a/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.spark330.SparkSqlUtils import com.intel.raydp.shims.{ShimDescriptor, SparkShims} import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.StructType class Spark330Shims 
extends SparkShims { @@ -46,7 +47,18 @@ class Spark330Shims extends SparkShims { TaskContextUtils.getDummyTaskContext(partitionId, env) } - override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) + override def toArrowSchema( + schema : StructType, + timeZoneId : String, + sparkSession: SparkSession) : Schema = { + SparkSqlUtils.toArrowSchema( + schema = schema, + timeZoneId = timeZoneId, + sparkSession = sparkSession + ) + } + + override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = { + SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession) } } diff --git a/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala index 162371ad..8c937dcd 100644 --- a/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala +++ b/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.spark330 import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.types.StructType @@ -29,7 +30,11 @@ object SparkSqlUtils { ArrowConverters.toDataFrame(rdd, schema, session) } - def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { + def toArrowSchema(schema : StructType, timeZoneId : String, sparkSession: SparkSession) : Schema = { ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) } + + def toArrowRDD(dataFrame: DataFrame, sparkSession: SparkSession): RDD[Array[Byte]] = { + dataFrame.toArrowBatchRdd + } } diff --git a/core/shims/spark340/pom.xml b/core/shims/spark340/pom.xml index 1b312747..52af6ed5 100644 --- a/core/shims/spark340/pom.xml +++ b/core/shims/spark340/pom.xml @@ -16,8 +16,7 @@ jar - 2.12.15 - 2.13.5 + 2.13.12 diff --git a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala index c444373f..26717840 100644 --- a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.spark340.SparkSqlUtils import com.intel.raydp.shims.{ShimDescriptor, SparkShims} import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.StructType class Spark340Shims extends SparkShims { @@ -46,7 +47,18 @@ class Spark340Shims extends SparkShims { TaskContextUtils.getDummyTaskContext(partitionId, env) } - override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) + override def toArrowSchema( + schema : StructType, + timeZoneId : String, + sparkSession: SparkSession) : Schema = { + SparkSqlUtils.toArrowSchema( + schema = schema, + timeZoneId = timeZoneId, + sparkSession = sparkSession + ) + } + + override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = { + SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession) } } diff --git a/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala 
b/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala index eb52d8e7..3ec33569 100644 --- a/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala +++ b/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.spark340 import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.types._ @@ -39,7 +40,11 @@ object SparkSqlUtils { session.internalCreateDataFrame(rdd.setName("arrow"), schema) } - def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { + def toArrowSchema(schema : StructType, timeZoneId : String, sparkSession: SparkSession) : Schema = { ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) } + + def toArrowRDD(dataFrame: DataFrame, sparkSession: SparkSession): RDD[Array[Byte]] = { + dataFrame.toArrowBatchRdd + } } diff --git a/core/shims/spark350/pom.xml b/core/shims/spark350/pom.xml index 2368daa2..0afa3289 100644 --- a/core/shims/spark350/pom.xml +++ b/core/shims/spark350/pom.xml @@ -16,8 +16,7 @@ jar - 2.12.15 - 2.13.5 + 2.13.12 diff --git a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala index 721d6923..5b2f2eec 100644 --- a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.spark350.SparkSqlUtils import com.intel.raydp.shims.{ShimDescriptor, SparkShims} import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.StructType class Spark350Shims extends SparkShims { @@ -46,7 +47,17 @@ class Spark350Shims extends SparkShims { TaskContextUtils.getDummyTaskContext(partitionId, env) } - override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) + override def toArrowSchema(schema : StructType, + timeZoneId : String, + sparkSession: SparkSession) : Schema = { + SparkSqlUtils.toArrowSchema( + schema = schema, + timeZoneId = timeZoneId, + sparkSession = sparkSession + ) + } + + override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = { + SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession) } } diff --git a/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala index dfd063f7..a12c4256 100644 --- a/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala +++ b/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.spark350 import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.types._ @@ -39,7 +40,18 @@
session.internalCreateDataFrame(rdd.setName("arrow"), schema) } - def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId, errorOnDuplicatedFieldNames = false) + def toArrowSchema(schema : StructType, timeZoneId : String, sparkSession: SparkSession) : Schema = { + val errorOnDuplicatedFieldNames = + sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy" + + ArrowUtils.toArrowSchema( + schema = schema, + timeZoneId = timeZoneId, + errorOnDuplicatedFieldNames = errorOnDuplicatedFieldNames + ) + } + + def toArrowRDD(dataFrame: DataFrame, sparkSession: SparkSession): RDD[Array[Byte]] = { + dataFrame.toArrowBatchRdd } } diff --git a/core/shims/spark400/pom.xml b/core/shims/spark400/pom.xml new file mode 100644 index 00000000..fd3f8494 --- /dev/null +++ b/core/shims/spark400/pom.xml @@ -0,0 +1,98 @@ + + + + 4.0.0 + + + com.intel + raydp-shims + 1.7.0-SNAPSHOT + ../pom.xml + + + raydp-shims-spark400 + RayDP Shims for Spark 4.0.0 + jar + + + 2.13.12 + + + + + + org.scalastyle + scalastyle-maven-plugin + + + net.alchim31.maven + scala-maven-plugin + 3.2.2 + + + scala-compile-first + process-resources + + compile + + + + scala-test-compile-first + process-test-resources + + testCompile + + + + + + + + + src/main/resources + + + + + + + com.intel + raydp-shims-common + ${project.version} + compile + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark400.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark400.version} + provided + + + org.xerial.snappy + snappy-java + + + io.netty + netty-handler + + + + + org.xerial.snappy + snappy-java + ${snappy.version} + + + io.netty + netty-handler + ${netty.version} + + + diff --git a/core/shims/spark400/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider b/core/shims/spark400/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider new file mode 100644 index 00000000..f88bbd7a --- /dev/null +++ b/core/shims/spark400/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider @@ -0,0 +1 @@ +com.intel.raydp.shims.spark400.SparkShimProvider diff --git a/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala new file mode 100644 index 00000000..6652c182 --- /dev/null +++ b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.intel.raydp.shims.spark400 + +import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor} + +object SparkShimProvider { + val SPARK400_DESCRIPTOR = SparkShimDescriptor(4, 0, 0) + val SPARK401_DESCRIPTOR = SparkShimDescriptor(4, 0, 1) + val DESCRIPTOR_STRINGS = Seq( + s"$SPARK400_DESCRIPTOR", s"$SPARK401_DESCRIPTOR" + ) + val DESCRIPTOR = SPARK400_DESCRIPTOR +} + +class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider { + def createShim: SparkShims = { + new Spark400Shims() + } + + def matches(version: String): Boolean = { + SparkShimProvider.DESCRIPTOR_STRINGS.contains(version) + } +} diff --git a/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShims.scala new file mode 100644 index 00000000..540edd2f --- /dev/null +++ b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.raydp.shims.spark400 + +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.executor.{RayDPExecutorBackendFactory, RayDPSpark400ExecutorBackendFactory} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.spark400.SparkSqlUtils +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.spark400.TaskContextUtils +import com.intel.raydp.shims.{ShimDescriptor, SparkShims} +import org.apache.spark.rdd.{MapPartitionsRDD, RDD} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.arrow.ArrowConverters + +import scala.reflect.ClassTag + +class Spark400Shims extends SparkShims { + override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR + + override def toDataFrame( + rdd: JavaRDD[Array[Byte]], + schema: String, + session: SparkSession): DataFrame = { + SparkSqlUtils.toDataFrame(rdd, schema, session) + } + + override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = { + new RayDPSpark400ExecutorBackendFactory() + } + + override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { + TaskContextUtils.getDummyTaskContext(partitionId, env) + } + + override def toArrowSchema( + schema: StructType, + timeZoneId: String, + sparkSession: SparkSession): Schema = { + SparkSqlUtils.toArrowSchema(schema, timeZoneId, sparkSession) + } + + override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = { + SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession) + } +} diff --git a/core/shims/spark400/src/main/scala/org/apache/spark/TaskContextUtils.scala 
b/core/shims/spark400/src/main/scala/org/apache/spark/TaskContextUtils.scala new file mode 100644 index 00000000..287105cd --- /dev/null +++ b/core/shims/spark400/src/main/scala/org/apache/spark/TaskContextUtils.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.spark400 + +import java.util.Properties + +import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} +import org.apache.spark.memory.TaskMemoryManager + +object TaskContextUtils { + def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { + new TaskContextImpl(0, 0, partitionId, -1024, 0, 0, + new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem) + } +} diff --git a/core/shims/spark400/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala b/core/shims/spark400/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala new file mode 100644 index 00000000..2e6b5e25 --- /dev/null +++ b/core/shims/spark400/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.executor + +import java.net.URL + +import org.apache.spark.SparkEnv +import org.apache.spark.resource.ResourceProfile +import org.apache.spark.rpc.RpcEnv + +class RayCoarseGrainedExecutorBackend( + rpcEnv: RpcEnv, + driverUrl: String, + executorId: String, + bindAddress: String, + hostname: String, + cores: Int, + userClassPath: Seq[URL], + env: SparkEnv, + resourcesFileOpt: Option[String], + resourceProfile: ResourceProfile) + extends CoarseGrainedExecutorBackend( + rpcEnv, + driverUrl, + executorId, + bindAddress, + hostname, + cores, + env, + resourcesFileOpt, + resourceProfile) { + + override def getUserClassPath: Seq[URL] = userClassPath + +} diff --git a/core/shims/spark400/src/main/scala/org/apache/spark/executor/RayDPSpark400ExecutorBackendFactory.scala b/core/shims/spark400/src/main/scala/org/apache/spark/executor/RayDPSpark400ExecutorBackendFactory.scala new file mode 100644 index 00000000..eed998bd --- /dev/null +++ b/core/shims/spark400/src/main/scala/org/apache/spark/executor/RayDPSpark400ExecutorBackendFactory.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.apache.spark.SparkEnv +import org.apache.spark.resource.ResourceProfile +import org.apache.spark.rpc.RpcEnv + +import java.net.URL + +class RayDPSpark400ExecutorBackendFactory + extends RayDPExecutorBackendFactory { + override def createExecutorBackend( + rpcEnv: RpcEnv, + driverUrl: String, + executorId: String, + bindAddress: String, + hostname: String, + cores: Int, + userClassPath: Seq[URL], + env: SparkEnv, + resourcesFileOpt: Option[String], + resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = { + new RayCoarseGrainedExecutorBackend( + rpcEnv, + driverUrl, + executorId, + bindAddress, + hostname, + cores, + userClassPath, + env, + resourcesFileOpt, + resourceProfile) + } +} diff --git a/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala new file mode 100644 index 00000000..aab0e2fe --- /dev/null +++ b/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.spark400 + +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.TaskContext +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.classic.ClassicConversions.castToImpl +import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.execution.arrow.ArrowConverters +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils + +object SparkSqlUtils { + def toDataFrame( + arrowBatchRDD: JavaRDD[Array[Byte]], + schemaString: String, + session: SparkSession): DataFrame = { + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + val timeZoneId = session.sessionState.conf.sessionLocalTimeZone + val internalRowRdd = arrowBatchRDD.rdd.mapPartitions { iter => + val context = TaskContext.get() + ArrowConverters.fromBatchIterator( + arrowBatchIter = iter, + schema = schema, + timeZoneId = timeZoneId, + errorOnDuplicatedFieldNames = false, + largeVarTypes = false, + context = context) + } + session.internalCreateDataFrame(internalRowRdd.setName("arrow"), schema) + } + + def toArrowRDD(dataFrame: DataFrame, sparkSession: SparkSession): RDD[Array[Byte]] = { + dataFrame.toArrowBatchRdd + } + + def toArrowSchema(schema : StructType, timeZoneId : String, sparkSession: SparkSession) : Schema = { + val errorOnDuplicatedFieldNames = + sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy" + val largeVarTypes = + sparkSession.sessionState.conf.arrowUseLargeVarTypes + + ArrowUtils.toArrowSchema( + schema = schema, + timeZoneId = timeZoneId, + errorOnDuplicatedFieldNames = errorOnDuplicatedFieldNames, + largeVarTypes = largeVarTypes + ) + } +} diff --git a/python/raydp/tests/conftest.py b/python/raydp/tests/conftest.py index 79c2d115..e0df0e35 100644 --- a/python/raydp/tests/conftest.py +++ b/python/raydp/tests/conftest.py @@ -64,7 +64,8 @@ def jdk17_extra_spark_configs() -> Dict[str, str]: @pytest.fixture(scope="function") def spark_session(request, jdk17_extra_spark_configs): - builder = SparkSession.builder.master("local[2]").appName("RayDP test") + builder = SparkSession.builder.master("local[2]").appName("RayDP test") \ + .config("spark.sql.ansi.enabled", "false") for k, v in jdk17_extra_spark_configs.items(): builder = builder.config(k, v) spark = builder.getOrCreate() @@ -94,6 +95,7 @@ def spark_on_ray_small(request, jdk17_extra_spark_configs): extra_configs = { "spark.driver.host": node_ip, "spark.driver.bindAddress": node_ip, + "spark.sql.ansi.enabled": "false", **jdk17_extra_spark_configs } spark = raydp.init_spark("test", 1, 1, "500M", configs=extra_configs) @@ -119,6 +121,7 @@ def spark_on_ray_2_executors(request, jdk17_extra_spark_configs): extra_configs = { "spark.driver.host": node_ip, "spark.driver.bindAddress": node_ip, + "spark.sql.ansi.enabled": "false", **jdk17_extra_spark_configs } spark = raydp.init_spark("test", 2, 1, "500M", configs=extra_configs) diff --git a/python/raydp/tests/test_spark_cluster.py b/python/raydp/tests/test_spark_cluster.py index cd279cf4..c9872dd5 100644 --- 
a/python/raydp/tests/test_spark_cluster.py +++ b/python/raydp/tests/test_spark_cluster.py @@ -155,6 +155,7 @@ def test_ray_dataset_roundtrip(jdk17_extra_spark_configs): # always get the same sparkContext between tests. # So we need to re-set the resource explicitly here. "spark.ray.raydp_spark_executor.actor.resource.spark_executor": "0", + "spark.sql.ansi.enabled": "false", **jdk17_extra_spark_configs } spark = raydp.init_spark(app_name="test_ray_dataset_roundtrip", num_executors=2, diff --git a/python/raydp/tf/estimator.py b/python/raydp/tf/estimator.py index 9fa21f1d..6fb3244c 100644 --- a/python/raydp/tf/estimator.py +++ b/python/raydp/tf/estimator.py @@ -17,6 +17,7 @@ from packaging import version import tempfile +import warnings from typing import Any, List, NoReturn, Optional, Union, Dict import tensorflow as tf @@ -188,11 +189,15 @@ def train_func(config): if config["evaluate"]: test_history = multi_worker_model.evaluate(eval_tf_dataset, callbacks=callbacks) results.append(test_history) - with tempfile.TemporaryDirectory() as temp_checkpoint_dir: - multi_worker_model.save(temp_checkpoint_dir, save_format="tf") - checkpoint = Checkpoint.from_directory(temp_checkpoint_dir) - session.report({}, checkpoint=checkpoint) + # Only save checkpoint from the chief worker to avoid race conditions + checkpoint = None + if session.get_world_rank() == 0: + with tempfile.TemporaryDirectory() as temp_checkpoint_dir: + multi_worker_model.save(temp_checkpoint_dir, save_format="tf") + checkpoint = Checkpoint.from_directory(temp_checkpoint_dir) + + session.report({}, checkpoint=checkpoint) def fit(self, train_ds: Dataset, diff --git a/python/setup.py b/python/setup.py index 49c5b8c7..0ce7ba33 100644 --- a/python/setup.py +++ b/python/setup.py @@ -100,7 +100,7 @@ def run(self): "psutil", "pyarrow >= 4.0.1", "ray >= 2.1.0", - "pyspark >= 3.1.1, <=3.5.7", + "pyspark >= 4.0.0", "netifaces", "protobuf > 3.19.5" ]
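Note on the RayAppMaster.setProperties change above: it replaces json4s parsing with plain Jackson so the JVM side no longer depends on a json4s release that can clash with the one Spark 4.0 ships. A minimal, self-contained sketch of that parsing approach (the JSON payload below is a made-up example of what the Python side might send):

```scala
import com.fasterxml.jackson.databind.ObjectMapper
import scala.jdk.CollectionConverters._

object JacksonPropertiesSketch {
  def setProperties(properties: String): Unit = {
    // Parse the JSON payload into a java.util.Map, then walk it as a Scala collection.
    val mapper = new ObjectMapper()
    val javaMap = mapper.readValue(properties, classOf[java.util.Map[String, Object]])
    javaMap.asScala.foreach { case (key, value) =>
      // System.setProperty expects String values, so stringify everything.
      System.setProperty(key, value.toString)
    }
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical payload of the shape the Python side sends.
    setProperties("""{"ray.session-dir": "/tmp/ray/session_xyz", "spark.app.name": "raydp"}""")
    println(System.getProperty("ray.session-dir"))
  }
}
```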
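Several Scala files above also migrate from scala.collection.JavaConverters to scala.jdk.CollectionConverters for Scala 2.13, e.g. replacing seqAsJavaList(appInfo.desc.command.javaOpts) with .asJava. A tiny sketch of the equivalent conversions:

```scala
import scala.jdk.CollectionConverters._

object CollectionConvertersSketch {
  def main(args: Array[String]): Unit = {
    val javaOpts = Seq("-Xmx1g", "-Dspark.test=true")

    // Old style (Scala 2.12): seqAsJavaList(javaOpts)
    // New style (Scala 2.13): the .asJava extension method.
    val asJavaList: java.util.List[String] = javaOpts.asJava

    // And back again, e.g. for values coming out of a Java API.
    val roundTripped: Seq[String] = asJavaList.asScala.toSeq
    println(roundTripped == javaOpts)
  }
}
```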
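In ObjectStoreWriter.save above, the Arrow schema is rendered to JSON on the driver and rebuilt with Schema.fromJSON inside the mapPartitions closure, so executors only capture a plain string instead of a SparkSession-dependent shim call. A small round-trip sketch using Arrow's own API (the single Int field is illustrative only):

```scala
import java.util.Collections

import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}

object ArrowSchemaJsonSketch {
  def main(args: Array[String]): Unit = {
    // Build a one-column Arrow schema; "value" is just an illustrative field name.
    val field = Field.nullable("value", new ArrowType.Int(32, true))
    val schema = new Schema(Collections.singletonList(field))

    // Serialize on the driver...
    val json: String = schema.toJson

    // ...and reconstruct on the executor side from the captured JSON string.
    val rebuilt: Schema = Schema.fromJSON(json)
    assert(rebuilt == schema)
    println(rebuilt)
  }
}
```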
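The SparkShims trait above now threads a SparkSession into toArrowSchema (so each shim can read version-specific SQLConf flags such as duplicated-field handling on 3.5 and large var types on 4.0) and adds toArrowBatchRDD. A hedged caller-side sketch, assuming SparkShimLoader lives in com.intel.raydp.shims like the other shim classes and that a DataFrame df is available:

```scala
import com.intel.raydp.shims.SparkShimLoader // assumed package, matching the other shim classes
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame

object ShimCallerSketch {
  // Mirrors how ObjectStoreWriter drives the widened shim API; `df` is assumed to exist.
  def arrowArtifacts(df: DataFrame): (Schema, RDD[Array[Byte]]) = {
    val session = df.sparkSession
    val timeZoneId = session.conf.get("spark.sql.session.timeZone")
    val shims = SparkShimLoader.getSparkShims
    // toArrowSchema now receives the session so each shim can consult SQLConf
    // (duplicated-field handling on 3.5, large var types on 4.0, ...).
    val schema = shims.toArrowSchema(df.schema, timeZoneId, session)
    // Arrow batch RDDs are likewise produced through the shim rather than df.toArrowBatchRdd.
    val batches = shims.toArrowBatchRDD(df)
    (schema, batches)
  }
}
```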
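The new spark400 SparkShimProvider above claims Spark 4.0.0 and 4.0.1 by comparing the running version string against its descriptor list. A simplified, self-contained sketch of that selection logic (Descriptor here is a stand-in for the project's SparkShimDescriptor):

```scala
object ShimMatchingSketch {
  // Stand-in for com.intel.raydp.shims.SparkShimDescriptor.
  final case class Descriptor(major: Int, minor: Int, patch: Int) {
    override def toString: String = s"$major.$minor.$patch"
  }

  val descriptorStrings: Seq[String] =
    Seq(Descriptor(4, 0, 0), Descriptor(4, 0, 1)).map(_.toString)

  // A provider "matches" when the running Spark version is one it was built for.
  def matches(sparkVersion: String): Boolean = descriptorStrings.contains(sparkVersion)

  def main(args: Array[String]): Unit = {
    println(matches("4.0.0")) // true
    println(matches("3.5.0")) // false -> another shim (e.g. spark350) would be selected
  }
}
```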
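The test fixtures above pin spark.sql.ansi.enabled to false because Spark 4.0 enables ANSI SQL mode by default, which can change casting and overflow behaviour that the existing tests rely on. The JVM-side equivalent of that opt-out looks roughly like this:

```scala
import org.apache.spark.sql.SparkSession

object AnsiOptOutSketch {
  def main(args: Array[String]): Unit = {
    // Spark 4.0 defaults spark.sql.ansi.enabled to true; the tests pin the legacy behaviour.
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("RayDP test")
      .config("spark.sql.ansi.enabled", "false")
      .getOrCreate()

    println(spark.conf.get("spark.sql.ansi.enabled")) // false
    spark.stop()
  }
}
```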