Skip to content

Commit d6b58cd

Browse files
authored
Merge pull request #635 from mspruc/627-incorrect-printout-of-execution-plan-in-explain
pass TargetPlatform when building with DataQuantaBuilder
2 parents 1e9e96c + d8669bc commit d6b58cd

File tree

1 file changed

+44
-36
lines changed

1 file changed

+44
-36
lines changed

wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuantaBuilder.scala

Lines changed: 44 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,14 @@ abstract class BasicDataQuantaBuilder[This <: DataQuantaBuilder[_, Out], Out](im
636636

637637
override def withTargetPlatform(platform: Platform): This = {
638638
this.targetPlatforms += platform
639-
this.asInstanceOf[This]
639+
this.asInstanceOf[This]
640+
}
641+
642+
def getTargetPlatforms(): ListBuffer[Platform] = this.targetPlatforms
643+
644+
def applyTargetPlatforms(op: DataQuanta[Out], targetPlatforms: ListBuffer[Platform]) = {
645+
targetPlatforms.foreach(platform => op.withTargetPlatforms(platform))
646+
op
640647
}
641648

642649
def withUdfJarOf(cls: Class[_]): This = this.withUdfJar(ReflectionUtils.getDeclaringJar(cls))
@@ -690,7 +697,7 @@ class UnarySourceDataQuantaBuilder[This <: DataQuantaBuilder[_, Out], Out](sourc
690697
(implicit javaPlanBuilder: JavaPlanBuilder)
691698
extends BasicDataQuantaBuilder[This, Out] {
692699

693-
override protected def build: DataQuanta[Out] = javaPlanBuilder.planBuilder.load(source)(this.classTag)
700+
override protected def build: DataQuanta[Out] = applyTargetPlatforms(javaPlanBuilder.planBuilder.load(source)(this.classTag), this.getTargetPlatforms())
694701

695702
}
696703

@@ -755,7 +762,7 @@ class MapDataQuantaBuilder[In, Out](inputDataQuanta: DataQuantaBuilder[_, In],
755762
this
756763
}
757764

758-
override protected def build = inputDataQuanta.dataQuanta().mapJava(udf, this.udfLoadProfileEstimator)
765+
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().mapJava(udf, this.udfLoadProfileEstimator), this.getTargetPlatforms())
759766

760767
}
761768

@@ -770,7 +777,7 @@ class ProjectionDataQuantaBuilder[In, Out](inputDataQuanta: DataQuantaBuilder[_,
770777
(implicit javaPlanBuilder: JavaPlanBuilder)
771778
extends BasicDataQuantaBuilder[ProjectionDataQuantaBuilder[In, Out], Out] {
772779

773-
override protected def build = inputDataQuanta.dataQuanta().project(fieldNames.toSeq)
780+
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().project(fieldNames.toSeq), this.getTargetPlatforms())
774781

775782
}
776783

@@ -842,9 +849,9 @@ class FilterDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T], udf:
842849
this
843850
}
844851

845-
override protected def build = inputDataQuanta.dataQuanta().filterJava(
852+
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().filterJava(
846853
udf, this.sqlUdf, this.selectivity, this.udfLoadProfileEstimator
847-
)
854+
), this.getTargetPlatforms())
848855

849856
}
850857

@@ -913,7 +920,7 @@ class SortDataQuantaBuilder[T, Key](inputDataQuanta: DataQuantaBuilder[_, T],
913920
}
914921

915922
override protected def build =
916-
inputDataQuanta.dataQuanta().sortJava(keyUdf)(this.keyTag)
923+
applyTargetPlatforms(inputDataQuanta.dataQuanta().sortJava(keyUdf)(this.keyTag), this.getTargetPlatforms())
917924

918925
}
919926

