
Commit 41a2748

AleksandrPavlenko authored and srgg committed
Removed usage of SQLContext
1 parent baf262f commit 41a2748

29 files changed: +312 -361 lines changed

connector/python/tests/pyspark_tests_fixtures.py

Lines changed: 21 additions & 7 deletions
@@ -2,21 +2,39 @@
 import pytest
 import findspark
 findspark.init()
-from pyspark import SparkContext, SparkConf, SQLContext, Row
+from pyspark import SparkContext, SparkConf, Row
+from pyspark.sql import SparkSession
 import riak, pyspark_riak

 @pytest.fixture(scope="session")
 def docker_cli(request):
     # Start spark context to get access to py4j gateway
     conf = SparkConf().setMaster("local[*]").setAppName("pytest-pyspark-py4j")
-    sc = SparkContext(conf=conf)
+    sparkSession = SparkSession.builder.config(conf).getOrCreate()
+    sc = sparkSession.sparkContext
     docker_cli = sc._gateway.jvm.com.basho.riak.test.cluster.DockerRiakCluster(1, 2)
     docker_cli.start()
     sc.stop()
     # Start spark context since it's not aware of riak nodes and thus can't be used to test riak
     request.addfinalizer(lambda: docker_cli.stop())
     return docker_cli

+@pytest.fixture(scope="session")
+def spark_session(request):
+    if not os.environ.has_key('RIAK_HOSTS'):
+        docker_cli = request.getfuncargvalue('docker_cli')
+        host_and_port = get_host_and_port(docker_cli)
+        os.environ['RIAK_HOSTS'] = host_and_port
+        os.environ['USE_DOCKER'] = 'true'
+    # Start new spark context
+    conf = SparkConf().setMaster('local[*]').setAppName('pytest-pyspark-local-testing')
+    conf.set('spark.riak.connection.host', os.environ['RIAK_HOSTS'])
+    conf.set('spark.driver.memory', '4g')
+    conf.set('spark.executor.memory', '4g')
+    spark_session = SparkSession.builder.config(conf=conf).getOrCreate()
+    return spark_session
+
+
 @pytest.fixture(scope="session")
 def spark_context(request):
     # If RIAK_HOSTS is not set, use Docker to start a Riak node
@@ -30,16 +48,12 @@ def spark_context(request):
     conf.set('spark.riak.connection.host', os.environ['RIAK_HOSTS'])
     conf.set('spark.driver.memory', '4g')
     conf.set('spark.executor.memory', '4g')
-    spark_context = SparkContext(conf=conf)
+    spark_context = SparkSession.builder.config(conf=conf).getOrCreate().sparkContext
     spark_context.setLogLevel('INFO')
     pyspark_riak.riak_context(spark_context)
     request.addfinalizer(lambda: spark_context.stop())
     return spark_context

-@pytest.fixture(scope="session")
-def sql_context(request, spark_context):
-    sqlContext = SQLContext(spark_context)
-    return sqlContext

 @pytest.fixture(scope="session")
 def riak_client(request):

connector/python/tests/test_pyspark_riak.py

Lines changed: 61 additions & 60 deletions
Large diffs are not rendered by default.

connector/src/main/scala/com/basho/riak/spark/rdd/ReadConf.scala

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ case class ReadConf (
   /**
    * Used only in ranged partitioner to identify quantized field.
    * Usage example:
-   *    sqlContext.read
+   *    sparkSession.read
    *      .option("spark.riak.partitioning.ts-range-field-name", "time")
    * Providing this property automatically turns on RangedRiakTSPartitioner
    */
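
For reference, a minimal Scala sketch of the SparkSession-based usage this comment now describes; the session setup, connection host, and table name "ts_weather" are illustrative assumptions, not part of this commit:

    import org.apache.spark.sql.SparkSession

    // Build (or reuse) a session; the host value is a placeholder.
    val spark = SparkSession.builder()
      .appName("riak-ts-ranged-read")
      .master("local[*]")
      .config("spark.riak.connection.host", "127.0.0.1:8087")
      .getOrCreate()

    // Providing ts-range-field-name turns on RangedRiakTSPartitioner.
    val df = spark.read
      .option("spark.riak.partitioning.ts-range-field-name", "time")
      .format("org.apache.spark.sql.riak")
      .load("ts_weather")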

connector/src/test/java/com/basho/riak/spark/rdd/AbstractJavaSparkTest.java

Lines changed: 1 addition & 8 deletions
@@ -32,14 +32,7 @@

 public abstract class AbstractJavaSparkTest extends AbstractRiakSparkTest {
     // JavaSparkContext, created per test case
-    protected JavaSparkContext jsc = null;
-
-    @Override
-    public SparkContext createSparkContext(SparkConf conf) {
-        final SparkContext sc = new SparkContext(conf);
-        jsc = new JavaSparkContext(sc);
-        return sc;
-    }
+    protected JavaSparkContext jsc = new JavaSparkContext(sparkSession().sparkContext());

     protected static class FuncReMapWithPartitionIdx<T> implements Function2<Integer, Iterator<T>, Iterator<Tuple2<Integer, T>>> {
         @Override

connector/src/test/java/com/basho/riak/spark/rdd/timeseries/AbstractJavaTimeSeriesTest.java

Lines changed: 1 addition & 8 deletions
@@ -8,19 +8,12 @@
 public abstract class AbstractJavaTimeSeriesTest extends AbstractTimeSeriesTest {

     // JavaSparkContext, created per test case
-    protected JavaSparkContext jsc = null;
+    protected JavaSparkContext jsc = new JavaSparkContext(sparkSession().sparkContext());

     public AbstractJavaTimeSeriesTest(boolean createTestDate) {
         super(createTestDate);
     }

-    @Override
-    public SparkContext createSparkContext(SparkConf conf) {
-        final SparkContext sc = new SparkContext(conf);
-        jsc = new JavaSparkContext(sc);
-        return sc;
-    }
-
     protected String stringify(String[] strings) {
         return "[" + StringUtils.join(strings, ",") + "]";
     }
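
Both Java base tests now derive their JavaSparkContext from the shared session instead of overriding a SparkContext factory. A hedged Scala sketch of the same wrapping, with the session setup invented for illustration:

    import org.apache.spark.api.java.JavaSparkContext
    import org.apache.spark.sql.SparkSession

    // Reuse one SparkSession and wrap its SparkContext for Java-facing helpers.
    val spark = SparkSession.builder().master("local[*]").appName("java-bridge").getOrCreate()
    val jsc = new JavaSparkContext(spark.sparkContext)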

connector/src/test/java/com/basho/riak/spark/rdd/timeseries/TimeSeriesJavaReadTest.java

Lines changed: 10 additions & 19 deletions
@@ -21,7 +21,6 @@
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
 import org.apache.spark.sql.api.java.UDF1;
 import org.apache.spark.sql.functions;
 import org.apache.spark.sql.types.DataTypes;
@@ -64,17 +63,16 @@ public void readDataAsSqlRow() {

     @Test
     public void riakTSRDDToDataFrame() {
-        SQLContext sqlContext = new SQLContext(jsc);
         JavaRDD<TimeSeriesDataBean> rows = javaFunctions(jsc)
                 .riakTSTable(bucketName(), Row.class)
                 .sql(String.format("SELECT time, user_id, temperature_k FROM %s %s", bucketName(), sqlWhereClause()))
                 .map(r -> new TimeSeriesDataBean(r.getTimestamp(0).getTime(), r.getString(1), r.getDouble(2)));

-        Dataset<Row> df = sqlContext.createDataFrame(rows, TimeSeriesDataBean.class);
-        df.registerTempTable("test");
+        Dataset<Row> df = sparkSession().createDataFrame(rows, TimeSeriesDataBean.class);
+        df.createOrReplaceTempView("test");

         // Explicit cast due to compilation error "Object cannot be converted to java.lang.String[]"
-        String[] data = (String[]) sqlContext.sql("select * from test").toJSON().collect();
+        String[] data = (String[]) sparkSession().sql("select * from test").toJSON().collect();
         assertEqualsUsingJSONIgnoreOrder("[" +
                 "{time:111111, user_id:'bryce', temperature_k:305.37}," +
                 "{time:111222, user_id:'bryce', temperature_k:300.12}," +
@@ -92,17 +90,16 @@ public void riakTSRDDToDataFrameConvertTimestamp() {
                 DataTypes.createStructField("temperature_k", DataTypes.DoubleType, true),
         });

-        SQLContext sqlContext = new SQLContext(jsc);
         JavaRDD<TimeSeriesDataBean> rows = javaFunctions(jsc)
                 .riakTSTable(bucketName(), structType, Row.class)
                 .sql(String.format("SELECT time, user_id, temperature_k FROM %s %s", bucketName(), sqlWhereClause()))
                 .map(r -> new TimeSeriesDataBean(r.getLong(0), r.getString(1), r.getDouble(2)));

-        Dataset<Row> df = sqlContext.createDataFrame(rows, TimeSeriesDataBean.class);
-        df.registerTempTable("test");
+        Dataset<Row> df = sparkSession().createDataFrame(rows, TimeSeriesDataBean.class);
+        df.createOrReplaceTempView("test");

         // Explicit cast due to compilation error "Object cannot be converted to java.lang.String[]"
-        String[] data = (String[]) sqlContext.sql("select * from test").toJSON().collect();
+        String[] data = (String[]) sparkSession().sql("select * from test").toJSON().collect();
         assertEqualsUsingJSONIgnoreOrder("[" +
                 "{time:111111, user_id:'bryce', temperature_k:305.37}," +
                 "{time:111222, user_id:'bryce', temperature_k:300.12}," +
@@ -114,11 +111,9 @@ public void riakTSRDDToDataFrameConvertTimestamp() {

     @Test
     public void dataFrameGenericLoad() {
-        SQLContext sqlContext = new SQLContext(jsc);
+        sparkSession().udf().register("getMillis", (UDF1<Timestamp, Object>) Timestamp::getTime, DataTypes.LongType);

-        sqlContext.udf().register("getMillis", (UDF1<Timestamp, Object>) Timestamp::getTime, DataTypes.LongType);
-
-        Dataset<Row> df = sqlContext.read()
+        Dataset<Row> df = sparkSession().read()
                 .format("org.apache.spark.sql.riak")
                 .schema(schema())
                 .load(bucketName())
@@ -138,8 +133,6 @@ public void dataFrameGenericLoad() {

     @Test
     public void dataFrameReadShouldConvertTimestampToLong() {
-        SQLContext sqlContext = new SQLContext(jsc);
-
         StructType structType = new StructType(new StructField[]{
                 DataTypes.createStructField("surrogate_key", DataTypes.LongType, true),
                 DataTypes.createStructField("family", DataTypes.StringType, true),
@@ -148,7 +141,7 @@ public void dataFrameReadShouldConvertTimestampToLong() {
                 DataTypes.createStructField("temperature_k", DataTypes.DoubleType, true),
         });

-        Dataset<Row> df = sqlContext.read()
+        Dataset<Row> df = sparkSession().read()
                 .option("spark.riak.partitioning.ts-range-field-name", "time")
                 .format("org.apache.spark.sql.riak")
                 .schema(structType)
@@ -169,9 +162,7 @@ public void dataFrameReadShouldConvertTimestampToLong() {

     @Test
     public void dataFrameReadShouldHandleTimestampAsLong() {
-        SQLContext sqlContext = new SQLContext(jsc);
-
-        Dataset<Row> df = sqlContext.read()
+        Dataset<Row> df = sparkSession().read()
                 .format("org.apache.spark.sql.riak")
                 .option("spark.riakts.bindings.timestamp", "useLong")
                 .option("spark.riak.partitioning.ts-range-field-name", "time")

connector/src/test/java/com/basho/riak/spark/rdd/timeseries/TimeSeriesJavaWriteTest.java

Lines changed: 1 addition & 3 deletions
@@ -71,8 +71,6 @@ public void saveSqlRowsToRiak() {

     @Test
     public void saveDataFrameWithSchemaToRiak() {
-        SQLContext sqlContext = new SQLContext(jsc);
-
         JavaRDD<String> jsonRdd = jsc.parallelize(asList(
                 "{\"surrogate_key\": 1, \"family\": \"f\", \"time\": 111111, \"user_id\": \"bryce\", \"temperature_k\": 305.37}",
                 "{\"surrogate_key\": 1, \"family\": \"f\", \"time\": 111222, \"user_id\": \"bryce\", \"temperature_k\": 300.12}",
@@ -81,7 +79,7 @@ public void saveDataFrameWithSchemaToRiak() {
                 "{\"surrogate_key\": 1, \"family\": \"f\", \"time\": 111555, \"user_id\": \"ratman\", \"temperature_k\": 3502.212}"
         ));

-        Dataset<Row> df = sqlContext.read().schema(StructType$.MODULE$.apply(asScalaBuffer(asList(
+        Dataset<Row> df = sparkSession().read().schema(StructType$.MODULE$.apply(asScalaBuffer(asList(
                 DataTypes.createStructField("surrogate_key", DataTypes.IntegerType, true),
                 DataTypes.createStructField("family", DataTypes.StringType, true),
                 DataTypes.createStructField("time", DataTypes.LongType, true),

connector/src/test/scala/com/basho/riak/spark/rdd/AbstractRiakSparkTest.scala

Lines changed: 4 additions & 2 deletions
@@ -31,13 +31,15 @@ import scala.reflect.ClassTag
 import com.basho.riak.spark.rdd.AbstractRiakSparkTest._
 import com.basho.riak.spark.rdd.mapper.ReadValueDataMapper
 import org.apache.spark.SparkConf
+import org.apache.spark.sql.SparkSession
 import org.junit.ClassRule

 import scala.collection.JavaConversions._


 abstract class AbstractRiakSparkTest extends AbstractRiakTest {
   // SparkContext, created per test case
+  protected val sparkSession: SparkSession = createSparkSession(initSparkConf())
   protected var sc: SparkContext = _

   protected override def riakHosts: Set[HostAndPort] = HostAndPort.hostsFromString(
@@ -55,10 +57,10 @@ abstract class AbstractRiakSparkTest extends AbstractRiakTest {

   override def initialize(): Unit = {
     super.initialize()
-    sc = createSparkContext(initSparkConf())
+    sc = sparkSession.sparkContext
   }

-  protected def createSparkContext(conf: SparkConf): SparkContext = new SparkContext(conf)
+  protected def createSparkSession(conf: SparkConf): SparkSession = SparkSession.builder().config(conf).getOrCreate()

   @After
   def destroySparkContext(): Unit = Option(sc).foreach(x => x.stop())
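
The base test now builds a single SparkSession up front and derives sc from it. A minimal sketch of that factory pattern outside the test harness, with illustrative configuration values:

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession

    val conf = new SparkConf()
      .setMaster("local[*]")
      .setAppName("riak-spark-test")
      .set("spark.riak.connection.host", "127.0.0.1:8087")

    // getOrCreate() returns the active session if one already exists.
    val sparkSession = SparkSession.builder().config(conf).getOrCreate()
    val sc = sparkSession.sparkContext   // replaces new SparkContext(conf)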

connector/src/test/scala/com/basho/riak/spark/rdd/SparkDataframesTest.scala

Lines changed: 8 additions & 10 deletions
@@ -19,7 +19,7 @@ package com.basho.riak.spark.rdd

 import scala.reflect.runtime.universe
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
 import org.junit.Assert._
 import org.junit.{ Before, Test }
 import com.basho.riak.spark.toSparkContextFunctions
@@ -44,17 +44,15 @@ class SparkDataframesTest extends AbstractRiakSparkTest {

   protected override def initSparkConf() = super.initSparkConf().setAppName("Dataframes Test")

-  var sqlContextHolder: SQLContext = _
   var df: DataFrame = _

   @Before
   def initializeDF(): Unit = {
-    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
-    import sqlContext.implicits._
-    sqlContextHolder = sqlContext
+    val spark = sparkSession
+    import spark.implicits._
     df = sc.riakBucket[TestData](DEFAULT_NAMESPACE.getBucketNameAsString)
       .queryAll().toDF
-    df.registerTempTable("test")
+    df.createTempView("test")
   }

   @Test
@@ -67,7 +65,7 @@ class SparkDataframesTest extends AbstractRiakSparkTest {

   @Test
   def sqlQueryTest(): Unit = {
-    val sqlResult = sqlContextHolder.sql("select * from test where category >= 'CategoryC'").toJSON.collect
+    val sqlResult = sparkSession.sql("select * from test where category >= 'CategoryC'").toJSON.collect
     val expected =
       """ [
         | {id:'u4',name:'Chris',age:10,category:'CategoryC'},
@@ -78,8 +76,8 @@ class SparkDataframesTest extends AbstractRiakSparkTest {

   @Test
   def udfTest(): Unit = {
-    sqlContextHolder.udf.register("stringLength", (s: String) => s.length)
-    val udf = sqlContextHolder.sql("select name, stringLength(name) strLgth from test order by strLgth, name").toJSON.collect
+    sparkSession.udf.register("stringLength", (s: String) => s.length)
+    val udf = sparkSession.sql("select name, stringLength(name) strLgth from test order by strLgth, name").toJSON.collect
     val expected =
       """ [
         | {name:'Ben',strLgth:3},
@@ -107,7 +105,7 @@ class SparkDataframesTest extends AbstractRiakSparkTest {

   @Test
   def sqlVsFilterTest(): Unit = {
-    val sql = sqlContextHolder.sql("select id, name from test where age >= 50").toJSON.collect
+    val sql = sparkSession.sql("select id, name from test where age >= 50").toJSON.collect
     val filtered = df.where(df("age") >= 50).select("id", "name").toJSON.collect
     assertEqualsUsingJSONIgnoreOrder(stringify(sql), stringify(filtered))
   }
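
UDF registration and SQL now go through the shared session as well, removing the sqlContextHolder field. A hedged Scala sketch of the stringLength UDF against a stand-in table; the sample rows are invented for illustration:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("udf-demo").getOrCreate()
    import spark.implicits._

    Seq(("u1", "Ben"), ("u2", "Chris")).toDF("id", "name").createOrReplaceTempView("test")

    // previously: sqlContextHolder.udf.register(...) / sqlContextHolder.sql(...)
    spark.udf.register("stringLength", (s: String) => s.length)
    val rows = spark.sql("select name, stringLength(name) strLgth from test order by strLgth, name")
      .toJSON.collect()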

connector/src/test/scala/com/basho/riak/spark/rdd/SparkJobCompletionTest.scala

Lines changed: 3 additions & 1 deletion
@@ -23,6 +23,7 @@ import com.basho.riak.client.core.query.Namespace
 import com.basho.riak.spark._
 import com.basho.riak.spark.rdd.SparkJobCompletionTest._
 import com.basho.riak.spark.rdd.connector.RiakConnectorConf
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.{SparkConf, SparkContext}
 import org.junit.Test
 import org.junit.Assert
@@ -127,7 +128,8 @@ object SparkJobCompletionTest extends JsonFunctions {
       .set("spark.riak.connections.inactivity.timeout",
         (RiakConnectorConf.defaultInactivityTimeout * 60 * 5).toString) // 5 minutes is enough time to complete Spark job

-    val data = new SparkContext(sparkConf).riakBucket(ns).queryAll().collect()
+    val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
+    val data = sparkSession.sparkContext.riakBucket(ns).queryAll().collect()

     // HACK: Results should be printed for further analysis in the original JVM
     // to indicate that Spark job was completed successfully
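
For the standalone job in this test, the SparkContext is now reached through a session built from the job's SparkConf. A sketch of that end-to-end shape under stated assumptions: the bucket name and host are placeholders, and the riakBucket syntax comes from the com.basho.riak.spark._ import shown in the diff:

    import com.basho.riak.client.core.query.Namespace
    import com.basho.riak.spark._
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession

    val sparkConf = new SparkConf()
      .setMaster("local[*]")
      .setAppName("spark-riak-job")
      .set("spark.riak.connection.host", "127.0.0.1:8087")

    // replaces: new SparkContext(sparkConf).riakBucket(ns).queryAll().collect()
    val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
    val ns = new Namespace("test-bucket")
    val data = sparkSession.sparkContext.riakBucket(ns).queryAll().collect()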
