[SPARK-52846][SQL] Add a metric in JDBCRDD for how long it takes to fetch the resultset #51536

Open · wants to merge 4 commits into base: master
@@ -189,6 +189,17 @@ class JDBCRDD(
sparkContext,
name = "JDBC query execution time")

/**
 * Time needed to fetch the data and transform it into Spark's InternalRow format.
 *
 * Most of this time is usually spent on network transfer, but it can also go to the
 * conversion itself when more complex data types such as structs are involved.
 */
val fetchAndTransformToInternalRowsMetric: SQLMetric = SQLMetrics.createNanoTimingMetric(
sparkContext,
// The message the user sees does not need to leak details about the conversion step
name = "Remote data fetch time over JDBC connection")

private lazy val dialect = JdbcDialects.get(url)

def generateJdbcQuery(partition: Option[JDBCPartition]): String = {
@@ -301,30 +312,33 @@
stmt.setFetchSize(options.fetchSize)
stmt.setQueryTimeout(options.queryTimeout)

- val startTime = System.nanoTime
- rs = try {
-   stmt.executeQuery()
- } catch {
-   case e: SQLException if dialect.isSyntaxErrorBestEffort(e) =>
-     throw new SparkException(
-       errorClass = "JDBC_EXTERNAL_ENGINE_SYNTAX_ERROR.DURING_QUERY_EXECUTION",
-       messageParameters = Map("jdbcQuery" -> sqlText),
-       cause = e)
+ rs = SQLMetrics.withTimingNs(queryExecutionTimeMetric) {
+   try {
+     stmt.executeQuery()
+   } catch {
+     case e: SQLException if dialect.isSyntaxErrorBestEffort(e) =>
+       throw new SparkException(
+         errorClass = "JDBC_EXTERNAL_ENGINE_SYNTAX_ERROR.DURING_QUERY_EXECUTION",
+         messageParameters = Map("jdbcQuery" -> sqlText),
+         cause = e)
+   }
}
- val endTime = System.nanoTime
-
- val executionTime = endTime - startTime
- queryExecutionTimeMetric.add(executionTime)

val rowsIterator =
- JdbcUtils.resultSetToSparkInternalRows(rs, dialect, schema, inputMetrics)
+ JdbcUtils.resultSetToSparkInternalRows(
+   rs,
+   dialect,
+   schema,
+   inputMetrics,
+   Some(fetchAndTransformToInternalRowsMetric))

CompletionIterator[InternalRow, Iterator[InternalRow]](
new InterruptibleIterator(context, rowsIterator), close())
}

override def getMetrics: Seq[(String, SQLMetric)] = {
Seq(
"fetchAndTransformToInternalRowsNs" -> fetchAndTransformToInternalRowsMetric,
"queryExecutionTime" -> queryExecutionTimeMetric
)
}
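
For context on how the two timers above divide the work, here is a minimal, self-contained Scala sketch of the same pattern, assuming nothing from Spark: NanoTimer stands in for SQLMetric, and the local withTimingNs mirrors the helper this PR adds to SQLMetrics. One timer wraps the (simulated) executeQuery call, the other is accumulated per row while the result is iterated.

// Illustrative sketch only: NanoTimer stands in for SQLMetric, and this local
// withTimingNs mirrors the helper added to SQLMetrics in this PR.
object JdbcTimingSketch {
  final class NanoTimer(val name: String) {
    private var totalNs: Long = 0L
    def add(ns: Long): Unit = totalNs += ns
    override def toString: String = s"$name: ${totalNs / 1e6} ms"
  }

  // Time a by-name block in nanoseconds and add the elapsed time to the timer.
  def withTimingNs[T](timer: NanoTimer)(f: => T): T = {
    val start = System.nanoTime()
    val result = f
    timer.add(System.nanoTime() - start)
    result
  }

  def main(args: Array[String]): Unit = {
    val queryTimer = new NanoTimer("JDBC query execution time")
    val fetchTimer = new NanoTimer("Remote data fetch time over JDBC connection")

    // One timer around the (simulated) executeQuery call...
    val resultSet = withTimingNs(queryTimer) { (1 to 3).toList }

    // ...and the other accumulated once per row, as the timed getNext() does.
    val rows = resultSet.iterator.map(r => withTimingNs(fetchTimer)(r * 2)).toList

    println(queryTimer)   // total time spent "executing" the query
    println(fetchTimer)   // total time spent producing rows
    println(rows)
  }
}
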
@@ -46,6 +46,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, TableChange}
import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex}
import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType, NoopDialect}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
@@ -357,7 +358,8 @@ object JdbcUtils extends Logging with SQLConfHelper {
resultSet: ResultSet,
dialect: JdbcDialect,
schema: StructType,
- inputMetrics: InputMetrics): Iterator[InternalRow] = {
+ inputMetrics: InputMetrics,
+ fetchAndTransformToInternalRowsMetric: Option[SQLMetric] = None): Iterator[InternalRow] = {
new NextIterator[InternalRow] {
private[this] val rs = resultSet
private[this] val getters: Array[JDBCValueGetter] = makeGetters(dialect, schema)
@@ -372,7 +374,7 @@ object JdbcUtils extends Logging with SQLConfHelper {
}
}

- override protected def getNext(): InternalRow = {
+ private def getNextWithoutTiming: InternalRow = {
if (rs.next()) {
inputMetrics.incRecordsRead(1)
var i = 0
@@ -387,6 +389,16 @@
null.asInstanceOf[InternalRow]
}
}

override protected def getNext(): InternalRow = {
if (fetchAndTransformToInternalRowsMetric.isDefined) {
SQLMetrics.withTimingNs(fetchAndTransformToInternalRowsMetric.get) {
getNextWithoutTiming
}
} else {
getNextWithoutTiming
}
}
}
}
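
The timed getNext() above covers both rs.next(), where JDBC drivers typically pull the next batch of fetchSize rows over the network, and the getter-based conversion to InternalRow, while passing None keeps the previous untimed behaviour. Below is a simplified, self-contained sketch of that optional per-element timing; TimedIterator and the AtomicLong accumulator are illustrative stand-ins for Spark's NextIterator and SQLMetric, not actual Spark classes.

import java.util.concurrent.atomic.AtomicLong

// Stand-in for the timed NextIterator above: each call to next() is timed only
// when an accumulator is supplied, so the untimed path has no extra overhead.
final class TimedIterator[T](underlying: Iterator[T], timingNs: Option[AtomicLong])
  extends Iterator[T] {

  override def hasNext: Boolean = underlying.hasNext

  private def nextWithoutTiming(): T = underlying.next()

  override def next(): T =
    if (timingNs.isDefined) {
      val start = System.nanoTime()
      val row = nextWithoutTiming()
      timingNs.get.addAndGet(System.nanoTime() - start)
      row
    } else {
      nextWithoutTiming()
    }
}

object TimedIteratorExample {
  def main(args: Array[String]): Unit = {
    val acc = new AtomicLong(0L)
    val it = new TimedIterator(Iterator(1, 2, 3), Some(acc))
    println(it.toList)                       // List(1, 2, 3)
    println(s"${acc.get()} ns spent producing rows")
  }
}
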

@@ -221,4 +221,19 @@ object SQLMetrics {
SparkListenerDriverAccumUpdates(executionId.toLong, metrics.map(m => m.id -> m.value)))
}
}

/**
* Measures the time taken by the function `f` in nanoseconds and adds it to the provided metric.
*
* @param metric SQLMetric to record the time taken.
* @param f Function or code block to execute and measure.
* @return The result of the function `f`.
*/
def withTimingNs[T](metric: SQLMetric)(f: => T): T = {
val startTime = System.nanoTime()
val result = f
val endTime = System.nanoTime()
metric.add(endTime - startTime)
result
}
}
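
Since f is a by-name parameter, the wrapped block is evaluated exactly once, inside the measured window, and its value is passed through unchanged, so call sites can wrap existing expressions in place. A small sketch of that contract, assuming a plain Long accumulator instead of a SQLMetric so it runs without a SparkContext:

// Stand-alone sketch of the withTimingNs contract; the mutable local accumulator
// below is an illustrative substitute for SQLMetric.add.
object WithTimingNsContract {
  def withTimingNs[T](add: Long => Unit)(f: => T): T = {
    val start = System.nanoTime()
    val result = f                           // the by-name block runs once, inside the timed window
    add(System.nanoTime() - start)
    result
  }

  def main(args: Array[String]): Unit = {
    var elapsedNs = 0L
    // The wrapped expression's value flows through unchanged.
    val answer = withTimingNs(elapsedNs += _) {
      Thread.sleep(5)
      21 * 2
    }
    assert(answer == 42 && elapsedNs > 0L)
    println(s"answer=$answer after ${elapsedNs} ns")
  }
}
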
@@ -987,6 +987,18 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
assert(SQLMetrics.createSizeMetric(sparkContext, name = "m").toInfoUpdate.update === Some(-1))
assert(SQLMetrics.createMetric(sparkContext, name = "m").toInfoUpdate.update === Some(0))
}

test("withTimingNs should time and return same result") {
val metric = SQLMetrics.createTimingMetric(sparkContext, name = "m")

// Use a simple block that returns a value
val result = SQLMetrics.withTimingNs(metric) {
42
}

assert(result === 42)
assert(!metric.isZero, "Metric was not increased")
}
}

case class CustomFileCommitProtocol(