@@ -976,9 +983,9 @@ class FlatMapDataQuantaBuilder[In, Out](inputDataQuanta: DataQuantaBuilder[_, In
976983
this
977984
}
978985

979-
override protected def build = inputDataQuanta.dataQuanta().flatMapJava(
986+
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().flatMapJava(
980987
udf, this.selectivity, this.udfLoadProfileEstimator
981-
)
988+
), this.getTargetPlatforms())
982989

983990
}
984991

@@ -1042,9 +1049,9 @@ class MapPartitionsDataQuantaBuilder[In, Out](inputDataQuanta: DataQuantaBuilder
10421049
this
10431050
}
10441051

1045-
override protected def build = inputDataQuanta.dataQuanta().mapPartitionsJava(
1052+
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().mapPartitionsJava(
10461053
udf, this.selectivity, this.udfLoadProfileEstimator
1047-
)
1054+
), this.getTargetPlatforms())
10481055

10491056
}
10501057

@@ -1110,7 +1117,7 @@ class SampleDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T], sampl
11101117
}
11111118

11121119
override protected def build =
1113-
inputDataQuanta.dataQuanta().sampleDynamicJava(sampleSizeFunction, this.datasetSize, this.seed, this.sampleMethod)
1120+
applyTargetPlatforms(inputDataQuanta.dataQuanta().sampleDynamicJava(sampleSizeFunction, this.datasetSize, this.seed, this.sampleMethod), this.getTargetPlatforms())
11141121

11151122
}
11161123

@@ -1174,9 +1181,7 @@ class ReduceByDataQuantaBuilder[Key, T](inputDataQuanta: DataQuantaBuilder[_, T]
11741181
this
11751182
}
11761183

1177-
override protected def build =
1178-
inputDataQuanta.dataQuanta().reduceByKeyJava(keyUdf, udf, this.udfLoadProfileEstimator)
1179-
1184+
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().reduceByKeyJava(keyUdf, udf, this.udfLoadProfileEstimator), this.getTargetPlatforms())
11801185
}
11811186

