From d0739fb9c642de848f8b949052395f3e2efbad06 Mon Sep 17 00:00:00 2001
From: Joan Goyeau
Date: Tue, 1 Apr 2025 18:47:45 -0400
Subject: [PATCH 1/4] Enable -Xsource:3-cross compiler flag

This not only eases a potential future Scala 3 migration but also makes the
compiler stricter about features that have proven to be warts.
---
 .../spark/internal/config/ConfigBuilder.scala | 2 +-
 .../sql/kafka010/KafkaOffsetReaderAdmin.scala | 5 +-
 .../kafka010/KafkaOffsetReaderConsumer.scala | 5 +-
 .../kafka010/consumer/FetchedDataPool.scala | 2 +-
 .../kafka010/KafkaMicroBatchSourceSuite.scala | 8 +-
 .../kafka010/DirectKafkaInputDStream.scala | 2 +-
 .../streaming/kinesis/KinesisTestUtils.scala | 5 +-
 .../sql/protobuf/ProtobufSerializer.scala | 2 +-
 .../apache/spark/api/java/JavaHadoopRDD.scala | 2 +-
 .../spark/api/java/JavaNewHadoopRDD.scala | 2 +-
 .../apache/spark/api/java/JavaPairRDD.scala | 2 +-
 .../spark/deploy/FaultToleranceTest.scala | 2 +-
 .../CoarseGrainedExecutorBackend.scala | 10 +-
 .../apache/spark/io/CompressionCodec.scala | 2 +-
 .../org/apache/spark/rdd/CoGroupedRDD.scala | 2 +-
 .../org/apache/spark/rdd/ShuffledRDD.scala | 2 +-
 .../org/apache/spark/rdd/SubtractedRDD.scala | 2 +-
 .../spark/rdd/ZippedPartitionsRDD.scala | 7 +-
 .../spark/resource/ResourceInformation.scala | 1 +
 .../apache/spark/resource/ResourceUtils.scala | 1 +
 .../spark/scheduler/TaskSchedulerImpl.scala | 5 +-
 .../shuffle/sort/SortShuffleManager.scala | 2 +-
 .../org/apache/spark/status/api/v1/api.scala | 20 +-
 .../storage/BlockManagerMasterEndpoint.scala | 4 +-
 .../storage/ShuffleBlockFetcherIterator.scala | 69 ++++---
 .../apache/spark/storage/StorageUtils.scala | 2 +-
 .../org/apache/spark/util/Distribution.scala | 2 +-
 .../org/apache/spark/util/SizeEstimator.scala | 5 +-
 .../org/apache/spark/util/ThreadUtils.scala | 1 +
 .../scala/org/apache/spark/util/Utils.scala | 2 +-
 .../util/logging/RollingFileAppender.scala | 5 +-
 .../org/apache/spark/CheckpointSuite.scala | 28 +--
 .../history/BasicEventFilterSuite.scala | 2 +-
 .../history/HistoryServerPageSuite.scala | 1 +
 .../deploy/history/HistoryServerSuite.scala | 2 +-
 .../scala/org/apache/spark/rdd/RDDSuite.scala | 2 +-
 .../spark/scheduler/SparkListenerSuite.scala | 2 +-
 .../spark/util/ClosureCleanerSuite.scala | 2 +-
 .../spark/util/collection/SorterSuite.scala | 10 +-
 .../org/apache/spark/examples/SparkTC.scala | 5 +-
 .../graphx/impl/VertexPartitionBaseOps.scala | 4 +-
 .../lib/StronglyConnectedComponents.scala | 5 +-
 ...treamBasedCheckpointFileManagerSuite.scala | 2 +-
 .../org/apache/spark/ml/linalg/Matrices.scala | 12 +-
 .../scala/org/apache/spark/ml/Predictor.scala | 2 +-
 .../scala/org/apache/spark/ml/ann/Layer.scala | 2 +-
 .../spark/ml/classification/Classifier.scala | 4 +-
 .../DecisionTreeClassifier.scala | 4 +-
 .../ml/classification/GBTClassifier.scala | 3 +-
 .../spark/ml/classification/NaiveBayes.scala | 8 +-
 .../ProbabilisticClassifier.scala | 6 +-
 .../RandomForestClassifier.scala | 3 +-
 .../org/apache/spark/ml/clustering/LDA.scala | 2 +-
 .../ml/evaluation/ClusteringMetrics.scala | 4 +-
 .../apache/spark/ml/feature/Binarizer.scala | 4 +-
 .../spark/ml/feature/CountVectorizer.scala | 2 +-
 .../spark/ml/feature/ElementwiseProduct.scala | 2 +-
 .../spark/ml/feature/FeatureHasher.scala | 2 +-
 .../apache/spark/ml/feature/HashingTF.scala | 2 +-
 .../org/apache/spark/ml/feature/IDF.scala | 2 +-
 .../apache/spark/ml/feature/Interaction.scala | 2 +-
 .../spark/ml/feature/MinMaxScaler.scala | 2 +-
 .../org/apache/spark/ml/feature/PCA.scala | 2 +-
 .../apache/spark/ml/feature/Selector.scala | 2 +-
.../spark/ml/feature/StandardScaler.scala | 8 +- .../spark/ml/feature/StopWordsRemover.scala | 4 +- .../spark/ml/feature/StringIndexer.scala | 6 +- .../spark/ml/feature/TargetEncoder.scala | 2 +- .../feature/UnivariateFeatureSelector.scala | 2 +- .../spark/ml/feature/VectorAssembler.scala | 2 +- .../spark/ml/feature/VectorIndexer.scala | 2 +- .../spark/ml/feature/VectorSizeHint.scala | 3 +- .../spark/ml/feature/VectorSlicer.scala | 2 +- .../apache/spark/ml/feature/Word2Vec.scala | 2 +- .../org/apache/spark/ml/fpm/FPGrowth.scala | 1 + .../spark/ml/linalg/JsonMatrixConverter.scala | 2 +- .../spark/ml/linalg/JsonVectorConverter.scala | 2 +- .../spark/ml/r/GaussianMixtureWrapper.scala | 2 +- .../org/apache/spark/ml/r/LDAWrapper.scala | 4 +- .../org/apache/spark/ml/r/RWrappers.scala | 1 + .../apache/spark/ml/recommendation/ALS.scala | 1 + .../ml/regression/DecisionTreeRegressor.scala | 8 +- .../spark/ml/regression/FMRegressor.scala | 1 + .../spark/ml/regression/GBTRegressor.scala | 6 +- .../GeneralizedLinearRegression.scala | 2 +- .../ml/regression/IsotonicRegression.scala | 6 +- .../ml/regression/RandomForestRegressor.scala | 5 +- .../spark/ml/tuning/CrossValidator.scala | 2 +- .../ml/tuning/TrainValidationSplit.scala | 2 +- .../classification/ClassificationModel.scala | 2 +- .../classification/LogisticRegression.scala | 17 +- .../spark/mllib/classification/SVM.scala | 8 +- .../clustering/GaussianMixtureModel.scala | 2 +- .../spark/mllib/clustering/LDAModel.scala | 3 +- .../spark/mllib/clustering/LDAOptimizer.scala | 1 + .../apache/spark/mllib/feature/Word2Vec.scala | 2 +- .../apache/spark/mllib/linalg/Matrices.scala | 10 +- .../apache/spark/mllib/linalg/Vectors.scala | 2 +- .../linalg/distributed/IndexedRowMatrix.scala | 2 +- .../apache/spark/mllib/regression/Lasso.scala | 4 +- .../mllib/regression/LinearRegression.scala | 2 +- .../mllib/regression/RegressionModel.scala | 2 +- .../mllib/regression/RidgeRegression.scala | 5 +- .../org/apache/spark/mllib/util/MLUtils.scala | 4 +- .../sql/ml/InternalFunctionRegistration.scala | 6 +- .../ml/classification/ClassifierSuite.scala | 2 +- .../ml/classification/FMClassifierSuite.scala | 2 +- .../classification/GBTClassifierSuite.scala | 2 +- .../ml/classification/LinearSVCSuite.scala | 2 +- .../LogisticRegressionSuite.scala | 8 +- .../MultilayerPerceptronClassifierSuite.scala | 2 +- .../ml/classification/NaiveBayesSuite.scala | 8 +- .../ml/clustering/GaussianMixtureSuite.scala | 8 +- .../BucketedRandomProjectionLSHSuite.scala | 2 +- .../spark/ml/feature/TargetEncoderSuite.scala | 4 +- .../ml/feature/VectorAssemblerSuite.scala | 2 +- .../spark/ml/feature/VectorIndexerSuite.scala | 20 +- .../AFTSurvivalRegressionSuite.scala | 4 +- .../org/apache/spark/ml/util/MLTest.scala | 4 +- .../apache/spark/ml/util/MLTestSuite.scala | 2 +- .../clustering/GaussianMixtureSuite.scala | 2 +- .../spark/mllib/clustering/LDASuite.scala | 2 +- pom.xml | 1 + project/SparkBuild.scala | 1 + .../k8s/SparkKubernetesClientFactory.scala | 12 +- .../KubernetesClusterSchedulerBackend.scala | 15 +- .../apache/spark/sql/MergeIntoWriter.scala | 4 +- .../sql/catalyst/util/RebaseDateTime.scala | 4 +- .../catalyst/util/TimestampFormatter.scala | 5 +- .../org/apache/spark/sql/types/DataType.scala | 4 +- .../apache/spark/sql/types/StructType.scala | 3 +- .../spark/sql/catalyst/StructFilters.scala | 8 +- .../sql/catalyst/analysis/Analyzer.scala | 7 +- .../NaturalAndUsingJoinResolution.scala | 2 +- .../analysis/RewriteMergeIntoTable.scala | 4 +- .../catalyst/analysis/TypeCoercionBase.scala | 
2 +- .../analysis/higherOrderFunctions.scala | 4 +- .../resolver/RelationMetadataProvider.scala | 2 +- .../catalog/ExternalCatalogUtils.scala | 2 +- .../sql/catalyst/catalog/interface.scala | 4 +- .../sql/catalyst/csv/CSVInferSchema.scala | 2 +- .../sql/catalyst/expressions/Between.scala | 2 +- .../catalyst/expressions/DynamicPruning.scala | 2 +- .../catalyst/expressions/ExpressionSet.scala | 2 +- ...ctionTableSubqueryArgumentExpression.scala | 1 - .../sql/catalyst/expressions/ScalaUDF.scala | 8 +- .../catalyst/expressions/SchemaPruning.scala | 2 +- .../catalyst/expressions/ToStringBase.scala | 2 +- .../expressions/aggregate/Count.scala | 6 +- .../expressions/datetimeExpressions.scala | 8 +- .../json/JsonExpressionEvalUtils.scala | 7 +- .../expressions/objects/objects.scala | 10 +- .../sql/catalyst/expressions/subquery.scala | 11 +- .../expressions/windowExpressions.scala | 8 +- .../optimizer/CostBasedJoinReorder.scala | 2 +- .../optimizer/DecorrelateInnerQuery.scala | 8 +- .../optimizer/NormalizeFloatingNumbers.scala | 2 +- .../sql/catalyst/optimizer/Optimizer.scala | 77 +++---- .../optimizer/PushDownLeftSemiAntiJoin.scala | 8 +- .../PushExtraPredicateThroughJoin.scala | 4 +- ...wnPredicatesAndPruneColumnsForCTEDef.scala | 4 +- .../sql/catalyst/optimizer/expressions.scala | 26 +-- .../catalyst/optimizer/finishAnalysis.scala | 2 +- .../spark/sql/catalyst/optimizer/joins.scala | 10 +- .../sql/catalyst/optimizer/subquery.scala | 22 +- .../sql/catalyst/parser/AstBuilder.scala | 38 ++-- .../sql/catalyst/planning/patterns.scala | 4 +- .../sql/catalyst/plans/NormalizePlan.scala | 8 +- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 +- .../plans/logical/ColumnDefinition.scala | 2 +- .../plans/logical/QueryPlanConstraints.scala | 2 +- .../plans/logical/basicLogicalOperators.scala | 2 +- .../spark/sql/catalyst/trees/TreeNode.scala | 2 +- .../sql/catalyst/types/PhysicalDataType.scala | 6 +- .../util/ResolveDefaultColumnsUtil.scala | 2 +- .../sql/catalyst/util/ToNumberParser.scala | 10 +- .../sql/catalyst/xml/StaxXmlParserUtils.scala | 15 +- .../spark/sql/CalendarIntervalBenchmark.scala | 2 +- .../org/apache/spark/sql/HashBenchmark.scala | 8 +- .../spark/sql/HashByteArrayBenchmark.scala | 6 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 6 +- .../analysis/AnsiTypeCoercionSuite.scala | 5 +- .../ExpressionTypeCheckingSuite.scala | 2 +- .../catalyst/analysis/TypeCoercionSuite.scala | 15 +- .../ArithmeticExpressionSuite.scala | 6 +- .../expressions/BitwiseExpressionsSuite.scala | 14 +- .../expressions/CanonicalizeSuite.scala | 16 +- .../ConditionalExpressionSuite.scala | 2 +- .../expressions/DateExpressionsSuite.scala | 36 ++-- .../expressions/HashExpressionsSuite.scala | 6 +- .../expressions/MathExpressionsSuite.scala | 193 +++++++++--------- .../catalyst/expressions/PredicateSuite.scala | 58 +++--- .../aggregate/HistogramNumericSuite.scala | 2 +- .../xml/XPathExpressionSuite.scala | 14 +- .../optimizer/ComputeCurrentTimeSuite.scala | 8 +- .../optimizer/ConstantFoldingSuite.scala | 2 +- ...ReplaceNullWithFalseInPredicateSuite.scala | 4 +- .../catalyst/parser/ParserUtilsSuite.scala | 2 +- .../catalyst/util/RebaseDateTimeSuite.scala | 10 +- .../catalog/EnumTypeSetBenchmark.scala | 12 +- .../connector/catalog/InMemoryBaseTable.scala | 3 +- .../UserDefinedFunctionE2ETestSuite.scala | 4 +- .../sql/connect/test/RemoteSparkSession.scala | 2 +- ...cutePlanResponseReattachableIterator.scala | 5 +- .../client/GrpcExceptionConverter.scala | 3 +- .../sql/connect/client/RetryPolicy.scala | 1 + 
.../client/arrow/ArrowSerializer.scala | 2 +- .../planner/SparkConnectPlannerSuite.scala | 2 +- ...parkConnectWithSessionExtensionSuite.scala | 2 +- .../spark/sql/avro/AvroSerializer.scala | 4 +- .../analysis/ResolveSessionCatalog.scala | 6 +- .../sql/columnar/CachedBatchSerializer.scala | 2 +- .../BaseScriptTransformationExec.scala | 4 +- .../sql/execution/DataSourceScanExec.scala | 2 +- .../spark/sql/execution/SortPrefixUtils.scala | 2 +- .../spark/sql/execution/SparkPlanner.scala | 2 +- .../adaptive/AQEShuffleReadExec.scala | 5 +- .../aggregate/MergingSessionsIterator.scala | 1 + .../command/DescribeRelationJsonCommand.scala | 2 +- .../spark/sql/execution/command/ddl.scala | 2 +- .../spark/sql/execution/command/tables.scala | 2 +- .../execution/datasources/DataSource.scala | 1 + .../datasources/DataSourceStrategy.scala | 4 +- .../datasources/DataSourceUtils.scala | 8 +- .../execution/datasources/FileFormat.scala | 14 +- .../sql/execution/datasources/FileIndex.scala | 2 +- .../PruneFileSourcePartitions.scala | 4 +- .../datasources/PushVariantIntoScan.scala | 2 +- .../execution/datasources/SchemaPruning.scala | 4 +- .../binaryfile/BinaryFileFormat.scala | 2 +- .../datasources/orc/OrcFilters.scala | 2 +- .../datasources/parquet/ParquetFilters.scala | 2 +- .../datasources/v2/DataSourceV2Strategy.scala | 3 +- ...upBasedRowLevelOperationScanPlanning.scala | 2 +- .../datasources/v2/PushDownUtils.scala | 1 + .../v2/V2ScanRelationPushDown.scala | 5 +- .../execution/joins/SortMergeJoinExec.scala | 15 +- .../sql/execution/python/EvaluatePython.scala | 2 +- .../WindowInPandasEvaluatorFactory.scala | 2 +- ...cProgressTrackingMicroBatchExecution.scala | 4 +- .../streaming/FileStreamSource.scala | 5 +- .../FlatMapGroupsWithStateExec.scala | 1 + .../StreamingSymmetricHashJoinHelper.scala | 6 +- .../streaming/TransformWithStateExec.scala | 1 + .../TransformWithStateVariableUtils.scala | 1 + .../continuous/ContinuousExecution.scala | 5 +- .../streaming/state/SchemaHelper.scala | 1 + .../StreamingAggregationStateManager.scala | 1 + .../StreamingSessionWindowStateManager.scala | 1 + .../state/SymmetricHashJoinStateManager.scala | 1 + .../window/WindowEvaluatorFactoryBase.scala | 14 +- .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- .../apache/spark/sql/DatasetCacheSuite.scala | 6 +- .../apache/spark/sql/SQLQueryTestHelper.scala | 2 +- .../apache/spark/sql/SQLQueryTestSuite.scala | 4 +- .../sql/SparkSessionExtensionSuite.scala | 56 ++--- .../sql/connector/DataSourceV2Suite.scala | 44 ++-- .../spark/sql/connector/FakeV2Provider.scala | 8 +- .../BaseScriptTransformationSuite.scala | 2 +- .../CoalesceShufflePartitionsSuite.scala | 24 +-- ...nalAppendOnlyUnsafeRowArrayBenchmark.scala | 8 +- .../benchmark/ByteArrayBenchmark.scala | 4 +- .../ConstantColumnVectorBenchmark.scala | 24 +-- .../benchmark/PrimitiveArrayBenchmark.scala | 4 +- .../execution/benchmark/UDFBenchmark.scala | 2 +- .../benchmark/UnsafeArrayDataBenchmark.scala | 16 +- .../columnar/CachedBatchSerializerSuite.scala | 1 + .../CompressionSchemeBenchmark.scala | 4 +- .../FileSourceCustomMetadataStructSuite.scala | 4 +- .../datasources/orc/OrcV1FilterSuite.scala | 4 +- .../SQLLiveEntitiesEventFilterSuite.scala | 2 +- .../execution/joins/BroadcastJoinSuite.scala | 2 +- .../CheckpointFileManagerSuite.scala | 2 +- .../streaming/state/StateStoreSuite.scala | 1 + ...treamingAggregationStateManagerSuite.scala | 1 + .../ui/SQLAppStatusListenerSuite.scala | 3 +- .../vectorized/ColumnarBatchBenchmark.scala | 28 +-- .../spark/sql/sources/InsertSuite.scala 
| 6 +- .../sql/streaming/FileStreamSourceSuite.scala | 16 +- .../streaming/ReportSinkMetricsSuite.scala | 92 ++++----- .../spark/sql/streaming/StreamSuite.scala | 4 +- .../sql/streaming/StreamingQuerySuite.scala | 6 +- .../apache/spark/sql/test/SQLTestData.scala | 2 +- .../SqlResourceWithActualMetricsSuite.scala | 1 + .../spark/sql/hive/HiveInspectors.scala | 56 ++--- .../sql/hive/client/HiveClientImpl.scala | 5 +- .../CreateHiveTableAsSelectCommand.scala | 1 + .../hive/execution/HiveTableScanExec.scala | 2 +- .../execution/InsertIntoHiveDirCommand.scala | 1 + .../execution/PruneHiveTablePartitions.scala | 4 +- .../sql/sources/SimpleTextRelation.scala | 2 +- .../spark/streaming/CheckpointSuite.scala | 3 +- .../spark/streaming/DStreamClosureSuite.scala | 2 +- .../streaming/ReceivedBlockTrackerSuite.scala | 14 +- 294 files changed, 1064 insertions(+), 973 deletions(-) diff --git a/common/utils/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/common/utils/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index d3e975d1782f0..7756c5713ead1 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -94,7 +94,7 @@ private[spark] class TypedConfigBuilder[T]( import ConfigHelpers._ def this(parent: ConfigBuilder, converter: String => T) = { - this(parent, converter, { v: T => v.toString }) + this(parent, converter, (v: T) => v.toString) } /** Apply a transformation to the user-provided values of the config entry. */ diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala index bfc6139bdb729..f903b34f840f1 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala @@ -336,7 +336,7 @@ private[kafka010] class KafkaOffsetReaderAdmin( // admin has a much bigger chance to hit KAFKA-7703 like issues. var incorrectOffsets: Seq[(TopicPartition, Long, Long)] = Nil var attempt = 0 - do { + while ({ partitionOffsets = listOffsets(admin, listOffsetsParams) attempt += 1 @@ -349,7 +349,8 @@ private[kafka010] class KafkaOffsetReaderAdmin( Thread.sleep(offsetFetchAttemptIntervalMs) } } - } while (incorrectOffsets.nonEmpty && attempt < maxOffsetFetchAttempts) + incorrectOffsets.nonEmpty && attempt < maxOffsetFetchAttempts + }) () logDebug(s"Got latest offsets for partitions: $partitionOffsets") partitionOffsets diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderConsumer.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderConsumer.scala index 7aadde7218f51..092f6bd966361 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderConsumer.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderConsumer.scala @@ -385,7 +385,7 @@ private[kafka010] class KafkaOffsetReaderConsumer( // - Avoid calling `consumer.poll(0)` which may cause KAFKA-7703. 
var incorrectOffsets: Seq[(TopicPartition, Long, Long)] = Nil var attempt = 0 - do { + while ({ consumer.seekToEnd(partitions) partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap attempt += 1 @@ -399,7 +399,8 @@ private[kafka010] class KafkaOffsetReaderConsumer( Thread.sleep(offsetFetchAttemptIntervalMs) } } - } while (incorrectOffsets.nonEmpty && attempt < maxOffsetFetchAttempts) + incorrectOffsets.nonEmpty && attempt < maxOffsetFetchAttempts + }) () logDebug(s"Got latest offsets for partition : $partitionOffsets") partitionOffsets diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/FetchedDataPool.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/FetchedDataPool.scala index 9f68cb6fd0882..f4c9c095bdec4 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/FetchedDataPool.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/FetchedDataPool.scala @@ -151,7 +151,7 @@ private[consumer] class FetchedDataPool( private def removeIdleFetchedData(): Unit = synchronized { val now = clock.getTimeMillis() val maxAllowedReleasedTimestamp = now - minEvictableIdleTimeMillis - cache.values.foreach { p: CachedFetchedDataList => + cache.values.foreach { (p: CachedFetchedDataList) => val expired = p.filter { q => !q.inUse && q.lastReleasedTimestamp < maxAllowedReleasedTimestamp } diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index e738abf21f597..bbb4bda8ab8b7 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -2362,7 +2362,7 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { AddKafkaData(Set(topic), 30, 31, 32), // Add data when stream is stopped StartStream(), CheckAnswer(-21, -22, -11, -12, 2, 12, 20, 21, 22, 30, 31, 32), // Should get the added data - AssertOnQuery("Add partitions") { query: StreamExecution => + AssertOnQuery("Add partitions") { (query: StreamExecution) => if (addPartitions) setTopicPartitions(topic, 10, query) true }, @@ -2413,7 +2413,7 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { AddKafkaData(Set(topic), 30, 31, 32), // Add data when stream is stopped StartStream(), CheckAnswer(-21, -22, -11, -12, 2, 12, 23, 24, 30, 31, 32), // Should get the added data - AssertOnQuery("Add partitions") { query: StreamExecution => + AssertOnQuery("Add partitions") { (query: StreamExecution) => if (addPartitions) setTopicPartitions(topic, 10, query) true }, @@ -2581,7 +2581,7 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { CheckAnswer(2, 3, 4, 5, 6, 7), // Should get the added data AddKafkaData(Set(topic), 7, 8), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), - AssertOnQuery("Add partitions") { query: StreamExecution => + AssertOnQuery("Add partitions") { (query: StreamExecution) => if (addPartitions) setTopicPartitions(topic, 10, query) true }, @@ -2622,7 +2622,7 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { AddKafkaData(Set(topic), 7, 8), StartStream(), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), - AssertOnQuery("Add partitions") { query: StreamExecution => + AssertOnQuery("Add partitions") { (query: StreamExecution) => if 
(addPartitions) setTopicPartitions(topic, 10, query) true }, diff --git a/connector/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/connector/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index f7bea064d2d6c..c1d0e4dd2e4b4 100644 --- a/connector/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/connector/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -112,7 +112,7 @@ private[spark] class DirectKafkaInputDStream[K, V]( private[streaming] override def name: String = s"Kafka 0.10 direct stream [$id]" - protected[streaming] override val checkpointData = + protected[streaming] override val checkpointData: DirectKafkaInputDStreamCheckpointData = new DirectKafkaInputDStreamCheckpointData diff --git a/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 652822c5fdc97..66932c9194e4b 100644 --- a/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -176,10 +176,11 @@ private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Loggi private def findNonExistentStreamName(): String = { var testStreamName: String = null - do { + while ({ Thread.sleep(TimeUnit.SECONDS.toMillis(describeStreamPollTimeSeconds)) testStreamName = s"KinesisTestUtils-${math.abs(Random.nextLong())}" - } while (describeStream(testStreamName).nonEmpty) + describeStream(testStreamName).nonEmpty + }) () testStreamName } diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala index 65e8cce0d056e..8d730285a8ae1 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala @@ -322,7 +322,7 @@ private[sql] class ProtobufSerializer( .unzip val numFields = catalystStruct.length - row: InternalRow => + (row: InternalRow) => val result = DynamicMessage.newBuilder(descriptor) var i = 0 while (i < numFields) { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala index 3e8911244c016..375444ec52c6b 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala @@ -29,7 +29,7 @@ import org.apache.spark.rdd.HadoopRDD @DeveloperApi class JavaHadoopRDD[K, V](rdd: HadoopRDD[K, V]) - (implicit override val kClassTag: ClassTag[K], implicit override val vClassTag: ClassTag[V]) + (implicit override val kClassTag: ClassTag[K], override val vClassTag: ClassTag[V]) extends JavaPairRDD[K, V](rdd) { /** Maps over a partition, providing the InputSplit that was used as the base of the partition. 
*/ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala index 936e7b684a5be..8afc248b650f1 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala @@ -29,7 +29,7 @@ import org.apache.spark.rdd.NewHadoopRDD @DeveloperApi class JavaNewHadoopRDD[K, V](rdd: NewHadoopRDD[K, V]) - (implicit override val kClassTag: ClassTag[K], implicit override val vClassTag: ClassTag[V]) + (implicit override val kClassTag: ClassTag[K], override val vClassTag: ClassTag[V]) extends JavaPairRDD[K, V](rdd) { /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index f16c0be75c6e8..7bac7f3d2cb9b 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -45,7 +45,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) - (implicit val kClassTag: ClassTag[K], implicit val vClassTag: ClassTag[V]) + (implicit val kClassTag: ClassTag[K], val vClassTag: ClassTag[V]) extends AbstractJavaRDDLike[(K, V), JavaPairRDD[K, V]] { override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd) diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 85e4ebd707cfc..0d35bbec64f65 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -30,7 +30,7 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.sys.process._ -import org.json4s.Formats +import org.json4s._ import org.json4s.jackson.JsonMethods import org.apache.spark.{SparkConf, SparkContext} diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index a30759e5d794e..18e2bbec724b0 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -408,11 +408,11 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { resourceProfileId: Int) def main(args: Array[String]): Unit = { - val createFn: (RpcEnv, Arguments, SparkEnv, ResourceProfile) => - CoarseGrainedExecutorBackend = { case (rpcEnv, arguments, env, resourceProfile) => - new CoarseGrainedExecutorBackend(rpcEnv, arguments.driverUrl, arguments.executorId, - arguments.bindAddress, arguments.hostname, arguments.cores, - env, arguments.resourcesFileOpt, resourceProfile) + val createFn: (RpcEnv, Arguments, SparkEnv, ResourceProfile) => CoarseGrainedExecutorBackend = { + case (rpcEnv, arguments, env, resourceProfile) => + new CoarseGrainedExecutorBackend(rpcEnv, arguments.driverUrl, arguments.executorId, + arguments.bindAddress, arguments.hostname, arguments.cores, + env, arguments.resourcesFileOpt, resourceProfile) } run(parseArguments(args, this.getClass.getCanonicalName.stripSuffix("$")), createFn) System.exit(0) diff --git 
a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 233228a9c6d4c..6127b6d1cffde 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -244,7 +244,7 @@ class ZStdCompressionCodec(conf: SparkConf) extends CompressionCodec { new BufferedOutputStream(os, bufferSize) } - override private[spark] def compressedContinuousOutputStream(s: OutputStream) = { + override private[spark] def compressedContinuousOutputStream(s: OutputStream): OutputStream = { // SPARK-29322: Set "closeFrameOnFlush" to 'true' to let continuous input stream not being // stuck on reading open frame. val os = new ZstdOutputStreamNoFinalizer(s, bufferPool) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 8bac6e736119d..b83f731acb02a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -96,7 +96,7 @@ class CoGroupedRDD[K: ClassTag]( } override def getDependencies: Seq[Dependency[_]] = { - rdds.map { rdd: RDD[_] => + rdds.map { (rdd: RDD[_]) => if (rdd.partitioner == Some(part)) { logDebug("Adding one-to-one dependency with " + rdd) new OneToOneDependency(rdd) diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 0930a5c9cfb96..8f03afb5b2664 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -87,7 +87,7 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine)) } - override val partitioner = Some(part) + override val partitioner: Option[Partitioner] = Some(part) override def getPartitions: Array[Partition] = { Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i)) diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 27dfdb4daa2c4..6c6246c775e98 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -84,7 +84,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( array } - override val partitioner = Some(part) + override val partitioner: Option[Partitioner] = Some(part) override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = { val partition = p.asInstanceOf[CoGroupPartition] diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 678a48948a3c1..86e6d8b3a5aa7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -97,8 +97,7 @@ private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag] } } -private[spark] class ZippedPartitionsRDD3 - [A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag]( +private[spark] class ZippedPartitionsRDD3[A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag]( sc: SparkContext, var f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V], var rdd1: RDD[A], @@ -123,8 +122,8 @@ private[spark] class ZippedPartitionsRDD3 } } 
-private[spark] class ZippedPartitionsRDD4 - [A: ClassTag, B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag]( +private[spark] class ZippedPartitionsRDD4[ + A: ClassTag, B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag]( sc: SparkContext, var f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], var rdd1: RDD[A], diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceInformation.scala b/core/src/main/scala/org/apache/spark/resource/ResourceInformation.scala index c7cce05f46f49..61899b469213c 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceInformation.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceInformation.scala @@ -21,6 +21,7 @@ import scala.util.control.NonFatal import org.json4s.{DefaultFormats, Extraction, Formats, JValue} import org.json4s.jackson.JsonMethods._ +import org.json4s.jvalue2extractable import org.apache.spark.SparkException import org.apache.spark.annotation.Evolving diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala b/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala index 78c45cdc75418..c8a1fae5e8232 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala @@ -24,6 +24,7 @@ import scala.util.control.NonFatal import org.json4s.{DefaultFormats, Formats} import org.json4s.jackson.JsonMethods._ +import org.json4s.jvalue2extractable import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 63b784c47d15a..b56b82a8dd3b8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -573,7 +573,7 @@ private[spark] class TaskSchedulerImpl( var globalMinLocality: Option[TaskLocality] = None for (currentMaxLocality <- taskSet.myLocalityLevels) { var launchedTaskAtCurrentMaxLocality = false - do { + while ({ val (noDelayScheduleReject, minLocality) = resourceOfferSingleTaskSet( taskSet, currentMaxLocality, shuffledOffers, availableCpus, availableResources, tasks) @@ -581,7 +581,8 @@ private[spark] class TaskSchedulerImpl( launchedAnyTask |= launchedTaskAtCurrentMaxLocality noDelaySchedulingRejects &= noDelayScheduleReject globalMinLocality = minTaskLocality(globalMinLocality, minLocality) - } while (launchedTaskAtCurrentMaxLocality) + launchedTaskAtCurrentMaxLocality + }) () } if (!legacyLocalityWaitReset) { diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 6902fb6d236de..edb071c2683b9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -81,7 +81,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) - override val shuffleBlockResolver = + override val shuffleBlockResolver: IndexShuffleBlockResolver = new IndexShuffleBlockResolver(conf, taskIdMapsForShuffle = taskIdMapsForShuffle) /** diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 
6ae1dce57f31c..fbc88164139aa 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -85,8 +85,8 @@ class ExecutorStageSummary private[spark]( val diskBytesSpilled : Long, @deprecated("use isExcludedForStage instead", "3.1.0") val isBlacklistedForStage: Boolean, - @JsonSerialize(using = classOf[ExecutorMetricsJsonSerializer]) - @JsonDeserialize(using = classOf[ExecutorMetricsJsonDeserializer]) + @JsonSerialize(`using` = classOf[ExecutorMetricsJsonSerializer]) + @JsonDeserialize(`using` = classOf[ExecutorMetricsJsonDeserializer]) val peakMemoryMetrics: Option[ExecutorMetrics], val isExcludedForStage: Boolean) @@ -125,8 +125,8 @@ class ExecutorSummary private[spark]( val memoryMetrics: Option[MemoryMetrics], @deprecated("use excludedInStages instead", "3.1.0") val blacklistedInStages: Set[Int], - @JsonSerialize(using = classOf[ExecutorMetricsJsonSerializer]) - @JsonDeserialize(using = classOf[ExecutorMetricsJsonDeserializer]) + @JsonSerialize(`using` = classOf[ExecutorMetricsJsonSerializer]) + @JsonDeserialize(`using` = classOf[ExecutorMetricsJsonDeserializer]) val peakMemoryMetrics: Option[ExecutorMetrics], val attributes: Map[String, String], val resources: Map[String, ResourceInformation], @@ -165,9 +165,9 @@ private[spark] class ExecutorMetricsJsonSerializer if (metrics.isEmpty) { jsonGenerator.writeNull() } else { - metrics.foreach { m: ExecutorMetrics => + metrics.foreach { metrics => val metricsMap = ExecutorMetricType.metricToOffset.map { case (metric, _) => - metric -> m.getMetricValue(metric) + metric -> metrics.getMetricValue(metric) } jsonGenerator.writeObject(metricsMap) } @@ -310,8 +310,8 @@ class StageData private[spark]( val speculationSummary: Option[SpeculationStageSummary], val killedTasksSummary: Map[String, Int], val resourceProfileId: Int, - @JsonSerialize(using = classOf[ExecutorMetricsJsonSerializer]) - @JsonDeserialize(using = classOf[ExecutorMetricsJsonDeserializer]) + @JsonSerialize(`using` = classOf[ExecutorMetricsJsonSerializer]) + @JsonDeserialize(`using` = classOf[ExecutorMetricsJsonDeserializer]) val peakExecutorMetrics: Option[ExecutorMetrics], val taskMetricsDistributions: Option[TaskMetricDistributions], val executorMetricsDistributions: Option[ExecutorMetricsDistributions], @@ -448,11 +448,11 @@ class ExecutorMetricsDistributions private[spark]( val shuffleWriteRecords: IndexedSeq[Double], val memoryBytesSpilled: IndexedSeq[Double], val diskBytesSpilled: IndexedSeq[Double], - @JsonSerialize(using = classOf[ExecutorPeakMetricsDistributionsJsonSerializer]) + @JsonSerialize(`using` = classOf[ExecutorPeakMetricsDistributionsJsonSerializer]) val peakMemoryMetrics: ExecutorPeakMetricsDistributions ) -@JsonSerialize(using = classOf[ExecutorPeakMetricsDistributionsJsonSerializer]) +@JsonSerialize(`using` = classOf[ExecutorPeakMetricsDistributionsJsonSerializer]) class ExecutorPeakMetricsDistributions private[spark]( val quantiles: IndexedSeq[Double], val executorMetrics: IndexedSeq[ExecutorMetrics]) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index fa3aee0103a99..3aea787a4beb3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -585,7 +585,7 @@ class BlockManagerMasterEndpoint( private def removeBlockFromWorkers(blockId: 
BlockId): Unit = { val locations = blockLocations.get(blockId) if (locations != null) { - locations.foreach { blockManagerId: BlockManagerId => + locations.foreach { blockManagerId => val blockManager = blockManagerInfo.get(blockManagerId) blockManager.foreach { bm => // Remove the block from the BlockManager. @@ -602,7 +602,7 @@ class BlockManagerMasterEndpoint( // Return a map from the block manager id to max memory and remaining memory. private def memoryStatus: Map[BlockManagerId, (Long, Long)] = { - blockManagerInfo.map { case(blockManagerId, info) => + blockManagerInfo.map { case (blockManagerId, info) => (blockManagerId, (info.maxMem, info.remainingMem)) }.toMap } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 57f6901a7a735..dd4490c8206a2 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -1014,39 +1014,44 @@ final class ShuffleBlockFetcherIterator( // a SuccessFetchResult or a FailureFetchResult. result = null - case PushMergedLocalMetaFetchResult( - shuffleId, shuffleMergeId, reduceId, bitmaps, localDirs) => - // Fetch push-merged-local shuffle block data as multiple shuffle chunks - val shuffleBlockId = ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId) - try { - val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId, - localDirs) - // Since the request for local block meta completed successfully, numBlocksToFetch - // is decremented. - numBlocksToFetch -= 1 - // Update total number of blocks to fetch, reflecting the multiple local shuffle - // chunks. - numBlocksToFetch += bufs.size - bufs.zipWithIndex.foreach { case (buf, chunkId) => - buf.retain() - val shuffleChunkId = ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId, - chunkId) - pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId)) - results.put(SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, - pushBasedFetchHelper.localShuffleMergerBlockMgrId, buf.size(), buf, - isNetworkReqDone = false)) - } - } catch { - case e: Exception => - // If we see an exception with reading push-merged-local index file, we fallback - // to fetch the original blocks. We do not report block fetch failure - // and will continue with the remaining local block read. - logWarning("Error occurred while reading push-merged-local index, " + - "prepare to fetch the original blocks", e) - pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock( - shuffleBlockId, pushBasedFetchHelper.localShuffleMergerBlockMgrId) + case PushMergedLocalMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + bitmaps, + localDirs + ) => + // Fetch push-merged-local shuffle block data as multiple shuffle chunks + val shuffleBlockId = ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId) + try { + val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId, + localDirs) + // Since the request for local block meta completed successfully, numBlocksToFetch + // is decremented. + numBlocksToFetch -= 1 + // Update total number of blocks to fetch, reflecting the multiple local shuffle + // chunks. 
+ numBlocksToFetch += bufs.size + bufs.zipWithIndex.foreach { case (buf, chunkId) => + buf.retain() + val shuffleChunkId = ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId, + chunkId) + pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId)) + results.put(SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, + pushBasedFetchHelper.localShuffleMergerBlockMgrId, buf.size(), buf, + isNetworkReqDone = false)) } - result = null + } catch { + case e: Exception => + // If we see an exception with reading push-merged-local index file, we fallback + // to fetch the original blocks. We do not report block fetch failure + // and will continue with the remaining local block read. + logWarning("Error occurred while reading push-merged-local index, " + + "prepare to fetch the original blocks", e) + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock( + shuffleBlockId, pushBasedFetchHelper.localShuffleMergerBlockMgrId) + } + result = null case PushMergedRemoteMetaFetchResult( shuffleId, shuffleMergeId, reduceId, blockSize, bitmaps, address) => diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index c409ee37a06a5..cd66be48af56b 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -200,7 +200,7 @@ private[spark] object StorageUtils extends Logging { val unsafeField = classOf[Unsafe].getDeclaredField("theUnsafe") unsafeField.setAccessible(true) val unsafe = unsafeField.get(null).asInstanceOf[Unsafe] - buffer: ByteBuffer => unsafe.invokeCleaner(buffer) + (buffer: ByteBuffer) => unsafe.invokeCleaner(buffer) } /** diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala index b72fcb53bdb62..b512f1e33524f 100644 --- a/core/src/main/scala/org/apache/spark/util/Distribution.scala +++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala @@ -42,7 +42,7 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va */ def getQuantiles(probabilities: Iterable[Double] = defaultProbabilities) : IndexedSeq[Double] = { - probabilities.toIndexedSeq.map { p: Double => data(closestIndex(p)) } + probabilities.toIndexedSeq.map(probabilitie => data(closestIndex(probabilitie))) } private def closestIndex(p: Double) = { diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 88fe64859a214..bab94d55b7b4b 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -276,9 +276,10 @@ object SizeEstimator extends Logging { var size = 0L for (i <- 0 until ARRAY_SAMPLE_SIZE) { var index = 0 - do { + while ({ index = rand.nextInt(length) - } while (drawn.contains(index)) + drawn.contains(index) + }) () drawn.add(index) val obj = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef] if (obj != null) { diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index e9d14f904db45..470538ecb3231 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -23,6 +23,7 @@ import java.util.concurrent.locks.ReentrantLock import scala.concurrent.{Awaitable, ExecutionContext, 
ExecutionContextExecutor, Future} import scala.concurrent.duration.{Duration, FiniteDuration} +import scala.concurrent.duration.durationToPair import scala.util.control.NonFatal import com.google.common.util.concurrent.ThreadFactoryBuilder diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index ea9b742fb2e1b..9b084b2d82d19 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1476,7 +1476,7 @@ private[spark] object Utils var insideSpark = true val callStack = new ArrayBuffer[String]() :+ "" - Thread.currentThread.getStackTrace().foreach { ste: StackTraceElement => + Thread.currentThread.getStackTrace().foreach { ste => // When running under some profilers, the current stack trace might contain some bogus // frames. This is intended to ensure that we don't crash in these situations by // ignoring any frames that we can't examine. diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index 6927c119a91c5..311a162e502e4 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -126,11 +126,12 @@ private[spark] class RollingFileAppender( // the right pattern such that name collisions do not occur. var i = 0 var altRolloverFile: File = null - do { + while ({ altRolloverFile = new File(activeFile.getParent, s"${activeFile.getName}$rolloverSuffix--$i").getAbsoluteFile i += 1 - } while (i < 10000 && rolloverFileExist(altRolloverFile)) + i < 10000 && rolloverFileExist(altRolloverFile) + }) () logWarning(log"Rollover file ${MDC(FILE_NAME, rolloverFile)} already exists, " + log"rolled over ${MDC(FILE_NAME2, activeFile)} " + diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 7a39ba4ab382b..b8fd1a3832449 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -271,7 +271,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS override def sparkContext: SparkContext = sc - runTest("basic checkpointing") { reliableCheckpoint: Boolean => + runTest("basic checkpointing") { (reliableCheckpoint: Boolean) => val parCollection = sc.makeRDD(1 to 4) val flatMappedRDD = parCollection.flatMap(x => 1 to x) checkpoint(flatMappedRDD, reliableCheckpoint) @@ -281,7 +281,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS assert(flatMappedRDD.collect() === result) } - runTest("checkpointing partitioners", skipLocalCheckpoint = true) { _: Boolean => + runTest("checkpointing partitioners", skipLocalCheckpoint = true) { (_: Boolean) => def testPartitionerCheckpointing( partitioner: Partitioner, @@ -324,7 +324,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS testPartitionerCheckpointing(partitioner, corruptPartitionerFile = true) } - runTest("RDDs with one-to-one dependencies") { reliableCheckpoint: Boolean => + runTest("RDDs with one-to-one dependencies") { (reliableCheckpoint: Boolean) => testRDD(_.map(x => x.toString), reliableCheckpoint) testRDD(_.flatMap(x => 1 to x), reliableCheckpoint) testRDD(_.filter(_ % 2 == 0), reliableCheckpoint) @@ -337,7 +337,7 @@ class CheckpointSuite 
extends SparkFunSuite with RDDCheckpointTester with LocalS testRDD(_.pipe(Seq("cat")), reliableCheckpoint) } - runTest("ParallelCollectionRDD") { reliableCheckpoint: Boolean => + runTest("ParallelCollectionRDD") { (reliableCheckpoint: Boolean) => val parCollection = sc.makeRDD(1 to 4, 2) val numPartitions = parCollection.partitions.length checkpoint(parCollection, reliableCheckpoint) @@ -353,7 +353,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS assert(parCollection.collect() === result) } - runTest("BlockRDD") { reliableCheckpoint: Boolean => + runTest("BlockRDD") { (reliableCheckpoint: Boolean) => val blockId = TestBlockId("id") val blockManager = SparkEnv.get.blockManager blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY) @@ -370,20 +370,20 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS assert(blockRDD.collect() === result) } - runTest("ShuffleRDD") { reliableCheckpoint: Boolean => + runTest("ShuffleRDD") { (reliableCheckpoint: Boolean) => testRDD(rdd => { // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD new ShuffledRDD[Int, Int, Int](rdd.map(x => (x % 2, 1)), partitioner) }, reliableCheckpoint) } - runTest("UnionRDD") { reliableCheckpoint: Boolean => + runTest("UnionRDD") { (reliableCheckpoint: Boolean) => def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1) testRDD(_.union(otherRDD), reliableCheckpoint) testRDDPartitions(_.union(otherRDD), reliableCheckpoint) } - runTest("CartesianRDD") { reliableCheckpoint: Boolean => + runTest("CartesianRDD") { (reliableCheckpoint: Boolean) => def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1) testRDD(new CartesianRDD(sc, _, otherRDD), reliableCheckpoint) testRDDPartitions(new CartesianRDD(sc, _, otherRDD), reliableCheckpoint) @@ -406,7 +406,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS ) } - runTest("CoalescedRDD") { reliableCheckpoint: Boolean => + runTest("CoalescedRDD") { (reliableCheckpoint: Boolean) => testRDD(_.coalesce(2), reliableCheckpoint) testRDDPartitions(_.coalesce(2), reliableCheckpoint) @@ -428,7 +428,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS ) } - runTest("CoGroupedRDD") { reliableCheckpoint: Boolean => + runTest("CoGroupedRDD") { (reliableCheckpoint: Boolean) => val longLineageRDD1 = generateFatPairRDD() // Collect the RDD as sequences instead of arrays to enable equality tests in testRDD @@ -446,7 +446,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS }, reliableCheckpoint, seqCollectFunc) } - runTest("ZippedPartitionsRDD") { reliableCheckpoint: Boolean => + runTest("ZippedPartitionsRDD") { (reliableCheckpoint: Boolean) => testRDD(rdd => rdd.zip(rdd.map(x => x)), reliableCheckpoint) testRDDPartitions(rdd => rdd.zip(rdd.map(x => x)), reliableCheckpoint) @@ -471,7 +471,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS ) } - runTest("PartitionerAwareUnionRDD") { reliableCheckpoint: Boolean => + runTest("PartitionerAwareUnionRDD") { (reliableCheckpoint: Boolean) => testRDD(rdd => { new PartitionerAwareUnionRDD[(Int, Int)](sc, Array( generateFatPairRDD(), @@ -505,7 +505,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS ) } - runTest("CheckpointRDD with zero partitions") { reliableCheckpoint: Boolean => + runTest("CheckpointRDD with zero partitions") { (reliableCheckpoint: Boolean) => val rdd = new BlockRDD[Int](sc, 
Array.empty[BlockId]) assert(rdd.partitions.length === 0) assert(rdd.isCheckpointed === false) @@ -519,7 +519,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS assert(rdd.partitions.length === 0) } - runTest("checkpointAllMarkedAncestors") { reliableCheckpoint: Boolean => + runTest("checkpointAllMarkedAncestors") { (reliableCheckpoint: Boolean) => testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = true) testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = false) } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/BasicEventFilterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/BasicEventFilterSuite.scala index 5d40a0610eb6c..c3b1a7bf5c2a7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/BasicEventFilterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/BasicEventFilterSuite.scala @@ -61,7 +61,7 @@ class BasicEventFilterSuite extends SparkFunSuite { val jobEndEventForJob1 = SparkListenerJobEnd(1, 0, JobSucceeded) val stageSubmittedEventsForJob1 = SparkListenerStageSubmitted(stage1) val stageCompletedEventsForJob1 = SparkListenerStageCompleted(stage1) - val unpersistRDDEventsForJob1 = (1 to 2).map(SparkListenerUnpersistRDD) + val unpersistRDDEventsForJob1 = (1 to 2).map(SparkListenerUnpersistRDD(_)) // job events for finished job should be rejected assert(Some(false) === acceptFn(jobStartEventForJob1)) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerPageSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerPageSuite.scala index 100145a2f4833..42a288b60af25 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerPageSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerPageSuite.scala @@ -23,6 +23,7 @@ import jakarta.servlet.http.HttpServletResponse import org.json4s.DefaultFormats import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods.parse +import org.json4s.jvalue2extractable import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 10092f416f9e1..81d5f00494fd0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -29,7 +29,7 @@ import jakarta.servlet._ import jakarta.servlet.http.{HttpServletRequest, HttpServletRequestWrapper, HttpServletResponse} import org.apache.commons.io.{FileUtils, IOUtils} import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} -import org.json4s.JsonAST._ +import org.json4s._ import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods._ import org.mockito.Mockito._ diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index aecb8b99d0e31..7afd1fb64b4f1 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -806,7 +806,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { test("runJob on an invalid partition") { intercept[IllegalArgumentException] { - sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2)) + sc.runJob(sc.parallelize(1 to 
10, 2), (iter: Iterator[Int]) => iter.size, Seq(0, 1, 2)) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 6dc4d4da7bfc1..e531b3b5b0270 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -357,7 +357,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // just to make sure some of the tasks and their deserialization take a noticeable // amount of time val slowDeserializable = new SlowDeserializable - val w = { i: Int => + val w = { (i: Int) => if (i == 0) { Thread.sleep(100) slowDeserializable.use() diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 938465ab53265..ff309916c3e97 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -318,7 +318,7 @@ private object TestUserClosuresActuallyCleaned { } def testRunJob2(sc: SparkContext): Unit = { val rdd = sc.parallelize(1 to 10, 10) - sc.runJob(rdd, { _: Iterator[Int] => return; 1 } ) + sc.runJob(rdd, { (_: Iterator[Int]) => return; 1 } ) } def testRunApproximateJob(sc: SparkContext): Unit = { val rdd = sc.parallelize(1 to 10, 10) diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala index e3a57a70e2d2a..61dbd3456b9f8 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -157,17 +157,19 @@ class SorterSuite extends SparkFunSuite { var i: Int = 0 sum = 0 val amountOfZeros = arrayToSort.length - runLengths.length - do { + while ({ sum += arrayToSort(i) i += 1 - } while (i < amountOfZeros) + i < amountOfZeros + }) () assert(sum === 0) val sizeOfArrayToSort = arrayToSort.length - do { + while ({ sum += arrayToSort(i) i += 1 - } while (i < sizeOfArrayToSort) + i < sizeOfArrayToSort + }) () assert(sum === runLengths.length) */ } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index 11f730c63ca8d..6cd37799dc96b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -60,13 +60,14 @@ object SparkTC { // This join is iterated until a fixed point is reached. var oldCount = 0L var nextCount = tc.count() - do { + while ({ oldCount = nextCount // Perform the join, obtaining an RDD of (y, (z, x)) pairs, // then project the result to obtain the new (x, z) paths. 
tc = tc.union(tc.join(edges).map(x => (x._2._2, x._2._1))).distinct().cache() nextCount = tc.count() - } while (nextCount != oldCount) + nextCount != oldCount + }) () println(s"TC has ${tc.count()} edges.") spark.stop() diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala index cf4c8ca2a9c42..317eead3d44af 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala @@ -30,8 +30,8 @@ import org.apache.spark.util.collection.BitSet * implicit evidence of membership in the `VertexPartitionBaseOpsConstructor` typeclass (for * example, `VertexPartition.VertexPartitionOpsConstructor`). */ -private[graphx] abstract class VertexPartitionBaseOps - [VD: ClassTag, Self[X] <: VertexPartitionBase[X]: VertexPartitionBaseOpsConstructor] +private[graphx] abstract class VertexPartitionBaseOps[ + VD: ClassTag, Self[X] <: VertexPartitionBase[X]: VertexPartitionBaseOpsConstructor] (self: Self[VD]) extends Serializable with Logging { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala index 3393ea06ff246..e623f1142bd76 100755 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala @@ -51,7 +51,7 @@ object StronglyConnectedComponents { var iter = 0 while (sccWorkGraph.numVertices > 0 && iter < numIter) { iter += 1 - do { + while ({ numVertices = sccWorkGraph.numVertices sccWorkGraph = sccWorkGraph.outerJoinVertices(sccWorkGraph.outDegrees) { (vid, data, degreeOpt) => if (degreeOpt.isDefined) data else (vid, true) @@ -77,7 +77,8 @@ object StronglyConnectedComponents { // only keep vertices that are not final sccWorkGraph = sccWorkGraph.subgraph(vpred = (vid, data) => !data._2).cache() - } while (sccWorkGraph.numVertices < numVertices) + sccWorkGraph.numVertices < numVertices + }) () // if iter < numIter at this point sccGraph that is returned // will not be recomputed and pregel executions are pointless diff --git a/hadoop-cloud/src/test/scala/org/apache/spark/internal/io/cloud/AbortableStreamBasedCheckpointFileManagerSuite.scala b/hadoop-cloud/src/test/scala/org/apache/spark/internal/io/cloud/AbortableStreamBasedCheckpointFileManagerSuite.scala index 0dbc650fc8c73..6cfff663510ca 100644 --- a/hadoop-cloud/src/test/scala/org/apache/spark/internal/io/cloud/AbortableStreamBasedCheckpointFileManagerSuite.scala +++ b/hadoop-cloud/src/test/scala/org/apache/spark/internal/io/cloud/AbortableStreamBasedCheckpointFileManagerSuite.scala @@ -34,7 +34,7 @@ class AbortableStreamBasedCheckpointFileManagerSuite extends CheckpointFileManagerTests with Logging { override def withTempHadoopPath(p: Path => Unit): Unit = { - withTempDir { f: File => + withTempDir { (f: File) => val basePath = new Path(AbortableFileSystem.ABORTABLE_FS_SCHEME, null, f.getAbsolutePath) p(basePath) } diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala index ad8869f8a81f7..105f236e5dda5 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala @@ -360,8 +360,8 @@ class 
DenseMatrix @Since("2.0.0") ( override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone()) - private[spark] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f), - isTransposed) + private[spark] def map(f: Double => Double): Matrix = + new DenseMatrix(numRows, numCols, values.map(f), isTransposed) private[ml] def update(f: Double => Double): DenseMatrix = { val len = values.length @@ -700,7 +700,7 @@ class SparseMatrix @Since("2.0.0") ( new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone()) } - private[spark] def map(f: Double => Double) = + private[spark] def map(f: Double => Double): Matrix = new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f), isTransposed) private[ml] def update(f: Double => Double): SparseMatrix = { @@ -1230,10 +1230,10 @@ object Matrices { numCols += mat.numCols } if (!hasSparse) { - new DenseMatrix(numRows, numCols, matrices.flatMap { m: Matrix => m.toArray }) + new DenseMatrix(numRows, numCols, matrices.flatMap { (m: Matrix) => m.toArray }) } else { var startCol = 0 - val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat: Matrix => + val entries: Array[(Int, Int, Double)] = matrices.flatMap { (mat: Matrix) => val nCols = mat.numCols mat match { case spMat: SparseMatrix => @@ -1302,7 +1302,7 @@ object Matrices { new DenseMatrix(numRows, numCols, allValues) } else { var startRow = 0 - val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat: Matrix => + val entries: Array[(Int, Int, Double)] = matrices.flatMap { (mat: Matrix) => val nRows = mat.numRows mat match { case spMat: SparseMatrix => diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index 83b77510602b2..0ea321fab48fc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -201,7 +201,7 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, protected def transformImpl(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema, logging = true) - val predictUDF = udf { features: Any => + val predictUDF = udf { (features: Any) => predict(features.asInstanceOf[FeaturesType]) } dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))), diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala index c31a99dd4fd3f..889f6febbe40d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.ann import java.util.Random -import breeze.linalg.{*, axpy => Baxpy, DenseMatrix => BDM, DenseVector => BDV, Vector => BV} +import breeze.linalg.{`*`, axpy => Baxpy, DenseMatrix => BDM, DenseVector => BDV, Vector => BV} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 7883a0dea54f1..f83c152ebc290 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -128,7 +128,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur var outputData = dataset 
var numColsOutput = 0 if (getRawPredictionCol != "") { - val predictRawUDF = udf { features: Any => + val predictRawUDF = udf { (features: Any) => predictRaw(features.asInstanceOf[FeaturesType]) } outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)), @@ -139,7 +139,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur val predCol = if (getRawPredictionCol != "") { udf(raw2prediction _).apply(col(getRawPredictionCol)) } else { - val predictUDF = udf { features: Any => + val predictUDF = udf { (features: Any) => predict(features.asInstanceOf[FeaturesType]) } predictUDF(col(getFeaturesCol)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 2c9f518c772c4..ac9345a3d7295 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.classification import org.apache.hadoop.fs.Path -import org.json4s.{DefaultFormats, JObject} +import org.json4s._ import org.json4s.JsonDSL._ import org.apache.spark.annotation.Since @@ -213,7 +213,7 @@ class DecisionTreeClassificationModel private[ml] ( val outputData = super.transform(dataset) if ($(leafCol).nonEmpty) { - val leafUDF = udf { features: Vector => predictLeaf(features) } + val leafUDF = udf { (features: Vector) => predictLeaf(features) } outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))), outputSchema($(leafCol)).metadata) } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 9f2c2c85115ba..f7f102232a812 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.classification +import org.json4s._ import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ @@ -302,7 +303,7 @@ class GBTClassificationModel private[ml]( val outputData = super.transform(dataset) if ($(leafCol).nonEmpty) { - val leafUDF = udf { features: Vector => predictLeaf(features) } + val leafUDF = udf { (features: Vector) => predictLeaf(features) } outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))), outputSchema($(leafCol)).metadata) } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 4b0f8c311c3d0..18a6b672ec264 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -543,13 +543,13 @@ class NaiveBayesModel private[ml] ( @transient private lazy val predictRawFunc = { $(modelType) match { case Multinomial => - features: Vector => multinomialCalculation(features) + (features: Vector) => multinomialCalculation(features) case Complement => - features: Vector => complementCalculation(features) + (features: Vector) => complementCalculation(features) case Bernoulli => - features: Vector => bernoulliCalculation(features) + (features: Vector) => bernoulliCalculation(features) case Gaussian => - features: Vector => gaussianCalculation(features) + (features: Vector) => 
gaussianCalculation(features) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 61fab02cb4518..08dac129d5d23 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -118,7 +118,7 @@ abstract class ProbabilisticClassificationModel[ var outputData = dataset var numColsOutput = 0 if ($(rawPredictionCol).nonEmpty) { - val predictRawUDF = udf { features: Any => + val predictRawUDF = udf { (features: Any) => predictRaw(features.asInstanceOf[FeaturesType]) } outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)), @@ -129,7 +129,7 @@ abstract class ProbabilisticClassificationModel[ val probCol = if ($(rawPredictionCol).nonEmpty) { udf(raw2probability _).apply(col($(rawPredictionCol))) } else { - val probabilityUDF = udf { features: Any => + val probabilityUDF = udf { (features: Any) => predictProbability(features.asInstanceOf[FeaturesType]) } probabilityUDF(col($(featuresCol))) @@ -144,7 +144,7 @@ abstract class ProbabilisticClassificationModel[ } else if ($(probabilityCol).nonEmpty) { udf(probability2prediction _).apply(col($(probabilityCol))) } else { - val predictUDF = udf { features: Any => + val predictUDF = udf { (features: Any) => predict(features.asInstanceOf[FeaturesType]) } predictUDF(col($(featuresCol))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 24dd7095513ac..a5dfbd12c8043 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.classification +import org.json4s._ import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ @@ -319,7 +320,7 @@ class RandomForestClassificationModel private[ml] ( val outputData = super.transform(dataset) if ($(leafCol).nonEmpty) { - val leafUDF = udf { features: Vector => predictLeaf(features) } + val leafUDF = udf { (features: Vector) => predictLeaf(features) } outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))), outputSchema($(leafCol)).metadata) } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 3ea1c8594e1fc..bef8a45b5aba2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -484,7 +484,7 @@ abstract class LDAModel private[ml] ( val k = oldModel.k val gammaSeed = oldModel.seed - vector: Vector => + (vector: Vector) => if (vector.numNonzeros == 0) { Vectors.zeros(k) } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala index 98fbe471f2977..c47d7e2f65615 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala @@ -390,7 +390,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette { 
SquaredEuclideanSilhouette.registerKryoClasses(dataset.sparkSession.sparkContext) val squaredNormUDF = udf { - features: Vector => math.pow(Vectors.norm(features, 2.0), 2.0) + (features: Vector) => math.pow(Vectors.norm(features, 2.0), 2.0) } val dfWithSquaredNorm = dataset.withColumn("squaredNorm", squaredNormUDF(col(featuresCol))) @@ -592,7 +592,7 @@ private[evaluation] object CosineSilhouette extends Silhouette { featuresCol: String, weightCol: String): Double = { val normalizeFeatureUDF = udf { - features: Vector => { + (features: Vector) => { val norm = Vectors.norm(features, 2.0) BLAS.scal(1.0 / norm, features) features diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 8123438fd8878..129871d906952 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -124,7 +124,7 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) .otherwise(lit(0.0)) case _: VectorUDT if td >= 0 => - udf { vector: Vector => + udf { (vector: Vector) => val indices = ArrayBuilder.make[Int] val values = ArrayBuilder.make[Double] vector.foreachNonZero { (index, value) => @@ -144,7 +144,7 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) logWarning(log"Binarization operations on sparse dataset with negative threshold " + log"${MDC(LogKeys.THRESHOLD, td)} will build a dense output, so take care when " + log"applying to sparse input.") - udf { vector: Vector => + udf { (vector: Vector) => val values = Array.fill(vector.size)(1.0) var nnz = vector.size vector.foreachNonZero { (index, value) => diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 34465248f20df..0cbde628d778c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -317,7 +317,7 @@ class CountVectorizerModel( // rather than once-per-row: val minTf = $(minTF) val isBinary = $(binary) - val vectorizer = udf { document: Seq[String] => + val vectorizer = udf { (document: Seq[String]) => val termCounts = new OpenHashMap[Int, Double] var tokenCount = 0L document.foreach { term => diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index 3b328f2fd8cee..9271ffe47b5e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -58,7 +58,7 @@ class ElementwiseProduct @Since("1.4.0") (@Since("1.4.0") override val uid: Stri val elemScaler = new OldElementwiseProduct(OldVectors.fromML($(scalingVec))) val vectorSize = $(scalingVec).size - vector: Vector => { + (vector: Vector) => { require(vector.size == vectorSize, s"vector sizes do not match: Expected $vectorSize but found ${vector.size}") vector match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index 4a12d77ed8400..850ce4f3d3db0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -156,7 +156,7 @@ class 
FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme } } - val hashFeatures = udf { row: Row => + val hashFeatures = udf { (row: Row) => val map = new OpenHashMap[Int, Double]() var i = 0 diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index dab0a6494fdb9..3721db48ddaf5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -93,7 +93,7 @@ class HashingTF @Since("3.0.0") private[ml] ( val n = $(numFeatures) val updateFunc = if ($(binary)) (v: Double) => 1.0 else (v: Double) => v + 1.0 - val hashUDF = udf { terms: Seq[_] => + val hashUDF = udf { (terms: Seq[_]) => val map = new OpenHashMap[Int, Double]() terms.foreach { term => map.changeValue(indexOf(term), 1.0, updateFunc) } Vectors.sparse(n, map.toSeq) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index c2b7ff7b00a3c..a3a1edeca8517 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -136,7 +136,7 @@ class IDFModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema, logging = true) - val func = { vector: Vector => + val func = { (vector: Vector) => vector match { case SparseVector(size, indices, values) => val (newIndices, newValues) = feature.IDFModel.transformSparse(idfModel.idf, diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 3311231e6d830..3cc4cd6b29bf9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -74,7 +74,7 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) ext val featureEncoders = getFeatureEncoders(inputFeatures.toImmutableArraySeq) val featureAttrs = getFeatureAttrs(inputFeatures.toImmutableArraySeq) - def interactFunc = udf { row: Row => + def interactFunc = udf { (row: Row) => var indices = ArrayBuilder.make[Int] var values = ArrayBuilder.make[Double] var size = 1 diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index c54e64f97953e..a566b8f22e38f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -191,7 +191,7 @@ class MinMaxScalerModel private[ml] ( if (range != 0) scale / range else 0.0 } - val transformer = udf { vector: Vector => + val transformer = udf { (vector: Vector) => require(vector.size == numFeatures, s"Number of features must be $numFeatures but got ${vector.size}") // 0 in sparse vector will probably be rescaled to non-zero diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 6b61e761f5894..7aec934f3571a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -149,7 +149,7 @@ class PCAModel private[ml] ( val outputSchema = transformSchema(dataset.schema, logging = true) val transposed = pc.transpose - val transformer = udf { vector: Vector => 
transposed.multiply(vector) } + val transformer = udf { (vector: Vector) => transposed.multiply(vector) } dataset.withColumn($(outputCol), transformer(col($(inputCol))), outputSchema($(outputCol)).metadata) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala index dde1068c5b924..4eea97a9bbd6a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala @@ -312,7 +312,7 @@ private[feature] object SelectorModel { outputCol: String, featuresCol: String): DataFrame = { val newSize = selectedFeatures.length - val func = { vector: Vector => + val func = { (vector: Vector) => vector match { case SparseVector(_, indices, values) => val (newIndices, newValues) = diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index c1ac1fdbba7d8..3e4c79f2131b1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -291,7 +291,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { withScale: Boolean): Vector => Vector = { (withShift, withScale) match { case (true, true) => - vector: Vector => + (vector: Vector) => val values = vector match { case d: DenseVector => d.values.clone() case v: Vector => v.toArray @@ -300,7 +300,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { Vectors.dense(newValues) case (true, false) => - vector: Vector => + (vector: Vector) => val values = vector match { case d: DenseVector => d.values.clone() case v: Vector => v.toArray @@ -309,7 +309,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { Vectors.dense(newValues) case (false, true) => - vector: Vector => + (vector: Vector) => vector match { case DenseVector(values) => val newValues = transformDenseWithScale(scale, values.clone()) @@ -322,7 +322,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { } case (false, false) => - vector: Vector => vector + (vector: Vector) => vector } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index cb9d8b32f0064..0af02ef27f5d9 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -139,7 +139,7 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String val outputSchema = transformSchema(dataset.schema) val t = if ($(caseSensitive)) { val stopWordsSet = $(stopWords).toSet - udf { terms: Seq[String] => + udf { (terms: Seq[String]) => terms.filter(s => !stopWordsSet.contains(s)) } } else { @@ -148,7 +148,7 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String val toLower = (s: String) => if (s != null) s.toLowerCase(lc) else s // scalastyle:on caselocale val lowerStopWords = $(stopWords).map(toLower(_)).toSet - udf { terms: Seq[String] => + udf { (terms: Seq[String]) => terms.filter(s => !lowerStopWords.contains(toLower(s))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 06a88e9b1c499..5102e19b72dea 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -356,7 +356,7 @@ class StringIndexerModel ( // expression, however, lookup for a key in a map is not efficient in SparkSQL now. // See `ElementAt` and `GetMapValue` expressions. If SQL's map lookup is improved, // we can consider to change this. - val filter = udf { label: String => + val filter = udf { (label: String) => labelToIndex.contains(label) } filter(dataset(inputColName)) @@ -369,7 +369,7 @@ class StringIndexerModel ( private def getIndexer(labels: Seq[String], labelToIndex: OpenHashMap[String, Double]) = { val keepInvalid = (getHandleInvalid == StringIndexer.KEEP_INVALID) - udf { label: String => + udf { (label: String) => if (label == null) { if (keepInvalid) { labels.length @@ -590,7 +590,7 @@ class IndexToString @Since("2.2.0") (@Since("1.5.0") override val uid: String) } else { $(labels) } - val indexer = udf { index: Double => + val indexer = udf { (index: Double) => val idx = index.toInt if (0 <= idx && idx < values.length) { values(idx) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala index 39ffaf32a1f36..852234130363e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala @@ -92,7 +92,7 @@ private[ml] trait TargetEncoderBase extends Params with HasLabelCol } else if (isSet(outputCols)) { $(outputCols) } else { - inputFeatures.map{field: String => s"${field}_indexed"} + inputFeatures.map(field => s"${field}_indexed") } private[feature] def validateSchema(schema: StructType, fitting: Boolean): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala index 704166d9b6575..8c43df58a0375 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala @@ -384,7 +384,7 @@ object UnivariateFeatureSelectorModel extends MLReadable[UnivariateFeatureSelect outputCol: String, featuresCol: String): DataFrame = { val newSize = selectedFeatures.length - val func = { vector: Vector => + val func = { (vector: Vector) => vector match { case SparseVector(_, indices, values) => val (newIndices, newValues) = diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 831a8a33afecb..fc6206a21fdd3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -141,7 +141,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) } val keepInvalid = $(handleInvalid) == VectorAssembler.KEEP_INVALID // Data transformation. 
- val assembleFunc = udf { r: Row => + val assembleFunc = udf { (r: Row) => VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*) }.asNondeterministic() val args = inputColsWithField.map { case (c, field) => diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 091e209227827..3f9d1cb9aab7a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -437,7 +437,7 @@ class VectorIndexerModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val newField = prepOutputField(dataset.schema) - val transformUDF = udf { vector: Vector => transformFunc(vector) } + val transformUDF = udf { (vector: Vector) => transformFunc(vector) } val newCol = transformUDF(dataset($(inputCol))) val ds = dataset.withColumn($(outputCol), newCol, newField.metadata) if (getHandleInvalid == VectorIndexer.SKIP_INVALID) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala index 4abb607733e35..09387b75aedf2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala @@ -51,7 +51,8 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String) this, "size", "Size of vectors in column.", - {s: Int => s >= 0}) + (s: Int) => s >= 0 + ) /** group getParam */ @Since("2.3.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index 58a44a41f0e84..8dbf5d1e2c6b7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -120,7 +120,7 @@ final class VectorSlicer @Since("1.5.0") (@Since("1.5.0") override val uid: Stri } val sorted = selectedIndices.length > 1 && selectedIndices.sliding(2).forall(t => t(1) > t(0)) - val slicer = udf { vec: Vector => + val slicer = udf { (vec: Vector) => vec match { case dv: DenseVector => Vectors.dense(selectedIndices.map(dv.apply)) case sv: SparseVector => sv.slice(selectedIndices, sorted) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 7d6765b231b5c..6e346f31c7c8f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -294,7 +294,7 @@ class Word2VecModel private[ml] ( val bcModel = dataset.sparkSession.sparkContext.broadcast(this.wordVectors) val size = $(vectorSize) val emptyVec = Vectors.sparse(size, Array.emptyIntArray, Array.emptyDoubleArray) - val transformer = udf { sentence: Seq[String] => + val transformer = udf { (sentence: Seq[String]) => if (sentence.isEmpty) { emptyVec } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 0b75753695fd5..5fbf5177552d0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.fpm import scala.reflect.ClassTag import org.apache.hadoop.fs.Path +import 
org.json4s._ import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonMatrixConverter.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonMatrixConverter.scala index 3882a97e0bb4d..46d6c4a34a343 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonMatrixConverter.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonMatrixConverter.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.ml.linalg -import org.json4s.{DefaultFormats, Formats} +import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods.{compact, parse => parseJson, render} diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonVectorConverter.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonVectorConverter.scala index 3e1dff58764d6..8837a6ab01769 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonVectorConverter.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonVectorConverter.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.linalg -import org.json4s.{DefaultFormats, Formats} +import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods.{compact, parse => parseJson, render} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala index dd6e91e891d66..3b5d6848f22ae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala @@ -47,7 +47,7 @@ private[r] class GaussianMixtureWrapper private ( lazy val sigma: Array[Double] = gmm.gaussians.flatMap(_.cov.toArray) - lazy val vectorToArray = udf { probability: Vector => probability.toArray } + lazy val vectorToArray = udf { (probability: Vector) => probability.toArray } lazy val posterior: DataFrame = gmm.summary.probability .withColumn("posterior", vectorToArray(col(gmm.summary.probabilityCol))) .drop(gmm.summary.probabilityCol) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala index cfcd4a85ab27b..55a3f174d8cc9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala @@ -56,7 +56,7 @@ private[r] class LDAWrapper private ( new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", pipeline.stages.dropRight(1)) def transform(data: Dataset[_]): DataFrame = { - val vec2ary = udf { vec: Vector => vec.toArray } + val vec2ary = udf { (vec: Vector) => vec.toArray } val outputCol = lda.getTopicDistributionCol val tempCol = s"${Identifiable.randomUID(outputCol)}" val preprocessed = preprocessor.transform(data) @@ -74,7 +74,7 @@ private[r] class LDAWrapper private ( if (vocabulary.isEmpty || vocabulary.length < vocabSize) { topicIndices } else { - val index2term = udf { indices: mutable.ArraySeq[Int] => indices.map(i => vocabulary(i)) } + val index2term = udf { (indices: mutable.ArraySeq[Int]) => indices.map(i => vocabulary(i)) } topicIndices .select(col("topic"), index2term(col("termIndices")).as("term"), col("termWeights")) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index 3a7539e0937fe..d89e7e8dc0e82 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.r import org.apache.hadoop.fs.Path +import org.json4s._ import org.json4s.DefaultFormats import org.json4s.jackson.JsonMethods._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 95c47531720d5..2bb05dee9c165 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -28,6 +28,7 @@ import scala.util.hashing.byteswap64 import com.google.common.collect.{Ordering => GuavaOrdering} import org.apache.hadoop.fs.Path +import org.json4s._ import org.json4s.DefaultFormats import org.json4s.JsonDSL._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 4f38d87574132..49df533710a42 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.regression import org.apache.hadoop.fs.Path -import org.json4s.{DefaultFormats, JObject} +import org.json4s._ import org.json4s.JsonDSL._ import org.apache.spark.annotation.Since @@ -219,21 +219,21 @@ class DecisionTreeRegressionModel private[ml] ( var predictionColumns = Seq.empty[Column] if ($(predictionCol).nonEmpty) { - val predictUDF = udf { features: Vector => predict(features) } + val predictUDF = udf { (features: Vector) => predict(features) } predictionColNames :+= $(predictionCol) predictionColumns :+= predictUDF(col($(featuresCol))) .as($(predictionCol), outputSchema($(predictionCol)).metadata) } if (isDefined(varianceCol) && $(varianceCol).nonEmpty) { - val predictVarianceUDF = udf { features: Vector => predictVariance(features) } + val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) } predictionColNames :+= $(varianceCol) predictionColumns :+= predictVarianceUDF(col($(featuresCol))) .as($(varianceCol), outputSchema($(varianceCol)).metadata) } if ($(leafCol).nonEmpty) { - val leafUDF = udf { features: Vector => predictLeaf(features) } + val leafUDF = udf { (features: Vector) => predictLeaf(features) } predictionColNames :+= $(leafCol) predictionColumns :+= leafUDF(col($(featuresCol))) .as($(leafCol), outputSchema($(leafCol)).metadata) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala index 5cc93e14fa3d5..f1bf8fdf8adbf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.regression import scala.util.Random import breeze.linalg.{axpy => brzAxpy, norm => brzNorm, Vector => BV} +import breeze.linalg.InjectNumericOps import breeze.numerics.{sqrt => brzSqrt} import org.apache.hadoop.fs.Path diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index f71eea6c62933..f1300b3fd3b88 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.regression 
-import org.json4s.{DefaultFormats, JObject} +import org.json4s._ import org.json4s.JsonDSL._ import org.apache.spark.annotation.Since @@ -275,14 +275,14 @@ class GBTRegressionModel private[ml]( val bcastModel = dataset.sparkSession.sparkContext.broadcast(this) if ($(predictionCol).nonEmpty) { - val predictUDF = udf { features: Vector => bcastModel.value.predict(features) } + val predictUDF = udf { (features: Vector) => bcastModel.value.predict(features) } predictionColNames :+= $(predictionCol) predictionColumns :+= predictUDF(col($(featuresCol))) .as($(featuresCol), outputSchema($(featuresCol)).metadata) } if ($(leafCol).nonEmpty) { - val leafUDF = udf { features: Vector => bcastModel.value.predictLeaf(features) } + val leafUDF = udf { (features: Vector) => bcastModel.value.predictLeaf(features) } predictionColNames :+= $(leafCol) predictionColumns :+= leafUDF(col($(featuresCol))) .as($(leafCol), outputSchema($(leafCol)).metadata) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 6d4669ec78af9..7a0411226aff9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1333,7 +1333,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( } private[regression] lazy val pearsonResiduals: DataFrame = { - val prUDF = udf { mu: Double => family.variance(mu) } + val prUDF = udf { (mu: Double) => family.variance(mu) } predictions.select(label.minus(prediction) .multiply(sqrt(weight)).divide(sqrt(prUDF(prediction))).as("pearsonResiduals")) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index e1bfff068cfe2..d653050bfcbcb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -85,7 +85,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) { val idx = $(featureIndex) - val extract = udf { v: Vector => v(idx) } + val extract = udf { (v: Vector) => v(idx) } extract(checkNonNanVectors($(featuresCol))) } else { checkNonNanValues($(featuresCol), "Features") @@ -249,10 +249,10 @@ class IsotonicRegressionModel private[ml] ( val outputSchema = transformSchema(dataset.schema, logging = true) val predict = dataset.schema($(featuresCol)).dataType match { case DoubleType => - udf { feature: Double => oldModel.predict(feature) } + udf { (feature: Double) => oldModel.predict(feature) } case _: VectorUDT => val idx = $(featureIndex) - udf { features: Vector => oldModel.predict(features(idx)) } + udf { (features: Vector) => oldModel.predict(features(idx)) } } dataset.withColumn($(predictionCol), predict(col($(featuresCol))), outputSchema($(predictionCol)).metadata) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 97d0f54d0eca4..17c98ec69ac06 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -17,6 +17,7 @@ package 
org.apache.spark.ml.regression +import org.json4s._ import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ @@ -242,14 +243,14 @@ class RandomForestRegressionModel private[ml] ( val bcastModel = dataset.sparkSession.sparkContext.broadcast(this) if ($(predictionCol).nonEmpty) { - val predictUDF = udf { features: Vector => bcastModel.value.predict(features) } + val predictUDF = udf { (features: Vector) => bcastModel.value.predict(features) } predictionColNames :+= $(predictionCol) predictionColumns :+= predictUDF(col($(featuresCol))) .as($(predictionCol), outputSchema($(predictionCol)).metadata) } if ($(leafCol).nonEmpty) { - val leafUDF = udf { features: Vector => bcastModel.value.predictLeaf(features) } + val leafUDF = udf { (features: Vector) => bcastModel.value.predictLeaf(features) } predictionColNames :+= $(leafCol) predictionColumns :+= leafUDF(col($(featuresCol))) .as($(leafCol), outputSchema($(leafCol)).metadata) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index d023c8990e76d..567a6ed01bf01 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -24,7 +24,7 @@ import scala.concurrent.duration.Duration import scala.jdk.CollectionConverters._ import org.apache.hadoop.fs.Path -import org.json4s.DefaultFormats +import org.json4s._ import org.apache.spark.annotation.Since import org.apache.spark.internal.{Logging, MDC} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index ebfcac2e4952b..cc9d47bb7d087 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -25,7 +25,7 @@ import scala.jdk.CollectionConverters._ import scala.language.existentials import org.apache.hadoop.fs.Path -import org.json4s.DefaultFormats +import org.json4s._ import org.apache.spark.annotation.Since import org.apache.spark.internal.{Logging, MDC} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index ad7435ce5be76..2938405754bb8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.classification -import org.json4s.{DefaultFormats, Formats, JValue} +import org.json4s._ import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 75262ac4fe06b..e4d1c70896e88 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -216,16 +216,18 @@ class LogisticRegressionWithSGD private[mllib] ( private val gradient = new LogisticGradient() private val updater = new SquaredL2Updater() @Since("0.8.0") - override val optimizer = new GradientDescent(gradient, updater) + override val optimizer: GradientDescent = new 
GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) override protected val validators = List(DataValidators.binaryLabelValidator) - override protected[mllib] def createModel(weights: Vector, intercept: Double) = { + override protected[mllib] def createModel( + weights: Vector, + intercept: Double + ): LogisticRegressionModel = new LogisticRegressionModel(weights, intercept) - } } /** @@ -249,9 +251,9 @@ class LogisticRegressionWithLBFGS this.setFeatureScaling(true) @Since("1.1.0") - override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater) + override val optimizer: LBFGS = new LBFGS(new LogisticGradient, new SquaredL2Updater) - override protected val validators = List(multiLabelValidator) + override protected val validators: Seq[RDD[LabeledPoint] => Boolean] = List(multiLabelValidator) private def multiLabelValidator: RDD[LabeledPoint] => Boolean = { data => if (numOfLinearPredictor > 1) { @@ -276,7 +278,10 @@ class LogisticRegressionWithLBFGS this } - override protected def createModel(weights: Vector, intercept: Double) = { + override protected def createModel( + weights: Vector, + intercept: Double + ): LogisticRegressionModel = { if (numOfLinearPredictor == 1) { new LogisticRegressionModel(weights, intercept) } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 33ce0d7a7cdf1..2bf411c96e4e7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -137,12 +137,13 @@ class SVMWithSGD private ( private val gradient = new HingeGradient() private val updater = new SquaredL2Updater() @Since("0.8.0") - override val optimizer = new GradientDescent(gradient, updater) + override val optimizer: GradientDescent = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) - override protected val validators = List(DataValidators.binaryLabelValidator) + override protected val validators: Seq[RDD[LabeledPoint] => Boolean] = + Seq(DataValidators.binaryLabelValidator) /** * Construct a SVM object with default parameters: {stepSize: 1.0, numIterations: 100, @@ -151,9 +152,8 @@ class SVMWithSGD private ( @Since("0.8.0") def this() = this(1.0, 100, 0.01, 1.0) - override protected def createModel(weights: Vector, intercept: Double) = { + override protected def createModel(weights: Vector, intercept: Double): SVMModel = new SVMModel(weights, intercept) - } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 40a810a699ac1..6d092d284b64b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BreezeVector} -import org.json4s.{DefaultFormats, Formats} +import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 
10a81acede0c7..6989c6dee9255 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -18,9 +18,10 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{argmax, argtopk, normalize, sum, DenseMatrix => BDM, DenseVector => BDV} +import breeze.linalg.InjectNumericOps import breeze.numerics.{exp, lgamma} import org.apache.hadoop.fs.Path -import org.json4s.{DefaultFormats, Formats} +import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index e4dd53ec9cce9..c9e0fc102e0d1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering import java.util.Random import breeze.linalg.{all, normalize, sum, DenseMatrix => BDM, DenseVector => BDV} +import breeze.linalg.InjectNumericOps import breeze.numerics.{abs, exp, trigamma} import breeze.stats.distributions.{Gamma, RandBasis} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 100fa13db5180..04c6948b240d1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import scala.jdk.CollectionConverters._ import com.google.common.collect.{Ordering => GuavaOrdering} -import org.json4s.{DefaultFormats, Formats} +import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index fc0c8d42579a9..0dacff1069a2e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -344,8 +344,8 @@ class DenseMatrix @Since("1.3.0") ( @Since("1.4.0") override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone()) - private[spark] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f), - isTransposed) + private[spark] def map(f: Double => Double): Matrix = + new DenseMatrix(numRows, numCols, values.map(f), isTransposed) private[mllib] def update(f: Double => Double): DenseMatrix = { val len = values.length @@ -665,7 +665,7 @@ class SparseMatrix @Since("1.3.0") ( new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone()) } - private[spark] def map(f: Double => Double) = + private[spark] def map(f: Double => Double): Matrix = new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f), isTransposed) private[mllib] def update(f: Double => Double): SparseMatrix = { @@ -1130,7 +1130,7 @@ object Matrices { new DenseMatrix(numRows, numCols, matrices.flatMap(_.toArray)) } else { var startCol = 0 - val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat: Matrix => + val entries: Array[(Int, Int, Double)] = matrices.flatMap { (mat: Matrix) => val nCols = mat.numCols mat match { case spMat: SparseMatrix => @@ -1199,7 +1199,7 @@ object Matrices { new DenseMatrix(numRows, numCols, allValues) } else { var startRow = 0 - val entries: 
Array[(Int, Int, Double)] = matrices.flatMap { mat: Matrix => + val entries: Array[(Int, Int, Double)] = matrices.flatMap { (mat: Matrix) => val nRows = mat.numRows mat match { case spMat: SparseMatrix => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 47209a5c9aa91..66504c881958c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -25,7 +25,7 @@ import scala.jdk.CollectionConverters._ import scala.language.implicitConversions import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} -import org.json4s.{DefaultFormats, Formats} +import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods.{compact, parse => parseJson, render} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index f92ac0789c952..97f2d0071f69d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -128,7 +128,7 @@ class IndexedRowMatrix @Since("1.0.0") ( val numRowBlocks = math.ceil(m.toDouble / rowsPerBlock).toInt val numColBlocks = math.ceil(n.toDouble / colsPerBlock).toInt - val blocks = rows.flatMap { ir: IndexedRow => + val blocks = rows.flatMap { (ir: IndexedRow) => val blockRow = ir.index / rowsPerBlock val rowInBlock = ir.index % rowsPerBlock diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index 13920b5330e93..5098eb696601b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -92,13 +92,13 @@ class LassoWithSGD private[mllib] ( private val gradient = new LeastSquaresGradient() private val updater = new L1Updater() @Since("0.8.0") - override val optimizer = new GradientDescent(gradient, updater) + override val optimizer: GradientDescent = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) - override protected def createModel(weights: Vector, intercept: Double) = { + override protected def createModel(weights: Vector, intercept: Double): LassoModel = { new LassoModel(weights, intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index bd42d7b220253..ee3f6b5f9d2b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -93,7 +93,7 @@ class LinearRegressionWithSGD private[mllib] ( private val gradient = new LeastSquaresGradient() private val updater = new SimpleUpdater() @Since("0.8.0") - override val optimizer = new GradientDescent(gradient, updater) + override val optimizer: GradientDescent = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) .setRegParam(regParam) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala index 0e2dbe43e45bb..e7d42722627f9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.regression -import org.json4s.{DefaultFormats, Formats, JValue} +import org.json4s._ import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 1f67536f699c8..3efce41c62626 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -92,13 +92,12 @@ class RidgeRegressionWithSGD private[mllib] ( private val gradient = new LeastSquaresGradient() private val updater = new SquaredL2Updater() @Since("0.8.0") - override val optimizer = new GradientDescent(gradient, updater) + override val optimizer: GradientDescent = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) - override protected def createModel(weights: Vector, intercept: Double) = { + override protected def createModel(weights: Vector, intercept: Double): RidgeRegressionModel = new RidgeRegressionModel(weights, intercept) - } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index b8fcb1ffcbfe1..408ba76dfbc29 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -355,7 +355,7 @@ object MLUtils extends Logging { // TODO: This implementation has performance issues due to unnecessary serialization. // TODO: It is better (but trickier) if we can cast the old vector type to new type directly. - val convertToML = udf { v: Vector => v.asML } + val convertToML = udf { (v: Vector) => v.asML } val exprs = schema.fields.map { field => val c = field.name if (colSet.contains(c)) { @@ -459,7 +459,7 @@ object MLUtils extends Logging { logWarning("Matrix column conversion has serialization overhead. 
" + "Please migrate your datasets and workflows to use the spark.ml package.") - val convertToML = udf { v: Matrix => v.asML } + val convertToML = udf { (v: Matrix) => v.asML } val exprs = schema.fields.map { field => val c = field.name if (colSet.contains(c)) { diff --git a/mllib/src/main/scala/org/apache/spark/sql/ml/InternalFunctionRegistration.scala b/mllib/src/main/scala/org/apache/spark/sql/ml/InternalFunctionRegistration.scala index 173f3d4f99c17..76a08fdd3c0c9 100644 --- a/mllib/src/main/scala/org/apache/spark/sql/ml/InternalFunctionRegistration.scala +++ b/mllib/src/main/scala/org/apache/spark/sql/ml/InternalFunctionRegistration.scala @@ -44,7 +44,7 @@ object InternalFunctionRegistration { FunctionRegistry.internal.createOrReplaceTempFunction(name, builder, "internal") } - private val vectorToArrayUdf = udf { vec: Any => + private val vectorToArrayUdf = udf { (vec: Any) => vec match { case v: Vector => v.toArray case v: OldVector => v.toArray @@ -55,7 +55,7 @@ object InternalFunctionRegistration { } }.asNonNullable() - private val vectorToArrayFloatUdf = udf { vec: Any => + private val vectorToArrayFloatUdf = udf { (vec: Any) => vec match { case v: SparseVector => val data = new Array[Float](v.size) @@ -87,7 +87,7 @@ object InternalFunctionRegistration { throw QueryCompilationErrors.wrongNumArgsError("vector_to_array", "2", exprs.size) } - private val arrayToVectorUdf = udf { array: Seq[Double] => + private val arrayToVectorUdf = udf { (array: Seq[Double]) => Vectors.dense(array.toArray) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index 57cd99ecced16..c0ec7bc138aaa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -31,7 +31,7 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import testImplicits._ private def getTestData(labels: Seq[Double]): DataFrame = { - labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }.toDF() + labels.map { (label: Double) => LabeledPoint(label, Vectors.dense(0.0)) }.toDF() } test("getNumClasses") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/FMClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/FMClassifierSuite.scala index 7066beb39923f..14c71e24f900b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/FMClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/FMClassifierSuite.scala @@ -171,7 +171,7 @@ class FMClassifierSuite extends MLTest with DefaultReadWriteTest { // constant threshold scaling is the same as no thresholds fmModel.setThresholds(Array(1.0, 1.0)) testTransformerByGlobalCheckFunc[(Double, Vector)](df, fmModel, "prediction") { - scaledPredictions: Seq[Row] => + (scaledPredictions: Seq[Row]) => assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => scaled.getDouble(0) === base.getDouble(0) }) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 6ce2108b1f7c8..5a0f099abc62e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -155,7 +155,7 @@ class GBTClassifierSuite 
extends MLTest with DefaultReadWriteTest { // constant threshold scaling is the same as no thresholds binaryModel.setThresholds(Array(1.0, 1.0)) testTransformerByGlobalCheckFunc[(Double, Vector)](df, binaryModel, "prediction") { - scaledPredictions: Seq[Row] => + (scaledPredictions: Seq[Row]) => assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => scaled.getDouble(0) === base.getDouble(0) }) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index 9dc03c147f6d7..e9a27ef7ba363 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -152,7 +152,7 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest { expected: Set[(Int, Double)]): Unit = { model.setThreshold(threshold) testTransformerByGlobalCheckFunc[(Int, Vector)](df, model, "id", "prediction") { - rows: Seq[Row] => + (rows: Seq[Row]) => val results = rows.map(r => (r.getInt(0), r.getDouble(1))).toSet assert(results === expected, s"Failed for threshold = $threshold") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 124a08db3de60..54f4cb3217d9e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -482,7 +482,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { // constant threshold scaling is the same as no thresholds model.setThresholds(Array(1000, 1000, 1000)) testTransformerByGlobalCheckFunc[(Double, Vector)](smallMultinomialDataset.toDF(), model, - "prediction") { scaledPredictions: Seq[Row] => + "prediction") { (scaledPredictions: Seq[Row]) => assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => scaled.getDouble(0) === base.getDouble(0) }) @@ -687,7 +687,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { ).toDF() testTransformerByGlobalCheckFunc[(Double, Vector)](overFlowData.toDF(), - model, "rawPrediction", "probability") { results: Seq[Row] => + model, "rawPrediction", "probability") { (results: Seq[Row]) => // probabilities are correct when margins have to be adjusted val raw1 = results(0).getAs[Vector](0) val prob1 = results(0).getAs[Vector](1) @@ -2749,7 +2749,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { .map(_.getDouble(0)) for (model <- Seq(model1, model2)) { testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(), model, - "prediction") { rows: Seq[Row] => + "prediction") { (rows: Seq[Row]) => rows.map(_.getDouble(0)).toArray === binaryExpected } } @@ -2764,7 +2764,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { .collect().map(_.getDouble(0)) for (model <- Seq(model3, model4)) { testTransformerByGlobalCheckFunc[(Double, Vector)](smallMultinomialDataset.toDF(), model, - "prediction") { rows: Seq[Row] => + "prediction") { (rows: Seq[Row]) => rows.map(_.getDouble(0)).toArray === multinomialExpected } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 
d2a22a03e2a86..293e9f2b5b13c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -206,7 +206,7 @@ class MultilayerPerceptronClassifierSuite extends MLTest with DefaultReadWriteTe // MLP's predictions should not differ a lot from LR's. val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels) testTransformerByGlobalCheckFunc[(Double, Vector)](dataFrame, model, "prediction", "label") { - rows: Seq[Row] => + (rows: Seq[Row]) => val mlpPredictionAndLabels = rows.map(x => (x.getDouble(0), x.getDouble(1))) val mlpMetrics = new MulticlassMetrics(sc.makeRDD(mlpPredictionAndLabels)) assert(mlpMetrics.confusionMatrix.asML ~== lrMetrics.confusionMatrix.asML absTol 100) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 381f4f071c4fd..f2413bce35792 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -203,12 +203,12 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest { generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF() testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, - "prediction", "label") { predictionAndLabels: Seq[Row] => + "prediction", "label") { (predictionAndLabels: Seq[Row]) => validatePrediction(predictionAndLabels) } testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, - "features", "probability") { featureAndProbabilities: Seq[Row] => + "features", "probability") { (featureAndProbabilities: Seq[Row]) => validateProbabilities(featureAndProbabilities, model, "multinomial") } @@ -290,12 +290,12 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest { generateNaiveBayesInput(piArray, thetaArray, nPoints, 20, "bernoulli").toDF() testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, - "prediction", "label") { predictionAndLabels: Seq[Row] => + "prediction", "label") { (predictionAndLabels: Seq[Row]) => validatePrediction(predictionAndLabels) } testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, - "features", "probability") { featureAndProbabilities: Seq[Row] => + "features", "probability") { (featureAndProbabilities: Seq[Row]) => validateProbabilities(featureAndProbabilities, model, "bernoulli") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index c8a748e251392..f6858f4881589 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -44,12 +44,12 @@ class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest { super.beforeAll() dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k) - denseDataset = denseData.map(FeatureData).toDF() + denseDataset = denseData.map(FeatureData(_)).toDF() sparseDataset = denseData.map { point => FeatureData(point.toSparse) }.toDF() - decompositionDataset = decompositionData.map(FeatureData).toDF() - rDataset = rData.map(FeatureData).toDF() + decompositionDataset = decompositionData.map(FeatureData(_)).toDF() + rDataset = 
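The test hunks above are all the same mechanical rewrite: when a function literal's parameter carries a type annotation, Scala 3 requires the parameter to be parenthesized, and -Xsource:3-cross enforces that already in Scala 2.13. A minimal sketch (Row here is a made-up stand-in, not Spark's Row):

    case class Row(value: Double)
    val rows = Seq(Row(1.0), Row(2.0))

    // Scala 2 accepted { r: Row => r.value }; with -Xsource:3-cross the typed
    // parameter must be wrapped in parentheses, exactly as in the hunks above.
    val values = rows.map { (r: Row) => r.value }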
rData.map(FeatureData(_)).toDF() } test("gmm fails on high dimensional data") { @@ -314,7 +314,7 @@ object GaussianMixtureSuite extends SparkFunSuite { Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) ) - val decompositionData: Seq[Vector] = Seq.tabulate(25) { i: Int => + val decompositionData: Seq[Vector] = Seq.tabulate(25) { (i: Int) => Vectors.dense(Array.tabulate(50)(i + _.toDouble)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index a7d320e8164b6..383058ac2fdd2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -98,7 +98,7 @@ class BucketedRandomProjectionLSHSuite extends MLTest with DefaultReadWriteTest .setSeed(12345) val brpModel = brp.fit(dataset) val unitVectors = brpModel.randUnitVectors - unitVectors.foreach { v: Vector => + unitVectors.foreach { (v: Vector) => assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala index 6bb3ce224a2e7..00c047efa4a14 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala @@ -279,8 +279,8 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { test("TargetEncoder - wrong data type") { val wrong_schema = new StructType( - schema.map{ - field: StructField => if (field.name != "input3") field + schema.map { (field: StructField) => + if (field.name != "input3") field else StructField(field.name, StringType, field.nullable, field.metadata) }.toArray) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 8d5ce6395af5f..f7a4aaed514a3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -158,7 +158,7 @@ class VectorAssemblerSuite val filteredDF = df.filter($"y".isNotNull) - val vectorUDF = udf { vector: Vector => + val vectorUDF = udf { (vector: Vector) => vector.numActives } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 384fcf6ceb859..600bb59cd9558 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -93,13 +93,13 @@ class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging { checkPair(densePoints1Seq, sparsePoints1Seq) checkPair(densePoints2Seq, sparsePoints2Seq) - densePoints1 = densePoints1Seq.map(FeatureData).toDF() - sparsePoints1 = sparsePoints1Seq.map(FeatureData).toDF() - densePoints1TestInvalid = densePoints1SeqTestInvalid.map(FeatureData).toDF() - sparsePoints1TestInvalid = sparsePoints1SeqTestInvalid.map(FeatureData).toDF() - densePoints2 = densePoints2Seq.map(FeatureData).toDF() - sparsePoints2 = sparsePoints2Seq.map(FeatureData).toDF() - badPoints = badPointsSeq.map(FeatureData).toDF() + densePoints1 = densePoints1Seq.map(FeatureData(_)).toDF() + 
sparsePoints1 = sparsePoints1Seq.map(FeatureData(_)).toDF() + densePoints1TestInvalid = densePoints1SeqTestInvalid.map(FeatureData(_)).toDF() + sparsePoints1TestInvalid = sparsePoints1SeqTestInvalid.map(FeatureData(_)).toDF() + densePoints2 = densePoints2Seq.map(FeatureData(_)).toDF() + sparsePoints2 = sparsePoints2Seq.map(FeatureData(_)).toDF() + badPoints = badPointsSeq.map(FeatureData(_)).toDF() } private def getIndexer: VectorIndexer = @@ -112,7 +112,7 @@ class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging { } test("Cannot fit an empty DataFrame") { - val rdd = Array.empty[Vector].map(FeatureData).toSeq.toDF() + val rdd = Array.empty[Vector].map(FeatureData.apply).toSeq.toDF() val vectorIndexer = getIndexer intercept[NoSuchElementException] { vectorIndexer.fit(rdd) @@ -192,7 +192,7 @@ class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging { val featureAttrs = AttributeGroup.fromStructField(rows.head.schema("indexed")) assert(featureAttrs.name === "indexed") assert(featureAttrs.attributes.get.length === model.numFeatures) - categoricalFeatures.foreach { feature: Int => + categoricalFeatures.foreach { (feature: Int) => val origValueSet = collectedData.map(_(feature)).toSet val targetValueIndexSet = Range(0, origValueSet.size).toSet val catMap = categoryMaps(feature) @@ -219,7 +219,7 @@ class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging { } // Check numerical feature metadata. Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature)) - .foreach { feature: Int => + .foreach { (feature: Int) => val featureAttr = featureAttrs(feature) featureAttr match { case attr: NumericAttribute => diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index d9b454e0d10a2..2afb5de35236c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -94,12 +94,12 @@ class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest { } test("AFTSurvivalRegression validate input dataset") { - testInvalidRegressionLabels { df: DataFrame => + testInvalidRegressionLabels { (df: DataFrame) => val dfWithCensors = df.withColumn("censor", lit(1.0)) new AFTSurvivalRegression().fit(dfWithCensors) } - testInvalidVectors { df: DataFrame => + testInvalidVectors { (df: DataFrame) => val dfWithCensors = df.withColumn("censor", lit(1.0)) new AFTSurvivalRegression().fit(dfWithCensors) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index 18cce169b4ce9..2cc612c551fe9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -75,7 +75,7 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => val group = AttributeGroup.fromStructField(dataframe.schema(vecColName)) assert(group.size === vecSize, s"the vector size obtained from schema should be $vecSize, but got ${group.size}") - val sizeUDF = udf { vector: Vector => vector.size } + val sizeUDF = udf { (vector: Vector) => vector.size } assert(dataframe.select(sizeUDF(col(vecColName))) .as[Int] .collect() @@ -147,7 +147,7 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => dataframe, transformer, firstResultCol, - 
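The FeatureData changes above, and the later reduce(And(_, _)), map(Not(_)) and map(IsNull(_)) changes in catalyst, share one cause: in Scala 3 the synthetic companion object of a case class no longer extends the corresponding FunctionN trait, so the bare companion can no longer be passed where a function value is expected. Spelling out Companion.apply, Companion(_) or Companion(_, _) cross-compiles. A minimal sketch with a made-up case class (this FeatureData is local to the example, not the test helper above):

    case class FeatureData(x: Double)
    val xs = Seq(1.0, 2.0, 3.0)

    // Scala 2: xs.map(FeatureData) compiled because the synthetic companion
    // extended Function1. Scala 3 / -Xsource:3-cross: name the apply method.
    val a = xs.map(FeatureData.apply)
    val b = xs.map(FeatureData(_))   // equivalent spelling, as used in the hunks above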
otherResultCols: _*) { rows: Seq[Row] => rows.foreach(checkFunction(_)) } + otherResultCols: _*) { (rows: Seq[Row]) => rows.foreach(checkFunction(_)) } } def testTransformerByGlobalCheckFunc[A : Encoder]( diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala index 1732469ccf590..982180152c50b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala @@ -47,7 +47,7 @@ class MLTestSuite extends MLTest { } intercept[Exception] { testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") { - rows: scala.collection.Seq[Row] => + (rows: scala.collection.Seq[Row]) => assert(rows.map(_.getDouble(1)).max === 1.0) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index 2ba987b96ef79..e4fc8cd10a78e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -197,7 +197,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) ) - val data2: Array[Vector] = Array.tabulate(25) { i: Int => + val data2: Array[Vector] = Array.tabulate(25) { (i: Int) => Vectors.dense(Array.tabulate(50)(i + _.toDouble)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 6c0f096dc14a6..52b163857d5db 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -129,7 +129,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { // Check: topDocumentsPerTopic // Compare it with top documents per topic derived from topicDistributions - val topDocsByTopicDistributions = { n: Int => + val topDocsByTopicDistributions = { (n: Int) => Range(0, k).map { topic => val (doc, docWeights) = topicDistributions.sortBy(-_._2(topic)).take(n).unzip (doc.toArray, docWeights.map(_(topic)).toArray) diff --git a/pom.xml b/pom.xml index a161c46c126a4..f8464069e7588 100644 --- a/pom.xml +++ b/pom.xml @@ -2736,6 +2736,7 @@ -explaintypes -release 17 + -Xsource:3-cross -Wconf:any:e -Wconf:cat=deprecation:wv -Wunused:imports diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index bc901b0fe6be2..1e3e7d25f3a57 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -229,6 +229,7 @@ object SparkBuild extends PomBuild { lazy val compilerWarningSettings: Seq[sbt.Def.Setting[_]] = Seq( (Compile / scalacOptions) ++= { Seq( + "-Xsource:3-cross", // replace -Xfatal-warnings with fine-grained configuration, since 2.13.2 // verbose warning on deprecation, error on all others // see `scalac -Wconf:help` for details diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala index 557bf01cbdbae..1046ee1720233 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala +++ 
b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala @@ -128,17 +128,17 @@ object SparkKubernetesClientFactory extends Logging { object ClientType extends Enumeration { import scala.language.implicitConversions - val Driver = Val(DRIVER_CLIENT_REQUEST_TIMEOUT, DRIVER_CLIENT_CONNECTION_TIMEOUT) - val Submission = Val(SUBMISSION_CLIENT_REQUEST_TIMEOUT, SUBMISSION_CLIENT_CONNECTION_TIMEOUT) + val Driver = ClientTypeVal(DRIVER_CLIENT_REQUEST_TIMEOUT, DRIVER_CLIENT_CONNECTION_TIMEOUT) + val Submission = + ClientTypeVal(SUBMISSION_CLIENT_REQUEST_TIMEOUT, SUBMISSION_CLIENT_CONNECTION_TIMEOUT) - protected case class Val( + protected case class ClientTypeVal( requestTimeoutEntry: ConfigEntry[Int], connectionTimeoutEntry: ConfigEntry[Int]) - extends super.Val { + extends Val { def requestTimeout(conf: SparkConf): Int = conf.get(requestTimeoutEntry) def connectionTimeout(conf: SparkConf): Int = conf.get(connectionTimeoutEntry) } - - implicit def convert(value: Value): Val = value.asInstanceOf[Val] + implicit def convert(value: Value): ClientTypeVal = value.asInstanceOf[ClientTypeVal] } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 09faa2a7fb1b3..62b9862669663 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -208,12 +208,13 @@ private[spark] class KubernetesClusterSchedulerBackend( .withLabelIn(SPARK_EXECUTOR_ID_LABEL, execIds: _*) .resources() .forEach { podResource => - podResource.edit({ p: Pod => + podResource.edit((p: Pod) => new PodBuilder(p).editOrNewMetadata() .addToLabels(label, conf.get(KUBERNETES_EXECUTOR_DECOMMISSION_LABEL_VALUE).getOrElse("")) .endMetadata() - .build()}) + .build() + ) } } } @@ -316,10 +317,12 @@ private[spark] class KubernetesClusterSchedulerBackend( kubernetesClient.pods() .inNamespace(namespace) .withName(x.podName) - .edit({p: Pod => new PodBuilder(p).editMetadata() - .addToLabels(SPARK_EXECUTOR_ID_LABEL, newId) - .endMetadata() - .build()}) + .edit((p: Pod) => + new PodBuilder(p).editMetadata() + .addToLabels(SPARK_EXECUTOR_ID_LABEL, newId) + .endMetadata() + .build() + ) } } executorService.execute(labelTask) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala b/sql/api/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala index db56b39e28aeb..2021a91aadc9e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala @@ -250,7 +250,7 @@ case class WhenMatched[T] private[sql] ( * @tparam T * The type of data in the MergeIntoWriter. */ -case class WhenNotMatched[T] private[sql] ( +case class WhenNotMatched[T]( mergeIntoWriter: MergeIntoWriter[T], condition: Option[Column]) { @@ -287,7 +287,7 @@ case class WhenNotMatched[T] private[sql] ( * @tparam T * the type parameter for the MergeIntoWriter. 
*/ -case class WhenNotMatchedBySource[T] private[sql] ( +case class WhenNotMatchedBySource[T]( mergeIntoWriter: MergeIntoWriter[T], condition: Option[Column]) { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala index 8dff1ceccfcfe..80274d9a176a3 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala @@ -65,7 +65,7 @@ object RebaseDateTime { */ private def rebaseDays(switches: Array[Int], diffs: Array[Int], days: Int): Int = { var i = switches.length - do { i -= 1 } while (i > 0 && days < switches(i)) + while ({ i -= 1; i > 0 && days < switches(i) }) () days + diffs(i) } @@ -274,7 +274,7 @@ object RebaseDateTime { private def rebaseMicros(rebaseInfo: RebaseInfo, micros: Long): Long = { val switches = rebaseInfo.switches var i = switches.length - do { i -= 1 } while (i > 0 && micros < switches(i)) + while ({ i -= 1; i > 0 && micros < switches(i) }) () micros + rebaseInfo.diffs(i) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala index 15784e9762e35..0d4ae44498776 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala @@ -394,11 +394,12 @@ class FractionTimestampFormatter(zoneId: ZoneId) formatted.getChars(0, formatted.length, buf, 0) buf(formatted.length) = '.' var i = totalLen - do { + while ({ i -= 1 buf(i) = ('0' + (nanos % 10)).toChar nanos /= 10 - } while (i > fracOffset) + i > fracOffset + }) () new String(buf) } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala index f798276d60f7c..fcb1e7298a158 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -46,8 +46,8 @@ import org.apache.spark.util.SparkClassUtils */ @Stable -@JsonSerialize(using = classOf[DataTypeJsonSerializer]) -@JsonDeserialize(using = classOf[DataTypeJsonDeserializer]) +@JsonSerialize(`using` = classOf[DataTypeJsonSerializer]) +@JsonDeserialize(`using` = classOf[DataTypeJsonDeserializer]) abstract class DataType extends AbstractDataType { /** diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala index cc95d8ee94b02..d4422bc454b6a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -23,6 +23,7 @@ import scala.collection.{immutable, mutable, Map} import scala.util.Try import scala.util.control.NonFatal +import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.apache.spark.SparkIllegalArgumentException @@ -399,7 +400,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru fields.foreach(field => field.buildFormattedString(prefix, stringConcat, maxDepth)) } - override private[sql] def jsonValue = + override private[sql] def jsonValue: JValue = ("type" -> typeName) ~ ("fields" -> map(_.jsonValue)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala 
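The DataType.scala hunk above backticks the Jackson annotation's named argument because `using` is a soft keyword in Scala 3 (it introduces context arguments), so under -Xsource:3-cross a plainly spelled using = ... inside an argument list no longer parses as an ordinary named argument. A minimal sketch with a home-grown annotation (serializeWith is made up; it only mimics the shape of @JsonSerialize):

    import scala.annotation.StaticAnnotation

    // A parameter may still be called `using`; it just has to be backticked,
    // both at the definition and at the call site.
    class serializeWith(`using`: Class[_]) extends StaticAnnotation

    @serializeWith(`using` = classOf[String])
    class Example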
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala index 1b2013d87eedf..689b4aee87937 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala @@ -63,7 +63,7 @@ abstract class StructFilters(pushedFilters: Seq[sources.Filter], schema: StructT val reducedExpr = filters .sortBy(_.references.length) .flatMap(filterToExpression(_, toRef)) - .reduce(And) + .reduce(And(_, _)) Predicate.create(reducedExpr) } @@ -122,15 +122,15 @@ object StructFilters { case sources.Or(left, right) => zip(translate(left), translate(right)).map(Or.tupled) case sources.Not(child) => - translate(child).map(Not) + translate(child).map(Not(_)) case sources.EqualTo(attribute, value) => zipAttributeAndValue(attribute, value).map(EqualTo.tupled) case sources.EqualNullSafe(attribute, value) => zipAttributeAndValue(attribute, value).map(EqualNullSafe.tupled) case sources.IsNull(attribute) => - toRef(attribute).map(IsNull) + toRef(attribute).map(IsNull(_)) case sources.IsNotNull(attribute) => - toRef(attribute).map(IsNotNull) + toRef(attribute).map(IsNotNull(_)) case sources.In(attribute, values) => val literals = values.toImmutableArraySeq.flatMap(toLiteral) if (literals.length == values.length) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index beef278a3dfe2..aaff6dedfefed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1380,11 +1380,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // an UnresolvedAttribute. 
EqualNullSafe( UnresolvedAttribute.quoted(attr.name), - Cast(Literal(value), attr.dataType)) + Cast(Literal(value), attr.dataType) + ): BinaryOperator case None => throw QueryCompilationErrors.missingStaticPartitionColumn(name) } - }.reduce(And) + }.reduce(And(_, _)) } } } @@ -3514,7 +3515,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor val inputNullCheck = inputPrimitivesPair.collect { case (isPrimitive, input) if isPrimitive && input.nullable => IsNull(input) - }.reduceLeftOption[Expression](Or) + }.reduceLeftOption(Or.apply) if (inputNullCheck.isDefined) { // Once we add an `If` check above the udf, it is safe to mark those checked inputs diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NaturalAndUsingJoinResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NaturalAndUsingJoinResolution.scala index 2e02d957013f4..b1144c75b4e81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NaturalAndUsingJoinResolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NaturalAndUsingJoinResolution.scala @@ -68,7 +68,7 @@ object NaturalAndUsingJoinResolution extends DataTypeErrorsBase with SQLConfHelp ) val joinPairs = leftKeys.zip(rightKeys) - val newCondition = (condition ++ joinPairs.map(EqualTo.tupled)).reduceOption(And) + val newCondition = (condition ++ joinPairs.map(EqualTo.tupled)).reduceOption(And(_, _)) // the output list looks like: join keys, columns from left, columns from right val (output, hiddenOutput) = computeOutputAndHiddenOutput( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala index 7e2cf4f29807c..6dd0adf229d9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala @@ -380,8 +380,8 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper val (targetPredicates, joinPredicates) = predicates.partition { predicate => predicate.references.subsetOf(targetTable.outputSet) } - val targetCond = targetPredicates.reduceOption(And).getOrElse(TrueLiteral) - val joinCond = joinPredicates.reduceOption(And).getOrElse(TrueLiteral) + val targetCond = targetPredicates.reduceOption(And(_, _)).getOrElse(TrueLiteral) + val joinCond = joinPredicates.reduceOption(And(_, _)).getOrElse(TrueLiteral) (Filter(targetCond, targetTable), joinCond) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala index eae7d5a74dbc2..bc0c5484f9f4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala @@ -54,7 +54,7 @@ abstract class TypeCoercionBase extends TypeCoercionHelper { class CombinedTypeCoercionRule(rules: Seq[TypeCoercionRule]) extends TypeCoercionRule { override def transform: PartialFunction[Expression, Expression] = { val transforms = rules.map(_.transform) - Function.unlift { e: Expression => + Function.unlift { (e: Expression) => val result = transforms.foldLeft(e) { case (current, transform) => transform.applyOrElse(current, 
identity[Expression]) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala index 1d6656fc6426f..9b266da9c0de4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala @@ -45,10 +45,10 @@ object ResolveLambdaVariables extends Rule[LogicalPlan] { private def canonicalizer = { if (!conf.caseSensitiveAnalysis) { // scalastyle:off caselocale - s: String => s.toLowerCase + (s: String) => s.toLowerCase // scalastyle:on caselocale } else { - s: String => s + (s: String) => s } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RelationMetadataProvider.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RelationMetadataProvider.scala index 6a3d8e161f2c3..8ec1db53a10b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RelationMetadataProvider.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RelationMetadataProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis.resolver import java.util.HashMap -import org.apache.spark.sql.catalyst.analysis.{RelationResolution, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.{AnalysisErrorAt, RelationResolution, UnresolvedRelation} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog.LookupCatalog import org.apache.spark.util.ArrayImplicits._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index 8960c7345521c..13b75f95ef200 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -205,7 +205,7 @@ object ExternalCatalogUtils { nonPartitionPruningPredicates) } - Predicate.createInterpreted(predicates.reduce(And).transform { + Predicate.createInterpreted(predicates.reduce(And(_, _)).transform { case att: AttributeReference => val index = partitionSchema.indexWhere(_.name == att.name) BoundReference(index, partitionSchema(index).dataType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 5c4e9d4bddc5f..a83424c0d2a76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CurrentUserContext, FunctionIdentifier, InternalRow, SQLConfHelper, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, Resolver, SchemaBinding, SchemaCompensation, SchemaEvolution, SchemaTypeEvolution, SchemaUnsupported, UnresolvedLeafNode, ViewSchemaMode} import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_PLAN -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, ExprId, Literal} 
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, AttributeSeq, Cast, ExprId, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.types.DataTypeUtils @@ -1016,7 +1016,7 @@ object CatalogColumnStat extends Logging { } -case class CatalogTableType private(name: String) +case class CatalogTableType(name: String) object CatalogTableType { val EXTERNAL = new CatalogTableType("EXTERNAL") val MANAGED = new CatalogTableType("MANAGED") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 5444ab6845867..b8329c70d6dff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -54,7 +54,7 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { private val decimalParser = if (options.locale == Locale.US) { // Special handling the default locale for backward compatibility - s: String => new java.math.BigDecimal(s) + (s: String) => new java.math.BigDecimal(s) } else { ExprUtils.getDecimalParser(options.locale) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala index c226e48c6be5e..4e8ca23e1ef8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.internal.SQLConf """, since = "1.0.0", group = "conditional_funcs") -case class Between private(input: Expression, lower: Expression, upper: Expression, replacement: Expression) +case class Between private (input: Expression, lower: Expression, upper: Expression, replacement: Expression) extends RuntimeReplaceable with InheritAnalysisRules { def this(input: Expression, lower: Expression, upper: Expression) = { this(input, lower, upper, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala index b65576403e9d8..b88d1f780ae3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala @@ -46,7 +46,7 @@ case class DynamicPruningSubquery( broadcastKeyIndices: Seq[Int], onlyInBroadcast: Boolean, exprId: ExprId = NamedExpression.newExprId, - hint: Option[HintInfo] = None) + override val hint: Option[HintInfo] = None) extends SubqueryExpression(buildQuery, Seq(pruningKey), exprId, Seq.empty, hint) with DynamicPruning with Unevaluable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala index cc6fea2f1b7f1..177f2856a0ff1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala @@ -183,7 +183,7 @@ class ExpressionSet protected( /** Returns a length limited 
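A note on the constructor-visibility changes (WhenNotMatched, WhenNotMatchedBySource, CatalogTableType and similar hunks): my reading, not stated in the patch, is that Scala 3, and 2.13 under -Xsource:3, makes the synthesized apply and copy of a case class follow the constructor's access level, so keeping the private modifier would have broken existing callers of the previously public apply; dropping the modifier preserves source compatibility. A made-up example of the pattern:

    // Tag is illustrative only, not one of the Spark classes above.
    // With `case class Tag private (name: String)`, Scala 3 semantics would make
    // Tag.apply (and copy) private as well, so the call below would stop
    // compiling. Dropping the modifier, as the hunks above do, keeps it valid.
    case class Tag(name: String)

    val t = Tag("managed")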
string that must be used for logging only. */ def simpleString(maxFields: Int): String = { - val customToString = { e: Expression => e.simpleString(maxFields) } + val customToString = (e: Expression) => e.simpleString(maxFields) SparkStringUtils.truncatedString( seq = originals.toSeq, start = "Set(", sep = ", ", end = ")", maxFields, Some(customToString)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala index bfd3bc8051dff..288df829323e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala @@ -83,7 +83,6 @@ case class FunctionTableSubqueryArgumentExpression( copy(plan = plan) override def withNewOuterAttrs(outerAttrs: Seq[Expression]) : FunctionTableSubqueryArgumentExpression = copy(outerAttrs = outerAttrs) - override def hint: Option[HintInfo] = None override def withNewHint(hint: Option[HintInfo]): FunctionTableSubqueryArgumentExpression = copy() override def toString: String = s"table-argument#${exprId.id} $conditionString" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index bba3d4b1a806b..940a1ceefa612 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -131,10 +131,10 @@ case class ScalaUDF( private def catalystConverter: Any => Any = outputEncoder.map { enc => val toRow = enc.createSerializer().asInstanceOf[Any => Any] if (enc.isSerializedAsStructForTopLevel) { - value: Any => + (value: Any) => if (value == null) null else toRow(value).asInstanceOf[InternalRow] } else { - value: Any => + (value: Any) => if (value == null) null else toRow(value).asInstanceOf[InternalRow].get(0, dataType) } }.getOrElse(createToCatalystConverter(dataType)) @@ -165,10 +165,10 @@ case class ScalaUDF( val unwrappedValueClass = enc.isSerializedAsStruct && enc.schema.fields.length == 1 && enc.schema.fields.head.dataType == dataType val converter = if (enc.isSerializedAsStructForTopLevel && !unwrappedValueClass) { - row: Any => fromRow(row.asInstanceOf[InternalRow]) + (row: Any) => fromRow(row.asInstanceOf[InternalRow]) } else { val inputRow = new GenericInternalRow(1) - value: Any => inputRow.update(0, value); fromRow(inputRow) + (value: Any) => inputRow.update(0, value); fromRow(inputRow) } (converter, true) } else { // use CatalystTypeConverters diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala index dd2d6c2cb610c..bc154e44b2438 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala @@ -42,7 +42,7 @@ object SchemaPruning extends SQLConfHelper { // in the resulting schema may differ from their ordering in the logical relation's // original schema val mergedSchema = requestedRootFields - .map { root: RootField => StructType(Array(root.field)) } + .map(root => 
StructType(Array(root.field))) .reduceLeft(_ merge _) val mergedDataSchema = StructType(schema.map(d => mergedSchema.find(m => resolver(m.name, d.name)).getOrElse(d))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala index 6cfcde5f52dae..187e55bb5dffe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala @@ -47,7 +47,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression => protected def useDecimalPlainString: Boolean - protected val binaryFormatter: BinaryFormatter = UTF8String.fromBytes + protected val binaryFormatter: BinaryFormatter = UTF8String.fromBytes(_) // Makes the function accept Any type input by doing `asInstanceOf[T]`. @inline private def acceptAny[T](func: T => UTF8String): Any => UTF8String = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 758ef22f0a2c2..f21b546796296 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -94,7 +94,11 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate ) } else { Seq( - /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L) + /* count = */ If( + nullableChildren.map(IsNull(_): Predicate).reduce(Or(_, _)), + count, + count + 1L + ) ) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 6767731fc25bf..97b7f067c1507 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1853,8 +1853,8 @@ sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes { // scalastyle:on line.size.limit case class FromUTCTimestamp(left: Expression, right: Expression) extends UTCTimestamp { override val func = DateTimeUtils.fromUTCTime - override val funcName: String = "fromUTCTime" - override val prettyName: String = "from_utc_timestamp" + override val funcName = "fromUTCTime" + override val prettyName = "from_utc_timestamp" override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): FromUTCTimestamp = copy(left = newLeft, right = newRight) @@ -1887,8 +1887,8 @@ case class FromUTCTimestamp(left: Expression, right: Expression) extends UTCTime // scalastyle:on line.size.limit case class ToUTCTimestamp(left: Expression, right: Expression) extends UTCTimestamp { override val func = DateTimeUtils.toUTCTime - override val funcName: String = "toUTCTime" - override val prettyName: String = "to_utc_timestamp" + override val funcName = "toUTCTime" + override val prettyName = "to_utc_timestamp" override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): ToUTCTimestamp = copy(left = newLeft, right = newRight) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala index 641f22ba3f786..5d8849d212706 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala @@ -62,7 +62,7 @@ private[this] object JsonPathParser extends RegexParsers { // parse `[*]` and `[123]` subscripts def subscript: Parser[List[PathInstruction]] = for { - operand <- '[' ~> ('*' ^^^ Wildcard | long ^^ Index) <~ ']' + operand <- '[' ~> ('*' ^^^ Wildcard | long ^^ Index.apply) <~ ']' } yield { Subscript :: operand :: Nil } @@ -301,10 +301,11 @@ case class JsonTupleEvaluator(foldableFieldNames: Array[Option[String]]) { // SPARK-21804: json_tuple returns null values within repeated columns // except the first one; so that we need to check the remaining fields. - do { + while ({ row(idx) = jsonValue idx = fieldNames.indexOf(jsonField, idx + 1) - } while (idx >= 0) + idx >= 0 + }) () } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 90ab8725e553e..f1b1ba21e968f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -911,7 +911,7 @@ object MapObjects { * @param customCollectionCls Class of the resulting collection (returning ObjectType) * or None (returning ArrayType) */ -case class MapObjects private( +case class MapObjects private ( loopVar: LambdaVariable, lambdaFunction: Expression, inputData: Expression, @@ -1353,7 +1353,7 @@ object CatalystToExternalMap { * @param inputData An expression that when evaluated returns a map object. * @param collClass The type of the resulting collection. */ -case class CatalystToExternalMap private( +case class CatalystToExternalMap private[expressions] ( keyLoopVar: LambdaVariable, keyLambdaFunction: Expression, valueLoopVar: LambdaVariable, @@ -1533,13 +1533,13 @@ object ExternalMapToCatalyst { * format. * @param inputData An expression that when evaluated returns the input map object. 
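JsonTupleEvaluator above, like the RebaseDateTime and TimestampFormatter hunks earlier in the patch, replaces do-while, which Scala 3 drops from the language. The cross-compiling form runs the block once and loops while its last expression is true: while ({ body; cond }) (). A small self-contained example of the same pattern (digits is made up for illustration):

    // Collects the decimal digits of n, executing the body at least once
    // (so digits(0) == List(0)), the same guarantee do-while used to give.
    def digits(n0: Long): List[Int] = {
      var n = math.abs(n0)
      var acc = List.empty[Int]
      while ({
        acc = (n % 10).toInt :: acc // body
        n /= 10
        n > 0                       // condition: last expression of the block
      }) ()
      acc
    }

    digits(205)  // List(2, 0, 5)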
*/ -case class ExternalMapToCatalyst private( +case class ExternalMapToCatalyst private[catalyst] ( keyLoopVar: LambdaVariable, keyConverter: Expression, valueLoopVar: LambdaVariable, valueConverter: Expression, - inputData: Expression) - extends Expression with NonSQLExpression { + inputData: Expression +) extends Expression with NonSQLExpression { override def foldable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 210b7f8fb5306..aca4cf75f4ce1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -78,7 +78,7 @@ abstract class SubqueryExpression( outerAttrs: Seq[Expression], exprId: ExprId, joinCond: Seq[Expression], - hint: Option[HintInfo]) extends PlanExpression[LogicalPlan] { + val hint: Option[HintInfo]) extends PlanExpression[LogicalPlan] { override lazy val resolved: Boolean = childrenResolved && plan.resolved override lazy val references: AttributeSet = AttributeSet.fromAttributeSets(outerAttrs.map(_.references)) @@ -86,7 +86,6 @@ abstract class SubqueryExpression( override def withNewPlan(plan: LogicalPlan): SubqueryExpression def withNewOuterAttrs(outerAttrs: Seq[Expression]): SubqueryExpression def isCorrelated: Boolean = outerAttrs.nonEmpty - def hint: Option[HintInfo] def withNewHint(hint: Option[HintInfo]): SubqueryExpression } @@ -397,7 +396,7 @@ case class ScalarSubquery( outerAttrs: Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId, joinCond: Seq[Expression] = Seq.empty, - hint: Option[HintInfo] = None, + override val hint: Option[HintInfo] = None, mayHaveCountBug: Option[Boolean] = None, needSingleJoin: Option[Boolean] = None) extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable { @@ -476,7 +475,7 @@ case class LateralSubquery( outerAttrs: Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId, joinCond: Seq[Expression] = Seq.empty, - hint: Option[HintInfo] = None) + override val hint: Option[HintInfo] = None) extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable { override def dataType: DataType = plan.output.toStructType override def nullable: Boolean = true @@ -522,7 +521,7 @@ case class ListQuery( // number of the columns of the original plan, to report the data type properly. 
numCols: Int = -1, joinCond: Seq[Expression] = Seq.empty, - hint: Option[HintInfo] = None) + override val hint: Option[HintInfo] = None) extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable { def childOutputs: Seq[Attribute] = plan.output.take(numCols) override def dataType: DataType = if (numCols > 1) { @@ -593,7 +592,7 @@ case class Exists( outerAttrs: Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId, joinCond: Seq[Expression] = Seq.empty, - hint: Option[HintInfo] = None) + override val hint: Option[HintInfo] = None) extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Predicate with Unevaluable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index ab787663c9923..f3a717d207b80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -786,7 +786,7 @@ case class NthValue(input: Expression, offset: Expression, ignoreNulls: Boolean) private lazy val count = AttributeReference("count", LongType)() override lazy val aggBufferAttributes: Seq[AttributeReference] = result :: count :: Nil - override lazy val initialValues: Seq[Literal] = Seq( + override lazy val initialValues = Seq( /* result = */ default, /* count = */ Literal(1L) ) @@ -936,7 +936,7 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow NoOp ) - override val evaluateExpression = bucket + override val evaluateExpression: Expression = bucket override protected def withNewChildInternal( newChild: Expression): NTile = copy(buckets = newChild) @@ -959,8 +959,8 @@ abstract class RankLike extends AggregateWindowFunction { /** Predicate that detects if the order attributes have changed. 
*/ protected val orderEquals = children.zip(orderAttrs) - .map(EqualNullSafe.tupled) - .reduceOption(And) + .map(EqualNullSafe.tupled(_): Predicate) + .reduceOption(And(_, _)) .getOrElse(Literal(true)) protected val orderInit = children.map(e => Literal.create(null, e.dataType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 471f0bd554105..23a1c55332c0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -304,7 +304,7 @@ object JoinReorderDP extends PredicateHelper with Logging { } else { (otherPlan, onePlan) } - val newJoin = Join(left, right, Inner, joinConds.reduceOption(And), JoinHint.NONE) + val newJoin = Join(left, right, Inner, joinConds.reduceOption(And(_, _)), JoinHint.NONE) val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds val remainingConds = conditions -- collectedJoinConds val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala index 47cee2e789c7c..d6efb856f4f60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala @@ -620,7 +620,7 @@ object DecorrelateInnerQuery extends PredicateHelper { val newFilterCond = newCorrelated ++ uncorrelated val newFilter = newFilterCond match { case Nil => newChild - case conditions => Filter(conditions.reduce(And), newChild) + case conditions => Filter(conditions.reduce(And(_, _)), newChild) } // Equality predicates are used as join conditions with the outer query. val newJoinCond = joinCond ++ equalityCond @@ -635,7 +635,7 @@ object DecorrelateInnerQuery extends PredicateHelper { val newOuterReferenceMap = outerReferenceMap ++ equivalences val newFilter = uncorrelated match { case Nil => newChild - case conditions => Filter(conditions.reduce(And), newChild) + case conditions => Filter(conditions.reduce(And(_, _)), newChild) } val newJoinCond = joinCond ++ correlated (newFilter, newJoinCond, newOuterReferenceMap) @@ -904,7 +904,7 @@ object DecorrelateInnerQuery extends PredicateHelper { // Use the current join conditions returned from the recursive call as the join // conditions for the left outer join. All outer references in the join // conditions are replaced by the newly created domain attributes. - val condition = replaceOuterReferences(joinCond, mapping).reduceOption(And) + val condition = replaceOuterReferences(joinCond, mapping).reduceOption(And(_, _)) val domainJoin = DomainJoin(domainAttrs, agg, LeftOuter, condition) // Original domain attributes preserved through Aggregate are no longer needed. 
val newProjectList = projectList.filter(!referencesToAdd.contains(_)) @@ -1031,7 +1031,7 @@ object DecorrelateInnerQuery extends PredicateHelper { case (outer, inner) => rightOuterReferenceMap.get(outer).map(EqualNullSafe(inner, _)) } val newCondition = (newCorrelated ++ uncorrelated - ++ augmentedConditions).reduceOption(And) + ++ augmentedConditions).reduceOption(And(_, _)) val newJoin = j.copy(left = newLeft, right = newRight, condition = newCondition) (newJoin, newJoinCond, newOuterReferenceMap) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index 776efbed273e3..85d322ca7ec98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -78,7 +78,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map { case (l, r) => EqualTo(l, r) } ++ condition - j.copy(condition = Some(newConditions.reduce(And))) + j.copy(condition = Some(newConditions.reduce(And(_, _)))) // TODO: ideally Aggregate should also be handled here, but its grouping expressions are // mixed in its aggregate expressions. It's unreliable to change the grouping expressions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3727b3ea19ed9..e83c6baed0aab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -403,7 +403,8 @@ abstract class Optimizer(catalogManager: CatalogManager) } } val newPlan = Project(projections, - if (predicates.nonEmpty) Filter(predicates.reduce(And), optimizedAgg) else optimizedAgg + if (predicates.nonEmpty) Filter(predicates.reduce(And(_, _)), optimizedAgg) + else optimizedAgg ) val needTopLevelProject = newPlan.output.zip(optimizedAgg.output).exists { case (oldAttr, newAttr) => oldAttr.exprId != newAttr.exprId @@ -1231,7 +1232,7 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { } private object SimplifyExtractValueExecutor extends RuleExecutor[LogicalPlan] { - override val batches = Batch("SimplifyExtractValueOps", FixedPoint(10), + override val batches: Seq[Batch] = Batch("SimplifyExtractValueOps", FixedPoint(10), SimplifyExtractValueOps, // `SimplifyExtractValueOps` turns map lookup to CaseWhen, and we need the following two rules // to further optimize CaseWhen. 
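Note on the recurring `reduce(And)` -> `reduce(And(_, _))` rewrites in the hunks above and throughout the optimizer files that follow: Scala 3 no longer lets a case-class companion object stand in for a function value, so the explicit placeholder form `And(_, _)` (equivalently `And.apply`) is needed, and `-Xsource:3-cross` appears to enforce the same rule in the Scala 2 build. A minimal sketch of the pattern, using made-up types rather than Spark's:

    // Made-up types for illustration only (Expr, AndExpr, Lit are not Spark classes).
    sealed trait Expr
    case class AndExpr(left: Expr, right: Expr) extends Expr
    case class Lit(value: Boolean) extends Expr

    object EtaExpansionSketch {
      val predicates: Seq[Expr] = Seq(Lit(true), Lit(false))
      // `predicates.reduce(AndExpr)` compiles in plain Scala 2, where the synthetic
      // companion extends Function2, but Scala 3 rejects it; the placeholder form
      // eta-expands the companion's apply method and compiles under both.
      val combined: Expr = predicates.reduce(AndExpr(_, _))
    }

The same rule accounts for the single-argument forms elsewhere in the patch, such as `Not(_)`, `IsNotNull(_)`, and `UnresolvedProcedure(_)`.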
@@ -1473,7 +1474,7 @@ object InferFiltersFromGenerate extends Rule[LogicalPlan] { ) -- generate.child.constraints if (inferredFilters.nonEmpty) { - generate.copy(child = Filter(inferredFilters.reduce(And), generate.child)) + generate.copy(child = Filter(inferredFilters.reduce(And(_, _)), generate.child)) } else { generate } @@ -1515,7 +1516,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] val newFilters = filter.constraints -- (child.constraints ++ splitConjunctivePredicates(condition)) if (newFilters.nonEmpty) { - Filter(And(newFilters.reduce(And), condition), child) + Filter(And(newFilters.reduce(And(_, _)), condition), child) } else { filter } @@ -1564,7 +1565,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] if (newPredicates.isEmpty) { plan } else { - Filter(newPredicates.reduce(And), plan) + Filter(newPredicates.reduce(And(_, _)), plan) } } } @@ -1655,13 +1656,13 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper { val (combineCandidates, rest) = splitConjunctivePredicates(fc).partition(p => p.deterministic && !p.throwable) val mergedFilter = (ExpressionSet(combineCandidates) -- - ExpressionSet(splitConjunctivePredicates(nc))).reduceOption(And) match { + ExpressionSet(splitConjunctivePredicates(nc))).reduceOption(And(_, _)) match { case Some(ac) => Filter(And(nc, ac), grandChild) case None => nf } - rest.reduceOption(And).map(c => Filter(c, mergedFilter)).getOrElse(mergedFilter) + rest.reduceOption(And(_, _)).map(c => Filter(c, mergedFilter)).getOrElse(mergedFilter) } } @@ -1785,7 +1786,7 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper { } else if (remainingPredicates.isEmpty) { p } else { - val newCond = remainingPredicates.reduce(And) + val newCond = remainingPredicates.reduce(And(_, _)) Filter(newCond, p) } } @@ -1845,12 +1846,12 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe } if (pushDown.nonEmpty) { - val pushDownPredicate = pushDown.reduce(And) + val pushDownPredicate = pushDown.reduce(And(_, _)) val replaced = replaceAlias(pushDownPredicate, aliasMap) val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child)) // If there is no more filter to stay up, just eliminate the filter. // Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)". - if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate) + if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And(_, _)), newAggregate) } else { filter } @@ -1874,9 +1875,9 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe val stayUp = rest ++ nonDeterministic if (pushDown.nonEmpty) { - val pushDownPredicate = pushDown.reduce(And) + val pushDownPredicate = pushDown.reduce(And(_, _)) val newWindow = w.copy(child = Filter(pushDownPredicate, w.child)) - if (stayUp.isEmpty) newWindow else Filter(stayUp.reduce(And), newWindow) + if (stayUp.isEmpty) newWindow else Filter(stayUp.reduce(And(_, _)), newWindow) } else { filter } @@ -1886,7 +1887,7 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition(_.deterministic) if (pushDown.nonEmpty) { - val pushDownCond = pushDown.reduceLeft(And) + val pushDownCond = pushDown.reduceLeft(And(_, _)) // The union is the child of the filter so it's children are grandchildren. 
// Moves filters down to the grandchild if there is an element in the grand child's // output which is semantically equal to the filter being evaluated. @@ -1902,7 +1903,7 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe val newUnion = union.withNewChildren(newGrandChildren) if (stayUp.nonEmpty) { // If there is any filter we can't push evaluate them post union - Filter(stayUp.reduceLeft(And), newUnion) + Filter(stayUp.reduceLeft(And(_, _)), newUnion) } else { // If we pushed all filters then just return the new union. newUnion @@ -1918,11 +1919,11 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe } if (pushDown.nonEmpty) { - val pushDownPredicate = pushDown.reduceLeft(And) + val pushDownPredicate = pushDown.reduceLeft(And(_, _)) val newWatermark = watermark.copy(child = Filter(pushDownPredicate, watermark.child)) // If there is no more filter to stay up, just eliminate the filter. // Otherwise, create "Filter(stayUp) <- watermark <- Filter(pushDownPredicate)". - if (stayUp.isEmpty) newWatermark else Filter(stayUp.reduceLeft(And), newWatermark) + if (stayUp.isEmpty) newWatermark else Filter(stayUp.reduceLeft(And(_, _)), newWatermark) } else { filter } @@ -1969,9 +1970,9 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe val stayUp = rest ++ nonDeterministic if (pushDown.nonEmpty) { - val newChild = insertFilter(pushDown.reduceLeft(And)) + val newChild = insertFilter(pushDown.reduceLeft(And(_, _))) if (stayUp.nonEmpty) { - Filter(stayUp.reduceLeft(And), newChild) + Filter(stayUp.reduceLeft(And(_, _)), newChild) } else { newChild } @@ -2042,17 +2043,17 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { case _: InnerLike => // push down the single side `where` condition into respective sides val newLeft = leftFilterConditions. - reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) + reduceLeftOption(And(_, _)).map(Filter(_, left)).getOrElse(left) val newRight = rightFilterConditions. - reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) + reduceLeftOption(And(_, _)).map(Filter(_, right)).getOrElse(right) // don't push throwable expressions into join condition val (newJoinConditions, others) = commonFilterCondition.partition(cond => canEvaluateWithinJoin(cond) && !cond.throwable) - val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And) + val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And(_, _)) val join = Join(newLeft, newRight, joinType, newJoinCond, hint) if (others.nonEmpty) { - Filter(others.reduceLeft(And), join) + Filter(others.reduceLeft(And(_, _)), join) } else { join } @@ -2060,22 +2061,22 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { // push down the right side only `where` condition val newLeft = left val newRight = rightFilterConditions. - reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) + reduceLeftOption(And(_, _)).map(Filter(_, right)).getOrElse(right) val newJoinCond = joinCondition val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond, hint) (leftFilterConditions ++ commonFilterCondition). - reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) + reduceLeftOption(And(_, _)).map(Filter(_, newJoin)).getOrElse(newJoin) case LeftOuter | LeftSingle | LeftExistence(_) => // push down the left side only `where` condition val newLeft = leftFilterConditions. 
- reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) + reduceLeftOption(And(_, _)).map(Filter(_, left)).getOrElse(left) val newRight = right val newJoinCond = joinCondition val newJoin = Join(newLeft, newRight, joinType, newJoinCond, hint) (rightFilterConditions ++ commonFilterCondition). - reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) + reduceLeftOption(And(_, _)).map(Filter(_, newJoin)).getOrElse(newJoin) case other => throw SparkException.internalError(s"Unexpected join type: $other") @@ -2090,26 +2091,26 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { case _: InnerLike | LeftSemi => // push down the single side only join filter for both sides sub queries val newLeft = leftJoinConditions. - reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) + reduceLeftOption(And(_, _)).map(Filter(_, left)).getOrElse(left) val newRight = rightJoinConditions. - reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) - val newJoinCond = commonJoinCondition.reduceLeftOption(And) + reduceLeftOption(And(_, _)).map(Filter(_, right)).getOrElse(right) + val newJoinCond = commonJoinCondition.reduceLeftOption(And(_, _)) Join(newLeft, newRight, joinType, newJoinCond, hint) case RightOuter => // push down the left side only join filter for left side sub query val newLeft = leftJoinConditions. - reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) + reduceLeftOption(And(_, _)).map(Filter(_, left)).getOrElse(left) val newRight = right - val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And) + val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And(_, _)) Join(newLeft, newRight, RightOuter, newJoinCond, hint) case LeftOuter | LeftAnti | ExistenceJoin(_) => // push down the right side only join filter for right sub query val newLeft = left val newRight = rightJoinConditions. - reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) - val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And) + reduceLeftOption(And(_, _)).map(Filter(_, right)).getOrElse(right) + val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And(_, _)) Join(newLeft, newRight, joinType, newJoinCond, hint) // Do not move join predicates of a single join. 
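A related adjustment appears in the two hunks that follow (ReplaceIntersectWithSemiJoin and ReplaceExceptWithAntiJoin), and earlier in the RankLike change to `EqualNullSafe.tupled(_): Predicate`: the mapped join conditions gain an explicit `: Predicate` ascription before being reduced with `And(_, _)`. Presumably, once the companion is replaced by a lambda, type inference fixes the reduce's result type to the sequence's element type, so the elements are widened up front to a supertype that `And` can also produce. A hedged sketch of the same shape, again with made-up types:

    // Made-up types for illustration only (Pred, Eq, Both are not Spark classes).
    sealed trait Pred
    case class Eq(a: Int, b: Int) extends Pred
    case class Both(left: Pred, right: Pred) extends Pred

    object AscriptionSketch {
      val pairs = Seq((1, 1), (2, 2))
      // Without the `: Pred` ascription the sequence is Seq[Eq], and
      // `reduceLeftOption(Both(_, _))` fails because a Both is not an Eq;
      // widening each element to the common supertype lets inference succeed.
      val conds: Seq[Pred] = pairs.map { case (a, b) => Eq(a, b): Pred }
      val combined: Option[Pred] = conds.reduceLeftOption(Both(_, _))
    }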
@@ -2349,8 +2350,9 @@ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] { _.containsPattern(INTERSECT), ruleId) { case Intersect(left, right, false) => assert(left.output.size == right.output.size) - val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } - Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And), JoinHint.NONE)) + val joinCond = + left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r): Predicate } + Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And(_, _)), JoinHint.NONE)) } } @@ -2371,8 +2373,9 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] { _.containsPattern(EXCEPT), ruleId) { case Except(left, right, false) => assert(left.output.size == right.output.size) - val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } - Distinct(Join(left, right, LeftAnti, joinCond.reduceLeftOption(And), JoinHint.NONE)) + val joinCond = + left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r): Predicate } + Distinct(Join(left, right, LeftAnti, joinCond.reduceLeftOption(And(_, _)), JoinHint.NONE)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala index a5fcbe6f16b38..26ccec7039cee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala @@ -73,7 +73,7 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] replaced.references.subsetOf(agg.child.outputSet ++ rightOp.outputSet) } val makeJoinCondition = (predicates: Seq[Expression]) => { - replaceAlias(predicates.reduce(And), aliasMap) + replaceAlias(predicates.reduce(And(_, _)), aliasMap) } pushDownJoin(join, canPushDownPredicate, makeJoinCondition) @@ -81,7 +81,7 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] case join @ Join(w: Window, rightOp, LeftSemiOrAnti(_), _, _) if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++ rightOp.outputSet - pushDownJoin(join, _.references.subsetOf(partitionAttrs), _.reduce(And)) + pushDownJoin(join, _.references.subsetOf(partitionAttrs), _.reduce(And(_, _))) // LeftSemi/LeftAnti over Union case Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond, hint) @@ -107,7 +107,7 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(_), _, _) if PushPredicateThroughNonJoin.canPushThrough(u) && u.expressions.forall(_.deterministic) => val validAttrs = u.child.outputSet ++ rightOp.outputSet - pushDownJoin(join, _.references.subsetOf(validAttrs), _.reduce(And)) + pushDownJoin(join, _.references.subsetOf(validAttrs), _.reduce(And(_, _))) } /** @@ -160,7 +160,7 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] join.joinType match { // In case of Left semi join, the part of the join condition which does not refer to // to attributes of the grandchild are kept as a Filter above. - case LeftSemi => Filter(stayUp.reduce(And), newPlan) + case LeftSemi => Filter(stayUp.reduce(And(_, _)), newPlan) // In case of left-anti join, the join is pushed down only when the entire join // condition is eligible to be pushed down to preserve the semantics of left-anti join. 
case _ => join diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushExtraPredicateThroughJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushExtraPredicateThroughJoin.scala index a2bc0bf83a2c6..1146331d29f23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushExtraPredicateThroughJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushExtraPredicateThroughJoin.scala @@ -60,9 +60,9 @@ object PushExtraPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHel j } else { lazy val newLeft = - leftExtraCondition.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) + leftExtraCondition.reduceLeftOption(And(_, _)).map(Filter(_, left)).getOrElse(left) lazy val newRight = - rightExtraCondition.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) + rightExtraCondition.reduceLeftOption(And(_, _)).map(Filter(_, right)).getOrElse(right) val newJoin = joinType match { case _: InnerLike | LeftSemi => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala index 11c106cf4a7f5..bfe48c0dcfa30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala @@ -91,7 +91,7 @@ object PushdownPredicatesAndPruneColumnsForCTEDef extends Rule[LogicalPlan] { if (filteredPredicates.isEmpty) { Seq(Literal.TrueLiteral) } else { - preds :+ filteredPredicates.reduce(And) + preds :+ filteredPredicates.reduce(And(_, _)) } } val newAttributes = attrs ++ @@ -128,7 +128,7 @@ object PushdownPredicatesAndPruneColumnsForCTEDef extends Rule[LogicalPlan] { val preds = originalPlanWithPredicates.map(_._2).getOrElse(Seq.empty) if (!isTruePredicate(newPreds) && newPreds.exists(newPred => !preds.exists(_.semanticEquals(newPred)))) { - val newCombinedPred = newPreds.reduce(Or) + val newCombinedPred = newPreds.reduce(Or(_, _)) val newChild = if (needsPruning(originalPlan, newAttrSet)) { Project(newAttrSet.toSeq, originalPlan) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index e867953bcf282..373101f67f6bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -198,7 +198,7 @@ object ConstantPropagation extends Rule[LogicalPlan] { case n: Not => // Ignore the EqualityPredicates from children since they are only propagated through And. val (newChild, _) = traverse(n.child, replaceChildren = true, nullIsFalse = false) - (newChild.map(Not), AttributeMap.empty) + (newChild.map(Not(_)), AttributeMap.empty) case _ => (None, AttributeMap.empty) } @@ -425,11 +425,11 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals)) if (ldiff.isEmpty || rdiff.isEmpty) { // (a || b || c || ...) && (a || b) => (a || b) - common.reduce(Or) + common.reduce(Or(_, _)) } else { // (a || b || c || ...) && (a || b || d || ...) => // a || b || ((c || ...) 
&& (d || ...)) - (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or) + (common :+ And(ldiff.reduce(Or(_, _)), rdiff.reduce(Or(_, _)))).reduce(Or(_, _)) } } else { // No common factors from disjunctive predicates, reduce common factor from conjunction @@ -440,7 +440,7 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { and } else { // (a && b) && a && (a && c) => a && b && c - buildBalancedPredicate(distinct.toSeq, And) + buildBalancedPredicate(distinct.toSeq, And(_, _)) } } @@ -463,11 +463,11 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals)) if (ldiff.isEmpty || rdiff.isEmpty) { // (a && b) || (a && b && c && ...) => a && b - common.reduce(And) + common.reduce(And(_, _)) } else { // (a && b && c && ...) || (a && b && d && ...) => // a && b && ((c && ...) || (d && ...)) - (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And) + (common :+ Or(ldiff.reduce(And(_, _)), rdiff.reduce(And(_, _)))).reduce(And(_, _)) } } else { // No common factors in conjunctive predicates, reduce common factor from disjunction @@ -478,7 +478,7 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { or } else { // (a || b) || a || (a || c) => a || b || c - buildBalancedPredicate(distinct.toSeq, Or) + buildBalancedPredicate(distinct.toSeq, Or(_, _)) } } @@ -792,16 +792,16 @@ object LikeSimplification extends Rule[LogicalPlan] with PredicateHelper { } else { multi match { case l: LikeAll => - val and = buildBalancedPredicate(replacements, And) + val and = buildBalancedPredicate(replacements, And(_, _)) if (remainPatterns.nonEmpty) And(and, l.copy(patterns = remainPatterns)) else and case l: NotLikeAll => - val and = buildBalancedPredicate(replacements.map(Not(_)), And) + val and = buildBalancedPredicate(replacements.map(Not(_)), And(_, _)) if (remainPatterns.nonEmpty) And(and, l.copy(patterns = remainPatterns)) else and case l: LikeAny => - val or = buildBalancedPredicate(replacements, Or) + val or = buildBalancedPredicate(replacements, Or(_, _)) if (remainPatterns.nonEmpty) Or(or, l.copy(patterns = remainPatterns)) else or case l: NotLikeAny => - val or = buildBalancedPredicate(replacements.map(Not(_)), Or) + val or = buildBalancedPredicate(replacements.map(Not(_)), Or(_, _)) if (remainPatterns.nonEmpty) Or(or, l.copy(patterns = remainPatterns)) else or } } @@ -930,9 +930,9 @@ object NullDownPropagation extends Rule[LogicalPlan] { case q: LogicalPlan => q.transformExpressionsDownWithPruning( _.containsPattern(NULL_CHECK), ruleId) { case IsNull(e) if e.nullIntolerant && supportedNullIntolerant(e) => - e.children.map(IsNull(_): Expression).reduceLeft(Or) + e.children.map(IsNull(_): Expression).reduceLeft(Or(_, _)) case IsNotNull(e) if e.nullIntolerant && supportedNullIntolerant(e) => - e.children.map(IsNotNull(_): Expression).reduceLeft(And) + e.children.map(IsNotNull(_): Expression).reduceLeft(And(_, _)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index 0fbfce5962c73..6a48f6e0e4a10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -22,7 +22,7 @@ import java.time.{Instant, LocalDateTime, ZoneId} import scala.util.control.NonFatal import 
org.apache.spark.sql.catalyst.{CurrentUserContext, InternalRow} -import org.apache.spark.sql.catalyst.analysis.{CastSupport, ResolvedInlineTable} +import org.apache.spark.sql.catalyst.analysis.{AnalysisErrorAt, CastSupport, ResolvedInlineTable} import org.apache.spark.sql.catalyst.analysis.ResolveInlineTables.prepareForEval import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 44875bfb3ec22..cd4ca31f512a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -64,9 +64,9 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { case (_, _) => Cross } val join = Join(left, right, innerJoinType, - joinConditions.reduceLeftOption(And), JoinHint.NONE) + joinConditions.reduceLeftOption(And(_, _)), JoinHint.NONE) if (others.nonEmpty) { - Filter(others.reduceLeft(And), join) + Filter(others.reduceLeft(And(_, _)), join) } else { join } @@ -88,7 +88,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { val (joinConditions, others) = conditions.partition( e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e)) val joined = Join(left, right, innerJoinType, - joinConditions.reduceLeftOption(And), JoinHint.NONE) + joinConditions.reduceLeftOption(And(_, _)), JoinHint.NONE) // should not have reference to same logical plan createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others) @@ -270,11 +270,11 @@ object ExtractPythonUDFFromJoinCondition extends Rule[LogicalPlan] with Predicat log" it will be moved out and the join plan will be turned to cross join.") None } else { - Some(rest.reduceLeft(And)) + Some(rest.reduceLeft(And(_, _))) } val newJoin = j.copy(condition = newCondition) joinType match { - case _: InnerLike => Filter(udf.reduceLeft(And), newJoin) + case _: InnerLike => Filter(udf.reduceLeft(And(_, _)), newJoin) case _ => throw QueryCompilationErrors.usePythonUDFInJoinConditionUnsupportedError(joinType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 378081221c8c1..d40a4b6fee780 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -147,7 +147,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Construct the pruned filter condition. val newFilter: LogicalPlan = withoutSubquery match { case Nil => child - case conditions => Filter(conditions.reduce(And), child) + case conditions => Filter(conditions.reduce(And(_, _)), child) } // Filter the plan by applying left semi and left anti joins. @@ -194,7 +194,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // SELECT ... 
FROM A WHERE A.A1 NOT IN (SELECT B.B1 FROM B WHERE B.B2 = A.A2 AND B.B3 > 1) // will have the final conditions in the LEFT ANTI as // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1 - val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And) + val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And(_, _)) Join(outerPlan, rewriteDomainJoinsIfPresent(outerPlan, newSub, Some(finalJoinCond)), LeftAnti, Option(finalJoinCond), JoinHint(None, subHint)) case (p, predicate) => @@ -403,7 +403,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { case Exists(sub, _, _, conditions, subHint) => val exists = AttributeReference("exists", BooleanType, nullable = false)() val existenceJoin = ExistenceJoin(exists) - val newCondition = conditions.reduceLeftOption(And) + val newCondition = conditions.reduceLeftOption(And(_, _)) newPlan = buildJoin(newPlan, rewriteDomainJoinsIfPresent(newPlan, sub, newCondition), existenceJoin, newCondition, subHint) @@ -427,7 +427,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // :- Relation[id#78,v#79] parquet // +- Relation[id#80] parquet val nullAwareJoinConds = inConditions.map(c => Or(c, IsNull(c))) - val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And) + val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And(_, _)) val joinHint = JoinHint(None, subHint) newPlan = Join(newPlan, rewriteDomainJoinsIfPresent(newPlan, newSub, Some(finalJoinCond)), @@ -439,7 +439,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Deduplicate conflicting attributes if any. val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values)) val inConditions = values.zip(newSub.output).map(EqualTo.tupled) - val newConditions = (inConditions ++ conditions).reduceLeftOption(And) + val newConditions = (inConditions ++ conditions).reduceLeftOption(And(_, _)) val joinHint = JoinHint(None, subHint) newPlan = Join(newPlan, rewriteDomainJoinsIfPresent(newPlan, newSub, newConditions), ExistenceJoin(exists), newConditions, joinHint) @@ -447,7 +447,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { exists } } - (newExprs.reduceOption(And), newPlan, introducedAttrs.toSeq) + (newExprs.reduceOption(And(_, _)), newPlan, introducedAttrs.toSeq) } } @@ -488,7 +488,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper correlated match { case Nil => f case xs if local.nonEmpty => - val newFilter = Filter(local.reduce(And), child) + val newFilter = Filter(local.reduce(And(_, _)), child) predicateMap += newFilter -> xs newFilter case xs => @@ -913,7 +913,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe } lazy val planWithoutCountBug = Project( currentChild.output :+ origOutput, - Join(currentChild, query, joinType, conditions.reduceOption(And), joinHint)) + Join(currentChild, query, joinType, conditions.reduceOption(And(_, _)), joinHint)) if (Utils.isTesting) { assert(mayHaveCountBug.isDefined) @@ -962,7 +962,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe currentChild.output :+ subqueryResultExpr, Join(currentChild, Project(query.output :+ alwaysTrueExpr, query), - joinType, conditions.reduceOption(And), joinHint)) + joinType, conditions.reduceOption(And(_, _)), joinHint)) } else { // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join. 
@@ -994,7 +994,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe currentChild.output :+ caseExpr, Join(currentChild, Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot), - joinType, conditions.reduceOption(And), joinHint)) + joinType, conditions.reduceOption(And(_, _)), joinHint)) } } } @@ -1087,7 +1087,7 @@ object RewriteLateralSubquery extends Rule[LogicalPlan] { _.containsPattern(LATERAL_JOIN)) { case LateralJoin(left, LateralSubquery(sub, _, _, joinCond, subHint), joinType, condition) => val newRight = DecorrelateInnerQuery.rewriteDomainJoins(left, sub, joinCond) - val newCond = (condition ++ joinCond).reduceOption(And) + val newCond = (condition ++ joinCond).reduceOption(And(_, _)) // The subquery appears on the right side of the join, hence add the hint to the right side Join(left, newRight, joinType, newCond, JoinHint(None, subHint)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 3366237e81ad1..97d4d259b716c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1792,13 +1792,17 @@ class AstBuilder extends DataTypeAstBuilder allowNamedGroupingExpressions: Boolean): LogicalPlan = withOrigin(ctx) { if (ctx.groupingExpressionsWithGroupingAnalytics.isEmpty) { val groupByExpressions: Seq[Expression] = - ctx.groupingExpressions.asScala.map { n: NamedExpressionContext => - if (!allowNamedGroupingExpressions && (n.name != null || n.identifierList != null)) { + ctx.groupingExpressions.asScala.map { expression => + if ( + !allowNamedGroupingExpressions && + (expression.name != null || expression.identifierList != null) + ) { // If we do not allow grouping expressions to have aliases in this context, we throw a // syntax error here accordingly. - val error: String = (if (n.name != null) n.name else n.identifierList).getText + val error = + (if (expression.name != null) expression.name else expression.identifierList).getText throw new ParseException( - command = Some(SparkParserUtils.command(n)), + command = Some(SparkParserUtils.command(expression)), start = Origin(), errorClass = "PARSE_SYNTAX_ERROR", messageParameters = Map( @@ -1806,7 +1810,7 @@ class AstBuilder extends DataTypeAstBuilder "hint" -> s": extra input '$error'"), queryContext = Array.empty) } - visitNamedExpression(n) + visitNamedExpression(expression) }.toSeq if (ctx.GROUPING != null) { // GROUP BY ... GROUPING SETS (...) 
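The AstBuilder hunk just above shows the other recurring shape in this patch: a lambda whose single parameter carried a type ascription, `{ n: NamedExpressionContext => ... }`, is rewritten with an unannotated parameter, while other files keep the annotation but add parentheses (`{ (spec: IdentityColumnSpec) => ... }` in ColumnDefinition, `{ (_: Int) => ... }` in the benchmark suites). Scala 3 requires parentheses around a type-ascribed lambda parameter, which is presumably what `-Xsource:3-cross` flags here; both rewrites are accepted by Scala 2 and Scala 3 alike. A small illustrative sketch, nothing Spark-specific:

    // Illustration only: both forms below compile under Scala 2 and Scala 3,
    // whereas `{ s: String => s.length }` needs parentheses in Scala 3.
    object LambdaParamSketch {
      val lengths: Seq[Int] = Seq("a", "bb", "ccc").map { (s: String) => s.length }
      val doubled: Seq[Int] = Seq(1, 2, 3).map { n => n * 2 }
    }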
@@ -2737,7 +2741,7 @@ class AstBuilder extends DataTypeAstBuilder } } else { ctx.expression.asScala.map(expression) - .map(p => invertIfNotDefined(getLike(e, p))).toSeq.reduceLeft(Or) + .map(p => invertIfNotDefined(getLike(e, p))).toSeq.reduceLeft(Or(_, _)) } case Some(SqlBaseParser.ALL) => validate(!ctx.expression.isEmpty, "Expected something between '(' and ')'.", ctx) @@ -2753,7 +2757,7 @@ class AstBuilder extends DataTypeAstBuilder } } else { ctx.expression.asScala.map(expression) - .map(p => invertIfNotDefined(getLike(e, p))).toSeq.reduceLeft(And) + .map(p => invertIfNotDefined(getLike(e, p))).toSeq.reduceLeft(And(_, _)) } case _ => val escapeChar = Option(ctx.escapeChar) @@ -5973,7 +5977,7 @@ class AstBuilder extends DataTypeAstBuilder * }}} */ override def visitCall(ctx: CallContext): LogicalPlan = withOrigin(ctx) { - val procedure = withIdentClause(ctx.identifierReference, UnresolvedProcedure) + val procedure = withIdentClause(ctx.identifierReference, UnresolvedProcedure(_)) val args = ctx.functionArgument.asScala.map { case expr if expr.namedArgumentExpression != null => val namedExpr = expr.namedArgumentExpression @@ -6146,8 +6150,8 @@ class AstBuilder extends DataTypeAstBuilder // it to generate clear error messages if the expression contains any aggregate functions // (this is not allowed in the EXTEND operator). val extendExpressions: Seq[NamedExpression] = - Option(ctx.extendList).map { n: NamedExpressionSeqContext => - visitNamedExpressionSeq(n).map { + Option(ctx.extendList).map { expressionSeq => + visitNamedExpressionSeq(expressionSeq).map { case (a: Alias, _) => a.copy( child = PipeExpression(a.child, isAggregate = false, PipeOperators.extendClause))( @@ -6249,8 +6253,8 @@ class AstBuilder extends DataTypeAstBuilder // Visit each aggregate expression, and add a [[PipeExpression]] on top of it to generate // clear error messages if the expression does not contain at least one aggregate function. val aggregateExpressions: Seq[NamedExpression] = - Option(ctx.namedExpressionSeq()).map { n: NamedExpressionSeqContext => - visitNamedExpressionSeq(n).map { + Option(ctx.namedExpressionSeq()).map { expressionSeq => + visitNamedExpressionSeq(expressionSeq).map { case (a: Alias, _) => a.copy(child = PipeExpression(a.child, isAggregate = true, PipeOperators.aggregateClause))( @@ -6260,14 +6264,18 @@ class AstBuilder extends DataTypeAstBuilder PipeExpression(e, isAggregate = true, PipeOperators.aggregateClause), aliasFunc) } }.getOrElse(Seq.empty) - Option(ctx.aggregationClause()).map { c: AggregationClauseContext => - withAggregationClause(c, aggregateExpressions, left, allowNamedGroupingExpressions = true) + Option(ctx.aggregationClause()).map { clause => + withAggregationClause( + clause, + aggregateExpressions, left, + allowNamedGroupingExpressions = true + ) match { case a: Aggregate => // GROUP BY ALL, GROUP BY CUBE, GROUPING_ID, GROUPING SETS, and GROUP BY ROLLUP are not // supported yet. 
def error(s: String): Unit = - throw QueryParsingErrors.pipeOperatorAggregateUnsupportedCaseError(s, c) + throw QueryParsingErrors.pipeOperatorAggregateUnsupportedCaseError(s, clause) a.groupingExpressions match { case Seq(key: UnresolvedAttribute) if key.equalsIgnoreCase("ALL") => error("GROUP BY ALL") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 54a4e75c90c95..09c71fd43e237 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -219,8 +219,8 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { if (joinKeys.nonEmpty) { val (leftKeys, rightKeys) = joinKeys.unzip logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys") - Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), - predicatesOfJoinKeys.reduceOption(And), left, right, hint)) + Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And(_, _)), + predicatesOfJoinKeys.reduceOption(And(_, _)), left, right, hint)) } else { None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala index 1651003dd7744..f679e4f064aaa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala @@ -126,7 +126,7 @@ object NormalizePlan extends PredicateHelper { splitConjunctivePredicates(condition) .map(rewriteBinaryComparison) .sortBy(_.hashCode()) - .reduce(And), + .reduce(And(_, _)), child ) case sample: Sample => @@ -143,7 +143,7 @@ object NormalizePlan extends PredicateHelper { splitConjunctivePredicates(condition.get) .map(rewriteBinaryComparison) .sortBy(_.hashCode()) - .reduce(And) + .reduce(And(_, _)) Join(left, right, newJoinType, Some(newCondition), hint) case project: Project if project @@ -184,8 +184,8 @@ object NormalizePlan extends PredicateHelper { * 3. 
(a > b), (b < a) */ private def rewriteBinaryComparison(condition: Expression): Expression = condition match { - case EqualTo(l, r) => Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo) - case EqualNullSafe(l, r) => Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe) + case EqualTo(l, r) => Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo(_, _)) + case EqualNullSafe(l, r) => Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe(_, _)) case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l) case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l) case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 1269c9bf8ca1e..b5928f0a5e5d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -751,7 +751,7 @@ object QueryPlan extends PredicateHelper { */ def normalizePredicates(predicates: Seq[Expression], output: AttributeSeq): Seq[Expression] = { if (predicates.nonEmpty) { - val normalized = normalizeExpressions(predicates.reduce(And), output) + val normalized = normalizeExpressions(predicates.reduce(And(_, _)), output) splitConjunctivePredicates(normalized) } else { Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ColumnDefinition.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ColumnDefinition.scala index ec7df606632a1..77dbcb3ff7d98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ColumnDefinition.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ColumnDefinition.scala @@ -89,7 +89,7 @@ case class ColumnDefinition( } private def encodeIdentityColumnSpec(metadataBuilder: MetadataBuilder): Unit = { - identityColumnSpec.foreach { spec: IdentityColumnSpec => + identityColumnSpec.foreach { (spec: IdentityColumnSpec) => metadataBuilder.putLong(IdentityColumn.IDENTITY_INFO_START, spec.getStart) metadataBuilder.putLong(IdentityColumn.IDENTITY_INFO_STEP, spec.getStep) metadataBuilder.putBoolean( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index ef035eba5922c..7cb00da7473eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -101,7 +101,7 @@ trait ConstraintHelper { // Second, we infer additional constraints from non-nullable attributes that are part of the // operator's output val nonNullableAttributes = output.filterNot(_.nullable) - isNotNullConstraints ++= nonNullableAttributes.map(IsNotNull) + isNotNullConstraints ++= nonNullableAttributes.map(IsNotNull(_)) isNotNullConstraints -- constraints } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 60f4453ca23fc..61bdd06d2a096 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -535,7 +535,7 @@ abstract class UnionBase extends LogicalPlan { val otherb = b.diff(common).filter(_.references.size == 1).groupBy(_.references.head) // loose the constraints by: A1 && B1 || A2 && B2 -> (A1 || A2) && (B1 || B2) val others = (othera.keySet intersect otherb.keySet).map { attr => - Or(othera(attr).reduceLeft(And), otherb(attr).reduceLeft(And)) + Or(othera(attr).reduceLeft(And(_, _)), otherb(attr).reduceLeft(And(_, _))) } common ++ others } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 4b8556b1bb5de..8a7fd8b11783d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -564,7 +564,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] val afterRuleOnChildren = mapChildren(_.transformUpWithBeforeAndAfterRuleOnChildren(cond, ruleId)(rule)) val newNode = CurrentOrigin.withOrigin(origin) { - rule.applyOrElse((this, afterRuleOnChildren), { t: (BaseType, BaseType) => t._2 }) + rule.applyOrElse((this, afterRuleOnChildren), (t: (BaseType, BaseType)) => t._2) } if (this eq newNode) { this.markRuleAsIneffective(ruleId) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala index 1084e99731510..c5f309aafe955 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala @@ -335,7 +335,7 @@ object PhysicalStringType { case class PhysicalArrayType( elementType: DataType, containsNull: Boolean) extends PhysicalDataType { override private[sql] type InternalType = ArrayData - override private[sql] def ordering = interpretedOrdering + override private[sql] def ordering: Ordering[InternalType] = interpretedOrdering @transient private[sql] lazy val tag = typeTag[InternalType] @transient @@ -397,7 +397,7 @@ class PhysicalVariantType extends PhysicalDataType { @transient private[sql] lazy val tag = typeTag[InternalType] // TODO(SPARK-45891): Support comparison for the Variant type. 
- override private[sql] def ordering = + override private[sql] def ordering: Ordering[InternalType] = throw QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError( "PhysicalVariantType") } @@ -405,7 +405,7 @@ class PhysicalVariantType extends PhysicalDataType { object PhysicalVariantType extends PhysicalVariantType object UninitializedPhysicalType extends PhysicalDataType { - override private[sql] def ordering = + override private[sql] def ordering: Ordering[InternalType] = throw QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError( "UninitializedPhysicalType") override private[sql] type InternalType = Any diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala index 2dca2eed85731..ec774e588ed77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala @@ -403,7 +403,7 @@ object ResolveDefaultColumns extends QueryErrorsBase // if we encounter an unresolved existsDefault private def fallbackResolveExistenceDefaultValue( field: StructField): Expression = { - field.getExistenceDefaultValue().map { defaultSQL: String => + field.getExistenceDefaultValue().map { defaultSQL => logWarning(log"Encountered unresolved exists default value: " + log"'${MDC(COLUMN_DEFAULT_VALUE, defaultSQL)}' " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala index b66658467c1b7..ce9ce40a66122 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala @@ -133,9 +133,10 @@ class ToNumberParser(numberFormat: String, errorOnFail: Boolean) extends Seriali char match { case ZERO_DIGIT => val prevI = i - do { + while ({ i += 1 - } while (i < len && (numberFormat(i) == ZERO_DIGIT || numberFormat(i) == NINE_DIGIT)) + i < len && (numberFormat(i) == ZERO_DIGIT || numberFormat(i) == NINE_DIGIT) + }) () if (reachedDecimalPoint) { tokens.append(AtMostAsManyDigits(i - prevI)) } else { @@ -143,9 +144,10 @@ class ToNumberParser(numberFormat: String, errorOnFail: Boolean) extends Seriali } case NINE_DIGIT => val prevI = i - do { + while ({ i += 1 - } while (i < len && (numberFormat(i) == ZERO_DIGIT || numberFormat(i) == NINE_DIGIT)) + i < len && (numberFormat(i) == ZERO_DIGIT || numberFormat(i) == NINE_DIGIT) + }) () tokens.append(AtMostAsManyDigits(i - prevI)) case POINT_SIGN | POINT_LETTER => tokens.append(DecimalPoint()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala index 5d267143b06c9..9fcfc8d19c76b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala @@ -132,7 +132,7 @@ object StaxXmlParserUtils { options: XmlOptions): String = { val xmlString = new StringBuilder() var indent = 0 - do { + while ({ parser.nextEvent match { case e: StartElement => xmlString.append('<').append(e.getName) @@ -153,12 +153,13 @@ object StaxXmlParserUtils { xmlString.append(c.getData) case _: XMLEvent => 
// do nothing } - } while (parser.peek() match { - case _: EndElement => - // until the unclosed end element for the whole parent is found - indent > 0 - case _ => true - }) + parser.peek() match { + case _: EndElement => + // until the unclosed end element for the whole parent is found + indent > 0 + case _ => true + } + }) () skipNextEndElement(parser, startElementName, options) xmlString.toString() } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/CalendarIntervalBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/CalendarIntervalBenchmark.scala index 043e2b0137862..4037d7cc8d55d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/CalendarIntervalBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/CalendarIntervalBenchmark.scala @@ -52,7 +52,7 @@ object CalendarIntervalBenchmark extends BenchmarkBase { val unsafeRow = UnsafeProjection.create(Array[DataType](CalendarIntervalType)).apply(row) val benchmark = new Benchmark(name, iters * numRows.toLong, output = output) - benchmark.addCase("Call setInterval & getInterval") { _: Int => + benchmark.addCase("Call setInterval & getInterval") { (_: Int) => for (_ <- 0L until iters) { var i = 0 while (i < numRows) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala index e515b771c96c6..a54fb30c4c0dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala @@ -52,7 +52,7 @@ object HashBenchmark extends BenchmarkBase { ).toArray val benchmark = new Benchmark("Hash For " + name, iters * numRows.toLong, output = output) - benchmark.addCase("interpreted version") { _: Int => + benchmark.addCase("interpreted version") { (_: Int) => var sum = 0 for (_ <- 0L until iters) { var i = 0 @@ -64,7 +64,7 @@ object HashBenchmark extends BenchmarkBase { } val getHashCode = UnsafeProjection.create(new Murmur3Hash(attrs) :: Nil, attrs) - benchmark.addCase("codegen version") { _: Int => + benchmark.addCase("codegen version") { (_: Int) => var sum = 0 for (_ <- 0L until iters) { var i = 0 @@ -76,7 +76,7 @@ object HashBenchmark extends BenchmarkBase { } val getHashCode64b = UnsafeProjection.create(new XxHash64(attrs) :: Nil, attrs) - benchmark.addCase("codegen version 64-bit") { _: Int => + benchmark.addCase("codegen version 64-bit") { (_: Int) => var sum = 0 for (_ <- 0L until iters) { var i = 0 @@ -88,7 +88,7 @@ object HashBenchmark extends BenchmarkBase { } val getHiveHashCode = UnsafeProjection.create(new HiveHash(attrs) :: Nil, attrs) - benchmark.addCase("codegen HiveHash version") { _: Int => + benchmark.addCase("codegen HiveHash version") { (_: Int) => var sum = 0 for (_ <- 0L until iters) { var i = 0 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala index 1baac88bf2d00..7b446c7e54138 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala @@ -47,7 +47,7 @@ object HashByteArrayBenchmark extends BenchmarkBase { val benchmark = new Benchmark( "Hash byte arrays with length " + length, iters * numArrays.toLong, output = output) - benchmark.addCase("Murmur3_x86_32") { _: Int => + benchmark.addCase("Murmur3_x86_32") { (_: Int) => var sum = 0L for (_ <- 0L until iters) { var i 
= 0 @@ -58,7 +58,7 @@ object HashByteArrayBenchmark extends BenchmarkBase { } } - benchmark.addCase("xxHash 64-bit") { _: Int => + benchmark.addCase("xxHash 64-bit") { (_: Int) => var sum = 0L for (_ <- 0L until iters) { var i = 0 @@ -69,7 +69,7 @@ object HashByteArrayBenchmark extends BenchmarkBase { } } - benchmark.addCase("HiveHasher") { _: Int => + benchmark.addCase("HiveHasher") { (_: Int) => var sum = 0L for (_ <- 0L until iters) { var i = 0 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index eab4ddc666be4..6176bcf841fbc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -1244,9 +1244,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { Option(ExpressionEncoder[String]().resolveAndBind()) :: Nil, udfDeterministic = false) - Seq(reflect, udf).foreach { e: Expression => - val plan = Sort(Seq(e.asc), false, testRelation) - val projected = Alias(e, "_nondeterministic")() + Seq(reflect, udf).foreach { expression => + val plan = Sort(Seq(expression.asc), false, testRelation) + val projected = Alias(expression, "_nondeterministic")() val expect = Project(testRelation.output, Sort(Seq(projected.toAttribute.asc), false, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala index 0792c1657456a..d1d77414e9ccc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala @@ -529,7 +529,7 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { } test("greatest/least cast") { - for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { + for (operator <- Seq(Greatest(_), Least(_))) { ruleTest(AnsiTypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal(1) @@ -1021,7 +1021,8 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { test("SPARK-35937: GetDateFieldOperations") { val ts = Literal(Timestamp.valueOf("2021-01-01 01:30:00")) Seq( - DayOfYear, Year, YearOfWeek, Quarter, Month, DayOfMonth, DayOfWeek, WeekDay, WeekOfYear + DayOfYear(_), Year(_), YearOfWeek(_), Quarter(_), Month(_), DayOfMonth(_), + DayOfWeek(_), WeekDay(_), WeekOfYear(_) ).foreach { operation => ruleTest( AnsiTypeCoercion.GetDateFieldOperations, operation(ts), operation(Cast(ts, DateType))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 95e118a30771c..0068265ee7c16 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -688,7 +688,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer } test("check types for Greatest/Least") { - for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { + for (operator <- Seq(Greatest(_), Least(_))) { val expr1 = operator(Seq($"booleanField")) assertErrorForWrongNumParameters( expr 
= expr1, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 330252d26dc56..dfd7414e82560 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -435,12 +435,13 @@ abstract class TypeCoercionSuiteBase extends AnalysisTest { val timestampNTZLiteral = Literal(LocalDateTime.parse("2021-01-01T00:00:00")) val timestampLiteral = Literal(Timestamp.valueOf("2021-01-01 00:00:00")) Seq( - EqualTo, - EqualNullSafe, - GreaterThan, - GreaterThanOrEqual, - LessThan, - LessThanOrEqual).foreach { op => + EqualTo(_, _), + EqualNullSafe(_, _), + GreaterThan(_, _), + GreaterThanOrEqual(_, _), + LessThan(_, _), + LessThanOrEqual(_, _) + ).foreach { op => ruleTest(rule, op(dateLiteral, timestampNTZLiteral), op(Cast(dateLiteral, TimestampNTZType), timestampNTZLiteral)) @@ -1114,7 +1115,7 @@ class TypeCoercionSuite extends TypeCoercionSuiteBase { } test("greatest/least cast") { - for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { + for (operator <- Seq(Greatest(_), Least(_))) { ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal(1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 89f0b95f5c18f..e74807c7efc32 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -483,7 +483,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } test("Remainder/Pmod: exception should contain SQL text context") { - Seq(("%", Remainder), ("pmod", Pmod)).foreach { case (symbol, exprBuilder) => + Seq(("%", Remainder(_, _, _)), ("pmod", Pmod(_, _, _))).foreach { case (symbol, exprBuilder) => val query = s"1L $symbol 0L" val o = Origin( line = Some(1), @@ -637,7 +637,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } DataTypeTestUtils.ordered.foreach { dt => - checkConsistencyBetweenInterpretedAndCodegen(Least, dt, 2) + checkConsistencyBetweenInterpretedAndCodegen(Least(_), dt, 2) } val least = Least(Seq( @@ -698,7 +698,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } DataTypeTestUtils.ordered.foreach { dt => - checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) + checkConsistencyBetweenInterpretedAndCodegen(Greatest(_), dt, 2) } val greatest = Greatest(Seq( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala index 63602d04b5c79..2962e991998bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala @@ -49,7 +49,7 @@ class BitwiseExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BitwiseNot(negativeLongLit), ~negativeLong) DataTypeTestUtils.integralType.foreach { dt => - 
checkConsistencyBetweenInterpretedAndCodegen(BitwiseNot, dt) + checkConsistencyBetweenInterpretedAndCodegen(BitwiseNot(_), dt) } } @@ -76,7 +76,7 @@ class BitwiseExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BitwiseAnd(positiveLongLit, negativeLongLit), positiveLong & negativeLong) DataTypeTestUtils.integralType.foreach { dt => - checkConsistencyBetweenInterpretedAndCodegen(BitwiseAnd, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(BitwiseAnd(_, _), dt, dt) } } @@ -103,7 +103,7 @@ class BitwiseExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BitwiseOr(positiveLongLit, negativeLongLit), positiveLong | negativeLong) DataTypeTestUtils.integralType.foreach { dt => - checkConsistencyBetweenInterpretedAndCodegen(BitwiseOr, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(BitwiseOr(_, _), dt, dt) } } @@ -130,7 +130,7 @@ class BitwiseExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BitwiseXor(positiveLongLit, negativeLongLit), positiveLong ^ negativeLong) DataTypeTestUtils.integralType.foreach { dt => - checkConsistencyBetweenInterpretedAndCodegen(BitwiseXor, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(BitwiseXor(_, _), dt, dt) } } @@ -231,7 +231,11 @@ class BitwiseExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { "upper" -> "8", "invalidValue" -> "16")) DataTypeTestUtils.integralType.foreach { dt => - checkConsistencyBetweenInterpretedAndCodegenAllowingException(BitwiseGet, dt, IntegerType) + checkConsistencyBetweenInterpretedAndCodegenAllowingException( + BitwiseGet(_, _), + dt, + IntegerType + ) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala index e0d3a176b1a43..aa8594bb7b3d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -199,11 +199,17 @@ class CanonicalizeSuite extends SparkFunSuite { } test("SPARK-40362: Commutative operator under BinaryComparison") { - Seq(EqualTo, EqualNullSafe, GreaterThan, LessThan, GreaterThanOrEqual, LessThanOrEqual) - .foreach { bc => - assert(bc(Multiply($"a", $"b"), Literal(10)).semanticEquals( - bc(Multiply($"b", $"a"), Literal(10)))) - } + Seq( + EqualTo(_, _), + EqualNullSafe(_, _), + GreaterThan(_, _), + LessThan(_, _), + GreaterThanOrEqual(_, _), + LessThanOrEqual(_, _) + ).foreach { bc => + assert(bc(Multiply($"a", $"b"), Literal(10)).semanticEquals( + bc(Multiply($"b", $"a"), Literal(10)))) + } } test("SPARK-40903: Only reorder decimal Add when the result data type is not changed") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index f4c71a1056939..dfb713ca6c855 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -67,7 +67,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper testIf(_.toString, StringType) DataTypeTestUtils.propertyCheckSupported.foreach { dt => - checkConsistencyBetweenInterpretedAndCodegen(If, 
BooleanType, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(If(_, _, _), BooleanType, dt, dt) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 2ddddad7a2942..4101332c1077c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -94,7 +94,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(DayOfYear(Cast(Literal("1582-10-15 13:10:15"), DateType)), 288) checkEvaluation(DayOfYear(Cast(Literal("1582-10-04 13:10:15"), DateType)), 277) - checkConsistencyBetweenInterpretedAndCodegen(DayOfYear, DateType) + checkConsistencyBetweenInterpretedAndCodegen(DayOfYear(_), DateType) } test("Year") { @@ -116,7 +116,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } checkEvaluation(Year(Cast(Literal("1582-01-01 13:10:15"), DateType)), 1582) checkEvaluation(Year(Cast(Literal("1581-12-31 13:10:15"), DateType)), 1581) - checkConsistencyBetweenInterpretedAndCodegen(Year, DateType) + checkConsistencyBetweenInterpretedAndCodegen(Year(_), DateType) } test("Quarter") { @@ -139,7 +139,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Quarter(Cast(Literal("1582-10-01 13:10:15"), DateType)), 4) checkEvaluation(Quarter(Cast(Literal("1582-09-30 13:10:15"), DateType)), 3) - checkConsistencyBetweenInterpretedAndCodegen(Quarter, DateType) + checkConsistencyBetweenInterpretedAndCodegen(Quarter(_), DateType) } test("Month") { @@ -163,7 +163,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } - checkConsistencyBetweenInterpretedAndCodegen(Month, DateType) + checkConsistencyBetweenInterpretedAndCodegen(Month(_), DateType) } test("Day / DayOfMonth") { @@ -187,7 +187,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.get(Calendar.DAY_OF_MONTH)) } } - checkConsistencyBetweenInterpretedAndCodegen(DayOfMonth, DateType) + checkConsistencyBetweenInterpretedAndCodegen(DayOfMonth(_), DateType) } test("Seconds") { @@ -217,7 +217,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } Seq(TimestampType, TimestampNTZType).foreach { dt => checkConsistencyBetweenInterpretedAndCodegen( - (child: Expression) => Second(child, timeZoneId), dt) + child => Second(child, timeZoneId), dt) } } } @@ -233,7 +233,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Calendar.SATURDAY) checkEvaluation(DayOfWeek(Literal(new Date(toMillis("1582-10-15 13:10:15")))), Calendar.FRIDAY) - checkConsistencyBetweenInterpretedAndCodegen(DayOfWeek, DateType) + checkConsistencyBetweenInterpretedAndCodegen(DayOfWeek(_), DateType) } test("WeekDay") { @@ -244,7 +244,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(WeekDay(Cast(Literal("2011-05-06"), DateType, UTC_OPT)), 4) checkEvaluation(WeekDay(Literal(new Date(toMillis("2017-05-27 13:10:15")))), 5) checkEvaluation(WeekDay(Literal(new Date(toMillis("1582-10-15 13:10:15")))), 4) - checkConsistencyBetweenInterpretedAndCodegen(WeekDay, DateType) + checkConsistencyBetweenInterpretedAndCodegen(WeekDay(_), DateType) } test("WeekOfYear") { @@ -255,7 +255,7 @@ class DateExpressionsSuite extends 
SparkFunSuite with ExpressionEvalHelper { checkEvaluation(WeekOfYear(Cast(Literal("2011-05-06"), DateType, UTC_OPT)), 18) checkEvaluation(WeekOfYear(Cast(Literal("1582-10-15 13:10:15"), DateType, UTC_OPT)), 41) checkEvaluation(WeekOfYear(Cast(Literal("1582-10-04 13:10:15"), DateType, UTC_OPT)), 40) - checkConsistencyBetweenInterpretedAndCodegen(WeekOfYear, DateType) + checkConsistencyBetweenInterpretedAndCodegen(WeekOfYear(_), DateType) } test("MonthName") { @@ -266,7 +266,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(MonthName(Cast(Literal("2011-05-06"), DateType, UTC_OPT)), "May") checkEvaluation(MonthName(Literal(new Date(toMillis("2017-01-27 13:10:15")))), "Jan") checkEvaluation(MonthName(Literal(new Date(toMillis("1582-12-15 13:10:15")))), "Dec") - checkConsistencyBetweenInterpretedAndCodegen(MonthName, DateType) + checkConsistencyBetweenInterpretedAndCodegen(MonthName(_), DateType) } test("DayName") { @@ -277,7 +277,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(DayName(Cast(Literal("2011-05-06"), DateType, UTC_OPT)), "Fri") checkEvaluation(DayName(Cast(Literal(LocalDate.parse("2017-05-27")), DateType, UTC_OPT)), "Sat") checkEvaluation(DayName(Cast(Literal(LocalDate.parse("1582-10-15")), DateType, UTC_OPT)), "Fri") - checkConsistencyBetweenInterpretedAndCodegen(DayName, DateType) + checkConsistencyBetweenInterpretedAndCodegen(DayName(_), DateType) } test("DateFormat") { @@ -410,9 +410,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { DateAdd(Literal(Date.valueOf("2016-02-28")), positiveIntLit), 49627) checkEvaluation( DateAdd(Literal(Date.valueOf("2016-02-28")), negativeIntLit), -15910) - checkConsistencyBetweenInterpretedAndCodegen(DateAdd, DateType, ByteType) - checkConsistencyBetweenInterpretedAndCodegen(DateAdd, DateType, ShortType) - checkConsistencyBetweenInterpretedAndCodegen(DateAdd, DateType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(DateAdd(_, _), DateType, ByteType) + checkConsistencyBetweenInterpretedAndCodegen(DateAdd(_, _), DateType, ShortType) + checkConsistencyBetweenInterpretedAndCodegen(DateAdd(_, _), DateType, IntegerType) } test("date add interval") { @@ -471,9 +471,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { DateSub(Literal(Date.valueOf("2016-02-28")), positiveIntLit), -15909) checkEvaluation( DateSub(Literal(Date.valueOf("2016-02-28")), negativeIntLit), 49628) - checkConsistencyBetweenInterpretedAndCodegen(DateSub, DateType, ByteType) - checkConsistencyBetweenInterpretedAndCodegen(DateSub, DateType, ShortType) - checkConsistencyBetweenInterpretedAndCodegen(DateSub, DateType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(DateSub(_, _), DateType, ByteType) + checkConsistencyBetweenInterpretedAndCodegen(DateSub(_, _), DateType, ShortType) + checkConsistencyBetweenInterpretedAndCodegen(DateSub(_, _), DateType, IntegerType) } test("time_add") { @@ -690,7 +690,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(LastDay(Literal(Date.valueOf("2016-01-06"))), Date.valueOf("2016-01-31")) checkEvaluation(LastDay(Literal(Date.valueOf("2016-02-07"))), Date.valueOf("2016-02-29")) checkEvaluation(LastDay(Literal.create(null, DateType)), null) - checkConsistencyBetweenInterpretedAndCodegen(LastDay, DateType) + checkConsistencyBetweenInterpretedAndCodegen(LastDay(_), DateType) } test("next_day") { diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index dddc33aa43580..359dd33bfefe7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -46,7 +46,7 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), "6ac1e56bc78f031059be7be854522c4c") checkEvaluation(Md5(Literal.create(null, BinaryType)), null) - checkConsistencyBetweenInterpretedAndCodegen(Md5, BinaryType) + checkConsistencyBetweenInterpretedAndCodegen(Md5(_), BinaryType) } test("sha1") { @@ -57,7 +57,7 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Sha1(Literal.create(null, BinaryType)), null) checkEvaluation(Sha1(Literal("".getBytes(StandardCharsets.UTF_8))), "da39a3ee5e6b4b0d3255bfef95601890afd80709") - checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType) + checkConsistencyBetweenInterpretedAndCodegen(Sha1(_), BinaryType) } test("sha2") { @@ -86,7 +86,7 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), 2180413220L) checkEvaluation(Crc32(Literal.create(null, BinaryType)), null) - checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) + checkConsistencyBetweenInterpretedAndCodegen(Crc32(_), BinaryType) } def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 823a6d2ce8675..f17eceb166fc3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -200,22 +200,22 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("e") { - testLeaf(EulerNumber, math.E) + testLeaf(EulerNumber.apply, math.E) } test("pi") { - testLeaf(Pi, math.Pi) + testLeaf(Pi.apply, math.Pi) } test("sin") { - testUnary(Sin, math.sin) - checkConsistencyBetweenInterpretedAndCodegen(Sin, DoubleType) + testUnary(Sin.apply, math.sin) + checkConsistencyBetweenInterpretedAndCodegen(Sin(_), DoubleType) } test("csc") { - def f: (Double) => Double = (x: Double) => 1 / math.sin(x) - testUnary(Csc, f) - checkConsistencyBetweenInterpretedAndCodegen(Csc, DoubleType) + def f: Double => Double = (x: Double) => 1 / math.sin(x) + testUnary(Csc(_), f) + checkConsistencyBetweenInterpretedAndCodegen(Csc(_), DoubleType) val nullLit = Literal.create(null, NullType) val intNullLit = Literal.create(null, IntegerType) val intLit = Literal.create(1, IntegerType) @@ -227,19 +227,19 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("asin") { - testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1)) - testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNaN = true) - checkConsistencyBetweenInterpretedAndCodegen(Asin, DoubleType) + testUnary(Asin(_), math.asin, (-10 to 10).map(_ * 0.1)) + testUnary(Asin(_), math.asin, (11 to 20).map(_ * 0.1), expectNaN = 
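Note: in MathExpressionsSuite, EulerNumber and Pi take no constructor arguments, so the placeholder form used for the other expressions (Sin(_), Pow(_, _)) has nothing to stand in for; the patch references the companion's apply method directly instead. A sketch with a hypothetical nullary case class, not from the patch:

    case class E0()
    def testLeaf(build: () => E0): E0 = build()

    testLeaf(() => E0())   // an explicit function literal works everywhere
    testLeaf(E0.apply)     // method reference, the form used above for EulerNumber and Pi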
true) + checkConsistencyBetweenInterpretedAndCodegen(Asin(_), DoubleType) } test("sinh") { - testUnary(Sinh, math.sinh) - checkConsistencyBetweenInterpretedAndCodegen(Sinh, DoubleType) + testUnary(Sinh(_), math.sinh) + checkConsistencyBetweenInterpretedAndCodegen(Sinh(_), DoubleType) } test("asinh") { - testUnary(Asinh, (x: Double) => StrictMath.log(x + math.sqrt(x * x + 1.0))) - checkConsistencyBetweenInterpretedAndCodegen(Asinh, DoubleType) + testUnary(Asinh(_), (x: Double) => StrictMath.log(x + math.sqrt(x * x + 1.0))) + checkConsistencyBetweenInterpretedAndCodegen(Asinh(_), DoubleType) checkEvaluation(Asinh(Double.NegativeInfinity), Double.NegativeInfinity) @@ -250,14 +250,14 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cos") { - testUnary(Cos, math.cos) - checkConsistencyBetweenInterpretedAndCodegen(Cos, DoubleType) + testUnary(Cos(_), math.cos) + checkConsistencyBetweenInterpretedAndCodegen(Cos(_), DoubleType) } test("sec") { - def f: (Double) => Double = (x: Double) => 1 / math.cos(x) - testUnary(Sec, f) - checkConsistencyBetweenInterpretedAndCodegen(Sec, DoubleType) + def f: Double => Double = (x: Double) => 1 / math.cos(x) + testUnary(Sec(_), f) + checkConsistencyBetweenInterpretedAndCodegen(Sec(_), DoubleType) val nullLit = Literal.create(null, NullType) val intNullLit = Literal.create(null, IntegerType) val intLit = Literal.create(1, IntegerType) @@ -269,19 +269,19 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("acos") { - testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1)) - testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNaN = true) - checkConsistencyBetweenInterpretedAndCodegen(Acos, DoubleType) + testUnary(Acos(_), math.acos, (-10 to 10).map(_ * 0.1)) + testUnary(Acos(_), math.acos, (11 to 20).map(_ * 0.1), expectNaN = true) + checkConsistencyBetweenInterpretedAndCodegen(Acos(_), DoubleType) } test("cosh") { - testUnary(Cosh, math.cosh) - checkConsistencyBetweenInterpretedAndCodegen(Cosh, DoubleType) + testUnary(Cosh(_), math.cosh) + checkConsistencyBetweenInterpretedAndCodegen(Cosh(_), DoubleType) } test("acosh") { - testUnary(Acosh, (x: Double) => StrictMath.log(x + math.sqrt(x * x - 1.0))) - checkConsistencyBetweenInterpretedAndCodegen(Cosh, DoubleType) + testUnary(Acosh(_), (x: Double) => StrictMath.log(x + math.sqrt(x * x - 1.0))) + checkConsistencyBetweenInterpretedAndCodegen(Cosh(_), DoubleType) val nullLit = Literal.create(null, NullType) val doubleNullLit = Literal.create(null, DoubleType) @@ -290,14 +290,14 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("tan") { - testUnary(Tan, math.tan) - checkConsistencyBetweenInterpretedAndCodegen(Tan, DoubleType) + testUnary(Tan(_), math.tan) + checkConsistencyBetweenInterpretedAndCodegen(Tan(_), DoubleType) } test("cot") { - def f: (Double) => Double = (x: Double) => 1 / math.tan(x) - testUnary(Cot, f) - checkConsistencyBetweenInterpretedAndCodegen(Cot, DoubleType) + def f: Double => Double = (x: Double) => 1 / math.tan(x) + testUnary(Cot(_), f) + checkConsistencyBetweenInterpretedAndCodegen(Cot(_), DoubleType) val nullLit = Literal.create(null, NullType) val intNullLit = Literal.create(null, IntegerType) val intLit = Literal.create(1, IntegerType) @@ -309,19 +309,19 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("atan") { - testUnary(Atan, math.atan) - checkConsistencyBetweenInterpretedAndCodegen(Atan, DoubleType) + testUnary(Atan(_), math.atan) + 
checkConsistencyBetweenInterpretedAndCodegen(Atan(_), DoubleType) } test("tanh") { - testUnary(Tanh, math.tanh) - checkConsistencyBetweenInterpretedAndCodegen(Tanh, DoubleType) + testUnary(Tanh(_), math.tanh) + checkConsistencyBetweenInterpretedAndCodegen(Tanh(_), DoubleType) } test("atanh") { // SPARK-28519: more accurate express for 1/2 * ln((1 + x) / (1 - x)) - testUnary(Atanh, (x: Double) => 0.5 * (StrictMath.log1p(x) - StrictMath.log1p(-x))) - checkConsistencyBetweenInterpretedAndCodegen(Atanh, DoubleType) + testUnary(Atanh(_), (x: Double) => 0.5 * (StrictMath.log1p(x) - StrictMath.log1p(-x))) + checkConsistencyBetweenInterpretedAndCodegen(Atanh(_), DoubleType) val nullLit = Literal.create(null, NullType) val doubleNullLit = Literal.create(null, DoubleType) @@ -330,18 +330,18 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("toDegrees") { - testUnary(ToDegrees, math.toDegrees) - checkConsistencyBetweenInterpretedAndCodegen(ToDegrees, DoubleType) + testUnary(ToDegrees(_), math.toDegrees) + checkConsistencyBetweenInterpretedAndCodegen(ToDegrees(_), DoubleType) } test("toRadians") { - testUnary(ToRadians, math.toRadians) - checkConsistencyBetweenInterpretedAndCodegen(ToRadians, DoubleType) + testUnary(ToRadians(_), math.toRadians) + checkConsistencyBetweenInterpretedAndCodegen(ToRadians(_), DoubleType) } test("cbrt") { - testUnary(Cbrt, math.cbrt) - checkConsistencyBetweenInterpretedAndCodegen(Cbrt, DoubleType) + testUnary(Cbrt(_), math.cbrt) + checkConsistencyBetweenInterpretedAndCodegen(Cbrt(_), DoubleType) } def checkDataTypeAndCast(expression: Expression): Expression = expression match { @@ -360,13 +360,13 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("ceil") { - testUnary(Ceil, (d: Double) => math.ceil(d).toLong) - checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType) + testUnary(Ceil(_), (d: Double) => math.ceil(d).toLong) + checkConsistencyBetweenInterpretedAndCodegen(Ceil(_), DoubleType) - testUnary(Ceil, (d: Decimal) => d.ceil, (-20 to 20).map(x => Decimal(x * 0.1))) - checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3)) - checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0)) - checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0)) + testUnary(Ceil(_), (d: Decimal) => d.ceil, (-20 to 20).map(x => Decimal(x * 0.1))) + checkConsistencyBetweenInterpretedAndCodegen(Ceil(_), DecimalType(25, 3)) + checkConsistencyBetweenInterpretedAndCodegen(Ceil(_), DecimalType(25, 0)) + checkConsistencyBetweenInterpretedAndCodegen(Ceil(_), DecimalType(5, 0)) val doublePi: Double = 3.1415 val floatPi: Float = 3.1415f @@ -390,13 +390,13 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("floor") { - testUnary(Floor, (d: Double) => math.floor(d).toLong) - checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType) + testUnary(Floor(_), (d: Double) => math.floor(d).toLong) + checkConsistencyBetweenInterpretedAndCodegen(Floor(_), DoubleType) - testUnary(Floor, (d: Decimal) => d.floor, (-20 to 20).map(x => Decimal(x * 0.1))) - checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3)) - checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0)) - checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0)) + testUnary(Floor(_), (d: Decimal) => d.floor, (-20 to 20).map(x => Decimal(x * 0.1))) + checkConsistencyBetweenInterpretedAndCodegen(Floor(_), DecimalType(25, 3)) + 
checkConsistencyBetweenInterpretedAndCodegen(Floor(_), DecimalType(25, 0)) + checkConsistencyBetweenInterpretedAndCodegen(Floor(_), DecimalType(5, 0)) val doublePi: Double = 3.1415 val floatPi: Float = 3.1415f @@ -426,49 +426,49 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, IntegerType), null, create_row(null)) checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow) checkEvaluation(Factorial(Literal(21)), null, EmptyRow) - checkConsistencyBetweenInterpretedAndCodegen(Factorial.apply _, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(Factorial(_), IntegerType) } test("rint") { - testUnary(Rint, math.rint) - checkConsistencyBetweenInterpretedAndCodegen(Rint, DoubleType) + testUnary(Rint(_), math.rint) + checkConsistencyBetweenInterpretedAndCodegen(Rint(_), DoubleType) } test("exp") { - testUnary(Exp, StrictMath.exp) - checkConsistencyBetweenInterpretedAndCodegen(Exp, DoubleType) + testUnary(Exp(_), StrictMath.exp) + checkConsistencyBetweenInterpretedAndCodegen(Exp(_), DoubleType) } test("expm1") { - testUnary(Expm1, StrictMath.expm1) - checkConsistencyBetweenInterpretedAndCodegen(Expm1, DoubleType) + testUnary(Expm1(_), StrictMath.expm1) + checkConsistencyBetweenInterpretedAndCodegen(Expm1(_), DoubleType) } test("signum") { - testUnary[Double, Double](Signum, math.signum) - checkConsistencyBetweenInterpretedAndCodegen(Signum, DoubleType) + testUnary[Double, Double](Signum(_), math.signum) + checkConsistencyBetweenInterpretedAndCodegen(Signum(_), DoubleType) } test("log") { - testUnary(Log, StrictMath.log, (1 to 20).map(_ * 0.1)) - testUnary(Log, StrictMath.log, (-5 to 0).map(_ * 0.1), expectNull = true) - checkConsistencyBetweenInterpretedAndCodegen(Log, DoubleType) + testUnary(Log(_), StrictMath.log, (1 to 20).map(_ * 0.1)) + testUnary(Log(_), StrictMath.log, (-5 to 0).map(_ * 0.1), expectNull = true) + checkConsistencyBetweenInterpretedAndCodegen(Log(_), DoubleType) } test("log10") { - testUnary(Log10, StrictMath.log10, (1 to 20).map(_ * 0.1)) - testUnary(Log10, StrictMath.log10, (-5 to 0).map(_ * 0.1), expectNull = true) - checkConsistencyBetweenInterpretedAndCodegen(Log10, DoubleType) + testUnary(Log10(_), StrictMath.log10, (1 to 20).map(_ * 0.1)) + testUnary(Log10(_), StrictMath.log10, (-5 to 0).map(_ * 0.1), expectNull = true) + checkConsistencyBetweenInterpretedAndCodegen(Log10(_), DoubleType) } test("log1p") { - testUnary(Log1p, StrictMath.log1p, (0 to 20).map(_ * 0.1)) - testUnary(Log1p, StrictMath.log1p, (-10 to -1).map(_ * 1.0), expectNull = true) - checkConsistencyBetweenInterpretedAndCodegen(Log1p, DoubleType) + testUnary(Log1p(_), StrictMath.log1p, (0 to 20).map(_ * 0.1)) + testUnary(Log1p(_), StrictMath.log1p, (-10 to -1).map(_ * 1.0), expectNull = true) + checkConsistencyBetweenInterpretedAndCodegen(Log1p(_), DoubleType) } test("bin") { - testUnary(Bin, java.lang.Long.toBinaryString, (-20 to 20).map(_.toLong), evalType = LongType) + testUnary(Bin(_), java.lang.Long.toBinaryString, (-20 to 20).map(_.toLong), evalType = LongType) val row = create_row(null, 12L, 123L, 1234L, -123L) val l1 = $"a".long.at(0) @@ -486,30 +486,31 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Bin(positiveLongLit), java.lang.Long.toBinaryString(positiveLong)) checkEvaluation(Bin(negativeLongLit), java.lang.Long.toBinaryString(negativeLong)) - checkConsistencyBetweenInterpretedAndCodegen(Bin, LongType) + checkConsistencyBetweenInterpretedAndCodegen(Bin(_), 
LongType) } test("log2") { - def f: (Double) => Double = (x: Double) => StrictMath.log(x) / StrictMath.log(2) - testUnary(Log2, f, (1 to 20).map(_ * 0.1)) - testUnary(Log2, f, (-5 to 0).map(_ * 1.0), expectNull = true) - checkConsistencyBetweenInterpretedAndCodegen(Log2, DoubleType) + def f: Double => Double = (x: Double) => StrictMath.log(x) / StrictMath.log(2) + testUnary(Log2(_), f, (1 to 20).map(_ * 0.1)) + testUnary(Log2(_), f, (-5 to 0).map(_ * 1.0), expectNull = true) + checkConsistencyBetweenInterpretedAndCodegen(Log2(_), DoubleType) } test("sqrt") { - testUnary(Sqrt, math.sqrt, (0 to 20).map(_ * 0.1)) - testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNaN = true) + testUnary(Sqrt(_), math.sqrt, (0 to 20).map(_ * 0.1)) + testUnary(Sqrt(_), math.sqrt, (-5 to -1).map(_ * 1.0), expectNaN = true) checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) checkNaN(Sqrt(Literal(-1.0)), EmptyRow) checkNaN(Sqrt(Literal(-1.5)), EmptyRow) - checkConsistencyBetweenInterpretedAndCodegen(Sqrt, DoubleType) + checkConsistencyBetweenInterpretedAndCodegen(Sqrt(_), DoubleType) } test("pow") { - testBinary(Pow, StrictMath.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) - testBinary(Pow, StrictMath.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true) - checkConsistencyBetweenInterpretedAndCodegen(Pow, DoubleType, DoubleType) + testBinary(Pow(_, _), StrictMath.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) + testBinary( + Pow(_, _), StrictMath.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true) + checkConsistencyBetweenInterpretedAndCodegen(Pow(_, _), DoubleType, DoubleType) } test("shift left") { @@ -531,8 +532,8 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ShiftLeft(negativeLongLit, positiveIntLit), negativeLong << positiveInt) checkEvaluation(ShiftLeft(negativeLongLit, negativeIntLit), negativeLong << negativeInt) - checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft, IntegerType, IntegerType) - checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft, LongType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft(_, _), IntegerType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft(_, _), LongType, IntegerType) } test("shift right") { @@ -554,8 +555,8 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ShiftRight(negativeLongLit, positiveIntLit), negativeLong >> positiveInt) checkEvaluation(ShiftRight(negativeLongLit, negativeIntLit), negativeLong >> negativeInt) - checkConsistencyBetweenInterpretedAndCodegen(ShiftRight, IntegerType, IntegerType) - checkConsistencyBetweenInterpretedAndCodegen(ShiftRight, LongType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftRight(_, _), IntegerType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftRight(_, _), LongType, IntegerType) } test("shift right unsigned") { @@ -585,8 +586,8 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ShiftRightUnsigned(negativeLongLit, negativeIntLit), negativeLong >>> negativeInt) - checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned, IntegerType, IntegerType) - checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned, LongType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned(_, _), IntegerType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned(_, _), LongType, IntegerType) } test("hex") { 
@@ -627,13 +628,13 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("hypot") { - testBinary(Hypot, math.hypot) - checkConsistencyBetweenInterpretedAndCodegen(Hypot, DoubleType, DoubleType) + testBinary(Hypot(_, _), math.hypot) + checkConsistencyBetweenInterpretedAndCodegen(Hypot(_, _), DoubleType, DoubleType) } test("atan2") { - testBinary(Atan2, math.atan2) - checkConsistencyBetweenInterpretedAndCodegen(Atan2, DoubleType, DoubleType) + testBinary(Atan2(_, _), math.atan2) + checkConsistencyBetweenInterpretedAndCodegen(Atan2(_, _), DoubleType, DoubleType) } test("binary log") { @@ -665,7 +666,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Logarithm(Literal(1.0), Literal(-1.0)), null, create_row(null)) - checkConsistencyBetweenInterpretedAndCodegen(Logarithm, DoubleType, DoubleType) + checkConsistencyBetweenInterpretedAndCodegen(Logarithm(_, _), DoubleType, DoubleType) } test("round/bround/floor/ceil") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 4a7bf807d1de9..e3839a9819e89 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -83,19 +83,19 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { notTrueTable.foreach { case (v, answer) => checkEvaluation(Not(NonFoldableLiteral.create(v, BooleanType)), answer) } - checkConsistencyBetweenInterpretedAndCodegen(Not, BooleanType) + checkConsistencyBetweenInterpretedAndCodegen(Not(_), BooleanType) } test("AND, OR, EqualTo, EqualNullSafe consistency check") { - checkConsistencyBetweenInterpretedAndCodegen(And, BooleanType, BooleanType) - checkConsistencyBetweenInterpretedAndCodegen(Or, BooleanType, BooleanType) + checkConsistencyBetweenInterpretedAndCodegen(And(_, _), BooleanType, BooleanType) + checkConsistencyBetweenInterpretedAndCodegen(Or(_, _), BooleanType, BooleanType) DataTypeTestUtils.propertyCheckSupported.foreach { dt => - checkConsistencyBetweenInterpretedAndCodegen(EqualTo, dt, dt) - checkConsistencyBetweenInterpretedAndCodegen(EqualNullSafe, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(EqualTo(_, _), dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(EqualNullSafe(_, _), dt, dt) } } - booleanLogicTest("AND", And, + booleanLogicTest("AND", And(_, _), (true, true, true) :: (true, false, false) :: (true, null, null) :: @@ -106,7 +106,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (null, false, false) :: (null, null, null) :: Nil) - booleanLogicTest("OR", Or, + booleanLogicTest("OR", Or(_, _), (true, true, true) :: (true, false, true) :: (true, null, true) :: @@ -117,7 +117,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (null, false, null) :: (null, null, null) :: Nil) - booleanLogicTest("=", EqualTo, + booleanLogicTest("=", EqualTo(_, _), (true, true, true) :: (true, false, false) :: (true, null, null) :: @@ -395,10 +395,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("BinaryComparison consistency check") { DataTypeTestUtils.ordered.foreach { dt => - checkConsistencyBetweenInterpretedAndCodegen(LessThan, dt, dt) - checkConsistencyBetweenInterpretedAndCodegen(LessThanOrEqual, dt, dt) - checkConsistencyBetweenInterpretedAndCodegen(GreaterThan, 
dt, dt) - checkConsistencyBetweenInterpretedAndCodegen(GreaterThanOrEqual, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(LessThan(_, _), dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(LessThanOrEqual(_, _), dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(GreaterThan(_, _), dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(GreaterThanOrEqual(_, _), dt, dt) } } @@ -463,11 +463,11 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(op(nullNullType, nullNullType), null) } - nullTest(LessThan) - nullTest(LessThanOrEqual) - nullTest(GreaterThan) - nullTest(GreaterThanOrEqual) - nullTest(EqualTo) + nullTest(LessThan(_, _)) + nullTest(LessThanOrEqual(_, _)) + nullTest(GreaterThan(_, _)) + nullTest(GreaterThanOrEqual(_, _)) + nullTest(EqualTo(_, _)) checkEvaluation(EqualNullSafe(normalInt, nullInt), false) checkEvaluation(EqualNullSafe(nullInt, normalInt), false) @@ -605,12 +605,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { exprBuilder(createSafeFloatArray(left), createSafeFloatArray(right)), expected) } - checkExpr(EqualTo, Double.NaN, Double.NaN, true) - checkExpr(EqualTo, Double.NaN, Double.PositiveInfinity, false) - checkExpr(EqualTo, 0.0, -0.0, true) - checkExpr(GreaterThan, Double.NaN, Double.PositiveInfinity, true) - checkExpr(GreaterThan, Double.NaN, Double.NaN, false) - checkExpr(GreaterThan, 0.0, -0.0, false) + checkExpr(EqualTo(_, _), Double.NaN, Double.NaN, true) + checkExpr(EqualTo(_, _), Double.NaN, Double.PositiveInfinity, false) + checkExpr(EqualTo(_, _), 0.0, -0.0, true) + checkExpr(GreaterThan(_, _), Double.NaN, Double.PositiveInfinity, true) + checkExpr(GreaterThan(_, _), Double.NaN, Double.NaN, false) + checkExpr(GreaterThan(_, _), 0.0, -0.0, false) } test("SPARK-32110: compare special double/float values in struct") { @@ -653,12 +653,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { exprBuilder(createSafeFloatRow(left), createSafeFloatRow(right)), expected) } - checkExpr(EqualTo, Double.NaN, Double.NaN, true) - checkExpr(EqualTo, Double.NaN, Double.PositiveInfinity, false) - checkExpr(EqualTo, 0.0, -0.0, true) - checkExpr(GreaterThan, Double.NaN, Double.PositiveInfinity, true) - checkExpr(GreaterThan, Double.NaN, Double.NaN, false) - checkExpr(GreaterThan, 0.0, -0.0, false) + checkExpr(EqualTo(_, _), Double.NaN, Double.NaN, true) + checkExpr(EqualTo(_, _), Double.NaN, Double.PositiveInfinity, false) + checkExpr(EqualTo(_, _), 0.0, -0.0, true) + checkExpr(GreaterThan(_, _), Double.NaN, Double.PositiveInfinity, true) + checkExpr(GreaterThan(_, _), Double.NaN, Double.NaN, false) + checkExpr(GreaterThan(_, _), 0.0, -0.0, false) } test("SPARK-36792: InSet should handle Double.NaN and Float.NaN") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumericSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumericSuite.scala index 82e8277b42b9a..efde76c6d04bc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumericSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumericSuite.scala @@ -64,7 +64,7 @@ class HistogramNumericSuite extends SparkFunSuite with SQLHelper { test("class NumericHistogram, basic operations") { val valueCount = 5 - Seq(3, 5).foreach { nBins: Int => + Seq(3, 5).foreach { (nBins: Int) => val buffer = new NumericHistogram() 
buffer.allocate(nBins) (1 to valueCount).grouped(nBins).foreach { group => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala index 252bcea76007a..b13002e06c1c1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala @@ -210,12 +210,12 @@ class XPathExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { )) } - testExpr(XPathBoolean) - testExpr(XPathShort) - testExpr(XPathInt) - testExpr(XPathLong) - testExpr(XPathFloat) - testExpr(XPathDouble) - testExpr(XPathString) + testExpr(XPathBoolean(_, _)) + testExpr(XPathShort(_, _)) + testExpr(XPathInt(_, _)) + testExpr(XPathLong(_, _)) + testExpr(XPathFloat(_, _)) + testExpr(XPathDouble(_, _)) + testExpr(XPathString(_, _)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala index 6e1c7fc887d4e..789dcc76d0329 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala @@ -171,10 +171,10 @@ class ComputeCurrentTimeSuite extends PlanTest { } val numTimezones = ZoneId.SHORT_IDS.size - checkLiterals({ _: String => CurrentTimestamp() }, 1) - checkLiterals({ zoneId: String => LocalTimestamp(Some(zoneId)) }, numTimezones) - checkLiterals({ _: String => Now() }, 1) - checkLiterals({ zoneId: String => CurrentDate(Some(zoneId)) }, numTimezones) + checkLiterals((_: String) => CurrentTimestamp(), 1) + checkLiterals((zoneId: String) => LocalTimestamp(Some(zoneId)), numTimezones) + checkLiterals((_: String) => Now(), 1) + checkLiterals((zoneId: String) => CurrentDate(Some(zoneId)), numTimezones) } private def literals[T](plan: LogicalPlan): scala.collection.mutable.ArrayBuffer[T] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index d8d58ea6aa903..92069ed858bff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -425,7 +425,7 @@ class ConstantFoldingSuite extends PlanTest { Optimize.execute(testRelation.select(ScalarSubquery(emptyRelation).as("o")).analyze), testRelation.select(nullIntLit.as("o")).analyze) - Seq(EqualTo, LessThan, GreaterThan).foreach { comparison => + Seq(EqualTo(_, _), LessThan(_, _), GreaterThan(_, _)).foreach { comparison => comparePlans( Optimize.execute(testRelation .select(comparison($"a", ScalarSubquery(emptyRelation)).as("o")).analyze), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index eaa651e62455e..91f87d5f27074 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -371,7 +371,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name)) test("replace nulls in lambda function of ArrayFilter") { - testHigherOrderFunc($"a", ArrayFilter, Seq(lv(Symbol("e")))) + testHigherOrderFunc($"a", ArrayFilter(_, _), Seq(lv(Symbol("e")))) } test("replace nulls in lambda function of ArrayExists") { @@ -390,7 +390,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { } test("replace nulls in lambda function of MapFilter") { - testHigherOrderFunc($"m", MapFilter, Seq(lv(Symbol("k")), lv(Symbol("v")))) + testHigherOrderFunc($"m", MapFilter(_, _), Seq(lv(Symbol("k")), lv(Symbol("v")))) } test("inability to replace nulls in arbitrary higher-order function") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala index e93ec751a7f82..225a007b69dac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala @@ -213,7 +213,7 @@ class ParserUtilsSuite extends SparkFunSuite { } test("validate") { - val f1 = { ctx: ParserRuleContext => + val f1 = { (ctx: ParserRuleContext) => ctx.children != null && !ctx.children.isEmpty } val message = "ParserRuleContext should not be empty." diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RebaseDateTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RebaseDateTimeSuite.scala index 475c59b368233..fde4f56272dd8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RebaseDateTimeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RebaseDateTimeSuite.scala @@ -209,14 +209,15 @@ class RebaseDateTimeSuite extends SparkFunSuite with Matchers with SQLHelper { .atZone(zid) .toInstant) var micros = start - do { + while ({ val rebased = rebaseGregorianToJulianMicros(TimeZone.getTimeZone(zid), micros) val rebasedAndOptimized = withDefaultTimeZone(zid) { rebaseGregorianToJulianMicros(micros) } assert(rebasedAndOptimized === rebased) micros += (MICROS_PER_DAY * 30 * (0.5 + Math.random())).toLong - } while (micros <= end) + micros <= end + }) () } } } @@ -229,14 +230,15 @@ class RebaseDateTimeSuite extends SparkFunSuite with Matchers with SQLHelper { val end = rebaseGregorianToJulianMicros( instantToMicros(LocalDateTime.of(2100, 1, 1, 0, 0, 0).atZone(zid).toInstant)) var micros = start - do { + while ({ val rebased = rebaseJulianToGregorianMicros(TimeZone.getTimeZone(zid), micros) val rebasedAndOptimized = withDefaultTimeZone(zid) { rebaseJulianToGregorianMicros(micros) } assert(rebasedAndOptimized === rebased) micros += (MICROS_PER_DAY * 30 * (0.5 + Math.random())).toLong - } while (micros <= end) + micros <= end + }) () } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/EnumTypeSetBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/EnumTypeSetBenchmark.scala index 5c3a0d239ee69..04a77a16ae3a3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/EnumTypeSetBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/EnumTypeSetBenchmark.scala @@ -76,11 +76,11 
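Note: the RebaseDateTimeSuite hunks above replace do/while, which Scala 3 removes from the language, with the cross-compiling encoding: the old loop body moves into the while condition block, that block ends with the continue test, and the loop body proper becomes the unit literal (). A standalone sketch of the shape, with throwaway values:

    var micros = 0L
    // Scala 2 only:
    //   do { println(micros); micros += 30 } while (micros <= 90)
    // Cross-compiling rewrite, as in the hunks above:
    while ({
      println(s"step at $micros")   // stands in for the original loop body
      micros += 30
      micros <= 90                  // continue condition, evaluated after the body
    }) ()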
@@ object EnumTypeSetBenchmark extends BenchmarkBase { val benchmark = new Benchmark(s"Test create $sizeLiteral Set", valuesPerIteration, output = output) - benchmark.addCase("Use HashSet") { _: Int => + benchmark.addCase("Use HashSet") { (_: Int) => for (_ <- 0L until valuesPerIteration) {creatHashSetFunctions.apply()} } - benchmark.addCase("Use EnumSet") { _: Int => + benchmark.addCase("Use EnumSet") { (_: Int) => for (_ <- 0L until valuesPerIteration) {creatEnumSetFunctions.apply()} } benchmark.run() @@ -99,13 +99,13 @@ object EnumTypeSetBenchmark extends BenchmarkBase { valuesPerIteration * capabilities.length, output = output) - benchmark.addCase("Use HashSet") { _: Int => + benchmark.addCase("Use HashSet") { (_: Int) => for (_ <- 0L until valuesPerIteration) { capabilities.foreach(hashSet.contains) } } - benchmark.addCase("Use EnumSet") { _: Int => + benchmark.addCase("Use EnumSet") { (_: Int) => for (_ <- 0L until valuesPerIteration) { capabilities.foreach(enumSet.contains) } @@ -126,13 +126,13 @@ object EnumTypeSetBenchmark extends BenchmarkBase { valuesPerIteration * capabilities.length, output = output) - benchmark.addCase("Use HashSet") { _: Int => + benchmark.addCase("Use HashSet") { (_: Int) => for (_ <- 0L until valuesPerIteration) { capabilities.foreach(creatHashSetFunctions.apply().contains) } } - benchmark.addCase("Use EnumSet") { _: Int => + benchmark.addCase("Use EnumSet") { (_: Int) => for (_ <- 0L until valuesPerIteration) { capabilities.foreach(creatEnumSetFunctions.apply().contains) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index d6d397b94648d..35354b778a5e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -808,9 +808,8 @@ class InMemoryCustomDriverMetric extends CustomSumMetric { override def description(): String = "number of rows from driver" } -class InMemoryCustomDriverTaskMetric(value: Long) extends CustomTaskMetric { +class InMemoryCustomDriverTaskMetric(override val value: Long) extends CustomTaskMetric { override def name(): String = "number_of_rows_from_driver" - override def value(): Long = value } sealed trait Operation diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UserDefinedFunctionE2ETestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UserDefinedFunctionE2ETestSuite.scala index 81545afadd3f8..02a94f30313ec 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UserDefinedFunctionE2ETestSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UserDefinedFunctionE2ETestSuite.scala @@ -129,7 +129,7 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest with RemoteSparkSession val result3 = Seq("a b c", "d e") .toDF("words") - .explode("words", "word") { word: String => + .explode("words", "word") { (word: String) => word.split(' ').toSeq } .select(col("word")) @@ -140,7 +140,7 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest with RemoteSparkSession val result4 = Seq("a b c", "d e") .toDF("words") - .explode("words", "word") { word: String => + .explode("words", "word") { (word: String) => word.split(' ').map(s => s -> s.head.toInt).toSeq } .select(col("word"), col("words")) diff --git 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala index 3c181740eb435..451413e7cb751 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala @@ -169,7 +169,7 @@ object SparkConnectServerUtils { val jars = System .getProperty("java.class.path") .split(File.pathSeparatorChar) - .filter { e: String => + .filter { (e: String) => val fileName = e.substring(e.lastIndexOf(File.separatorChar) + 1) fileName.endsWith(".jar") && (fileName.startsWith("scalatest") || fileName.startsWith("scalactic") || diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala index f3c13c9c2c4d8..4fc30bef5c72c 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala @@ -170,14 +170,15 @@ class ExecutePlanResponseReattachableIterator( // If iter ended, but there was no ResultComplete, it means that there is more, // and we need to reattach. if (!hasNext && !resultComplete) { - do { + while ({ iter = None // unset iterator for new ReattachExecute to be called in _call_iter assert(!resultComplete) // shouldn't change... hasNext = callIter(_.hasNext()) // It's possible that the new iter will be empty, so we need to loop to get another. // Eventually, there will be a non empty iter, because there is always a // ResultComplete inserted by the server at the end of the stream. 
- } while (!hasNext) + !hasNext + }) () } hasNext } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index d3dae47f4c471..1df7834af9c3c 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -26,6 +26,7 @@ import io.grpc.{ManagedChannel, StatusRuntimeException} import io.grpc.protobuf.StatusProto import org.json4s.{DefaultFormats, Formats} import org.json4s.jackson.JsonMethods +import org.json4s.jvalue2extractable import org.apache.spark.{QueryContext, QueryContextType, SparkArithmeticException, SparkArrayIndexOutOfBoundsException, SparkDateTimeException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} import org.apache.spark.connect.proto.{FetchErrorDetailsRequest, FetchErrorDetailsResponse, SparkConnectServiceGrpc, UserContext} @@ -177,7 +178,7 @@ private[client] object GrpcExceptionConverter { private def errorConstructor[T <: Throwable: ClassTag]( throwableCtr: ErrorParams => T): (String, ErrorParams => Throwable) = { - val className = implicitly[reflect.ClassTag[T]].runtimeClass.getName + val className = implicitly[ClassTag[T]].runtimeClass.getName (className, throwableCtr) } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala index 8c8472d780dbc..06e51c8f8c1e3 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.connect.client +import scala.concurrent.duration._ import scala.concurrent.duration.{Duration, FiniteDuration} import scala.util.Random diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala index d79fb25ec1a0b..1e6bc129a79da 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala @@ -404,7 +404,7 @@ object ArrowSerializer { case (ArrayEncoder(element, _), v: ListVector) => val elementSerializer = serializerFor(element, v.getDataVector) - val toIterator = { array: Any => + val toIterator = { (array: Any) => array.asInstanceOf[Array[_]].iterator } new ArraySerializer(v, toIterator, elementSerializer) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 72f7065b44240..3efd331c9250e 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -769,7 +769,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { 
proto.Expression.ExpressionString.newBuilder().setExpression("id"))))) .build() val df = Dataset.ofRows(spark, transform(relation)) - df.foreachPartition { p: Iterator[Row] => + df.foreachPartition { (p: Iterator[Row]) => var previousValue: Int = -1 p.foreach { r => val v = r.getAs[Int](0) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectWithSessionExtensionSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectWithSessionExtensionSuite.scala index 63d623cd2779b..b93e1941931d9 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectWithSessionExtensionSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectWithSessionExtensionSuite.scala @@ -61,7 +61,7 @@ class SparkConnectWithSessionExtensionSuite extends SparkFunSuite { val spark = classic.SparkSession .builder() .master("local[1]") - .withExtensions(extension => extension.injectParser(MyParser)) + .withExtensions(extension => extension.injectParser(MyParser(_, _))) .getOrCreate() val read = proto.Read.newBuilder().build() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index 1d83a46a278f7..3e927ab20b305 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -286,7 +286,7 @@ private[sql] class AvroSerializer( val numFields = catalystStruct.length val avroFields = avroStruct.getFields() val isSchemaNullable = avroFields.asScala.map(_.schema().isNullable) - row: InternalRow => + (row: InternalRow) => val result = new Record(avroStruct) var i = 0 while (i < numFields) { @@ -340,7 +340,7 @@ private[sql] class AvroSerializer( }.toArray val numBranches = catalystStruct.length - row: InternalRow => { + (row: InternalRow) => { var idx = 0 var retVal: Any = null while (idx < numBranches && retVal == null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 03ce8882f2fa5..14d4422c4e1fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -698,9 +698,9 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) private def convertToStructField(col: QualifiedColType): StructField = { val builder = new MetadataBuilder col.comment.foreach(builder.putString("comment", _)) - col.default.map { - value: String => builder.putString(DefaultCols.CURRENT_DEFAULT_COLUMN_METADATA_KEY, value) - } + col.default.map(value => + builder.putString(DefaultCols.CURRENT_DEFAULT_COLUMN_METADATA_KEY, value) + ) StructField(col.name.head, col.dataType, nullable = true, builder.build()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/CachedBatchSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/CachedBatchSerializer.scala index 885ddf4110cbb..bf25ee56c6836 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/CachedBatchSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/CachedBatchSerializer.scala @@ -318,7 +318,7 @@ abstract class SimpleMetricsCachedBatchSerializer extends CachedBatchSerializer def 
ret(index: Int, cachedBatchIterator: Iterator[CachedBatch]): Iterator[CachedBatch] = { val partitionFilter = Predicate.create( - partitionFilters.reduceOption(And).getOrElse(Literal(true)), + partitionFilters.reduceOption(And(_, _)).getOrElse(Literal(true)), cachedAttributes) partitionFilter.initialize(index) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 60156bff1fb71..a25d53d3fc405 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -117,7 +117,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { val outputRowFormat = ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD") val processRowWithoutSerde = if (!ioschema.schemaLess) { - prevLine: String => + (prevLine: String) => new GenericInternalRow( prevLine.split(outputRowFormat, -1).padTo(outputFieldWriters.size, null) .zip(outputFieldWriters) @@ -128,7 +128,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { // Here we split row string and choose first 2 values, if values's size less than 2, // we pad NULL value until 2 to make behavior same with hive. val kvWriter = CatalystTypeConverters.createToCatalystConverter(StringType) - prevLine: String => + (prevLine: String) => new GenericInternalRow( prevLine.split(outputRowFormat, -1).slice(0, 2).padTo(2, null) .map(kvWriter)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 31ab367c2d003..3d55b2b83712f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -313,7 +313,7 @@ trait FileSourceScanLike extends DataSourceScanExec { if (dynamicPartitionFilters.nonEmpty) { val startTime = System.nanoTime() // call the file index for the files matching all filters except dynamic partition filters - val predicate = dynamicPartitionFilters.reduce(And) + val predicate = dynamicPartitionFilters.reduce(And(_, _)) val partitionColumns = relation.partitionSchema val boundPredicate = Predicate.create(predicate.transform { case a: AttributeReference => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 4b561b813067e..61744a344dca4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -163,7 +163,7 @@ object SortPrefixUtils { } } } else { - _: InternalRow => emptyPrefix + (_: InternalRow) => emptyPrefix } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 5632e46c04374..a696ebcd1e99b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -93,7 +93,7 @@ class SparkPlanner(val session: SparkSession, val experimentalMethods: Experimen val projectSet = AttributeSet(projectList.flatMap(_.references)) val filterSet = 
AttributeSet(filterPredicates.flatMap(_.references)) val filterCondition: Option[Expression] = - prunePushedDownFilters(filterPredicates).reduceLeftOption(And) + prunePushedDownFilters(filterPredicates).reduceLeftOption(And(_, _)) // Right now we still use a projection even if the only evaluation is applying an alias // to a column. Since this is a no-op, it could be avoided. However, using this diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala index e8b70f94a7692..adad0d5471836 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala @@ -38,9 +38,10 @@ import org.apache.spark.sql.vectorized.ColumnarBatch * @param partitionSpecs The partition specs that defines the arrangement, requires at least one * partition. */ -case class AQEShuffleReadExec private( +case class AQEShuffleReadExec private[adaptive] ( child: SparkPlan, - partitionSpecs: Seq[ShufflePartitionSpec]) extends UnaryExecNode { + partitionSpecs: Seq[ShufflePartitionSpec] +) extends UnaryExecNode { assert(partitionSpecs.nonEmpty, s"${getClass.getSimpleName} requires at least one partition") // If this is to read shuffle files locally, then all partition specs should be diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala index 90f4ee70539ce..fe5b8a3832e1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, MutableProjection, NamedExpression, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.AttributeSeq import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.metric.SQLMetric diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala index ed248ccca67a7..71bfdc859d728 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala @@ -141,7 +141,7 @@ case class DescribeRelationJsonCommand( addKeyValueToMap("table_name", JString(ident.last), jsonMap) addKeyValueToMap("catalog_name", JString(ident.head), jsonMap) val namespace = ident.init.tail - addKeyValueToMap("namespace", JArray(namespace.map(JString).toList), jsonMap) + addKeyValueToMap("namespace", JArray(namespace.map(JString(_)).toList), jsonMap) if (namespace.nonEmpty) { addKeyValueToMap("schema_name", JString(namespace.last), jsonMap) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 6eb81e6ec670b..266afb5d3a1bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -691,7 +691,7 @@ case class RepairTableCommand( // It's very expensive to create a JobConf(ClassUtil.findContainingJar() is slow) val jobConf = new JobConf(hadoopConf, this.getClass) val pathFilter = FileInputFormat.getInputPathFilter(jobConf) - path: Path => { + (path: Path) => { val name = path.getName if (name != "_SUCCESS" && name != "_temporary" && !name.startsWith(".")) { pathFilter == null || pathFilter.accept(path) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 092e6669338ee..64185512eff3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -292,7 +292,7 @@ case class AlterTableAddColumnsCommand( */ private def constantFoldCurrentDefaultsToExistDefaults( sparkSession: SparkSession, tableProvider: Option[String]): Seq[StructField] = { - colsToAdd.map { col: StructField => + colsToAdd.map { col => if (col.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)) { val schema = StructType(Array(col)) ResolveDefaultColumns.validateTableProviderForDefaultValue( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 882bc12a0d29b..ba0752bfcea36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.DataSourceOptions import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils} +import org.apache.spark.sql.catalyst.expressions.AttributeSeq import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, TypeUtils} import org.apache.spark.sql.classic.ClassicConversions.castToImpl diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2f6588c3aac35..dc75e93dcf8b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -427,7 +427,7 @@ object DataSourceStrategy // Combines all Catalyst filter `Expression`s that are either not convertible to data source // `Filter`s or cannot be handled by `relation`. 
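Aside (editorial illustration, not part of the patch): the `reduceOption(And)` to `reduceOption(And(_, _))`, `.map(Not)` to `.map(Not(_))`, and `.map(JString)` to `.map(JString(_))` changes in the surrounding hunks are all the same fix. A case-class companion object no longer extends FunctionN in Scala 3, so -Xsource:3-cross rejects passing the bare companion where a function value is expected, and the eta-expansion is written out instead. A minimal standalone sketch with hypothetical Expr, Lit and And types (not Spark classes):

  object EtaExpansionSketch {
    sealed trait Expr
    final case class Lit(value: Boolean) extends Expr
    final case class And(left: Expr, right: Expr) extends Expr

    val filters: Seq[Expr] = Seq(Lit(true), Lit(false))

    // Old style relied on the companion object And acting as a Function2:
    //   filters.reduceLeftOption(And)
    // Explicit eta-expansion compiles under Scala 2 with -Xsource:3-cross and under Scala 3
    // (And.apply(_, _) would work as well):
    val combined: Option[Expr] = filters.reduceLeftOption(And(_, _))
  }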
- val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And) + val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And(_, _)) if (projects.map(_.toAttribute) == projects && projectSet.size == projects.size && @@ -686,7 +686,7 @@ object DataSourceStrategy case expressions.Not(child) => translateFilterWithMapping(child, translatedFilterToExpr, nestedPredicatePushdownEnabled) - .map(sources.Not) + .map(sources.Not(_)) case other => val filter = translateLeafNodeFilter(other, PushableColumn(nestedPredicatePushdownEnabled)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index 875c5dfc59638..38c06a7abdfca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -200,7 +200,7 @@ object DataSourceUtils extends PredicateHelper { def createDateRebaseFuncInRead( rebaseMode: LegacyBehaviorPolicy.Value, format: String): Int => Int = rebaseMode match { - case LegacyBehaviorPolicy.EXCEPTION => days: Int => + case LegacyBehaviorPolicy.EXCEPTION => (days: Int) => if (days < RebaseDateTime.lastSwitchJulianDay) { throw DataSourceUtils.newRebaseExceptionInRead(format) } @@ -212,7 +212,7 @@ object DataSourceUtils extends PredicateHelper { def createDateRebaseFuncInWrite( rebaseMode: LegacyBehaviorPolicy.Value, format: String): Int => Int = rebaseMode match { - case LegacyBehaviorPolicy.EXCEPTION => days: Int => + case LegacyBehaviorPolicy.EXCEPTION => (days: Int) => if (days < RebaseDateTime.lastSwitchGregorianDay) { throw DataSourceUtils.newRebaseExceptionInWrite(format) } @@ -224,7 +224,7 @@ object DataSourceUtils extends PredicateHelper { def createTimestampRebaseFuncInRead( rebaseSpec: RebaseSpec, format: String): Long => Long = rebaseSpec.mode match { - case LegacyBehaviorPolicy.EXCEPTION => micros: Long => + case LegacyBehaviorPolicy.EXCEPTION => (micros: Long) => if (micros < RebaseDateTime.lastSwitchJulianTs) { throw DataSourceUtils.newRebaseExceptionInRead(format) } @@ -237,7 +237,7 @@ object DataSourceUtils extends PredicateHelper { def createTimestampRebaseFuncInWrite( rebaseMode: LegacyBehaviorPolicy.Value, format: String): Long => Long = rebaseMode match { - case LegacyBehaviorPolicy.EXCEPTION => micros: Long => + case LegacyBehaviorPolicy.EXCEPTION => (micros: Long) => if (micros < RebaseDateTime.lastSwitchGregorianTs) { throw DataSourceUtils.newRebaseExceptionInWrite(format) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index d3078740b819c..a71bc6b190431 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -278,19 +278,19 @@ object FileFormat { * fields of the [[PartitionedFile]], and do have entries in the file's metadata map. 
*/ val BASE_METADATA_EXTRACTORS: Map[String, PartitionedFile => Any] = Map( - FILE_PATH -> { pf: PartitionedFile => + FILE_PATH -> { (pf: PartitionedFile) => // Use `new Path(Path.toString)` as a form of canonicalization new Path(pf.filePath.toPath.toString).toUri.toString }, - FILE_NAME -> { pf: PartitionedFile => + FILE_NAME -> { (pf: PartitionedFile) => pf.filePath.toUri.getRawPath.split("/").lastOption.getOrElse("") }, - FILE_SIZE -> { pf: PartitionedFile => pf.fileSize }, - FILE_BLOCK_START -> { pf: PartitionedFile => pf.start }, - FILE_BLOCK_LENGTH -> { pf: PartitionedFile => pf.length }, + FILE_SIZE -> { (pf: PartitionedFile) => pf.fileSize }, + FILE_BLOCK_START -> { (pf: PartitionedFile) => pf.start }, + FILE_BLOCK_LENGTH -> { (pf: PartitionedFile) => pf.length }, // The modificationTime from the file has millisecond granularity, but the TimestampType for // `file_modification_time` has microsecond granularity. - FILE_MODIFICATION_TIME -> { pf: PartitionedFile => pf.modificationTime * 1000 } + FILE_MODIFICATION_TIME -> { (pf: PartitionedFile) => pf.modificationTime * 1000 } ) /** @@ -306,7 +306,7 @@ object FileFormat { file: PartitionedFile, metadataExtractors: Map[String, PartitionedFile => Any]): Literal = { val extractor = metadataExtractors.getOrElse(name, - { pf: PartitionedFile => pf.otherConstantMetadataColumnValues.get(name).orNull } + { (pf: PartitionedFile) => pf.otherConstantMetadataColumnValues.get(name).orNull } ) Literal(extractor.apply(file)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala index 0291a5fd28a72..f7db4d9853f92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala @@ -61,7 +61,7 @@ class FilePruningRunner(filters: Seq[Expression]) { metadataAttr.name != FileFormat.FILE_BLOCK_LENGTH case _ => false } - }.reduceOption(And) + }.reduceOption(And(_, _)) // - Retrieve all required metadata attributes and put them into a sequence // - Bind all file constant metadata attribute references to their respective index diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 6af6dc721a6f8..16c390fdd7587 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -39,7 +39,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { filters: Seq[Expression], relation: LeafNode): Project = { val withFilter = if (filters.nonEmpty) { - val filterExpression = filters.reduceLeft(And) + val filterExpression = filters.reduceLeft(And(_, _)) Filter(filterExpression, relation) } else { relation @@ -78,7 +78,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { fsRelation.copy(location = prunedFileIndex)(fsRelation.sparkSession) // Change table stats based on the sizeInBytes of pruned files val filteredStats = - FilterEstimation(Filter(partitionKeyFilters.reduce(And), logicalRelation)).estimate + FilterEstimation(Filter(partitionKeyFilters.reduce(And(_, _)), logicalRelation)).estimate val colStats = filteredStats.map(_.attributeStats.map { case (attr, colStat) 
=> (attr.name, colStat.toCatalogColumnStat(attr.name, attr.dataType)) }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala index 5960cf8c38ced..76c1b3dbedbc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala @@ -327,7 +327,7 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput.toIndexedSeq) val withFilter = if (filters.nonEmpty) { - Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), newRelation) + Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And(_, _)), newRelation) } else { newRelation } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala index 1b23fd1a5e829..8edcdb059d3e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala @@ -42,7 +42,7 @@ object SchemaPruning extends Rule[LogicalPlan] { plan transformDown { case op @ ScanOperation(projects, filtersStayUp, filtersPushDown, l @ LogicalRelationWithTable(hadoopFsRelation: HadoopFsRelation, _)) => - val allFilters = filtersPushDown.reduceOption(And).toSeq ++ filtersStayUp + val allFilters = filtersPushDown.reduceOption(And(_, _)).toSeq ++ filtersStayUp prunePhysicalColumns(l, projects, allFilters, hadoopFsRelation, (prunedDataSchema, prunedMetadataSchema) => { val prunedHadoopRelation = @@ -71,7 +71,7 @@ object SchemaPruning extends Rule[LogicalPlan] { // If requestedRootFields includes a nested field, continue. 
Otherwise, // return op - if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) { + if (requestedRootFields.exists { (root: RootField) => !root.derivedFromAtt }) { val prunedDataSchema = if (canPruneDataSchema(hadoopFsRelation)) { pruneSchema(hadoopFsRelation.dataSchema, requestedRootFields) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala index 54c100282e2db..7a111c0b6eda7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala @@ -100,7 +100,7 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister { val filterFuncs = filters.flatMap(filter => createFilterFunction(filter)) val maxLength = sparkSession.sessionState.conf.getConf(SOURCES_BINARY_FILE_MAX_LENGTH) - file: PartitionedFile => { + (file: PartitionedFile) => { val path = file.toPath val fs = path.getFileSystem(broadcastedHadoopConf.value.value) val status = fs.getFileStatus(path) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 4bb1c187c45fa..03620cfcbb200 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -127,7 +127,7 @@ private[sql] object OrcFilters extends OrcFiltersBase { } yield Or(lhs, rhs) case Not(pred) => val childResultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) - childResultOptional.map(Not) + childResultOptional.map(Not(_)) case other => for (_ <- buildLeafSearchArgument(dataTypeMap, other, newBuilder())) yield other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 565742671b9cd..ca731a6cf497f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -620,7 +620,7 @@ class ParquetFilters( } case sources.Not(pred) => val resultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) - resultOptional.map(sources.Not) + resultOptional.map(sources.Not(_)) case other => if (createFilter(other).isDefined) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 232cb0935a260..2df468b622ee8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.{ResolvedIdentifier, ResolvedNames import org.apache.spark.sql.catalyst.catalog.CatalogUtils import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, NamedExpression, Not, Or, PredicateHelper, SubqueryExpression} +import 
org.apache.spark.sql.catalyst.expressions.AttributeSeq import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ @@ -688,7 +689,7 @@ private[sql] object DataSourceV2Strategy extends Logging { filters: Seq[Expression], scan: LeafExecNode, needsUnsafeConversion: Boolean): SparkPlan = { - val filterCondition = filters.reduceLeftOption(And) + val filterCondition = filters.reduceLeftOption(And(_, _)) val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) if (withFilter.output != project || needsUnsafeConversion) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala index 8b8cdc06d398b..149cda7b27fd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala @@ -130,6 +130,6 @@ object GroupBasedRowLevelOperationScanPlanning extends Rule[LogicalPlan] with Pr val evaluatedFilterSet = ExpressionSet(evaluatedFilters) val predicates = splitConjunctivePredicates(cond) val remainingPredicates = predicates.filterNot(evaluatedFilterSet.contains) - remainingPredicates.reduceLeftOption(And).getOrElse(TrueLiteral) + remainingPredicates.reduceLeftOption(And(_, _)).getOrElse(TrueLiteral) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 34a1adcb6e091..6f55182d140a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, SchemaPruning} +import org.apache.spark.sql.catalyst.expressions.AttributeSeq import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.expressions.SortOrder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 5f7e86cab5240..43e13c551093b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import org.apache.spark.internal.LogKeys.{AGGREGATE_FUNCTIONS, GROUP_BY_EXPRS, POST_SCAN_FILTERS, PUSHED_FILTERS, RELATION_NAME, RELATION_OUTPUT} import org.apache.spark.internal.MDC import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} +import 
org.apache.spark.sql.catalyst.expressions.AttributeSeq import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, ScanOperation} @@ -94,7 +95,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { |Post-Scan Filters: ${MDC(POST_SCAN_FILTERS, postScanFilters.mkString(","))} """.stripMargin) - val filterCondition = postScanFilters.reduceLeftOption(And) + val filterCondition = postScanFilters.reduceLeftOption(And(_, _)) filterCondition.map(Filter(_, sHolder)).getOrElse(sHolder) } @@ -362,7 +363,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { val normalizedProjects = DataSourceStrategy .normalizeExprs(project, sHolder.output) .asInstanceOf[Seq[NamedExpression]] - val allFilters = filtersPushDown.reduceOption(And).toSeq ++ filtersStayUp + val allFilters = filtersPushDown.reduceOption(And(_, _)).toSeq ++ filtersStayUp val normalizedFilters = DataSourceStrategy.normalizeExprs(allFilters, sHolder.output) val (scan, output) = PushDownUtils.pruneColumns( sHolder.builder, sHolder.relation, normalizedProjects, normalizedFilters) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 8d49b1558d687..52bce593cc313 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -1114,7 +1114,7 @@ private[joins] class SortMergeJoinScanner( } else { // Advance both the streamed and buffered iterators to find the next pair of matching rows. var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) - do { + while ({ if (streamedRowKey.anyNull) { advancedStreamed() } else { @@ -1123,7 +1123,8 @@ private[joins] class SortMergeJoinScanner( if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey() else if (comp < 0) advancedStreamed() } - } while (streamedRow != null && bufferedRow != null && comp != 0) + streamedRow != null && bufferedRow != null && comp != 0 + }) () if (streamedRow == null || bufferedRow == null) { // We have either hit the end of one of the iterators, so there can be no more matches. matchJoinKey = null @@ -1165,9 +1166,10 @@ private[joins] class SortMergeJoinScanner( // The buffered iterator could still contain matching rows, so we'll need to walk through // it until we either find matches or pass where they would be found. 
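Aside (editorial illustration, not part of the patch): the `do { ... } while (cond)` loops rewritten in this file, and in FileStreamSource and ContinuousExecution further down, follow the standard migration for the do-while syntax dropped in Scala 3: body and condition move into one block whose last expression is the condition, and `while (block) ()` preserves the run-at-least-once behaviour. A self-contained sketch with a made-up counter:

  object DoWhileSketch {
    def main(args: Array[String]): Unit = {
      var i = 0
      // Scala 2 only:
      //   do { i += 1 } while (i < 3)
      // Scala 2 with -Xsource:3-cross and Scala 3:
      while ({
        i += 1   // body runs at least once, exactly like do-while
        i < 3    // the block's last expression is the loop condition
      }) ()
      assert(i == 3)
    }
  }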
var comp = 1 - do { + while ({ comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) - } while (comp > 0 && advancedBufferedToRowWithNullFreeJoinKey()) + comp > 0 && advancedBufferedToRowWithNullFreeJoinKey() + }) () if (comp == 0) { // We have found matches, so buffer them (this updates matchJoinKey) bufferMatchingRows() @@ -1233,12 +1235,13 @@ private[joins] class SortMergeJoinScanner( // This join key may have been produced by a mutable projection, so we need to make a copy: matchJoinKey = streamedRowKey.copy() bufferedMatches.clear() - do { + while ({ if (!onlyBufferFirstMatch || bufferedMatches.isEmpty) { bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow]) } advancedBufferedToRowWithNullFreeJoinKey() - } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) + bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0 + }) () } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index fd7ccb2189bff..6b3105363e1e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -218,7 +218,7 @@ object EvaluatePython { f.applyOrElse(input, { // all other unexpected type should be null, or we will have runtime exception // TODO(davies): we could improve this by try to cast the object to expected type - _: Any => null + (_: Any) => null }) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala index e7fc9c7391af4..b6f22a146700e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala @@ -143,7 +143,7 @@ class WindowInPandasEvaluatorFactory( private val (numBoundIndices, lowerBoundIndex, upperBoundIndex, frameWindowBoundTypes) = computeWindowBoundHelpers(factories.toImmutableArraySeq) - private val isBounded = { frameIndex: Int => lowerBoundIndex(frameIndex) >= 0 } + private val isBounded = { (frameIndex: Int) => lowerBoundIndex(frameIndex) >= 0 } private val numFrames = factories.length private val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala index 0f3ae844808e5..57f0f2875f563 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala @@ -83,7 +83,7 @@ class AsyncProgressTrackingMicroBatchExecution( } }) - override val offsetLog = new AsyncOffsetSeqLog( + override val offsetLog: AsyncOffsetSeqLog = new AsyncOffsetSeqLog( sparkSession, checkpointFile("offsets"), asyncWritesExecutorService, @@ -91,7 +91,7 @@ class AsyncProgressTrackingMicroBatchExecution( clock = triggerClock ) - override val commitLog = + override val commitLog: AsyncCommitLog = new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService) 
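Aside (editorial illustration, not part of the patch): the explicit `AsyncOffsetSeqLog` and `AsyncCommitLog` annotations added just above appear to be needed because, under the Scala 3 semantics that -Xsource:3-cross adopts, an overriding val without a declared type takes the type of the member it overrides rather than the type of its right-hand side, which would hide the more specific async subtype from callers. A sketch with made-up Log and Execution classes:

  object OverrideTypeSketch {
    class Log { def name: String = "log" }
    class AsyncLog extends Log { def flushAsync(): Unit = () }

    class Execution { val log: Log = new Log }

    class AsyncExecution extends Execution {
      // Without the annotation the member would be typed as Log (the overridden
      // declaration), so log.flushAsync() below would not compile.
      override val log: AsyncLog = new AsyncLog
    }

    def main(args: Array[String]): Unit = {
      val exec = new AsyncExecution
      exec.log.flushAsync() // resolves because the more specific type is declared
    }
  }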
// perform quick validation to fail faster diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 465973cabe587..efa218b72b384 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -603,7 +603,7 @@ object FileStreamSource { // pathToCompare should have same depth as sourceGlobFilters.length var pathToCompare = baseArchivePathMinDepth var index = 0 - do { + while ({ // GlobFilter only matches against its name, not full path so it's safe to compare if (!sourceGlobFilters(index).accept(pathToCompare)) { matched = false @@ -611,7 +611,8 @@ object FileStreamSource { pathToCompare = pathToCompare.getParent index += 1 } - } while (matched && !pathToCompare.isRoot) + matched && !pathToCompare.isRoot + }) () matched } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 5fe3b0f82a0a4..17319e3091b0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.AttributeSeq import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index 497e71070a09a..52abb4d780686 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -118,10 +118,10 @@ object StreamingSymmetricHashJoinHelper extends Logging { } ( - leftConjuncts.reduceOption(And), - rightConjuncts.reduceOption(And), + leftConjuncts.reduceOption(And(_, _)), + rightConjuncts.reduceOption(And(_, _)), (nonLeftConjuncts.intersect(nonRightConjuncts) ++ nonDeterministicConjuncts) - .reduceOption(And) + .reduceOption(And(_, _)) ) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 6e0502e186597..6c638f26036ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeRow} +import 
org.apache.spark.sql.catalyst.expressions.AttributeSeq import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala index 021a6fa1ecbdc..0c2a4758292b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala @@ -25,6 +25,7 @@ import org.json4s.JsonAST._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods.{compact, render} +import org.json4s.jvalue2extractable import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index a13c00ee20576..010196507c69e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -147,9 +147,10 @@ class ContinuousExecution( } } - do { + while ({ runContinuous(sparkSessionForStream) - } while (state.updateAndGet(stateUpdate) == ACTIVE) + state.updateAndGet(stateUpdate) == ACTIVE + }) () stopSources() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala index d67eb40fde2c2..83395182eb93e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala @@ -24,6 +24,7 @@ import scala.util.Try import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream} import org.json4s.DefaultFormats import org.json4s.jackson.JsonMethods +import org.json4s.jvalue2extractable import org.apache.spark.sql.execution.streaming.MetadataVersionUtil import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala index 97feb9b579af9..c367ff77ba090 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.AttributeSeq import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManager.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManager.scala index 71df9dc65b419..0a7aae254468b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManager.scala @@ -21,6 +21,7 @@ import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.AttributeSeq import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.types.{StructType, TimestampType} import org.apache.spark.util.NextIterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 66ab0006c4982..42014a143370d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -28,6 +28,7 @@ import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX, STATE_STORE_ID} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, JoinedRow, Literal, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.AttributeSeq import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala index 7d13dbbe2a06a..24047e69c25ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala @@ -206,7 +206,7 @@ trait WindowEvaluatorFactoryBase { val factory = key match { // Frameless offset Frame case ("FRAME_LESS_OFFSET", _, IntegerLiteral(offset), _, expr) => - target: InternalRow => + (target: InternalRow) => new FrameLessOffsetWindowFunctionFrame( target, ordinal, @@ -218,7 +218,7 @@ trait WindowEvaluatorFactoryBase { offset, expr.nonEmpty) case ("UNBOUNDED_OFFSET", _, IntegerLiteral(offset), _, expr) => - target: InternalRow => { + (target: InternalRow) => { new UnboundedOffsetWindowFunctionFrame( target, ordinal, @@ -231,7 +231,7 @@ trait WindowEvaluatorFactoryBase { expr.nonEmpty) } case ("UNBOUNDED_PRECEDING_OFFSET", _, IntegerLiteral(offset), _, expr) => - target: InternalRow => { + (target: InternalRow) => { new UnboundedPrecedingOffsetWindowFunctionFrame( target, ordinal, @@ -246,13 +246,13 @@ trait WindowEvaluatorFactoryBase { // Entire Partition Frame. 
case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing, _) => - target: InternalRow => { + (target: InternalRow) => { new UnboundedWindowFunctionFrame(target, processor) } // Growing Frame. case ("AGGREGATE", frameType, UnboundedPreceding, upper, _) => - target: InternalRow => { + (target: InternalRow) => { new UnboundedPrecedingWindowFunctionFrame( target, processor, @@ -261,7 +261,7 @@ trait WindowEvaluatorFactoryBase { // Shrinking Frame. case ("AGGREGATE", frameType, lower, UnboundedFollowing, _) => - target: InternalRow => { + (target: InternalRow) => { new UnboundedFollowingWindowFunctionFrame( target, processor, @@ -270,7 +270,7 @@ trait WindowEvaluatorFactoryBase { // Moving Frame. case ("AGGREGATE", frameType, lower, upper, _) => - target: InternalRow => { + (target: InternalRow) => { new SlidingWindowFunctionFrame( target, processor, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5b88eeefeca75..c685d035e605f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -107,7 +107,7 @@ class DataFrameSuite extends QueryTest val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDF("words") checkAnswer( - df.explode("words", "word") { word: String => word.split(" ").toSeq }.select($"word"), + df.explode("words", "word") { (word: String) => word.split(" ").toSeq }.select($"word"), Row("a") :: Row("b") :: Row("c") :: Row("d") ::Row("e") :: Nil ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 9d8aaf8d90e32..e0ae68c4a2b35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -136,7 +136,7 @@ class DatasetCacheSuite extends QueryTest } test("cache UDF result correctly") { - val expensiveUDF = udf({x: Int => Thread.sleep(2000); x}) + val expensiveUDF = udf { (x: Int) => Thread.sleep(2000); x } val df = spark.range(0, 2).toDF("a").repartition(1).withColumn("b", expensiveUDF($"a")) val df2 = df.agg(sum(df("b"))) @@ -154,7 +154,7 @@ class DatasetCacheSuite extends QueryTest } test("SPARK-24613 Cache with UDF could not be matched with subsequent dependent caches") { - val udf1 = udf({x: Int => x + 1}) + val udf1 = udf { (x: Int) => x + 1 } val df = spark.range(0, 10).toDF("a").withColumn("b", udf1($"a")) val df2 = df.agg(sum(df("b"))) @@ -184,7 +184,7 @@ class DatasetCacheSuite extends QueryTest } test("SPARK-24596 Non-cascading Cache Invalidation - verify cached data reuse") { - val expensiveUDF = udf({ x: Int => Thread.sleep(5000); x }) + val expensiveUDF = udf { (x: Int) => Thread.sleep(5000); x } val df = spark.range(0, 5).toDF("a") val df1 = df.withColumn("b", expensiveUDF($"a")) val df2 = df1.groupBy($"a").agg(sum($"b")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala index e807ae306ce76..d2e2d5a53182f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala @@ -477,7 +477,7 @@ trait SQLQueryTestHelper extends Logging { def normalizeTestResults(output: String): String = { val strippedPythonErrors: String = { var traceback = false - output.split("\n").filter { line: String 
=> + output.split("\n").filter { (line: String) => if (line == "Traceback (most recent call last):") { traceback = true } else if (!line.startsWith(" ")) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 575a4ae69d1a9..6dde4f3ac96d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -411,9 +411,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper withClue(clue) { testCase match { case _: AnalyzerTest => - readGoldenFileAndCompareResults(testCase.resultFile, outputs, AnalyzerOutput) + readGoldenFileAndCompareResults(testCase.resultFile, outputs, AnalyzerOutput(_, _, _)) case _ => - readGoldenFileAndCompareResults(testCase.resultFile, outputs, ExecutionOutput) + readGoldenFileAndCompareResults(testCase.resultFile, outputs, ExecutionOutput(_, _, _)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 8f0a62e210d85..bc6a59a4c9b57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -86,37 +86,37 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt } test("inject analyzer rule") { - withSession(Seq(_.injectResolutionRule(MyRule))) { session => + withSession(Seq(_.injectResolutionRule(MyRule(_)))) { session => assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session))) } } test("inject post hoc resolution analyzer rule") { - withSession(Seq(_.injectPostHocResolutionRule(MyRule))) { session => + withSession(Seq(_.injectPostHocResolutionRule(MyRule(_)))) { session => assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule(session))) } } test("inject check analysis rule") { - withSession(Seq(_.injectCheckRule(MyCheckRule))) { session => + withSession(Seq(_.injectCheckRule(MyCheckRule(_)))) { session => assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session))) } } test("inject optimizer rule") { - withSession(Seq(_.injectOptimizerRule(MyRule))) { session => + withSession(Seq(_.injectOptimizerRule(MyRule(_)))) { session => assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session))) } } test("SPARK-33621: inject a pre CBO rule") { - withSession(Seq(_.injectPreCBORule(MyRule))) { session => + withSession(Seq(_.injectPreCBORule(MyRule(_)))) { session => assert(session.sessionState.optimizer.preCBORules.contains(MyRule(session))) } } test("inject spark planner strategy") { - withSession(Seq(_.injectPlannerStrategy(MySparkStrategy))) { session => + withSession(Seq(_.injectPlannerStrategy(MySparkStrategy(_)))) { session => assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) } } @@ -131,8 +131,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt } test("inject multiple rules") { - withSession(Seq(_.injectOptimizerRule(MyRule), - _.injectPlannerStrategy(MySparkStrategy))) { session => + withSession(Seq(_.injectOptimizerRule(MyRule(_)), + _.injectPlannerStrategy(MySparkStrategy(_)))) { session => assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session))) 
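Aside (editorial illustration, not part of the patch): the single most common change across these hunks is purely syntactic. A lambda whose one parameter carries a type ascription, `x: T => ...`, must be parenthesised as `(x: T) => ...`, since Scala 3 no longer accepts the bare form inside braces and -Xsource:3-cross enforces the same rule. A small before/after sketch with made-up data:

  object LambdaParensSketch {
    def main(args: Array[String]): Unit = {
      val words = Seq("a b", "c")
      // Scala 2 only (rejected under -Xsource:3-cross):
      //   words.flatMap { w: String => w.split(" ").toSeq }
      // Scala 2 with -Xsource:3-cross and Scala 3:
      val split = words.flatMap { (w: String) => w.split(" ").toSeq }
      println(split) // List(a, b, c)
    }
  }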
assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) } @@ -141,8 +141,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt test("inject stacked parsers") { val extension = create { extensions => extensions.injectParser((_: SparkSession, _: ParserInterface) => CatalystSqlParser) - extensions.injectParser(MyParser) - extensions.injectParser(MyParser) + extensions.injectParser(MyParser(_, _)) + extensions.injectParser(MyParser(_, _)) } withSession(extension) { session => val parser = MyParser(session, MyParser(session, CatalystSqlParser)) @@ -171,7 +171,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt } test("inject custom hint rule") { - withSession(Seq(_.injectHintResolutionRule(MyHintRule))) { session => + withSession(Seq(_.injectHintResolutionRule(MyHintRule(_)))) { session => assert( session.range(1).hint("CONVERT_TO_EMPTY").logicalPlan.isInstanceOf[LocalRelation], "plan is expected to be a local relation" @@ -460,7 +460,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt } test("SPARK-35673: user-defined hint and unrecognized hint in subquery") { - withSession(Seq(_.injectPostHocResolutionRule(MyHintRule))) { session => + withSession(Seq(_.injectPostHocResolutionRule(MyHintRule(_)))) { session => // unrecognized hint QueryTest.checkAnswer( session.sql( @@ -559,8 +559,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt test("custom aggregate hint") { // The custom hint allows us to replace the aggregate (without grouping keys) with just // Literal. - withSession(Seq(_.injectHintResolutionRule(CustomAggregateHintResolutionRule), - _.injectOptimizerRule(CustomAggregateRule))) { session => + withSession(Seq(_.injectHintResolutionRule(CustomAggregateHintResolutionRule(_)), + _.injectOptimizerRule(CustomAggregateRule(_)))) { session => val res = session.range(10).agg(max("id")).as("max_id") .hint("MAX_VALUE", "id", 10) .queryExecution.optimizedPlan @@ -572,8 +572,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt test("custom sort hint") { // The custom hint allows us to replace the sort with its input - withSession(Seq(_.injectHintResolutionRule(CustomSortHintResolutionRule), - _.injectOptimizerRule(CustomSortRule))) { session => + withSession(Seq(_.injectHintResolutionRule(CustomSortHintResolutionRule(_)), + _.injectOptimizerRule(CustomSortRule(_)))) { session => val res = session.range(10).sort("id") .hint("INPUT_SORTED") .queryExecution.optimizedPlan @@ -592,7 +592,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt } } - withSession(Seq(_.injectHintResolutionRule(UnresolvedRelationRule))) { session => + withSession(Seq(_.injectHintResolutionRule(UnresolvedRelationRule(_)))) { session => withTable(session, "my_table") { session.sql("CREATE TABLE IF NOT EXISTS my_table (col1 INT)") ruleApplied = false @@ -1134,12 +1134,12 @@ case class MyColumnarRule(pre: Rule[SparkPlan], post: Rule[SparkPlan]) extends C class MyExtensions extends (SparkSessionExtensions => Unit) { def apply(e: SparkSessionExtensions): Unit = { - e.injectPlannerStrategy(MySparkStrategy) - e.injectResolutionRule(MyRule) - e.injectPostHocResolutionRule(MyRule) - e.injectCheckRule(MyCheckRule) - e.injectOptimizerRule(MyRule) - e.injectParser(MyParser) + e.injectPlannerStrategy(MySparkStrategy(_)) + e.injectResolutionRule(MyRule(_)) + e.injectPostHocResolutionRule(MyRule(_)) + 
e.injectCheckRule(MyCheckRule(_)) + e.injectOptimizerRule(MyRule(_)) + e.injectParser(MyParser(_, _)) e.injectFunction(MyExtensions.myFunction) e.injectColumnar(session => MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())) } @@ -1206,11 +1206,11 @@ object MyExtensions2 { class MyExtensions2 extends (SparkSessionExtensions => Unit) { def apply(e: SparkSessionExtensions): Unit = { - e.injectPlannerStrategy(MySparkStrategy2) - e.injectResolutionRule(MyRule2) - e.injectPostHocResolutionRule(MyRule2) - e.injectCheckRule(MyCheckRule2) - e.injectOptimizerRule(MyRule2) + e.injectPlannerStrategy(MySparkStrategy2(_)) + e.injectResolutionRule(MyRule2(_)) + e.injectPostHocResolutionRule(MyRule2(_)) + e.injectCheckRule(MyCheckRule2(_)) + e.injectOptimizerRule(MyRule2(_)) e.injectParser((_: SparkSession, _: ParserInterface) => CatalystSqlParser) e.injectFunction(MyExtensions2.myFunction) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index f5ca885b1ad63..0594e40b30695 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -1040,7 +1040,7 @@ object TestingV2Source { class SimpleSinglePartitionSource extends TestingV2Source { - class MyScanBuilder extends SimpleScanBuilder { + class SimpleSinglePartitionSourceScanBuilder extends SimpleScanBuilder { override def planInputPartitions(): Array[InputPartition] = { Array(RangeInputPartition(0, 5)) } @@ -1048,14 +1048,14 @@ class SimpleSinglePartitionSource extends TestingV2Source { override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - new MyScanBuilder() + new SimpleSinglePartitionSourceScanBuilder() } } } class ScanDefinedColumnarSupport extends TestingV2Source { - class MyScanBuilder(st: ColumnarSupportMode) extends SimpleScanBuilder { + class ScanDefinedColumnarSupportScanBuilder(st: ColumnarSupportMode) extends SimpleScanBuilder { override def planInputPartitions(): Array[InputPartition] = { throw new IllegalArgumentException("planInputPartitions must not be called") } @@ -1066,7 +1066,9 @@ class ScanDefinedColumnarSupport extends TestingV2Source { override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - new MyScanBuilder(Scan.ColumnarSupportMode.valueOf(options.get("columnar"))) + new ScanDefinedColumnarSupportScanBuilder( + Scan.ColumnarSupportMode.valueOf(options.get("columnar")) + ) } } @@ -1077,7 +1079,7 @@ class ScanDefinedColumnarSupport extends TestingV2Source { // tests still pass. 
class SimpleDataSourceV2 extends TestingV2Source { - class MyScanBuilder extends SimpleScanBuilder { + class SimpleDataSourceV2ScanBuilder extends SimpleScanBuilder { override def planInputPartitions(): Array[InputPartition] = { Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10)) } @@ -1085,7 +1087,7 @@ class SimpleDataSourceV2 extends TestingV2Source { override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - new MyScanBuilder() + new SimpleDataSourceV2ScanBuilder() } } } @@ -1251,7 +1253,7 @@ class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderF class SchemaRequiredDataSource extends TableProvider { - class MyScanBuilder(schema: StructType) extends SimpleScanBuilder { + class SchemaRequiredDataSourceScanBuilder(schema: StructType) extends SimpleScanBuilder { override def planInputPartitions(): Array[InputPartition] = { Array(RangeInputPartition(0, 2)) } @@ -1274,7 +1276,7 @@ class SchemaRequiredDataSource extends TableProvider { override def schema(): StructType = userGivenSchema override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - new MyScanBuilder(userGivenSchema) + new SchemaRequiredDataSourceScanBuilder(userGivenSchema) } } } @@ -1288,7 +1290,7 @@ class PartitionsRequiredDataSource extends SchemaRequiredDataSource { class ColumnarDataSourceV2 extends TestingV2Source { - class MyScanBuilder extends SimpleScanBuilder { + class ColumnarDataSourceV2ScanBuilder extends SimpleScanBuilder { override def planInputPartitions(): Array[InputPartition] = { Array(RangeInputPartition(0, 50), RangeInputPartition(50, 90)) @@ -1301,7 +1303,7 @@ class ColumnarDataSourceV2 extends TestingV2Source { override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - new MyScanBuilder() + new ColumnarDataSourceV2ScanBuilder() } } } @@ -1353,7 +1355,7 @@ object ColumnarReaderFactory extends PartitionReaderFactory { class PartitionAwareDataSource extends TestingV2Source { - class MyScanBuilder extends SimpleScanBuilder + class PartitionAwareDataSourceScanBuilder extends SimpleScanBuilder with SupportsReportPartitioning { override def planInputPartitions(): Array[InputPartition] = { @@ -1373,14 +1375,14 @@ class PartitionAwareDataSource extends TestingV2Source { override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - new MyScanBuilder() + new PartitionAwareDataSourceScanBuilder() } } } class OrderAndPartitionAwareDataSource extends PartitionAwareDataSource { - class MyScanBuilder( + private class OrderAndPartitionAwareDataSourceScanBuilder( val partitionKeys: Option[Seq[String]], val orderKeys: Seq[String]) extends SimpleScanBuilder @@ -1414,7 +1416,7 @@ class OrderAndPartitionAwareDataSource extends PartitionAwareDataSource { override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - new MyScanBuilder( + new OrderAndPartitionAwareDataSourceScanBuilder( Option(options.get("partitionKeys")).map(_.split(",").toImmutableArraySeq), Option(options.get("orderKeys")).map(_.split(",").toSeq).getOrElse(Seq.empty) ) @@ -1490,13 +1492,9 @@ class WritableDataSourceSupportsExternalMetadata extends SimpleWritableDataSourc 
*/ class CustomSchemaAndPartitioningDataSource extends WritableDataSourceSupportsExternalMetadata { class TestTable( - schema: StructType, - partitioning: Array[Transform], - options: CaseInsensitiveStringMap) extends MyTable(options) { - override def schema(): StructType = schema - - override def partitioning(): Array[Transform] = partitioning - } + override val schema: StructType, + override val partitioning: Array[Transform], + options: CaseInsensitiveStringMap) extends MyTable(options) override def getTable( schema: StructType, @@ -1518,7 +1516,7 @@ class SupportsExternalMetadataWritableDataSource extends SimpleWritableDataSourc class ReportStatisticsDataSource extends SimpleWritableDataSource { - class MyScanBuilder extends SimpleScanBuilder + private class ReportStatisticsDataSourceScanBuilder extends SimpleScanBuilder with SupportsReportStatistics { override def estimateStatistics(): Statistics = { new Statistics { @@ -1536,7 +1534,7 @@ class ReportStatisticsDataSource extends SimpleWritableDataSource { override def getTable(options: CaseInsensitiveStringMap): Table = { new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - new MyScanBuilder + new ReportStatisticsDataSourceScanBuilder } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FakeV2Provider.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FakeV2Provider.scala index 25d2d5a67d44e..e0a664c5fddac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FakeV2Provider.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/FakeV2Provider.scala @@ -64,13 +64,9 @@ object FakeV2Provider { class FakeV2ProviderWithCustomSchema extends FakeV2Provider { class FakeTable( - schema: StructType, - partitioning: Array[Transform], + override val schema: StructType, + override val partitioning: Array[Transform], options: CaseInsensitiveStringMap) extends SimpleBatchTable { - override def schema(): StructType = schema - - override def partitioning(): Array[Transform] = partitioning - override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index b2a46afb13b9e..bf2053fd0591e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import java.sql.{Date, Timestamp} import java.time.{Duration, Period} -import org.json4s.DefaultFormats +import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.scalatest.Assertions._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala index 0a06e3cc3f6e7..a129e35e48642 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala @@ -95,7 +95,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper } test(s"determining the number of reducers: aggregate operator$testNameNote") { - val test: SparkSession => Unit = { spark: 
SparkSession => + val test: SparkSession => Unit = { spark => val df = spark .range(0, 1000, 1, numInputPartitions) @@ -130,7 +130,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper } test(s"determining the number of reducers: join operator$testNameNote") { - val test: SparkSession => Unit = { spark: SparkSession => + val test: SparkSession => Unit = { spark => val df1 = spark .range(0, 1000, 1, numInputPartitions) @@ -175,7 +175,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper } test(s"determining the number of reducers: complex query 1$testNameNote") { - val test: (SparkSession) => Unit = { spark: SparkSession => + val test: SparkSession => Unit = { spark => val df1 = spark .range(0, 1000, 1, numInputPartitions) @@ -225,7 +225,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper } test(s"determining the number of reducers: complex query 2$testNameNote") { - val test: (SparkSession) => Unit = { spark: SparkSession => + val test: SparkSession => Unit = { spark => val df1 = spark .range(0, 1000, 1, numInputPartitions) @@ -275,7 +275,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper } test(s"determining the number of reducers: plan already partitioned$testNameNote") { - val test: SparkSession => Unit = { spark: SparkSession => + val test: SparkSession => Unit = { spark => try { spark.range(1000).write.bucketBy(30, "id").saveAsTable("t") // `df1` is hash partitioned by `id`. @@ -309,7 +309,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper } test("SPARK-46590 adaptive query execution works correctly with broadcast join and union") { - val test: SparkSession => Unit = { spark: SparkSession => + val test: SparkSession => Unit = { spark => import spark.implicits._ spark.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "1KB") spark.conf.set(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key, "10KB") @@ -337,7 +337,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper } test("SPARK-46590 adaptive query execution works correctly with cartesian join and union") { - val test: SparkSession => Unit = { spark: SparkSession => + val test: SparkSession => Unit = { spark => import spark.implicits._ spark.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1") spark.conf.set(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key, "100B") @@ -370,7 +370,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper } test("SPARK-24705 adaptive query execution works correctly when exchange reuse enabled") { - val test: SparkSession => Unit = { spark: SparkSession => + val test: SparkSession => Unit = { spark => withSQLConf("spark.sql.exchange.reuse" -> "true") { val df = spark.range(0, 6, 1).selectExpr("id AS key", "id AS value") @@ -441,7 +441,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper } test("Do not reduce the number of shuffle partition for repartition") { - val test: SparkSession => Unit = { spark: SparkSession => + val test: SparkSession => Unit = { spark => val ds = spark.range(3) val resultDf = ds.repartition(2, ds.col("id")).toDF() @@ -457,7 +457,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper } test("Union two datasets with different pre-shuffle partition number") { - val test: SparkSession => Unit = { spark: SparkSession => + val test: SparkSession => Unit = { spark => val df1 = spark.range(3).join(spark.range(3), 
"id").toDF() val df2 = spark.range(3).groupBy().sum() @@ -477,7 +477,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper } test("SPARK-34790: enable IO encryption in AQE partition coalescing") { - val test: SparkSession => Unit = { spark: SparkSession => + val test: SparkSession => Unit = { spark => val ds = spark.range(0, 100, 1, numInputPartitions) val resultDf = ds.repartition(ds.col("id")) resultDf.collect() @@ -495,7 +495,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper } test("SPARK-51505: log empty partition number metrics") { - val test: SparkSession => Unit = { spark: SparkSession => + val test: SparkSession => Unit = { spark => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") { val df = spark.range(0, 1000, 1, 5).withColumn("value", when(col("id") < 500, 0) .otherwise(1)).groupBy("value").agg("value" -> "sum") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala index 31b002a1e245d..0c845e25d5dda 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala @@ -83,7 +83,7 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark extends BenchmarkBase { ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer, numSpillThreshold) - benchmark.addCase("ArrayBuffer") { _: Int => + benchmark.addCase("ArrayBuffer") { (_: Int) => var sum = 0L for (_ <- 0L until iterations) { val array = new ArrayBuffer[UnsafeRow](initialSize) @@ -102,7 +102,7 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark extends BenchmarkBase { } } - benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => + benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { (_: Int) => var sum = 0L for (_ <- 0L until iterations) { val array = new ExternalAppendOnlyUnsafeRowArray( @@ -133,7 +133,7 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark extends BenchmarkBase { val benchmark = new Benchmark(s"Spilling with $numRows rows", iterations * numRows, output = output) - benchmark.addCase("UnsafeExternalSorter") { _: Int => + benchmark.addCase("UnsafeExternalSorter") { (_: Int) => var sum = 0L for (_ <- 0L until iterations) { val array = UnsafeExternalSorter.create( @@ -167,7 +167,7 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark extends BenchmarkBase { } } - benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => + benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { (_: Int) => var sum = 0L for (_ <- 0L until iterations) { val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold, numSpillThreshold) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ByteArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ByteArrayBenchmark.scala index f6bd881a82a02..b88ba98b18d16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ByteArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ByteArrayBenchmark.scala @@ -58,7 +58,7 @@ object ByteArrayBenchmark extends BenchmarkBase { val dataLargeSlow = Seq.fill(count)( Array.tabulate(512) {i => if (i < 511) 0.toByte else 1.toByte}).toArray - def compareBinary(data: Array[Array[Byte]]) = { _: Int => + def compareBinary(data: 
Array[Array[Byte]]) = { (_: Int) => var sum = 0L for (_ <- 0L until iters) { var i = 0 @@ -80,7 +80,7 @@ object ByteArrayBenchmark extends BenchmarkBase { } def byteArrayEquals(iters: Long): Unit = { - def binaryEquals(inputs: Array[BinaryEqualInfo]) = { _: Int => + def binaryEquals(inputs: Array[BinaryEqualInfo]) = { (_: Int) => var res = false for (_ <- 0L until iters) { inputs.foreach { input => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ConstantColumnVectorBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ConstantColumnVectorBenchmark.scala index 078954f1a6023..d01b2fd3490c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ConstantColumnVectorBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ConstantColumnVectorBenchmark.scala @@ -93,20 +93,20 @@ object ConstantColumnVectorBenchmark extends BenchmarkBase { valuesPerIteration * batchSize, output = output) - benchmark.addCase("ConstantColumnVector") { _: Int => + benchmark.addCase("ConstantColumnVector") { (_: Int) => for (_ <- 0 until valuesPerIteration) { ColumnVectorUtils.populate(constantColumnVector, row, 0) } } - benchmark.addCase("OnHeapColumnVector") { _: Int => + benchmark.addCase("OnHeapColumnVector") { (_: Int) => for (_ <- 0 until valuesPerIteration) { onHeapColumnVector.reset() populate(onHeapColumnVector, batchSize, row, 0) } } - benchmark.addCase("OffHeapColumnVector") { _: Int => + benchmark.addCase("OffHeapColumnVector") { (_: Int) => for (_ <- 0 until valuesPerIteration) { offHeapColumnVector.reset() populate(offHeapColumnVector, batchSize, row, 0) @@ -146,19 +146,19 @@ object ConstantColumnVectorBenchmark extends BenchmarkBase { valuesPerIteration * batchSize, output = output) - benchmark.addCase("ConstantColumnVector") { _: Int => + benchmark.addCase("ConstantColumnVector") { (_: Int) => for (_ <- 0 until valuesPerIteration) { readValues(dataType, batchSize, constantColumnVector) } } - benchmark.addCase("OnHeapColumnVector") { _: Int => + benchmark.addCase("OnHeapColumnVector") { (_: Int) => for (_ <- 0 until valuesPerIteration) { readValues(dataType, batchSize, onHeapColumnVector) } } - benchmark.addCase("OffHeapColumnVector") { _: Int => + benchmark.addCase("OffHeapColumnVector") { (_: Int) => for (_ <- 0 until valuesPerIteration) { readValues(dataType, batchSize, offHeapColumnVector) } @@ -191,14 +191,14 @@ object ConstantColumnVectorBenchmark extends BenchmarkBase { valuesPerIteration * batchSize, output = output) - benchmark.addCase("ConstantColumnVector") { _: Int => + benchmark.addCase("ConstantColumnVector") { (_: Int) => ColumnVectorUtils.populate(constantColumnVector, row, 0) for (_ <- 0 until valuesPerIteration) { readValues(dataType, batchSize, constantColumnVector) } } - benchmark.addCase("OnHeapColumnVector") { _: Int => + benchmark.addCase("OnHeapColumnVector") { (_: Int) => onHeapColumnVector.reset() populate(onHeapColumnVector, batchSize, row, 0) for (_ <- 0 until valuesPerIteration) { @@ -206,7 +206,7 @@ object ConstantColumnVectorBenchmark extends BenchmarkBase { } } - benchmark.addCase("OffHeapColumnVector") { _: Int => + benchmark.addCase("OffHeapColumnVector") { (_: Int) => offHeapColumnVector.reset() populate(offHeapColumnVector, batchSize, row, 0) for (_ <- 0 until valuesPerIteration) { @@ -238,19 +238,19 @@ object ConstantColumnVectorBenchmark extends BenchmarkBase { valuesPerIteration * batchSize, output = output) - benchmark.addCase("ConstantColumnVector") { 
_: Int => + benchmark.addCase("ConstantColumnVector") { (_: Int) => for (_ <- 0 until valuesPerIteration) { (0 until batchSize).foreach(constantColumnVector.isNullAt) } } - benchmark.addCase("OnHeapColumnVector") { _: Int => + benchmark.addCase("OnHeapColumnVector") { (_: Int) => for (_ <- 0 until valuesPerIteration) { (0 until batchSize).foreach(onHeapColumnVector.isNullAt) } } - benchmark.addCase("OffHeapColumnVector") { _: Int => + benchmark.addCase("OffHeapColumnVector") { (_: Int) => for (_ <- 0 until valuesPerIteration) { (0 until batchSize).foreach(offHeapColumnVector.isNullAt) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala index a09a64d6a8fd3..a92318b223f27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala @@ -56,7 +56,7 @@ object PrimitiveArrayBenchmark extends SqlBasedBenchmark { val primitiveIntArray = Array.fill[Int](count)(65535) val dsInt = sc.parallelize(Seq(primitiveIntArray), 1).toDS() dsInt.count() // force to build dataset - val intArray = { i: Int => + val intArray = { (i: Int) => var n = 0 var len = 0 while (n < iters) { @@ -67,7 +67,7 @@ object PrimitiveArrayBenchmark extends SqlBasedBenchmark { val primitiveDoubleArray = Array.fill[Double](count)(65535.0) val dsDouble = sc.parallelize(Seq(primitiveDoubleArray), 1).toDS() dsDouble.count() // force to build dataset - val doubleArray = { i: Int => + val doubleArray = { (i: Int) => var n = 0 var len = 0 while (n < iters) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UDFBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UDFBenchmark.scala index 6ea65d863a964..36a84d4f0c770 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UDFBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UDFBenchmark.scala @@ -113,7 +113,7 @@ object UDFBenchmark extends SqlBasedBenchmark { .noop() } - val identityUDF = udf { x: Long => x } + val identityUDF = udf { (x: Long) => x } benchmark.addCase(s"With identity UDF", numIters = 5) { _ => spark.range(cardinality) .select( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala index d84fa5ec6a7f1..c90aabd37e3d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala @@ -51,7 +51,7 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase { val intArrayToRow = intEncoder.createSerializer() val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt() } val intUnsafeArray = intArrayToRow(intPrimitiveArray).getArray(0) - val readIntArray = { i: Int => + val readIntArray = { (i: Int) => var n = 0 while (n < iters) { val len = intUnsafeArray.numElements() @@ -68,7 +68,7 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase { val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble() } val doubleArrayToRow = doubleEncoder.createSerializer() val doubleUnsafeArray = doubleArrayToRow(doublePrimitiveArray).getArray(0) - val readDoubleArray = { i: Int 
=> + val readDoubleArray = { (i: Int) => var n = 0 while (n < iters) { val len = doubleUnsafeArray.numElements() @@ -95,7 +95,7 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase { var intTotalLength: Int = 0 val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt() } val intArrayToRow = intEncoder.createSerializer() - val writeIntArray = { i: Int => + val writeIntArray = { (i: Int) => var len = 0 var n = 0 while (n < iters) { @@ -108,7 +108,7 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase { var doubleTotalLength: Int = 0 val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble() } val doubleArrayToRow = doubleEncoder.createSerializer() - val writeDoubleArray = { i: Int => + val writeDoubleArray = { (i: Int) => var len = 0 var n = 0 while (n < iters) { @@ -132,7 +132,7 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase { val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt() } val intArrayToRow = intEncoder.createSerializer() val intUnsafeArray = intArrayToRow(intPrimitiveArray).getArray(0) - val readIntArray = { i: Int => + val readIntArray = { (i: Int) => var len = 0 var n = 0 while (n < iters) { @@ -146,7 +146,7 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase { val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble() } val doubleArrayToRow = doubleEncoder.createSerializer() val doubleUnsafeArray = doubleArrayToRow(doublePrimitiveArray).getArray(0) - val readDoubleArray = { i: Int => + val readDoubleArray = { (i: Int) => var len = 0 var n = 0 while (n < iters) { @@ -169,7 +169,7 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase { var intTotalLen: Int = 0 val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt() } - val createIntArray = { i: Int => + val createIntArray = { (i: Int) => var len = 0 var n = 0 while (n < iters) { @@ -181,7 +181,7 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase { var doubleTotalLen: Int = 0 val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble() } - val createDoubleArray = { i: Int => + val createDoubleArray = { (i: Int) => var len = 0 var n = 0 while (n < iters) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/CachedBatchSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/CachedBatchSerializerSuite.scala index 46f60e881ddba..59adeb09f5d3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/CachedBatchSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/CachedBatchSerializerSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.AttributeSeq import org.apache.spark.sql.columnar.{CachedBatch, CachedBatchSerializer} import org.apache.spark.sql.execution.ColumnarToRowExec import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala index 290cfd56b8bce..a3daae5ce9e61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -93,7 +93,7 @@ object CompressionSchemeBenchmark extends BenchmarkBase with AllCompressionSchem val (compressFunc, compressionRatio, buf) = prepareEncodeInternal(count, tpe, scheme, input) val label = s"${getFormattedClassName(scheme)}(${"%.3f".format(compressionRatio)})" - benchmark.addCase(label)({ i: Int => + benchmark.addCase(label)({ (i: Int) => for (n <- 0L until iters) { compressFunc(input, buf) input.rewind() @@ -120,7 +120,7 @@ object CompressionSchemeBenchmark extends BenchmarkBase with AllCompressionSchem input.rewind() - benchmark.addCase(label)({ i: Int => + benchmark.addCase(label)({ (i: Int) => val rowBuf = new GenericInternalRow(1) for (n <- 0L until iters) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceCustomMetadataStructSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceCustomMetadataStructSuite.scala index d0dd8e03e58ef..b26d777b8fa05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceCustomMetadataStructSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceCustomMetadataStructSuite.scala @@ -138,13 +138,13 @@ class FileSourceCustomMetadataStructSuite extends QueryTest with SharedSparkSess test("[SPARK-43226] extra constant metadata fields with extractors") { withTempData("parquet", FILE_SCHEMA) { (_, f0, f1) => val format = new TestFileFormat(extraConstantMetadataFields) { - val extractPartitionNumber = { pf: PartitionedFile => + val extractPartitionNumber = { (pf: PartitionedFile) => pf.toPath.toString.split("/").collectFirst { case "f0" => 9990 case "f1" => 9991 }.get } - val extractPartitionName = { pf: PartitionedFile => + val extractPartitionName = { (pf: PartitionedFile) => pf.toPath.toString.split("/").collectFirst { case "f0" => "f0f" case "f1" => "f1f" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala index 8018417f923af..acc0e12541514 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala @@ -53,7 +53,7 @@ class OrcV1FilterSuite extends OrcFilterSuite { LogicalRelationWithTable(orcRelation: HadoopFsRelation, _)) => maybeRelation = Some(orcRelation) filters - }.flatten.reduceLeftOption(And) + }.flatten.reduceLeftOption(And(_, _)) assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") val (_, selectedFilters, _) = @@ -99,7 +99,7 @@ class OrcV1FilterSuite extends OrcFilterSuite { LogicalRelationWithTable(orcRelation: HadoopFsRelation, _)) => maybeRelation = Some(orcRelation) filters - }.flatten.reduceLeftOption(And) + }.flatten.reduceLeftOption(And(_, _)) assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") val (_, selectedFilters, _) = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLLiveEntitiesEventFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLLiveEntitiesEventFilterSuite.scala index f9eea3816fcaa..6f632bba730be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLLiveEntitiesEventFilterSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLLiveEntitiesEventFilterSuite.scala @@ -57,7 +57,7 @@ class SQLLiveEntitiesEventFilterSuite extends SparkFunSuite { val jobEndEventForJob1 = SparkListenerJobEnd(1, 0, JobSucceeded) val stageSubmittedEventsForJob1 = SparkListenerStageSubmitted(stage1) val stageCompletedEventsForJob1 = SparkListenerStageCompleted(stage1) - val unpersistRDDEventsForJob1 = (1 to 2).map(SparkListenerUnpersistRDD) + val unpersistRDDEventsForJob1 = (1 to 2).map(SparkListenerUnpersistRDD(_)) // job events for finished job should be considered as "don't know" assert(None === acceptFn(jobStartEventForJob1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 69dd04e07d551..7a23d64a9f441 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -423,7 +423,7 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils test("Broadcast timeout") { val timeout = 5 - val slowUDF = udf({ x: Int => Thread.sleep(timeout * 1000); x }) + val slowUDF = udf { (x: Int) => Thread.sleep(timeout * 1000); x } val df1 = spark.range(10).select($"id" as Symbol("a")) val df2 = spark.range(5).select(slowUDF($"id") as Symbol("a")) val testDf = df1.join(broadcast(df2), "a") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala index cdf736b1fffca..6c871ff15d4c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala @@ -151,7 +151,7 @@ abstract class CheckpointFileManagerTestsOnLocalFs extends CheckpointFileManagerTests with SQLHelper { protected def withTempHadoopPath(p: Path => Unit): Unit = { - withTempDir { f: File => + withTempDir { (f: File) => val basePath = new Path(f.getAbsolutePath) p(basePath) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 08648148b4af4..198ce96be6c4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -32,6 +32,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.json4s.DefaultFormats import org.json4s.jackson.JsonMethods +import org.json4s.jvalue2extractable import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala index 6685b140960d9..ddae5023e436b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.AttributeSeq import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.streaming.StreamTest diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index 800a58f0c1d63..f1a4b7322c843 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -1152,9 +1152,8 @@ class SimpleCustomDriverMetric extends CustomMetric { } } -class SimpleCustomDriverTaskMetric(value : Long) extends CustomTaskMetric { +class SimpleCustomDriverTaskMetric(override val value : Long) extends CustomTaskMetric { override def name(): String = "custom_driver_metric_partition_count" - override def value(): Long = value } class BytesWrittenCustomMetric extends CustomMetric { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index a11b20912e91f..792ba01df9b6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -48,7 +48,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { val count = 8 * 1000 // Accessing a java array. 
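// A minimal standalone sketch (not from the Spark sources; names are illustrative) of the
// lambda rewrite these benchmark hunks apply: Scala 2.13 accepts a type-ascribed lambda
// parameter without parentheses, while -Xsource:3, like Scala 3, requires the parentheses.
object LambdaParamSketch {
  // before: val square: Int => Int = { x: Int => x * x }   // flagged under -Xsource:3
  val square: Int => Int = { (x: Int) => x * x }             // accepted by 2.13 and Scala 3
}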
- val javaArray = { i: Int => + val javaArray = { (i: Int) => val data = new Array[Int](count) var sum = 0L for (n <- 0L until iters) { @@ -66,7 +66,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { } // Accessing ByteBuffers - val byteBufferUnsafe = { i: Int => + val byteBufferUnsafe = { (i: Int) => val data = ByteBuffer.allocate(count * 4) var sum = 0L for (n <- 0L until iters) { @@ -84,7 +84,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { } // Accessing offheap byte buffers - val directByteBuffer = { i: Int => + val directByteBuffer = { (i: Int) => val data = ByteBuffer.allocateDirect(count * 4).asIntBuffer() var sum = 0L for (n <- 0L until iters) { @@ -104,7 +104,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { } // Accessing ByteBuffer using the typed APIs - val byteBufferApi = { i: Int => + val byteBufferApi = { (i: Int) => val data = ByteBuffer.allocate(count * 4) var sum = 0L for (n <- 0L until iters) { @@ -124,7 +124,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { } // Using unsafe memory - val unsafeBuffer = { i: Int => + val unsafeBuffer = { (i: Int) => val data: Long = Platform.allocateMemory(count * 4) var sum = 0L for (n <- 0L until iters) { @@ -146,7 +146,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { } // Access through the column API with on heap memory - val columnOnHeap = { i: Int => + val columnOnHeap = { (i: Int) => val col = new OnHeapColumnVector(count, IntegerType) var sum = 0L for (n <- 0L until iters) { @@ -165,7 +165,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { } // Access through the column API with off heap memory - def columnOffHeap = { i: Int => { + def columnOffHeap = { (i: Int) => { val col = new OffHeapColumnVector(count, IntegerType) var sum = 0L for (n <- 0L until iters) { @@ -184,7 +184,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { }} // Access by directly getting the buffer backing the column. - val columnOffheapDirect = { i: Int => + val columnOffheapDirect = { (i: Int) => val col = new OffHeapColumnVector(count, IntegerType) var sum = 0L for (n <- 0L until iters) { @@ -207,7 +207,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { } // Access by going through a batch of unsafe rows. - val unsafeRowOnheap = { i: Int => + val unsafeRowOnheap = { (i: Int) => val buffer = new Array[Byte](count * 16) var sum = 0L for (n <- 0L until iters) { @@ -228,7 +228,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { } // Access by going through a batch of unsafe rows. - val unsafeRowOffheap = { i: Int => + val unsafeRowOffheap = { (i: Int) => val buffer = Platform.allocateMemory(count * 16) var sum = 0L for (n <- 0L until iters) { @@ -250,7 +250,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { } // Adding values by appending, instead of putting. 
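// A standalone sketch (not from the Spark sources; Wrap is illustrative) of another recurring
// rewrite in this series, seen in map(SparkListenerUnpersistRDD(_)), map(BlockAdditionEvent(_))
// and reduceLeftOption(And(_, _)): in Scala 3 a case-class companion no longer extends
// FunctionN, so call sites pass an explicit function instead of the bare companion object.
final case class Wrap(n: Int)

object CompanionEtaSketch {
  // before: (1 to 3).map(Wrap)                // relied on the companion being a Function1
  val viaLambda: Seq[Wrap] = (1 to 3).map(Wrap(_))
  val viaApply: Seq[Wrap] = (1 to 3).map(Wrap.apply)
}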
- val onHeapAppend = { i: Int => + val onHeapAppend = { (i: Int) => val col = new OnHeapColumnVector(count, IntegerType) var sum = 0L for (n <- 0L until iters) { @@ -287,7 +287,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { def booleanAccess(iters: Int): Unit = { val count = 8 * 1024 val benchmark = new Benchmark("Boolean Read/Write", iters * count.toLong, output = output) - benchmark.addCase("Bitset") { i: Int => { + benchmark.addCase("Bitset") { (i: Int) => { val b = new BitSet(count) var sum = 0L for (n <- 0L until iters) { @@ -304,7 +304,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { } }} - benchmark.addCase("Byte Array") { i: Int => { + benchmark.addCase("Byte Array") { (i: Int) => { val b = new Array[Byte](count) var sum = 0L for (n <- 0L until iters) { @@ -345,7 +345,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { val data = Seq.fill(count)(randomString(minString, maxString)) .map(_.getBytes(StandardCharsets.UTF_8)).toArray - def column(memoryMode: MemoryMode) = { i: Int => + def column(memoryMode: MemoryMode) = { (i: Int) => val column = if (memoryMode == MemoryMode.OFF_HEAP) { new OffHeapColumnVector(count, BinaryType) } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index baf99798965da..13571c60d04d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -1936,11 +1936,11 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { None), Config( Some(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false")))) - ).foreach { testCase: TestCase => - testCase.configs.foreach { config: Config => + ).foreach { (testCase: TestCase) => + testCase.configs.foreach { (config: Config) => // Run the test twice, once using SQL for the INSERT operations and again using DataFrames. for (useDataFrames <- Seq(false, true)) { - config.sqlConf.map { kv: (String, String) => + config.sqlConf.map { (kv: (String, String)) => withSQLConf(kv) { // Run the test with the pair of custom SQLConf values. 
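// A standalone sketch (not from the Spark sources; names and the retry limit are illustrative)
// of the loop rewrite the HiveClientImpl hunk applies later in this series: Scala 3 drops
// do/while, so the body moves into the condition block of a plain while loop and still runs
// at least once before the condition is evaluated.
object DoWhileSketch {
  def retry(limit: Int)(attempt: () => Unit): Unit = {
    var tries = 0
    // before: do { tries += 1; attempt() } while (tries < limit)
    while ({
      tries += 1
      attempt()
      tries < limit   // the block's last expression is the continue condition
    }) ()
  }
}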
runTest(testCase.dataSource, config.copy(useDataFrames = useDataFrames)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index a753da116924d..72ef2028afe43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -1671,7 +1671,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { CheckAnswer("keep1", "keep2"), AddTextFileData("keep3", src, tmp), CheckAnswer("keep1", "keep2", "keep3"), - AssertOnQuery("check getBatch") { execution: StreamExecution => + AssertOnQuery("check getBatch") { (execution: StreamExecution) => val _sources = PrivateMethod[Seq[Source]](Symbol("sources")) val fileSource = getSourcesFromStreamingQuery(execution).head @@ -2004,14 +2004,14 @@ class FileStreamSourceSuite extends FileStreamSourceTest { testStream(filtered)( AddTextFileData("keep1", src, tmp, tmpFilePrefix = "keep1"), CheckAnswer("keep1"), - AssertOnQuery("input file removed") { _: StreamExecution => + AssertOnQuery("input file removed") { _ => // it doesn't rename any file yet assertFileIsNotRemoved(src, "keep1") true }, AddTextFileData("keep2", src, tmp, tmpFilePrefix = "ke ep2 %"), CheckAnswer("keep1", "keep2"), - AssertOnQuery("input file removed") { _: StreamExecution => + AssertOnQuery("input file removed") { _ => // it renames input file for first batch, but not for second batch yet assertFileIsRemoved(src, "keep1") assertFileIsNotRemoved(src, "ke ep2 %") @@ -2020,7 +2020,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { }, AddTextFileData("keep3", src, tmp, tmpFilePrefix = "keep3"), CheckAnswer("keep1", "keep2", "keep3"), - AssertOnQuery("input file renamed") { _: StreamExecution => + AssertOnQuery("input file renamed") { _ => // it renames input file for second batch, but not third batch yet assertFileIsRemoved(src, "ke ep2 %") assertFileIsNotRemoved(src, "keep3") @@ -2064,14 +2064,14 @@ class FileStreamSourceSuite extends FileStreamSourceTest { testStream(filtered)( AddTextFileData("keep1", dirForKeep1, tmp, tmpFilePrefix = "keep1"), CheckAnswer("keep1"), - AssertOnQuery("input file archived") { _: StreamExecution => + AssertOnQuery("input file archived") { _ => // it doesn't rename any file yet assertFileIsNotMoved(dirForKeep1, expectedMovedDir1, "keep1") true }, AddTextFileData("keep2", dirForKeep2, tmp, tmpFilePrefix = "keep2 %"), CheckAnswer("keep1", "keep2"), - AssertOnQuery("input file archived") { _: StreamExecution => + AssertOnQuery("input file archived") { _ => // it renames input file for first batch, but not for second batch yet assertFileIsMoved(dirForKeep1, expectedMovedDir1, "keep1") assertFileIsNotMoved(dirForKeep2, expectedMovedDir2, "keep2 %") @@ -2079,7 +2079,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { }, AddTextFileData("keep3", dirForKeep3, tmp, tmpFilePrefix = "keep3"), CheckAnswer("keep1", "keep2", "keep3"), - AssertOnQuery("input file archived") { _: StreamExecution => + AssertOnQuery("input file archived") { _ => // it renames input file for second batch, but not third batch yet assertFileIsMoved(dirForKeep2, expectedMovedDir2, "keep2 %") assertFileIsNotMoved(dirForKeep3, expectedMovedDir3, "keep3") @@ -2088,7 +2088,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { }, AddTextFileData("keep4", dirForKeep3, tmp, tmpFilePrefix = "keep4"), CheckAnswer("keep1", 
"keep2", "keep3", "keep4"), - AssertOnQuery("input file archived") { _: StreamExecution => + AssertOnQuery("input file archived") { _ => // it renames input file for third batch, but not fourth batch yet assertFileIsMoved(dirForKeep3, expectedMovedDir3, "keep3") assertFileIsNotMoved(dirForKeep3, expectedMovedDir3, "keep4") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ReportSinkMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ReportSinkMetricsSuite.scala index c417693b5d7a6..dd7ece3cf7171 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ReportSinkMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ReportSinkMetricsSuite.scala @@ -93,72 +93,70 @@ class ReportSinkMetricsSuite extends StreamTest { } } - case class TestSinkRelation(override val sqlContext: SQLContext, data: DataFrame) - extends BaseRelation { - override def schema: StructType = data.schema - } - - class TestSinkProvider extends SimpleTableProvider - with DataSourceRegister - with CreatableRelationProvider with Logging { +case class TestSinkRelation(override val sqlContext: SQLContext, data: DataFrame) + extends BaseRelation { + override def schema: StructType = data.schema +} - override def getTable(options: CaseInsensitiveStringMap): Table = { - val useCommitCoordinator = options.getBoolean("useCommitCoordinator", false) - new TestSinkTable(useCommitCoordinator) - } +class TestSinkProvider extends SimpleTableProvider + with DataSourceRegister + with CreatableRelationProvider with Logging { - def createRelation( - sqlContext: SQLContext, - mode: SaveMode, - parameters: Map[String, String], - data: DataFrame): BaseRelation = { + override def getTable(options: CaseInsensitiveStringMap): Table = { + val useCommitCoordinator = options.getBoolean("useCommitCoordinator", false) + new TestSinkTable(useCommitCoordinator) + } - TestSinkRelation(sqlContext, data) - } + def createRelation( + sqlContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { - def shortName(): String = "test" + TestSinkRelation(sqlContext, data) } - class TestSinkTable(useCommitCoordinator: Boolean) - extends Table with SupportsWrite with ReportsSinkMetrics with Logging { + def shortName(): String = "test" +} - override def name(): String = "test" +class TestSinkTable(useCommitCoordinator: Boolean) + extends Table with SupportsWrite with ReportsSinkMetrics with Logging { - override def schema(): StructType = StructType(Nil) + override def name(): String = "test" - override def capabilities(): java.util.Set[TableCapability] = { - java.util.EnumSet.of(TableCapability.STREAMING_WRITE) - } + override def schema(): StructType = StructType(Nil) - override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { - new WriteBuilder with SupportsTruncate with SupportsStreamingUpdateAsAppend { + override def capabilities(): java.util.Set[TableCapability] = { + java.util.EnumSet.of(TableCapability.STREAMING_WRITE) + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + new WriteBuilder with SupportsTruncate with SupportsStreamingUpdateAsAppend { - override def truncate(): WriteBuilder = this + override def truncate(): WriteBuilder = this - override def build(): Write = { - new Write { - override def toStreaming: StreamingWrite = { - new TestSinkWrite(useCommitCoordinator) - } + override def build(): Write = { + new Write { + override def toStreaming: StreamingWrite = { + new 
TestSinkWrite(useCommitCoordinator) } } } } - - override def metrics(): java.util.Map[String, String] = { - Map("metrics-1" -> "value-1", "metrics-2" -> "value-2").asJava - } } - class TestSinkWrite(useCommitCoordinator: Boolean) - extends StreamingWrite with Logging with Serializable { + override def metrics(): java.util.Map[String, String] = { + Map("metrics-1" -> "value-1", "metrics-2" -> "value-2").asJava + } +} - def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = - PackedRowWriterFactory +class TestSinkWrite(override val useCommitCoordinator: Boolean) + extends StreamingWrite with Logging with Serializable { - override def useCommitCoordinator(): Boolean = useCommitCoordinator + def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = + PackedRowWriterFactory - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - } + def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index b0967d5ffdf10..fa92354f59172 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -1256,12 +1256,12 @@ class StreamSuite extends StreamTest { DateTimeUtils.millisToMicros(lastTimestamp), ZoneId.systemDefault) testStream(df) ( AddData(input, 1), - CheckLastBatch { rows: Seq[Row] => + CheckLastBatch { (rows: Seq[Row]) => lastTimestamp = assertBatchOutputAndUpdateLastTimestamp(rows, lastTimestamp, currentDate, 1) }, Execute { _ => Thread.sleep(1000) }, AddData(input, 2), - CheckLastBatch { rows: Seq[Row] => + CheckLastBatch { (rows: Seq[Row]) => lastTimestamp = assertBatchOutputAndUpdateLastTimestamp(rows, lastTimestamp, currentDate, 2) } ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index ff23e00336a40..711f4ed5e0eaa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -997,7 +997,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi test("Uuid in streaming query should not produce same uuids in each execution") { val uuids = mutable.ArrayBuffer[String]() - def collectUuid: Seq[Row] => Unit = { rows: Seq[Row] => + def collectUuid: Seq[Row] => Unit = { (rows: Seq[Row]) => rows.foreach(r => uuids += r.getString(0)) } @@ -1014,7 +1014,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi test("Rand/Randn in streaming query should not produce same results in each execution") { val rands = mutable.ArrayBuffer[Double]() - def collectRand: Seq[Row] => Unit = { rows: Seq[Row] => + def collectRand: Seq[Row] => Unit = { (rows: Seq[Row]) => rows.foreach { r => rands += r.getDouble(0) rands += r.getDouble(1) @@ -1034,7 +1034,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi test("Shuffle in streaming query should not produce same results in each execution") { val rands = mutable.ArrayBuffer[Seq[Int]]() - def 
collectShuffle: Seq[Row] => Unit = { rows: Seq[Row] => + def collectShuffle: Seq[Row] => Unit = { (rows: Seq[Row]) => rows.foreach { r => rands += r.getSeq[Int](0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 477da731b81b8..d503ddc515d19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -239,7 +239,7 @@ private[sql] trait SQLTestData { self => // An RDD with 4 elements and 8 partitions protected lazy val withEmptyParts: RDD[IntField] = { - val rdd = spark.sparkContext.parallelize((1 to 4).map(IntField), 8) + val rdd = spark.sparkContext.parallelize((1 to 4).map(IntField(_)), 8) rdd.toDF().createOrReplaceTempView("withEmptyParts") rdd } diff --git a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala index 982d57fb28756..d3677d7fd3d7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala @@ -23,6 +23,7 @@ import java.text.SimpleDateFormat import jakarta.servlet.http.HttpServletResponse import org.json4s.DefaultFormats import org.json4s.jackson.JsonMethods +import org.json4s.jvalue2extractable import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 9f1954cbf6868..da27889aadb3b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -536,7 +536,7 @@ private[hive] trait HiveInspectors { case pi: PrimitiveObjectInspector => pi match { // We think HiveVarchar/HiveChar is also a String case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() => - data: Any => { + (data: Any) => { if (data != null) { UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) } else { @@ -544,7 +544,7 @@ private[hive] trait HiveInspectors { } } case hvoi: HiveVarcharObjectInspector => - data: Any => { + (data: Any) => { if (data != null) { UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) } else { @@ -552,7 +552,7 @@ private[hive] trait HiveInspectors { } } case hvoi: HiveCharObjectInspector if hvoi.preferWritable() => - data: Any => { + (data: Any) => { if (data != null) { UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveChar.getValue) } else { @@ -560,7 +560,7 @@ private[hive] trait HiveInspectors { } } case hvoi: HiveCharObjectInspector => - data: Any => { + (data: Any) => { if (data != null) { UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) } else { @@ -568,7 +568,7 @@ private[hive] trait HiveInspectors { } } case x: StringObjectInspector if x.preferWritable() => - data: Any => { + (data: Any) => { if (data != null) { // Text is in UTF-8 already. No need to convert again via fromString. 
Copy bytes val wObj = x.getPrimitiveWritableObject(data) @@ -579,7 +579,7 @@ private[hive] trait HiveInspectors { } } case x: StringObjectInspector => - data: Any => { + (data: Any) => { if (data != null) { UTF8String.fromString(x.getPrimitiveJavaObject(data)) } else { @@ -587,35 +587,35 @@ private[hive] trait HiveInspectors { } } case x: IntObjectInspector if x.preferWritable() => - data: Any => { + (data: Any) => { if (data != null) x.get(data) else null } case x: BooleanObjectInspector if x.preferWritable() => - data: Any => { + (data: Any) => { if (data != null) x.get(data) else null } case x: FloatObjectInspector if x.preferWritable() => - data: Any => { + (data: Any) => { if (data != null) x.get(data) else null } case x: DoubleObjectInspector if x.preferWritable() => - data: Any => { + (data: Any) => { if (data != null) x.get(data) else null } case x: LongObjectInspector if x.preferWritable() => - data: Any => { + (data: Any) => { if (data != null) x.get(data) else null } case x: ShortObjectInspector if x.preferWritable() => - data: Any => { + (data: Any) => { if (data != null) x.get(data) else null } case x: ByteObjectInspector if x.preferWritable() => - data: Any => { + (data: Any) => { if (data != null) x.get(data) else null } case x: HiveDecimalObjectInspector => - data: Any => { + (data: Any) => { if (data != null) { HiveShim.toCatalystDecimal(x, data) } else { @@ -623,7 +623,7 @@ private[hive] trait HiveInspectors { } } case x: BinaryObjectInspector if x.preferWritable() => - data: Any => { + (data: Any) => { if (data != null) { x.getPrimitiveWritableObject(data).copyBytes() } else { @@ -631,7 +631,7 @@ private[hive] trait HiveInspectors { } } case x: DateObjectInspector if x.preferWritable() => - data: Any => { + (data: Any) => { if (data != null) { new DaysWritable(x.getPrimitiveWritableObject(data)).gregorianDays } else { @@ -639,7 +639,7 @@ private[hive] trait HiveInspectors { } } case x: DateObjectInspector => - data: Any => { + (data: Any) => { if (data != null) { DateTimeUtils.fromJavaDate(x.getPrimitiveJavaObject(data)) } else { @@ -647,7 +647,7 @@ private[hive] trait HiveInspectors { } } case x: TimestampObjectInspector if x.preferWritable() => - data: Any => { + (data: Any) => { if (data != null) { DateTimeUtils.fromJavaTimestamp(x.getPrimitiveWritableObject(data).getTimestamp) } else { @@ -655,7 +655,7 @@ private[hive] trait HiveInspectors { } } case ti: TimestampObjectInspector => - data: Any => { + (data: Any) => { if (data != null) { DateTimeUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data)) } else { @@ -663,7 +663,7 @@ private[hive] trait HiveInspectors { } } case dt: HiveIntervalDayTimeObjectInspector if dt.preferWritable() => - data: Any => { + (data: Any) => { if (data != null) { val dayTime = dt.getPrimitiveWritableObject(data).getHiveIntervalDayTime IntervalUtils.durationToMicros( @@ -673,7 +673,7 @@ private[hive] trait HiveInspectors { } } case dt: HiveIntervalDayTimeObjectInspector => - data: Any => { + (data: Any) => { if (data != null) { val dayTime = dt.getPrimitiveJavaObject(data) IntervalUtils.durationToMicros( @@ -683,7 +683,7 @@ private[hive] trait HiveInspectors { } } case ym: HiveIntervalYearMonthObjectInspector if ym.preferWritable() => - data: Any => { + (data: Any) => { if (data != null) { ym.getPrimitiveWritableObject(data).getHiveIntervalYearMonth.getTotalMonths } else { @@ -691,7 +691,7 @@ private[hive] trait HiveInspectors { } } case ym: HiveIntervalYearMonthObjectInspector => - data: Any => { + (data: Any) => { if (data != null) { 
ym.getPrimitiveJavaObject(data).getTotalMonths } else { @@ -699,7 +699,7 @@ private[hive] trait HiveInspectors { } } case _ => - data: Any => { + (data: Any) => { if (data != null) { pi.getPrimitiveJavaObject(data) } else { @@ -709,7 +709,7 @@ private[hive] trait HiveInspectors { } case li: ListObjectInspector => val unwrapper = unwrapperFor(li.getListElementObjectInspector) - data: Any => { + (data: Any) => { if (data != null) { Option(li.getList(data)) .map { l => @@ -724,7 +724,7 @@ private[hive] trait HiveInspectors { case mi: MapObjectInspector => val keyUnwrapper = unwrapperFor(mi.getMapKeyObjectInspector) val valueUnwrapper = unwrapperFor(mi.getMapValueObjectInspector) - data: Any => { + (data: Any) => { if (data != null) { val map = mi.getMap(data) if (map == null) { @@ -741,9 +741,9 @@ private[hive] trait HiveInspectors { val fields = si.getAllStructFieldRefs.asScala val unwrappers = fields.map { field => val unwrapper = unwrapperFor(field.getFieldObjectInspector) - data: Any => unwrapper(si.getStructFieldData(data, field)) + (data: Any) => unwrapper(si.getStructFieldData(data, field)) } - data: Any => { + (data: Any) => { if (data != null) { new GenericInternalRow(unwrappers.map(_(data)).toArray) } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 57f6f999b6ade..e1637d729ec52 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -231,7 +231,7 @@ private[hive] class HiveClientImpl( val deadline = System.nanoTime + (retryLimit * retryDelayMillis * 1e6).toLong var numTries = 0 var caughtException: Exception = null - do { + while ({ numTries += 1 try { return f @@ -244,7 +244,8 @@ private[hive] class HiveClientImpl( clientLoader.cachedHive = null Thread.sleep(retryDelayMillis) } - } while (numTries <= retryLimit && System.nanoTime < deadline) + numTries <= retryLimit && System.nanoTime < deadline + }) () if (System.nanoTime > deadline) { logWarning("Deadline exceeded") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 154d07f80d898..a4ce5632fe5d7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -21,6 +21,7 @@ import scala.util.control.NonFatal import org.apache.spark.sql.{Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.expressions.AttributeSeq import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, CTERelationDef, LogicalPlan, WithCTE} import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.errors.QueryCompilationErrors diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 335d552fd50b7..3ad47764ac2f9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -81,7 +81,7 @@ case class HiveTableScanExec( // Bind all partition 
key attribute references in the partition pruning predicate for later // evaluation. - private lazy val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => + private lazy val boundPruningPred = partitionPruningPred.reduceLeftOption(And(_, _)).map { pred => require(pred.dataType == BooleanType, s"Data type of predicate $pred must be ${BooleanType.catalogString} rather than " + s"${pred.dataType.catalogString}.") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala index 91918fe62362b..98153a622b3ec 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala @@ -28,6 +28,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} +import org.apache.spark.sql.catalyst.expressions.AttributeSeq import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.execution.SparkPlan diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala index 6486904fe65af..9e778f9e47ce6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala @@ -84,7 +84,7 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession) } if (sizeOfPartitions.forall(_ > 0)) { val filteredStats = - FilterEstimation(Filter(partitionKeyFilters.reduce(And), relation)).estimate + FilterEstimation(Filter(partitionKeyFilters.reduce(And(_, _)), relation)).estimate val colStats = filteredStats.map(_.attributeStats.map { case (attr, colStat) => (attr.name, colStat.toCatalogColumnStat(attr.name, attr.dataType)) }) @@ -114,7 +114,7 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession) val newRelation = relation.copy( tableMeta = newTableMeta, prunedPartitions = Some(newPartitions)) // Keep partition filters so that they are visible in physical planning - Project(projections, Filter(filters.reduceLeft(And), newRelation)) + Project(projections, Filter(filters.reduceLeft(And(_, _)), newRelation)) } else { op } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index a89ea2424696e..b65a6638286a6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -91,7 +91,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { val literal = Literal.create(value, dataType) val attribute = inputAttributes.find(_.name == column).get expressions.GreaterThan(attribute, literal) - }.reduceOption(expressions.And).getOrElse(Literal(true)) + }.reduceOption(expressions.And.apply).getOrElse(Literal(true)) Predicate.create(filterCondition, inputAttributes) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala 
b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 4aeb0e043a973..9caf5372a46ca 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -46,7 +46,8 @@ import org.apache.spark.util.{Clock, ManualClock, MutableURLClassLoader, ResetSy */ private[streaming] class CheckpointInputDStream(_ssc: StreamingContext) extends InputDStream[Int](_ssc) { - protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData + protected[streaming] override val checkpointData: FileInputDStreamCheckpointData = + new FileInputDStreamCheckpointData override def start(): Unit = { } override def stop(): Unit = { } override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.makeRDD(Seq(1))) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala index 7fcc67b461419..82a36c603a977 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala @@ -127,7 +127,7 @@ class DStreamClosureSuite extends SparkFunSuite with LocalStreamingContext { private def testCombineByKey(ds: DStream[(Int, Int)]): Unit = { expectCorrectException { ds.combineByKey[Int]( - { _: Int => return; 1 }, + { (_: Int) => return; 1 }, { case (_: Int, _: Int) => return; 1 }, { case (_: Int, _: Int) => return; 1 }, new HashPartitioner(5) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 2079ed729c40c..30c0f3773d89b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -109,7 +109,7 @@ class ReceivedBlockTrackerSuite extends SparkFunSuite with BeforeAndAfter with M receivedBlockTracker.getBlocksOfBatch(1) shouldEqual Map(streamId -> blockInfos) receivedBlockTracker.getBlocksOfBatchAndStream(1, streamId) shouldEqual blockInfos - val expectedWrittenData1 = blockInfos.map(BlockAdditionEvent) :+ + val expectedWrittenData1 = blockInfos.map(BlockAdditionEvent(_)) :+ BatchAllocationEvent(1, AllocatedBlocks(Map(streamId -> blockInfos))) getWrittenLogData() shouldEqual expectedWrittenData1 getWriteAheadLogFiles() should have size 1 @@ -216,7 +216,7 @@ class ReceivedBlockTrackerSuite extends SparkFunSuite with BeforeAndAfter with M tracker1.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 // Verify whether write ahead log has correct contents - val expectedWrittenData1 = blockInfos1.map(BlockAdditionEvent) + val expectedWrittenData1 = blockInfos1.map(BlockAdditionEvent(_)) getWrittenLogData() shouldEqual expectedWrittenData1 getWriteAheadLogFiles() should have size 1 tracker1.stop() @@ -255,7 +255,7 @@ class ReceivedBlockTrackerSuite extends SparkFunSuite with BeforeAndAfter with M // Verify whether log has correct contents val expectedWrittenData2 = expectedWrittenData1 ++ Seq(createBatchAllocation(batchTime1, blockInfos1)) ++ - blockInfos2.map(BlockAdditionEvent) ++ + blockInfos2.map(BlockAdditionEvent(_)) ++ Seq(createBatchAllocation(batchTime2, blockInfos2)) getWrittenLogData() shouldEqual expectedWrittenData2 @@ -319,14 +319,14 @@ class ReceivedBlockTrackerSuite extends SparkFunSuite 
with BeforeAndAfter with M // If we face any issue during recovery, because these old files exist, then we need to make // deletion more robust rather than a parallelized operation where we fire and forget val batch1Allocation = createBatchAllocation(t(1), batch1) - writeEventsManually(getLogFileName(t(1)), batch1.map(BlockAdditionEvent) :+ batch1Allocation) + writeEventsManually(getLogFileName(t(1)), batch1.map(BlockAdditionEvent(_)) :+ batch1Allocation) writeEventsManually(getLogFileName(t(2)), Seq(createBatchCleanup(t(1)))) val batch2Allocation = createBatchAllocation(t(3), batch2) - writeEventsManually(getLogFileName(t(3)), batch2.map(BlockAdditionEvent) :+ batch2Allocation) + writeEventsManually(getLogFileName(t(3)), batch2.map(BlockAdditionEvent(_)) :+ batch2Allocation) - writeEventsManually(getLogFileName(t(4)), batch3.map(BlockAdditionEvent)) + writeEventsManually(getLogFileName(t(4)), batch3.map(BlockAdditionEvent(_))) // We should have 5 different log files as we called `writeEventsManually` with 5 different // timestamps @@ -360,7 +360,7 @@ class ReceivedBlockTrackerSuite extends SparkFunSuite with BeforeAndAfter with M compareTrackers(tracker, tracker3) // rewrite second file - writeEventsManually(getLogFileName(t(1)), batch1.map(BlockAdditionEvent) :+ batch1Allocation) + writeEventsManually(getLogFileName(t(1)), batch1.map(BlockAdditionEvent(_)) :+ batch1Allocation) assert(getWriteAheadLogFiles().length === 5) // make sure trackers are consistent val tracker4 = createTracker(recoverFromWriteAheadLog = true, clock = new ManualClock(t(4))) From 744472efd6e1190b963b7a26bb876c250400fea1 Mon Sep 17 00:00:00 2001 From: Joan Goyeau Date: Wed, 2 Apr 2025 01:52:23 -0400 Subject: [PATCH 2/4] Make MiMa happy without the -Xsource:3-cross option --- core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala | 2 +- .../spark/mllib/classification/LogisticRegression.scala | 2 +- .../scala/org/apache/spark/mllib/classification/SVM.scala | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 8f03afb5b2664..c2562446e14c1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -87,7 +87,7 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine)) } - override val partitioner: Option[Partitioner] = Some(part) + override val partitioner: Some[Partitioner] = Some(part) override def getPartitions: Array[Partition] = { Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index e4d1c70896e88..b6bb5dfc872b9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -253,7 +253,7 @@ class LogisticRegressionWithLBFGS @Since("1.1.0") override val optimizer: LBFGS = new LBFGS(new LogisticGradient, new SquaredL2Updater) - override protected val validators: Seq[RDD[LabeledPoint] => Boolean] = List(multiLabelValidator) + override protected val validators: List[RDD[LabeledPoint] => Boolean] = List(multiLabelValidator) private def multiLabelValidator: 
RDD[LabeledPoint] => Boolean = { data => if (numOfLinearPredictor > 1) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 2bf411c96e4e7..36ec7faadd160 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -142,8 +142,8 @@ class SVMWithSGD private ( .setNumIterations(numIterations) .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) - override protected val validators: Seq[RDD[LabeledPoint] => Boolean] = - Seq(DataValidators.binaryLabelValidator) + override protected val validators: List[RDD[LabeledPoint] => Boolean] = + List(DataValidators.binaryLabelValidator) /** * Construct a SVM object with default parameters: {stepSize: 1.0, numIterations: 100, From 0d519ac13bb1df90fe080f5c921a68e0174b1add Mon Sep 17 00:00:00 2001 From: Joan Goyeau Date: Wed, 2 Apr 2025 05:31:15 -0400 Subject: [PATCH 3/4] Make MiMa happy with the -Xsource:3 -Xsource-features:v2.13.15,-case-companion-function options --- .../classification/LogisticRegression.scala | 3 +- pom.xml | 3 +- project/MimaExcludes.scala | 33 ++++++++++++++++++- project/SparkBuild.scala | 3 +- .../k8s/ExecutorPodsSnapshotSuite.scala | 30 ++++++++--------- .../apache/spark/sql/MergeIntoWriter.scala | 4 +-- .../client/arrow/ArrowEncoderSuite.scala | 2 +- .../datasources/FileSourceStrategy.scala | 5 +-- .../PartitioningAwareFileIndex.scala | 2 +- .../sources/TextSocketMicroBatchStream.scala | 2 +- 10 files changed, 60 insertions(+), 27 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index b6bb5dfc872b9..4e7e3a9c26537 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -221,7 +221,8 @@ class LogisticRegressionWithSGD private[mllib] ( .setNumIterations(numIterations) .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) - override protected val validators = List(DataValidators.binaryLabelValidator) + override protected val validators: List[RDD[LabeledPoint] => Boolean] = + List(DataValidators.binaryLabelValidator) override protected[mllib] def createModel( weights: Vector, diff --git a/pom.xml b/pom.xml index f8464069e7588..6cf348e9dcf54 100644 --- a/pom.xml +++ b/pom.xml @@ -2736,7 +2736,8 @@ -explaintypes -release 17 - -Xsource:3-cross + -Xsource:3 + -Xsource-features:v2.13.15,-case-companion-function -Wconf:any:e -Wconf:cat=deprecation:wv -Wunused:imports diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d89bb285ed8dc..6690b932a1e90 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -40,7 +40,7 @@ object MimaExcludes { ) // Exclude rules for 4.0.x from 3.5.0 - lazy val v40excludes = defaultExcludes ++ Seq( + lazy val v40excludes = defaultExcludes ++ scala3Excludes ++ Seq( // [SPARK-44863][UI] Add a button to download thread dump as a txt in Spark UI ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ThreadStackTrace.*"), ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.status.api.v1.ThreadStackTrace$"), @@ -238,6 +238,37 @@ object MimaExcludes { loggingExcludes("org.apache.spark.sql.streaming.DataStreamReader") ++ 
loggingExcludes("org.apache.spark.sql.SparkSession#Builder") + // Enable -Xsource:3 compiler flag + lazy val scala3Excludes = Seq( + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.sql.Metric.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.sql.Metric.tupled"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.sql.Metric.curried"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.status.api.v1.sql.Metric$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.sql.Node.apply$default$3"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.sql.Node.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.sql.Node.tupled"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.sql.Node.curried"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.status.api.v1.sql.Node$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.paths.SparkPath.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.paths.SparkPath.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.paths.SparkPath.copy$default$1"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.paths.SparkPath.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.AccumulableInfo.apply$default$7"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.AccumulableInfo.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.AccumulableInfo.tupled"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.AccumulableInfo.curried"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.AccumulableInfo$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationAttemptInfo.apply$default$7"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationAttemptInfo.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationAttemptInfo.tupled"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationAttemptInfo.curried"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.status.api.v1.ApplicationAttemptInfo$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationInfo.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationInfo.tupled"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationInfo.curried"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.status.api.v1.ApplicationInfo$"), + ) + // Default exclude rules lazy val defaultExcludes = Seq( // Spark Internals diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 1e3e7d25f3a57..5c1af19fea232 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -229,7 +229,8 @@ object SparkBuild extends PomBuild { lazy val compilerWarningSettings: Seq[sbt.Def.Setting[_]] = Seq( (Compile / scalacOptions) ++= { Seq( - "-Xsource:3-cross", + "-Xsource:3", + "-Xsource-features:v2.13.15,-case-companion-function", // replace -Xfatal-warnings with fine-grained 
configuration, since 2.13.2 // verbose warning on deprecation, error on all others // see `scalac -Wconf:help` for details diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala index 5e66726927e36..6a7581657d12c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala @@ -38,14 +38,14 @@ class ExecutorPodsSnapshotSuite extends SparkFunSuite { test("States are interpreted correctly from pod metadata.") { ExecutorPodsSnapshot.setShouldCheckAllContainers(false) val testCases = Seq( - testCase(pendingExecutor(0), PodPending), - testCase(runningExecutor(1), PodRunning), - testCase(succeededExecutor(2), PodSucceeded), - testCase(failedExecutorWithoutDeletion(3), PodFailed), - testCase(deletedExecutor(4), PodDeleted), - testCase(unknownExecutor(5), PodUnknown), - testCase(finishedExecutorWithRunningSidecar(6, 0), PodSucceeded), - testCase(finishedExecutorWithRunningSidecar(7, 1), PodFailed) + testCase(pendingExecutor(0), PodPending(_)), + testCase(runningExecutor(1), PodRunning(_)), + testCase(succeededExecutor(2), PodSucceeded(_)), + testCase(failedExecutorWithoutDeletion(3), PodFailed(_)), + testCase(deletedExecutor(4), PodDeleted(_)), + testCase(unknownExecutor(5), PodUnknown(_)), + testCase(finishedExecutorWithRunningSidecar(6, 0), PodSucceeded(_)), + testCase(finishedExecutorWithRunningSidecar(7, 1), PodFailed(_)) ) doTest(testCases) } @@ -54,13 +54,13 @@ class ExecutorPodsSnapshotSuite extends SparkFunSuite { + " when configured to check all containers.") { ExecutorPodsSnapshot.setShouldCheckAllContainers(true) val testCases = Seq( - testCase(pendingExecutor(0), PodPending), - testCase(runningExecutor(1), PodRunning), - testCase(runningExecutorWithFailedContainer(2), PodFailed), - testCase(succeededExecutor(3), PodSucceeded), - testCase(failedExecutorWithoutDeletion(4), PodFailed), - testCase(deletedExecutor(5), PodDeleted), - testCase(unknownExecutor(6), PodUnknown) + testCase(pendingExecutor(0), PodPending(_)), + testCase(runningExecutor(1), PodRunning(_)), + testCase(runningExecutorWithFailedContainer(2), PodFailed(_)), + testCase(succeededExecutor(3), PodSucceeded(_)), + testCase(failedExecutorWithoutDeletion(4), PodFailed(_)), + testCase(deletedExecutor(5), PodDeleted(_)), + testCase(unknownExecutor(6), PodUnknown(_)) ) doTest(testCases) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala b/sql/api/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala index 2021a91aadc9e..a2d53aa8cc8a9 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala @@ -250,9 +250,7 @@ case class WhenMatched[T] private[sql] ( * @tparam T * The type of data in the MergeIntoWriter. */ -case class WhenNotMatched[T]( - mergeIntoWriter: MergeIntoWriter[T], - condition: Option[Column]) { +case class WhenNotMatched[T](mergeIntoWriter: MergeIntoWriter[T], condition: Option[Column]) { /** * Specifies an action to insert all non-matched rows into the DataFrame. 
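Note on the recurring rewrites in this series (And(_, _), BlockAdditionEvent(_), PodPending(_) and so on): under Scala 3 semantics a case-class companion no longer extends FunctionN, so the bare companion can no longer be passed where a function value is expected; under -Xsource:3 that use is flagged, and with -Wconf:any:e the warning becomes an error, which is presumably why every such call site is spelled out explicitly. The snippet below is an illustrative sketch only (Wrap and CompanionAsFunctionDemo are made-up names, not Spark code), showing the spellings that stay valid on both compilers:

  // Stand-in for And, BlockAdditionEvent, PodPending, etc.
  case class Wrap(n: Int)

  object CompanionAsFunctionDemo {
    val xs = Seq(1, 2, 3)

    // Scala 2 only: relies on the synthetic `object Wrap extends (Int => Wrap)`.
    // val a = xs.map(Wrap)

    // Cross-building spellings, as used throughout this patch:
    val b = xs.map(Wrap(_))       // placeholder lambda
    val c = xs.map(Wrap.apply)    // explicit eta-expansion of apply

    // -Xsource:3 also asks for parentheses around a typed lambda parameter,
    // hence the `{ (_: Int) => ... }` change in DStreamClosureSuite.
    val d = xs.map((i: Int) => Wrap(i))
  }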
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 58e19389cae2e..a854905beed09 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -816,7 +816,7 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { test("REPL generated classes") { val encoder = ScalaReflection.encoderFor[MyTestClass] roundTripAndCheckIdentical(encoder) { () => - Iterator.tabulate(10)(MyTestClass) + Iterator.tabulate(10)(MyTestClass(_)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 7291da248294a..66128319d41a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -134,7 +134,7 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging { val numBuckets = bucketSpec.numBuckets val normalizedFiltersAndExpr = normalizedFilters - .reduce(expressions.And) + .reduce(expressions.And(_, _)) val matchedBuckets = getExpressionBuckets(normalizedFiltersAndExpr, bucketColumnName, numBuckets) @@ -353,7 +353,8 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging { }.getOrElse(scan) // bottom-most filters are put in the left of the list. - val finalFilters = afterScanFilters.toSeq.reduceOption(expressions.And).toSeq ++ stayUpFilters + val finalFilters = + afterScanFilters.toSeq.reduceOption(expressions.And(_, _)).toSeq ++ stayUpFilters val withFilter = finalFilters.foldLeft(withMetadataProjections)((plan, cond) => { execution.FilterExec(cond, plan) }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 07be3f89872cc..ea5ce47e774ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -184,7 +184,7 @@ abstract class PartitioningAwareFileIndex( } if (partitionPruningPredicates.nonEmpty) { - val predicate = partitionPruningPredicates.reduce(expressions.And) + val predicate = partitionPruningPredicates.reduce(expressions.And(_, _)) val boundPredicate = Predicate.createInterpreted(predicate.transform { case a: AttributeReference => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala index 597b981ebe556..1519381fc6952 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala @@ -128,7 +128,7 @@ class TextSocketMicroBatchStream(host: String, port: Int, numPartitions: Int) slices(idx % numPartitions).append(r) } - 
slices.map(TextSocketInputPartition) + slices.map(TextSocketInputPartition(_)) } override def createReaderFactory(): PartitionReaderFactory = From 31fdb05dae91c855eec1e738451ae2e7ab7e1dbc Mon Sep 17 00:00:00 2001 From: Joan Goyeau Date: Tue, 8 Apr 2025 12:57:02 -0400 Subject: [PATCH 4/4] Apply the changes to the new code after rebase --- .../resolver/RelationMetadataProvider.scala | 2 +- .../expressions/V2ExpressionUtils.scala | 96 +++++++++---------- 2 files changed, 49 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RelationMetadataProvider.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RelationMetadataProvider.scala index 8ec1db53a10b4..6a3d8e161f2c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RelationMetadataProvider.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RelationMetadataProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis.resolver import java.util.HashMap -import org.apache.spark.sql.catalyst.analysis.{AnalysisErrorAt, RelationResolution, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.{RelationResolution, UnresolvedRelation} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog.LookupCatalog import org.apache.spark.util.ArrayImplicits._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala index 7cc03f3ac3fa6..608a4b45ca29a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala @@ -227,21 +227,21 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { private def convertPredicate(expr: GeneralScalarExpression): Option[Expression] = { expr.name match { - case "IS_NULL" => convertUnaryExpr(expr, IsNull) - case "IS_NOT_NULL" => convertUnaryExpr(expr, IsNotNull) - case "NOT" => convertUnaryExpr(expr, Not) - case "=" => convertBinaryExpr(expr, EqualTo) - case "<=>" => convertBinaryExpr(expr, EqualNullSafe) - case ">" => convertBinaryExpr(expr, GreaterThan) - case ">=" => convertBinaryExpr(expr, GreaterThanOrEqual) - case "<" => convertBinaryExpr(expr, LessThan) - case "<=" => convertBinaryExpr(expr, LessThanOrEqual) + case "IS_NULL" => convertUnaryExpr(expr, IsNull(_)) + case "IS_NOT_NULL" => convertUnaryExpr(expr, IsNotNull(_)) + case "NOT" => convertUnaryExpr(expr, Not(_)) + case "=" => convertBinaryExpr(expr, EqualTo(_, _)) + case "<=>" => convertBinaryExpr(expr, EqualNullSafe(_, _)) + case ">" => convertBinaryExpr(expr, GreaterThan(_, _)) + case ">=" => convertBinaryExpr(expr, GreaterThanOrEqual(_, _)) + case "<" => convertBinaryExpr(expr, LessThan(_, _)) + case "<=" => convertBinaryExpr(expr, LessThanOrEqual(_, _)) case "<>" => convertBinaryExpr(expr, (left, right) => Not(EqualTo(left, right))) - case "AND" => convertBinaryExpr(expr, And) - case "OR" => convertBinaryExpr(expr, Or) - case "STARTS_WITH" => convertBinaryExpr(expr, StartsWith) - case "ENDS_WITH" => convertBinaryExpr(expr, EndsWith) - case "CONTAINS" => convertBinaryExpr(expr, Contains) + case "AND" => convertBinaryExpr(expr, And(_, _)) + case "OR" => convertBinaryExpr(expr, Or(_, _)) + case "STARTS_WITH" => 
convertBinaryExpr(expr, StartsWith(_, _)) + case "ENDS_WITH" => convertBinaryExpr(expr, EndsWith(_, _)) + case "CONTAINS" => convertBinaryExpr(expr, Contains(_, _)) case "IN" => convertExpr(expr, children => In(children.head, children.tail)) case _ => None } @@ -278,9 +278,9 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { case "/" => convertBinaryExpr(expr, Divide(_, _, evalMode = EvalMode.ANSI)) case "%" => convertBinaryExpr(expr, Remainder(_, _, evalMode = EvalMode.ANSI)) case "ABS" => convertUnaryExpr(expr, Abs(_, failOnError = true)) - case "COALESCE" => convertExpr(expr, Coalesce) - case "GREATEST" => convertExpr(expr, Greatest) - case "LEAST" => convertExpr(expr, Least) + case "COALESCE" => convertExpr(expr, Coalesce(_)) + case "GREATEST" => convertExpr(expr, Greatest(_)) + case "LEAST" => convertExpr(expr, Least(_)) case "RAND" => if (expr.children.isEmpty) { Some(new Rand()) @@ -289,20 +289,20 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { } else { None } - case "LOG" => convertBinaryExpr(expr, Logarithm) - case "LOG10" => convertUnaryExpr(expr, Log10) - case "LOG2" => convertUnaryExpr(expr, Log2) - case "LN" => convertUnaryExpr(expr, Log) - case "EXP" => convertUnaryExpr(expr, Exp) - case "POWER" => convertBinaryExpr(expr, Pow) - case "SQRT" => convertUnaryExpr(expr, Sqrt) - case "FLOOR" => convertUnaryExpr(expr, Floor) - case "CEIL" => convertUnaryExpr(expr, Ceil) + case "LOG" => convertBinaryExpr(expr, Logarithm(_, _)) + case "LOG10" => convertUnaryExpr(expr, Log10(_)) + case "LOG2" => convertUnaryExpr(expr, Log2(_)) + case "LN" => convertUnaryExpr(expr, Log(_)) + case "EXP" => convertUnaryExpr(expr, Exp(_)) + case "POWER" => convertBinaryExpr(expr, Pow(_, _)) + case "SQRT" => convertUnaryExpr(expr, Sqrt(_)) + case "FLOOR" => convertUnaryExpr(expr, Floor(_)) + case "CEIL" => convertUnaryExpr(expr, Ceil(_)) case "ROUND" => convertBinaryExpr(expr, Round(_, _, ansiEnabled = true)) - case "CBRT" => convertUnaryExpr(expr, Cbrt) - case "DEGREES" => convertUnaryExpr(expr, ToDegrees) - case "RADIANS" => convertUnaryExpr(expr, ToRadians) - case "SIGN" => convertUnaryExpr(expr, Signum) + case "CBRT" => convertUnaryExpr(expr, Cbrt(_)) + case "DEGREES" => convertUnaryExpr(expr, ToDegrees(_)) + case "RADIANS" => convertUnaryExpr(expr, ToRadians(_)) + case "SIGN" => convertUnaryExpr(expr, Signum(_)) case "WIDTH_BUCKET" => convertExpr( expr, @@ -313,30 +313,30 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { private def convertTrigonometricFunc(expr: GeneralScalarExpression): Option[Expression] = { expr.name match { - case "SIN" => convertUnaryExpr(expr, Sin) - case "SINH" => convertUnaryExpr(expr, Sinh) - case "COS" => convertUnaryExpr(expr, Cos) - case "COSH" => convertUnaryExpr(expr, Cosh) - case "TAN" => convertUnaryExpr(expr, Tan) - case "TANH" => convertUnaryExpr(expr, Tanh) - case "COT" => convertUnaryExpr(expr, Cot) - case "ASIN" => convertUnaryExpr(expr, Asin) - case "ASINH" => convertUnaryExpr(expr, Asinh) - case "ACOS" => convertUnaryExpr(expr, Acos) - case "ACOSH" => convertUnaryExpr(expr, Acosh) - case "ATAN" => convertUnaryExpr(expr, Atan) - case "ATANH" => convertUnaryExpr(expr, Atanh) - case "ATAN2" => convertBinaryExpr(expr, Atan2) + case "SIN" => convertUnaryExpr(expr, Sin(_)) + case "SINH" => convertUnaryExpr(expr, Sinh(_)) + case "COS" => convertUnaryExpr(expr, Cos(_)) + case "COSH" => convertUnaryExpr(expr, Cosh(_)) + case "TAN" => convertUnaryExpr(expr, Tan(_)) + case "TANH" => convertUnaryExpr(expr, Tanh(_)) + case "COT" 
=> convertUnaryExpr(expr, Cot(_)) + case "ASIN" => convertUnaryExpr(expr, Asin(_)) + case "ASINH" => convertUnaryExpr(expr, Asinh(_)) + case "ACOS" => convertUnaryExpr(expr, Acos(_)) + case "ACOSH" => convertUnaryExpr(expr, Acosh(_)) + case "ATAN" => convertUnaryExpr(expr, Atan(_)) + case "ATANH" => convertUnaryExpr(expr, Atanh(_)) + case "ATAN2" => convertBinaryExpr(expr, Atan2(_, _)) case _ => None } } private def convertBitwiseFunc(expr: GeneralScalarExpression): Option[Expression] = { expr.name match { - case "~" => convertUnaryExpr(expr, BitwiseNot) - case "&" => convertBinaryExpr(expr, BitwiseAnd) - case "|" => convertBinaryExpr(expr, BitwiseOr) - case "^" => convertBinaryExpr(expr, BitwiseXor) + case "~" => convertUnaryExpr(expr, BitwiseNot(_)) + case "&" => convertBinaryExpr(expr, BitwiseAnd(_, _)) + case "|" => convertBinaryExpr(expr, BitwiseOr(_, _)) + case "^" => convertBinaryExpr(expr, BitwiseXor(_, _)) case _ => None } }
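One more note, on the previously un-annotated overrides that gained explicit result types earlier in this series (checkpointData in CheckpointSuite, validators in LogisticRegressionWithSGD): as far as I understand the source features being enabled here, the type of an un-annotated override is inferred from its right-hand side (Scala 3 behaviour) rather than copied from the overridden member (Scala 2 behaviour), so a member that used to expose the inherited Seq type can quietly start exposing List, which is exactly the kind of signature drift MiMa reports. Pinning the type keeps the compiled getter stable regardless of compiler mode. Illustrative sketch only, with made-up names, not Spark code:

  trait Validated {
    protected val validators: Seq[String => Boolean]
  }

  class Checked extends Validated {
    // Left un-annotated, this override could be typed as Seq[String => Boolean]
    // (taken from the parent) or as List[String => Boolean] (inferred from the
    // right-hand side), depending on the compiler mode; either way it compiles,
    // but the getter signature in bytecode differs. The annotation pins it down.
    override protected val validators: List[String => Boolean] =
      List((s: String) => s.nonEmpty)
  }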