
Commit 5ae573f

ganeshas-db authored and HeartSaVioR committed
[SPARK-53656][SS] Refactor MemoryStream to use SparkSession instead of SQLContext
### What changes were proposed in this pull request?

Refactor MemoryStream to use SparkSession instead of SQLContext.

### Why are the changes needed?

SQLContext is deprecated in newer versions of Spark.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Verified that the affected tests pass.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52402 from ganeshashree/SPARK-53656.

Authored-by: Ganesha S <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
1 parent 5edebd2 · commit 5ae573f

25 files changed: +283 −100 lines
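
In practice, the refactor only swaps the implicit expected at call sites. A minimal sketch of the new calling convention, assuming a local test session (the session setup and import path below are illustrative, not part of this commit; Spark's own suites typically get `spark` from `SharedSparkSession`):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.MemoryStream // path may vary by version

object MemoryStreamSketch extends App {
  // Illustrative session setup, assumed rather than taken from this commit.
  implicit val spark: SparkSession = SparkSession.builder()
    .master("local[2]")
    .appName("memory-stream-sketch")
    .getOrCreate()
  import spark.implicits._ // supplies Encoder[Int]

  // Resolves against the implicit SparkSession (previously: an implicit SQLContext).
  val stream = MemoryStream[Int]
  stream.addData(1, 2, 3)  // returns the OffsetV2 covering the newly added batch
  val df = stream.toDF()   // now built as Dataset.ofRows(sparkSession, logicalPlan)
  spark.stop()
}
```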

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala

Lines changed: 34 additions & 13 deletions

```diff
@@ -24,7 +24,7 @@ import javax.annotation.concurrent.GuardedBy
 import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Encoder, SQLContext}
+import org.apache.spark.sql.{Encoder, SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
@@ -43,32 +43,53 @@ import org.apache.spark.sql.internal.connector.SimpleTableProvider
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
-object MemoryStream {
+object MemoryStream extends LowPriorityMemoryStreamImplicits {
   protected val currentBlockId = new AtomicInteger(0)
   protected val memoryStreamId = new AtomicInteger(0)
 
-  def apply[A : Encoder](implicit sqlContext: SQLContext): MemoryStream[A] =
-    new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
+  def apply[A : Encoder](implicit sparkSession: SparkSession): MemoryStream[A] =
+    new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)
 
-  def apply[A : Encoder](numPartitions: Int)(implicit sqlContext: SQLContext): MemoryStream[A] =
-    new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, Some(numPartitions))
+  def apply[A : Encoder](numPartitions: Int)(implicit sparkSession: SparkSession): MemoryStream[A] =
+    new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, Some(numPartitions))
+}
+
+/**
+ * Provides lower-priority implicits for MemoryStream to prevent ambiguity when both
+ * SparkSession and SQLContext are in scope. The implicits in the companion object,
+ * which use SparkSession, take higher precedence.
+ */
+trait LowPriorityMemoryStreamImplicits {
+  this: MemoryStream.type =>
+
+  // Deprecated: Used when an implicit SQLContext is in scope
+  @deprecated("Use MemoryStream.apply with an implicit SparkSession instead of SQLContext", "4.1.0")
+  def apply[A: Encoder]()(implicit sqlContext: SQLContext): MemoryStream[A] =
+    new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession)
+
+  @deprecated("Use MemoryStream.apply with an implicit SparkSession instead of SQLContext", "4.1.0")
+  def apply[A: Encoder](numPartitions: Int)(implicit sqlContext: SQLContext): MemoryStream[A] =
+    new MemoryStream[A](
+      memoryStreamId.getAndIncrement(),
+      sqlContext.sparkSession,
+      Some(numPartitions))
 }
 
 /**
  * A base class for memory stream implementations. Supports adding data and resetting.
  */
-abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends SparkDataStream {
+abstract class MemoryStreamBase[A : Encoder](sparkSession: SparkSession) extends SparkDataStream {
   val encoder = encoderFor[A]
   protected val attributes = toAttributes(encoder.schema)
 
   protected lazy val toRow: ExpressionEncoder.Serializer[A] = encoder.createSerializer()
 
   def toDS(): Dataset[A] = {
-    Dataset[A](sqlContext.sparkSession, logicalPlan)
+    Dataset[A](sparkSession, logicalPlan)
   }
 
   def toDF(): DataFrame = {
-    Dataset.ofRows(sqlContext.sparkSession, logicalPlan)
+    Dataset.ofRows(sparkSession, logicalPlan)
   }
 
   def addData(data: A*): OffsetV2 = {
@@ -156,16 +177,16 @@ class MemoryStreamScanBuilder(stream: MemoryStreamBase[_]) extends ScanBuilder w
  */
 case class MemoryStream[A : Encoder](
     id: Int,
-    sqlContext: SQLContext,
+    sparkSession: SparkSession,
    numPartitions: Option[Int] = None)
   extends MemoryStreamBaseClass[A](
-    id, sqlContext, numPartitions = numPartitions)
+    id, sparkSession, numPartitions = numPartitions)
 
 abstract class MemoryStreamBaseClass[A: Encoder](
     id: Int,
-    sqlContext: SQLContext,
+    sparkSession: SparkSession,
     numPartitions: Option[Int] = None)
-  extends MemoryStreamBase[A](sqlContext)
+  extends MemoryStreamBase[A](sparkSession)
   with MicroBatchStream
   with SupportsTriggerAvailableNow
   with Logging {
```
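
The trait layering above is the standard Scala low-priority pattern: when two overloads take the same explicit arguments and differ only in their implicit parameter, overload resolution prefers the one defined directly on the companion object over the one inherited from a trait, so the SparkSession overloads win whenever both a SparkSession and a SQLContext are implicitly available. `scala.Predef` orders its implicits the same way via `LowPriorityImplicits`. A self-contained sketch of the mechanism, with stand-in names (`Session`, `LegacyContext`, `Factory` are illustrative, not Spark types):

```scala
class Session
class LegacyContext(val session: Session)

trait LowPriorityFactory {
  this: Factory.type =>

  // Inherited overload: chosen only when the Session-based one does not apply.
  @deprecated("Use the Session overload", "1.0")
  def make(n: Int)(implicit ctx: LegacyContext): String =
    s"legacy path, n=$n"
}

object Factory extends LowPriorityFactory {
  // Defined directly on the object, so overload resolution prefers it.
  def make(n: Int)(implicit session: Session): String =
    s"new path, n=$n"
}

object Demo extends App {
  implicit val s: Session = new Session
  implicit val c: LegacyContext = new LegacyContext(s)
  // Both implicits are in scope, yet the call is unambiguous: prints "new path, n=3".
  println(Factory.make(3))
}
```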

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala

Lines changed: 44 additions & 8 deletions

```diff
@@ -27,7 +27,7 @@ import org.json4s.jackson.Serialization
 
 import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.rpc.RpcEndpointRef
-import org.apache.spark.sql.{Encoder, SQLContext}
+import org.apache.spark.sql.{Encoder, SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.connector.read.InputPartition
@@ -44,8 +44,11 @@ import org.apache.spark.util.RpcUtils
  * ContinuousMemoryStreamInputPartitionReader instances to poll. It returns the record at
  * the specified offset within the list, or null if that offset doesn't yet have a record.
  */
-class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2)
-  extends MemoryStreamBase[A](sqlContext) with ContinuousStream {
+class ContinuousMemoryStream[A : Encoder](
+    id: Int,
+    sparkSession: SparkSession,
+    numPartitions: Int = 2)
+  extends MemoryStreamBase[A](sparkSession) with ContinuousStream {
 
   private implicit val formats: Formats = Serialization.formats(NoTypeHints)
 
@@ -109,14 +112,47 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa
   override def commit(end: Offset): Unit = {}
 }
 
-object ContinuousMemoryStream {
+object ContinuousMemoryStream extends LowPriorityContinuousMemoryStreamImplicits {
   protected val memoryStreamId = new AtomicInteger(0)
 
-  def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
-    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
+  def apply[A : Encoder](implicit sparkSession: SparkSession): ContinuousMemoryStream[A] =
+    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)
 
-  def singlePartition[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
-    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, 1)
+  def apply[A : Encoder](numPartitions: Int)(implicit sparkSession: SparkSession):
+      ContinuousMemoryStream[A] =
+    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, numPartitions)
+
+  def singlePartition[A : Encoder](implicit sparkSession: SparkSession): ContinuousMemoryStream[A] =
+    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1)
+}
+
+/**
+ * Provides lower-priority implicits for ContinuousMemoryStream to prevent ambiguity when both
+ * SparkSession and SQLContext are in scope. The implicits in the companion object,
+ * which use SparkSession, take higher precedence.
+ */
+trait LowPriorityContinuousMemoryStreamImplicits {
+  this: ContinuousMemoryStream.type =>
+
+  // Deprecated: Used when an implicit SQLContext is in scope
+  @deprecated("Use ContinuousMemoryStream with an implicit SparkSession " +
+    "instead of SQLContext", "4.1.0")
+  def apply[A: Encoder]()(implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
+    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession)
+
+  @deprecated("Use ContinuousMemoryStream with an implicit SparkSession " +
+    "instead of SQLContext", "4.1.0")
+  def apply[A: Encoder](numPartitions: Int)(implicit sqlContext: SQLContext):
+      ContinuousMemoryStream[A] =
+    new ContinuousMemoryStream[A](
+      memoryStreamId.getAndIncrement(),
+      sqlContext.sparkSession,
+      numPartitions)
+
+  @deprecated("Use ContinuousMemoryStream.singlePartition with an implicit SparkSession " +
+    "instead of SQLContext", "4.1.0")
+  def singlePartition[A: Encoder]()(implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
+    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession, 1)
 }
 
 /**
```
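
Usage-wise, the continuous variant now mirrors the micro-batch one, and this change also adds a `numPartitions` overload that the companion previously lacked. A hedged sketch, assuming a test scope where `spark` is an active SparkSession and the relevant imports are in place:

```scala
implicit val session: SparkSession = spark // `spark` assumed from the surrounding test
import spark.implicits._                   // supplies Encoder[Int]

val single = ContinuousMemoryStream.singlePartition[Int] // 1 partition
val dflt = ContinuousMemoryStream[Int]                   // numPartitions defaults to 2
val wide = ContinuousMemoryStream[Int](4)                // overload added in this commit
wide.addData(1, 2, 3, 4)                                 // returns the offset after adding these records
```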

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala

Lines changed: 40 additions & 10 deletions

```diff
@@ -27,7 +27,7 @@ import org.json4s.jackson.Serialization
 
 import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef}
-import org.apache.spark.sql.{Encoder, SQLContext}
+import org.apache.spark.sql.{Encoder, SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.connector.read.InputPartition
@@ -60,10 +60,10 @@ import org.apache.spark.util.{Clock, RpcUtils}
  */
 class LowLatencyMemoryStream[A: Encoder](
     id: Int,
-    sqlContext: SQLContext,
+    sparkSession: SparkSession,
     numPartitions: Int = 2,
     clock: Clock = LowLatencyClock.getClock)
-  extends MemoryStreamBaseClass[A](0, sqlContext)
+  extends MemoryStreamBaseClass[A](0, sparkSession)
   with SupportsRealTimeMode {
   private implicit val formats: Formats = Serialization.formats(NoTypeHints)
 
@@ -172,23 +172,53 @@ class LowLatencyMemoryStream[A: Encoder](
   }
 }
 
-object LowLatencyMemoryStream {
+object LowLatencyMemoryStream extends LowPriorityLowLatencyMemoryStreamImplicits {
   protected val memoryStreamId = new AtomicInteger(0)
 
-  def apply[A: Encoder](implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] =
-    new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
+  def apply[A: Encoder](implicit sparkSession: SparkSession): LowLatencyMemoryStream[A] =
+    new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)
 
   def apply[A: Encoder](numPartitions: Int)(
       implicit
-      sqlContext: SQLContext): LowLatencyMemoryStream[A] =
+      sparkSession: SparkSession): LowLatencyMemoryStream[A] =
     new LowLatencyMemoryStream[A](
       memoryStreamId.getAndIncrement(),
-      sqlContext,
+      sparkSession,
       numPartitions = numPartitions
     )
 
-  def singlePartition[A: Encoder](implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] =
-    new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, 1)
+  def singlePartition[A: Encoder](implicit sparkSession: SparkSession): LowLatencyMemoryStream[A] =
+    new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1)
+}
+
+/**
+ * Provides lower-priority implicits for LowLatencyMemoryStream to prevent ambiguity when both
+ * SparkSession and SQLContext are in scope. The implicits in the companion object,
+ * which use SparkSession, take higher precedence.
+ */
+trait LowPriorityLowLatencyMemoryStreamImplicits {
+  this: LowLatencyMemoryStream.type =>
+
+  // Deprecated: Used when an implicit SQLContext is in scope
+  @deprecated("Use LowLatencyMemoryStream with an implicit SparkSession " +
+    "instead of SQLContext", "4.1.0")
+  def apply[A: Encoder]()(implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] =
+    new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession)
+
+  @deprecated("Use LowLatencyMemoryStream with an implicit SparkSession " +
+    "instead of SQLContext", "4.1.0")
+  def apply[A: Encoder](numPartitions: Int)(implicit sqlContext: SQLContext):
+      LowLatencyMemoryStream[A] =
+    new LowLatencyMemoryStream[A](
+      memoryStreamId.getAndIncrement(),
+      sqlContext.sparkSession,
+      numPartitions = numPartitions
+    )
+
+  @deprecated("Use LowLatencyMemoryStream.singlePartition with an implicit SparkSession " +
+    "instead of SQLContext", "4.1.0")
+  def singlePartition[A: Encoder]()(implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] =
+    new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession, 1)
 }
 
 /**
```
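
For callers that still only have a SQLContext, the deprecated trait overloads keep old code compiling by forwarding `sqlContext.sparkSession` into the new constructors; writing the empty parameter list `()` steers resolution to the deprecated overload. A sketch under that assumption (`spark` is assumed from the caller's scope):

```scala
import org.apache.spark.sql.SQLContext

// Legacy caller: only a SQLContext is implicitly available.
implicit val sqlContext: SQLContext = spark.sqlContext
import spark.implicits._ // supplies Encoder[Int]

// Compiles with a deprecation warning (since 4.1.0); internally this forwards
// sqlContext.sparkSession to the SparkSession-based constructor.
val stream = LowLatencyMemoryStream.singlePartition[Int]()
```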

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -1012,7 +1012,7 @@ class DatasetSuite extends QueryTest
     assert(err.getMessage.contains("An Observation can be used with a Dataset only once"))
 
     // streaming datasets are not supported
-    val streamDf = new MemoryStream[Int](0, sqlContext).toDF()
+    val streamDf = new MemoryStream[Int](0, spark).toDF()
     val streamObservation = Observation("stream")
     val streamErr = intercept[IllegalArgumentException] {
       streamDf.observe(streamObservation, avg($"value").cast("int").as("avg_val"))
```
