
[SPARK-51243][CORE][ML] Configurable allow native BLAS #49986


Open · wants to merge 5 commits into master
@@ -27,6 +27,10 @@ private[spark] trait SparkEnvUtils {
System.getenv("SPARK_TESTING") != null || System.getProperty("spark.testing") != null
}

/**
* Whether to allow using native BLAS/LAPACK/ARPACK libraries if available.
*/
val allowNativeBlas = "true".equals(System.getProperty("netlib.allowNativeBlas", "true"))
}

object SparkEnvUtils extends SparkEnvUtils
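
Since `allowNativeBlas` is a plain `val`, the property is read once when `SparkEnvUtils` is initialized, so the flag has to be in place before the JVM starts (e.g. via `extraJavaOptions`). A standalone sketch of the same check, for illustration only:

```scala
// Sketch of the flag evaluation above: the property defaults to "true", and
// any other value (including "TRUE" or a typo) disables native BLAS, because
// the comparison is an exact string match.
def allowNativeBlas: Boolean =
  "true".equals(System.getProperty("netlib.allowNativeBlas", "true"))

// Launched with -Dnetlib.allowNativeBlas=false this yields false;
// with the property unset it yields true.
```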
40 changes: 22 additions & 18 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -432,6 +432,7 @@ class SparkContext(config: SparkConf) extends Logging {

SparkContext.supplementJavaModuleOptions(_conf)
SparkContext.supplementJavaIPv6Options(_conf)
SparkContext.supplementBlasOptions(_conf)

_driverLogger = DriverLogger(_conf)

@@ -3409,32 +3410,35 @@ object SparkContext extends Logging {
}
}

private def supplementJavaOpts(
conf: SparkConf, key: OptionalConfigEntry[String], javaOpts: String): Unit = {
val v = conf.get(key) match {
case Some(opts) => s"$javaOpts $opts"
case None => javaOpts
}
conf.set(key.key, v)
}

/**
* SPARK-36796: This is a helper function to supplement some JVM runtime options to
* `spark.driver.extraJavaOptions` and `spark.executor.extraJavaOptions`.
*/
private def supplementJavaModuleOptions(conf: SparkConf): Unit = {
def supplement(key: OptionalConfigEntry[String]): Unit = {
val v = conf.get(key) match {
case Some(opts) => s"${JavaModuleOptions.defaultModuleOptions()} $opts"
case None => JavaModuleOptions.defaultModuleOptions()
}
conf.set(key.key, v)
}
supplement(DRIVER_JAVA_OPTIONS)
supplement(EXECUTOR_JAVA_OPTIONS)
val opts = JavaModuleOptions.defaultModuleOptions()
supplementJavaOpts(conf, DRIVER_JAVA_OPTIONS, opts)
supplementJavaOpts(conf, EXECUTOR_JAVA_OPTIONS, opts)
}

private def supplementJavaIPv6Options(conf: SparkConf): Unit = {
def supplement(key: OptionalConfigEntry[String]): Unit = {
val v = conf.get(key) match {
case Some(opts) => s"-Djava.net.preferIPv6Addresses=${Utils.preferIPv6} $opts"
case None => s"-Djava.net.preferIPv6Addresses=${Utils.preferIPv6}"
}
conf.set(key.key, v)
}
supplement(DRIVER_JAVA_OPTIONS)
supplement(EXECUTOR_JAVA_OPTIONS)
val opts = s"-Djava.net.preferIPv6Addresses=${Utils.preferIPv6}"
supplementJavaOpts(conf, DRIVER_JAVA_OPTIONS, opts)
supplementJavaOpts(conf, EXECUTOR_JAVA_OPTIONS, opts)
}

private def supplementBlasOptions(conf: SparkConf): Unit = {
val opts = s"-Dnetlib.allowNativeBlas=${Utils.allowNativeBlas}"
supplementJavaOpts(conf, DRIVER_JAVA_OPTIONS, opts)
supplementJavaOpts(conf, EXECUTOR_JAVA_OPTIONS, opts)
}
}
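
The refactor factors the shared prepend-and-set logic of the three `supplement*` methods into `supplementJavaOpts`. A small sketch of the merge semantics, assuming the usual last-occurrence-wins JVM behavior for duplicate `-D` flags (simplified, no `SparkConf`):

```scala
// Sketch of supplementJavaOpts' merge: supplied defaults are prepended, so
// options the user already configured come last and keep precedence.
def merge(existing: Option[String], supplied: String): String =
  existing match {
    case Some(userOpts) => s"$supplied $userOpts"
    case None           => supplied
  }

merge(None, "-Dnetlib.allowNativeBlas=true")
// => "-Dnetlib.allowNativeBlas=true"
merge(Some("-Dnetlib.allowNativeBlas=false"), "-Dnetlib.allowNativeBlas=true")
// => "-Dnetlib.allowNativeBlas=true -Dnetlib.allowNativeBlas=false"
//    (the user's "false" wins if the JVM honors the last -D occurrence)
```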

@@ -2833,4 +2833,12 @@ package object config {
.checkValues(Set("connect", "classic"))
.createWithDefault(
if (sys.env.get("SPARK_CONNECT_MODE").contains("1")) "connect" else "classic")

private[spark] val SPARK_ML_ALLOW_NATIVE_BLAS =
ConfigBuilder("spark.ml.allowNativeBlas")
.doc("Whether allow using native BLAS/LAPACK/ARPACK implementations when native " +
"libraries are available. If disabled, always use Java implementations.")
.version("4.1.0")
.booleanConf
.createWithDefault(true)
}
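
For reference, a sketch of how such a typed entry is consumed (mirroring the YARN client change at the end of this diff); note that `SparkConf.get(ConfigEntry)` is Spark-internal API, so this only compiles inside Spark's own packages:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.internal.config._

// Sketch: .booleanConf makes the entry come back as a Boolean, defaulting
// to true, which is then rendered into the JVM flag.
val allow: Boolean = new SparkConf().get(SPARK_ML_ALLOW_NATIVE_BLAS)
val javaOpt = s"-Dnetlib.allowNativeBlas=$allow"
```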
5 changes: 3 additions & 2 deletions docs/ml-linalg-guide.md
@@ -46,8 +46,7 @@ The installation should be done on all nodes of the cluster. Generic version of

For Debian / Ubuntu:
```
sudo apt-get install libopenblas-base
sudo update-alternatives --config libblas.so.3
sudo apt-get install libopenblas-dev
```

Member Author: libopenblas-base was removed in Debian 12 and Ubuntu 24.04; libopenblas-dev should be used instead.

https://github.com/luhenry/netlib/blob/6835050840ead1a724e2f305875c92d7cc21f834/.github/workflows/build-and-test.yml#L31

Also, `update-alternatives --config libblas.so.3` does not work, and the alternative name varies across CPU architectures and OS distributions:

```
root@0bef5c80cdaa:/# update-alternatives --config libblas.so.3
update-alternatives: error: no alternatives for libblas.so.3
```

Given that `-Ddev.ludovic.netlib.lapack.nativeLib=...` already allows choosing the native libraries, I would suggest dropping the update-alternatives instructions for managing the OS default library from our docs.
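
The comment above mentions the `-Ddev.ludovic.netlib.lapack.nativeLib=...` escape hatch. A minimal sketch of passing it through Spark confs, with an illustrative library file name that is not from this PR:

```scala
import org.apache.spark.SparkConf

// Sketch: pointing dev.ludovic.netlib at a specific native LAPACK via a JVM
// option instead of OS-level alternatives. The property name comes from the
// comment above; "liblapack.so.3" is illustrative and varies by distribution.
val conf = new SparkConf()
  .set("spark.executor.extraJavaOptions",
    "-Ddev.ludovic.netlib.lapack.nativeLib=liblapack.so.3")
```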
For CentOS / RHEL:
```
@@ -76,6 +75,8 @@ You can also point `dev.ludovic.netlib` to specific libraries names and paths. F

If native libraries are not properly configured in the system, the Java implementation (javaBLAS) will be used as a fallback option.

You can also set the Spark conf `spark.ml.allowNativeBlas` or the Java system property `netlib.allowNativeBlas` to `false` to disable native BLAS and always use the Java implementation.
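
A minimal sketch of both switches; either must take effect before the driver and executor JVMs launch:

```scala
// Sketch: disabling native BLAS through either knob described above.
// 1) Spark conf, typically at submit time:
//      spark-submit --conf spark.ml.allowNativeBlas=false ...
// 2) JVM system property, set through the extraJavaOptions confs:
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.driver.extraJavaOptions", "-Dnetlib.allowNativeBlas=false")
  .set("spark.executor.extraJavaOptions", "-Dnetlib.allowNativeBlas=false")
```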

## Spark Configuration

The default behavior of multi-threading in either Intel MKL or OpenBLAS may not be optimal with Spark's execution model [^1].
@@ -332,6 +332,11 @@ private List<String> buildSparkSubmitCommand(Map<String, String> env)
config.get(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH));
}

if (config.containsKey("spark.ml.allowNativeBlas")) {
String allowNativeBlas = config.get("spark.ml.allowNativeBlas");
addOptionString(cmd, "-Dnetlib.allowNativeBlas=" + allowNativeBlas);
}

// SPARK-36796: Always add some JVM runtime default options to submit command
addOptionString(cmd, JavaModuleOptions.defaultModuleOptions());
addOptionString(cmd, "-Dderby.connection.requireAuthentication=false");
5 changes: 5 additions & 0 deletions mllib-local/pom.xml
@@ -66,6 +66,11 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-tags_${scala.binary.version}</artifactId>
</dependency>
<dependency>
Contributor: I am still a bit worried about the dependency change. Deferring to @WeichenXu123's review.

Member Author: The dependency change can be eliminated if logging is unnecessary.
<groupId>org.apache.spark</groupId>
<artifactId>spark-common-utils_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>

<!--
This spark-tags test-dep is needed even though it isn't used in this module, otherwise testing-cmds that exclude
11 changes: 9 additions & 2 deletions mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
@@ -19,10 +19,13 @@ package org.apache.spark.ml.linalg

import dev.ludovic.netlib.blas.{BLAS => NetlibBLAS, JavaBLAS => NetlibJavaBLAS, NativeBLAS => NetlibNativeBLAS}

import org.apache.spark.internal.Logging
import org.apache.spark.util.SparkEnvUtils

/**
* BLAS routines for MLlib's vectors and matrices.
*/
private[spark] object BLAS extends Serializable {
private[spark] object BLAS extends Serializable with Logging {

@transient private var _javaBLAS: NetlibBLAS = _
@transient private var _nativeBLAS: NetlibBLAS = _
@@ -39,8 +42,12 @@ private[spark] object BLAS extends Serializable {
// For level-3 routines, we use the native BLAS.
private[spark] def nativeBLAS: NetlibBLAS = {
if (_nativeBLAS == null) {
_nativeBLAS =
_nativeBLAS = if (SparkEnvUtils.allowNativeBlas) {
try { NetlibNativeBLAS.getInstance } catch { case _: Throwable => javaBLAS }
} else {
logInfo("Disable native BLAS because netlib.allowNativeBlas is false.")
javaBLAS
}
}
_nativeBLAS
}
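
The guarded lookup above (repeated for ARPACK and LAPACK later in this diff) boils down to: prefer the native implementation when the flag allows it, fall back to Java on any load failure. A standalone sketch using the netlib API already imported in this file:

```scala
import dev.ludovic.netlib.blas.{BLAS => NetlibBLAS, JavaBLAS => NetlibJavaBLAS, NativeBLAS => NetlibNativeBLAS}

// Sketch of the selection logic: NativeBLAS.getInstance throws when no native
// library can be loaded, so any failure falls back to the pure-Java build.
def resolveBlas(allowNative: Boolean): NetlibBLAS =
  if (allowNative) {
    try NetlibNativeBLAS.getInstance
    catch { case _: Throwable => NetlibJavaBLAS.getInstance }
  } else {
    NetlibJavaBLAS.getInstance
  }
```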
11 changes: 9 additions & 2 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/ARPACK.scala
@@ -19,10 +19,13 @@ package org.apache.spark.mllib.linalg

import dev.ludovic.netlib.arpack.{ARPACK => NetlibARPACK, JavaARPACK => NetlibJavaARPACK, NativeARPACK => NetlibNativeARPACK}

import org.apache.spark.internal.Logging
import org.apache.spark.util.SparkEnvUtils

/**
* ARPACK routines for MLlib's vectors and matrices.
*/
private[spark] object ARPACK extends Serializable {
private[spark] object ARPACK extends Serializable with Logging {
Member Author: Should I move ARPACK and LAPACK to mllib-local, to align with BLAS?


@transient private var _javaARPACK: NetlibARPACK = _
@transient private var _nativeARPACK: NetlibARPACK = _
@@ -36,8 +39,12 @@ private[spark] object ARPACK extends Serializable {

private[spark] def nativeARPACK: NetlibARPACK = {
if (_nativeARPACK == null) {
_nativeARPACK =
_nativeARPACK = if (SparkEnvUtils.allowNativeBlas) {
try { NetlibNativeARPACK.getInstance } catch { case _: Throwable => javaARPACK }
} else {
logInfo("Disable native ARPACK because netlib.allowNativeBlas is false.")
javaARPACK
}
}
_nativeARPACK
}
32 changes: 1 addition & 31 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -17,44 +17,14 @@

package org.apache.spark.mllib.linalg

import dev.ludovic.netlib.blas.{BLAS => NetlibBLAS, JavaBLAS => NetlibJavaBLAS, NativeBLAS => NetlibNativeBLAS}

import org.apache.spark.internal.Logging
import org.apache.spark.ml.linalg.BLAS.{getBLAS, nativeBLAS}

/**
* BLAS routines for MLlib's vectors and matrices.
*/
private[spark] object BLAS extends Serializable with Logging {

@transient private var _javaBLAS: NetlibBLAS = _
@transient private var _nativeBLAS: NetlibBLAS = _
Member Author: Remove the duplicated instance creation and call org.apache.spark.ml.linalg.BLAS instead.

private val nativeL1Threshold: Int = 256

// For level-1 function dspmv, use javaBLAS for better performance.
private[spark] def javaBLAS: NetlibBLAS = {
if (_javaBLAS == null) {
_javaBLAS = NetlibJavaBLAS.getInstance
}
_javaBLAS
}

// For level-3 routines, we use the native BLAS.
private[spark] def nativeBLAS: NetlibBLAS = {
if (_nativeBLAS == null) {
_nativeBLAS =
try { NetlibNativeBLAS.getInstance } catch { case _: Throwable => javaBLAS }
}
_nativeBLAS
}

private[spark] def getBLAS(vectorSize: Int): NetlibBLAS = {
if (vectorSize < nativeL1Threshold) {
javaBLAS
} else {
nativeBLAS
}
}

/**
* y += a * x
*/
11 changes: 9 additions & 2 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/LAPACK.scala
@@ -19,10 +19,13 @@ package org.apache.spark.mllib.linalg

import dev.ludovic.netlib.lapack.{JavaLAPACK => NetlibJavaLAPACK, LAPACK => NetlibLAPACK, NativeLAPACK => NetlibNativeLAPACK}

import org.apache.spark.internal.Logging
import org.apache.spark.util.SparkEnvUtils

/**
* LAPACK routines for MLlib's vectors and matrices.
*/
private[spark] object LAPACK extends Serializable {
private[spark] object LAPACK extends Serializable with Logging {

@transient private var _javaLAPACK: NetlibLAPACK = _
@transient private var _nativeLAPACK: NetlibLAPACK = _
@@ -36,8 +39,12 @@ private[spark] object LAPACK extends Serializable {

private[spark] def nativeLAPACK: NetlibLAPACK = {
if (_nativeLAPACK == null) {
_nativeLAPACK =
_nativeLAPACK = if (SparkEnvUtils.allowNativeBlas) {
try { NetlibNativeLAPACK.getInstance } catch { case _: Throwable => javaLAPACK }
} else {
logInfo("Disable native LAPACK because netlib.allowNativeBlas is false.")
javaLAPACK
}
}
_nativeLAPACK
}
@@ -427,7 +427,7 @@ class DenseMatrix @Since("1.3.0") (
if (isTransposed) {
Iterator.tabulate(numCols) { j =>
val col = new Array[Double](numRows)
BLAS.nativeBLAS.dcopy(numRows, values, j, numCols, col, 0, 1)
newlinalg.BLAS.nativeBLAS.dcopy(numRows, values, j, numCols, col, 0, 1)
new DenseVector(col)
}
} else {
@@ -31,7 +31,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.mllib.linalg.BLAS
import org.apache.spark.ml.linalg.BLAS
import org.apache.spark.mllib.rdd.MLPairRDDFunctions._
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
@@ -23,12 +23,6 @@ import org.apache.spark.mllib.util.TestingUtils._

class BLASSuite extends SparkFunSuite {

test("nativeL1Threshold") {
assert(getBLAS(128) == BLAS.javaBLAS)
assert(getBLAS(256) == BLAS.nativeBLAS)
assert(getBLAS(512) == BLAS.nativeBLAS)
}

test("copy") {
val sx = Vectors.sparse(4, Array(0, 2), Array(1.0, -2.0))
val dx = Vectors.dense(1.0, 0.0, -2.0, 0.0)
@@ -1049,6 +1049,8 @@ private[spark] class Client(

javaOpts += s"-Djava.net.preferIPv6Addresses=${Utils.preferIPv6}"

javaOpts += s"-Dnetlib.allowNativeBlas=${sparkConf.get(SPARK_ML_ALLOW_NATIVE_BLAS)}"

// SPARK-37106: To start AM with Java 17, `JavaModuleOptions.defaultModuleOptions`
// is added by default.
javaOpts += JavaModuleOptions.defaultModuleOptions()