diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index fa3391f7e1cc8..8aea4919ff5a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -189,6 +189,17 @@ class JDBCRDD(
     sparkContext,
     name = "JDBC query execution time")
 
+  /**
+   * Time needed to fetch the data and transform it into Spark's InternalRow format.
+   *
+   * Most of this time is usually spent on network transfer, but it can also be spent on
+   * transformation when converting more complex data types such as structs.
+   */
+  val fetchAndTransformToInternalRowsMetric: SQLMetric = SQLMetrics.createNanoTimingMetric(
+    sparkContext,
+    // The user-facing message does not need to leak details about the conversion
+    name = "JDBC remote data fetch and translation time")
+
   private lazy val dialect = JdbcDialects.get(url)
 
   def generateJdbcQuery(partition: Option[JDBCPartition]): String = {
@@ -301,23 +312,25 @@ class JDBCRDD(
     stmt.setFetchSize(options.fetchSize)
     stmt.setQueryTimeout(options.queryTimeout)
 
-    val startTime = System.nanoTime
-    rs = try {
-      stmt.executeQuery()
-    } catch {
-      case e: SQLException if dialect.isSyntaxErrorBestEffort(e) =>
-        throw new SparkException(
-          errorClass = "JDBC_EXTERNAL_ENGINE_SYNTAX_ERROR.DURING_QUERY_EXECUTION",
-          messageParameters = Map("jdbcQuery" -> sqlText),
-          cause = e)
+    rs = SQLMetrics.withTimingNs(queryExecutionTimeMetric) {
+      try {
+        stmt.executeQuery()
+      } catch {
+        case e: SQLException if dialect.isSyntaxErrorBestEffort(e) =>
+          throw new SparkException(
+            errorClass = "JDBC_EXTERNAL_ENGINE_SYNTAX_ERROR.DURING_QUERY_EXECUTION",
+            messageParameters = Map("jdbcQuery" -> sqlText),
+            cause = e)
+      }
     }
-    val endTime = System.nanoTime
-
-    val executionTime = endTime - startTime
-    queryExecutionTimeMetric.add(executionTime)
 
     val rowsIterator =
-      JdbcUtils.resultSetToSparkInternalRows(rs, dialect, schema, inputMetrics)
+      JdbcUtils.resultSetToSparkInternalRows(
+        rs,
+        dialect,
+        schema,
+        inputMetrics,
+        Some(fetchAndTransformToInternalRowsMetric))
 
     CompletionIterator[InternalRow, Iterator[InternalRow]](
       new InterruptibleIterator(context, rowsIterator), close())
@@ -325,6 +338,7 @@ class JDBCRDD(
 
   override def getMetrics: Seq[(String, SQLMetric)] = {
     Seq(
+      "fetchAndTransformToInternalRowsNs" -> fetchAndTransformToInternalRowsMetric,
       "queryExecutionTime" -> queryExecutionTimeMetric
     )
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 0077012e2b0e4..f34299b08726c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -46,6 +46,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, TableChange}
 import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex}
 import org.apache.spark.sql.connector.expressions.NamedReference
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType, NoopDialect}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
@@ -357,7 +358,8 @@ object JdbcUtils extends Logging with SQLConfHelper {
       resultSet: ResultSet,
       dialect: JdbcDialect,
       schema: StructType,
-      inputMetrics: InputMetrics): Iterator[InternalRow] = {
+      inputMetrics: InputMetrics,
+      fetchAndTransformToInternalRowsMetric: Option[SQLMetric] = None): Iterator[InternalRow] = {
     new NextIterator[InternalRow] {
       private[this] val rs = resultSet
       private[this] val getters: Array[JDBCValueGetter] = makeGetters(dialect, schema)
@@ -372,7 +374,7 @@ object JdbcUtils extends Logging with SQLConfHelper {
         }
       }
 
-      override protected def getNext(): InternalRow = {
+      private def getNextWithoutTiming: InternalRow = {
         if (rs.next()) {
           inputMetrics.incRecordsRead(1)
           var i = 0
@@ -387,6 +389,16 @@ object JdbcUtils extends Logging with SQLConfHelper {
           null.asInstanceOf[InternalRow]
         }
       }
+
+      override protected def getNext(): InternalRow = {
+        if (fetchAndTransformToInternalRowsMetric.isDefined) {
+          SQLMetrics.withTimingNs(fetchAndTransformToInternalRowsMetric.get) {
+            getNextWithoutTiming
+          }
+        } else {
+          getNextWithoutTiming
+        }
+      }
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 065c8db7ac6f9..13f4d7926bea8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -221,4 +221,19 @@ object SQLMetrics {
         SparkListenerDriverAccumUpdates(executionId.toLong, metrics.map(m => m.id -> m.value)))
     }
   }
+
+  /**
+   * Measures the time taken by the function `f` in nanoseconds and adds it to the provided metric.
+   *
+   * @param metric SQLMetric to record the time taken.
+   * @param f Function/code block to execute and measure.
+   * @return The result of the function `f`.
+   */
+  def withTimingNs[T](metric: SQLMetric)(f: => T): T = {
+    val startTime = System.nanoTime()
+    val result = f
+    val endTime = System.nanoTime()
+    metric.add(endTime - startTime)
+    result
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 29481599362a4..36604adbd48e3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -987,6 +987,18 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
     assert(SQLMetrics.createSizeMetric(sparkContext, name = "m").toInfoUpdate.update === Some(-1))
     assert(SQLMetrics.createMetric(sparkContext, name = "m").toInfoUpdate.update === Some(0))
   }
+
+  test("withTimingNs should time and return same result") {
+    val metric = SQLMetrics.createTimingMetric(sparkContext, name = "m")
+
+    // Use a simple block that returns a value
+    val result = SQLMetrics.withTimingNs(metric) {
+      42
+    }
+
+    assert(result === 42)
+    assert(!metric.isZero, "Metric was not increased")
+  }
 }
 
 case class CustomFileCommitProtocol(
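
A minimal usage sketch of the new `SQLMetrics.withTimingNs` helper, not taken from the patch: the metric name and value names here are illustrative, and a `SharedSparkSession`-style scope providing `spark` and `sparkContext` is assumed.

  import org.apache.spark.sql.execution.metric.SQLMetrics

  // A nanosecond-based timing metric, created the same way JDBCRDD creates its new metric above.
  val exampleMetric = SQLMetrics.createNanoTimingMetric(sparkContext, name = "example timing")

  // withTimingNs records the wall-clock duration of the block in nanoseconds into the metric
  // and returns the block's result unchanged.
  val rowCount = SQLMetrics.withTimingNs(exampleMetric) {
    spark.range(1000).count()
  }

Because `withTimingNs` only adds the elapsed nanoseconds and passes the result through, the patch can wrap both the one-off `executeQuery()` call and the per-row `getNext()` path with it without changing their return types.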