11821187
/**
@@ -1222,8 +1227,7 @@ class GroupByDataQuantaBuilder[Key, T](inputDataQuanta: DataQuantaBuilder[_, T],
12221227
this
12231228
}
12241229

1225-
override protected def build =
1226-
inputDataQuanta.dataQuanta().groupByKeyJava(keyUdf, this.keyUdfLoadProfileEstimator)
1230+
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().groupByKeyJava(keyUdf, this.keyUdfLoadProfileEstimator), this.getTargetPlatforms())
12271231

12281232
}
12291233

@@ -1235,7 +1239,7 @@ class GroupByDataQuantaBuilder[Key, T](inputDataQuanta: DataQuantaBuilder[_, T],
12351239
class GlobalGroupDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T])(implicit javaPlanBuilder: JavaPlanBuilder)
12361240
extends BasicDataQuantaBuilder[GlobalGroupDataQuantaBuilder[T], java.lang.Iterable[T]] {
12371241

1238-
override protected def build = inputDataQuanta.dataQuanta().group()
1242+
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().group(), this.getTargetPlatforms())
12391243

12401244
}
12411245

@@ -1277,7 +1281,7 @@ class GlobalReduceDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T],
12771281
this
12781282
}
12791283

1280-
override protected def build = inputDataQuanta.dataQuanta().reduceJava(udf, this.udfLoadProfileEstimator)
1284+
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().reduceJava(udf, this.udfLoadProfileEstimator), this.getTargetPlatforms())
12811285

12821286
}
12831287

@@ -1294,7 +1298,7 @@ class UnionDataQuantaBuilder[T](inputDataQuanta0: DataQuantaBuilder[_, T],
12941298

12951299
override def getOutputTypeTrap = inputDataQuanta0.outputTypeTrap
12961300

1297-
override protected def build = inputDataQuanta0.dataQuanta().union(inputDataQuanta1.dataQuanta())
1301+
override protected def build = applyTargetPlatforms(inputDataQuanta0.dataQuanta().union(inputDataQuanta1.dataQuanta()), this.getTargetPlatforms())
12981302

12991303
}
13001304

@@ -1311,7 +1315,7 @@ class IntersectDataQuantaBuilder[T](inputDataQuanta0: DataQuantaBuilder[_, T],
13111315

13121316
override def getOutputTypeTrap = inputDataQuanta0.outputTypeTrap
13131317

1314-
override protected def build = inputDataQuanta0.dataQuanta().intersect(inputDataQuanta1.dataQuanta())
1318+
override protected def build = applyTargetPlatforms(inputDataQuanta0.dataQuanta().intersect(inputDataQuanta1.dataQuanta()), this.getTargetPlatforms())
13151319

13161320
}
13171321

@@ -1435,7 +1439,7 @@ class JoinDataQuantaBuilder[In0, In1, Key](inputDataQuanta0: DataQuantaBuilder[_
14351439
})
14361440

14371441
override protected def build =
1438-
inputDataQuanta0.dataQuanta().joinJava(keyUdf0, inputDataQuanta1.dataQuanta(), keyUdf1)(inputDataQuanta1.classTag, this.keyTag)
1442+
applyTargetPlatforms(inputDataQuanta0.dataQuanta().joinJava(keyUdf0, inputDataQuanta1.dataQuanta(), keyUdf1)(inputDataQuanta1.classTag, this.keyTag), this.getTargetPlatforms())
14391443

14401444
}
14411445

@@ -1460,8 +1464,8 @@ class DLTrainingDataQuantaBuilder[In0, In1](inputDataQuanta0: DataQuantaBuilder[
14601464
}
14611465

14621466
override protected def build =
1463-
inputDataQuanta0.dataQuanta()
1464-
.dlTrainingJava(model, option, inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag)
1467+
applyTargetPlatforms(inputDataQuanta0.dataQuanta()
1468+
.dlTrainingJava(model, option, inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag), this.getTargetPlatforms())
14651469
}
14661470

14671471
/**
@@ -1482,8 +1486,9 @@ class PredictDataQuantaBuilder[In1, Out](inputDataQuanta0: DataQuantaBuilder[_,
14821486
}
14831487

14841488
override protected def build =
1489+
applyTargetPlatforms(
14851490
inputDataQuanta0.dataQuanta().
1486-
predictJava(inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag, ClassTag.apply(outClass))
1491+
predictJava(inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag, ClassTag.apply(outClass)), this.getTargetPlatforms())
14871492
}
14881493

14891494
/**
@@ -1595,7 +1600,7 @@ class CoGroupDataQuantaBuilder[In0, In1, Key](inputDataQuanta0: DataQuantaBuilde
15951600
}
15961601

15971602
override protected def build =
1598-
inputDataQuanta0.dataQuanta().coGroupJava(keyUdf0, inputDataQuanta1.dataQuanta(), keyUdf1)(inputDataQuanta1.classTag, this.keyTag)
1603+
applyTargetPlatforms(inputDataQuanta0.dataQuanta().coGroupJava(keyUdf0, inputDataQuanta1.dataQuanta(), keyUdf1)(inputDataQuanta1.classTag, this.keyTag), this.getTargetPlatforms())
15991604

16001605
}
16011606

@@ -1617,7 +1622,7 @@ class CartesianDataQuantaBuilder[In0, In1](inputDataQuanta0: DataQuantaBuilder[_
16171622
}
16181623

16191624
override protected def build =
1620-
inputDataQuanta0.dataQuanta().cartesian(inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag)
1625+
applyTargetPlatforms(inputDataQuanta0.dataQuanta().cartesian(inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag), this.getTargetPlatforms())
16211626

16221627
}
16231628

@@ -1635,7 +1640,7 @@ class ZipWithIdDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T])
16351640
this.outputTypeTrap.dataSetType = dataSetType[RT2[_, _]]
16361641
}
16371642

1638-
override protected def build = inputDataQuanta.dataQuanta().zipWithId
1643+
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().zipWithId, this.getTargetPlatforms())
16391644

16401645
}
16411646

@@ -1651,7 +1656,7 @@ class DistinctDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T])
16511656
// Reuse the input TypeTrap to enforce type equality between input and output.
16521657
override def getOutputTypeTrap: TypeTrap = inputDataQuanta.outputTypeTrap
16531658

1654-
override protected def build = inputDataQuanta.dataQuanta().distinct
1659+
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().distinct, this.getTargetPlatforms())
16551660

16561661
}
16571662

@@ -1669,7 +1674,7 @@ class CountDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T])
16691674
this.outputTypeTrap.dataSetType = dataSetType[java.lang.Long]
16701675
}
16711676

1672-
override protected def build = inputDataQuanta.dataQuanta().count
1677+
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().count, this.getTargetPlatforms())
16731678

16741679
}
16751680

@@ -1697,7 +1702,8 @@ class CustomOperatorDataQuantaBuilder[T](operator: Operator,
16971702
val dataQuanta = javaPlanBuilder.planBuilder.customOperator(operator, inputDataQuanta.map(_.dataQuanta()): _*)
16981703
buildCache.cache(dataQuanta)
16991704
}
1700-
buildCache(outputIndex)
1705+
1706+
applyTargetPlatforms(buildCache(outputIndex), this.getTargetPlatforms())
17011707
}
17021708

17031709
}
@@ -1775,10 +1781,10 @@ class DoWhileDataQuantaBuilder[T, ConvOut](inputDataQuanta: DataQuantaBuilder[_,
17751781
this
17761782
}
17771783

1778-
override protected def build =
1784+
override protected def build = applyTargetPlatforms(
17791785
inputDataQuanta.dataQuanta().doWhileJava[ConvOut](
17801786
conditionUdf, dataQuantaBodyBuilder, this.numExpectedIterations, this.udfLoadProfileEstimator
1781-
)(this.convOutClassTag)
1787+
)(this.convOutClassTag), this.getTargetPlatforms())
17821788

17831789

17841790
/**
@@ -1816,10 +1822,11 @@ class RepeatDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T],
18161822
// TODO: We could improve by combining the TypeTraps in the body loop.
18171823

18181824
override protected def build =
1825+
applyTargetPlatforms(
18191826
inputDataQuanta.dataQuanta().repeat(numRepetitions, startDataQuanta => {
18201827
val loopStartbuilder = new FakeDataQuantaBuilder(startDataQuanta)
18211828
bodyBuilder(loopStartbuilder).dataQuanta()
1822-
})
1829+
}), this.getTargetPlatforms())
18231830

18241831
}
18251832

@@ -1840,7 +1847,7 @@ class FakeDataQuantaBuilder[T](_dataQuanta: DataQuanta[T])(implicit javaPlanBuil
18401847
*
18411848
* @return the created and partially configured [[DataQuanta]]
18421849
*/
1843-
override protected def build: DataQuanta[T] = _dataQuanta
1850+
override protected def build: DataQuanta[T] = applyTargetPlatforms(_dataQuanta, this.getTargetPlatforms())
18441851
}
18451852

18461853
/**
@@ -1860,9 +1867,10 @@ class LogisticRegressionDataQuantaBuilder(inputDataQuanta0: DataQuantaBuilder[_,
18601867
}
18611868

18621869
override protected def build: DataQuanta[LogisticRegressionModel] =
1870+
applyTargetPlatforms(
18631871
inputDataQuanta0
18641872
.dataQuanta()
1865-
.trainLogisticRegression(inputDataQuanta1.dataQuanta(), fitIntercept)
1873+
.trainLogisticRegression(inputDataQuanta1.dataQuanta(), fitIntercept), this.getTargetPlatforms())
18661874

18671875

18681876
}

0 commit comments

Comments
 (0)