diff --git a/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuantaBuilder.scala b/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuantaBuilder.scala index 6869dd439..a962662eb 100644 --- a/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuantaBuilder.scala +++ b/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuantaBuilder.scala @@ -636,7 +636,14 @@ abstract class BasicDataQuantaBuilder[This <: DataQuantaBuilder[_, Out], Out](im override def withTargetPlatform(platform: Platform): This = { this.targetPlatforms += platform - this.asInstanceOf[This] + this.asInstanceOf[This] + } + + def getTargetPlatforms(): ListBuffer[Platform] = this.targetPlatforms + + def applyTargetPlatforms(op: DataQuanta[Out], targetPlatforms: ListBuffer[Platform]) = { + targetPlatforms.foreach(platform => op.withTargetPlatforms(platform)) + op } def withUdfJarOf(cls: Class[_]): This = this.withUdfJar(ReflectionUtils.getDeclaringJar(cls)) @@ -682,7 +689,7 @@ class UnarySourceDataQuantaBuilder[This <: DataQuantaBuilder[_, Out], Out](sourc (implicit javaPlanBuilder: JavaPlanBuilder) extends BasicDataQuantaBuilder[This, Out] { - override protected def build: DataQuanta[Out] = javaPlanBuilder.planBuilder.load(source)(this.classTag) + override protected def build: DataQuanta[Out] = applyTargetPlatforms(javaPlanBuilder.planBuilder.load(source)(this.classTag), this.getTargetPlatforms()) } @@ -747,7 +754,7 @@ class MapDataQuantaBuilder[In, Out](inputDataQuanta: DataQuantaBuilder[_, In], this } - override protected def build = inputDataQuanta.dataQuanta().mapJava(udf, this.udfLoadProfileEstimator) + override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().mapJava(udf, this.udfLoadProfileEstimator), this.getTargetPlatforms()) } @@ -762,7 +769,7 @@ class ProjectionDataQuantaBuilder[In, Out](inputDataQuanta: DataQuantaBuilder[_, (implicit javaPlanBuilder: JavaPlanBuilder) extends BasicDataQuantaBuilder[ProjectionDataQuantaBuilder[In, Out], Out] { - override protected def build = inputDataQuanta.dataQuanta().project(fieldNames.toSeq) + override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().project(fieldNames.toSeq), this.getTargetPlatforms()) } @@ -834,9 +841,9 @@ class FilterDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T], udf: this } - override protected def build = inputDataQuanta.dataQuanta().filterJava( + override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().filterJava( udf, this.sqlUdf, this.selectivity, this.udfLoadProfileEstimator - ) + ), this.getTargetPlatforms()) } @@ -905,7 +912,7 @@ class SortDataQuantaBuilder[T, Key](inputDataQuanta: DataQuantaBuilder[_, T], } override protected def build = - inputDataQuanta.dataQuanta().sortJava(keyUdf)(this.keyTag) + applyTargetPlatforms(inputDataQuanta.dataQuanta().sortJava(keyUdf)(this.keyTag), this.getTargetPlatforms()) } @@ -968,9 +975,9 @@ class FlatMapDataQuantaBuilder[In, Out](inputDataQuanta: DataQuantaBuilder[_, In this } - override protected def build = inputDataQuanta.dataQuanta().flatMapJava( + override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().flatMapJava( udf, this.selectivity, this.udfLoadProfileEstimator - ) + ), this.getTargetPlatforms()) } @@ -1034,9 +1041,9 @@ class MapPartitionsDataQuantaBuilder[In, Out](inputDataQuanta: DataQuantaBuilder this } - override protected def build = inputDataQuanta.dataQuanta().mapPartitionsJava( + override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().mapPartitionsJava( udf, this.selectivity, this.udfLoadProfileEstimator - ) + ), this.getTargetPlatforms()) } @@ -1102,7 +1109,7 @@ class SampleDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T], sampl } override protected def build = - inputDataQuanta.dataQuanta().sampleDynamicJava(sampleSizeFunction, this.datasetSize, this.seed, this.sampleMethod) + applyTargetPlatforms(inputDataQuanta.dataQuanta().sampleDynamicJava(sampleSizeFunction, this.datasetSize, this.seed, this.sampleMethod), this.getTargetPlatforms()) } @@ -1166,9 +1173,7 @@ class ReduceByDataQuantaBuilder[Key, T](inputDataQuanta: DataQuantaBuilder[_, T] this } - override protected def build = - inputDataQuanta.dataQuanta().reduceByKeyJava(keyUdf, udf, this.udfLoadProfileEstimator) - + override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().reduceByKeyJava(keyUdf, udf, this.udfLoadProfileEstimator), this.getTargetPlatforms()) } /** @@ -1214,8 +1219,7 @@ class GroupByDataQuantaBuilder[Key, T](inputDataQuanta: DataQuantaBuilder[_, T], this } - override protected def build = - inputDataQuanta.dataQuanta().groupByKeyJava(keyUdf, this.keyUdfLoadProfileEstimator) + override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().groupByKeyJava(keyUdf, this.keyUdfLoadProfileEstimator), this.getTargetPlatforms()) } @@ -1227,7 +1231,7 @@ class GroupByDataQuantaBuilder[Key, T](inputDataQuanta: DataQuantaBuilder[_, T], class GlobalGroupDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T])(implicit javaPlanBuilder: JavaPlanBuilder) extends BasicDataQuantaBuilder[GlobalGroupDataQuantaBuilder[T], java.lang.Iterable[T]] { - override protected def build = inputDataQuanta.dataQuanta().group() + override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().group(), this.getTargetPlatforms()) } @@ -1269,7 +1273,7 @@ class GlobalReduceDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T], this } - override protected def build = inputDataQuanta.dataQuanta().reduceJava(udf, this.udfLoadProfileEstimator) + override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().reduceJava(udf, this.udfLoadProfileEstimator), this.getTargetPlatforms()) } @@ -1286,7 +1290,7 @@ class UnionDataQuantaBuilder[T](inputDataQuanta0: DataQuantaBuilder[_, T], override def getOutputTypeTrap = inputDataQuanta0.outputTypeTrap - override protected def build = inputDataQuanta0.dataQuanta().union(inputDataQuanta1.dataQuanta()) + override protected def build = applyTargetPlatforms(inputDataQuanta0.dataQuanta().union(inputDataQuanta1.dataQuanta()), this.getTargetPlatforms()) } @@ -1303,7 +1307,7 @@ class IntersectDataQuantaBuilder[T](inputDataQuanta0: DataQuantaBuilder[_, T], override def getOutputTypeTrap = inputDataQuanta0.outputTypeTrap - override protected def build = inputDataQuanta0.dataQuanta().intersect(inputDataQuanta1.dataQuanta()) + override protected def build = applyTargetPlatforms(inputDataQuanta0.dataQuanta().intersect(inputDataQuanta1.dataQuanta()), this.getTargetPlatforms()) } @@ -1427,7 +1431,7 @@ class JoinDataQuantaBuilder[In0, In1, Key](inputDataQuanta0: DataQuantaBuilder[_ }) override protected def build = - inputDataQuanta0.dataQuanta().joinJava(keyUdf0, inputDataQuanta1.dataQuanta(), keyUdf1)(inputDataQuanta1.classTag, this.keyTag) + applyTargetPlatforms(inputDataQuanta0.dataQuanta().joinJava(keyUdf0, inputDataQuanta1.dataQuanta(), keyUdf1)(inputDataQuanta1.classTag, this.keyTag), this.getTargetPlatforms()) } @@ -1452,8 +1456,8 @@ class DLTrainingDataQuantaBuilder[In0, In1](inputDataQuanta0: DataQuantaBuilder[ } override protected def build = - inputDataQuanta0.dataQuanta() - .dlTrainingJava(model, option, inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag) + applyTargetPlatforms(inputDataQuanta0.dataQuanta() + .dlTrainingJava(model, option, inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag), this.getTargetPlatforms()) } /** @@ -1474,8 +1478,9 @@ class PredictDataQuantaBuilder[In1, Out](inputDataQuanta0: DataQuantaBuilder[_, } override protected def build = + applyTargetPlatforms( inputDataQuanta0.dataQuanta(). - predictJava(inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag, ClassTag.apply(outClass)) + predictJava(inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag, ClassTag.apply(outClass)), this.getTargetPlatforms()) } /** @@ -1587,7 +1592,7 @@ class CoGroupDataQuantaBuilder[In0, In1, Key](inputDataQuanta0: DataQuantaBuilde } override protected def build = - inputDataQuanta0.dataQuanta().coGroupJava(keyUdf0, inputDataQuanta1.dataQuanta(), keyUdf1)(inputDataQuanta1.classTag, this.keyTag) + applyTargetPlatforms(inputDataQuanta0.dataQuanta().coGroupJava(keyUdf0, inputDataQuanta1.dataQuanta(), keyUdf1)(inputDataQuanta1.classTag, this.keyTag), this.getTargetPlatforms()) } @@ -1609,7 +1614,7 @@ class CartesianDataQuantaBuilder[In0, In1](inputDataQuanta0: DataQuantaBuilder[_ } override protected def build = - inputDataQuanta0.dataQuanta().cartesian(inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag) + applyTargetPlatforms(inputDataQuanta0.dataQuanta().cartesian(inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag), this.getTargetPlatforms()) } @@ -1627,7 +1632,7 @@ class ZipWithIdDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T]) this.outputTypeTrap.dataSetType = dataSetType[RT2[_, _]] } - override protected def build = inputDataQuanta.dataQuanta().zipWithId + override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().zipWithId, this.getTargetPlatforms()) } @@ -1643,7 +1648,7 @@ class DistinctDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T]) // Reuse the input TypeTrap to enforce type equality between input and output. override def getOutputTypeTrap: TypeTrap = inputDataQuanta.outputTypeTrap - override protected def build = inputDataQuanta.dataQuanta().distinct + override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().distinct, this.getTargetPlatforms()) } @@ -1661,7 +1666,7 @@ class CountDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T]) this.outputTypeTrap.dataSetType = dataSetType[java.lang.Long] } - override protected def build = inputDataQuanta.dataQuanta().count + override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().count, this.getTargetPlatforms()) } @@ -1689,7 +1694,8 @@ class CustomOperatorDataQuantaBuilder[T](operator: Operator, val dataQuanta = javaPlanBuilder.planBuilder.customOperator(operator, inputDataQuanta.map(_.dataQuanta()): _*) buildCache.cache(dataQuanta) } - buildCache(outputIndex) + + applyTargetPlatforms(buildCache(outputIndex), this.getTargetPlatforms()) } } @@ -1767,10 +1773,10 @@ class DoWhileDataQuantaBuilder[T, ConvOut](inputDataQuanta: DataQuantaBuilder[_, this } - override protected def build = + override protected def build = applyTargetPlatforms( inputDataQuanta.dataQuanta().doWhileJava[ConvOut]( conditionUdf, dataQuantaBodyBuilder, this.numExpectedIterations, this.udfLoadProfileEstimator - )(this.convOutClassTag) + )(this.convOutClassTag), this.getTargetPlatforms()) /** @@ -1808,10 +1814,11 @@ class RepeatDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T], // TODO: We could improve by combining the TypeTraps in the body loop. override protected def build = + applyTargetPlatforms( inputDataQuanta.dataQuanta().repeat(numRepetitions, startDataQuanta => { val loopStartbuilder = new FakeDataQuantaBuilder(startDataQuanta) bodyBuilder(loopStartbuilder).dataQuanta() - }) + }), this.getTargetPlatforms()) } @@ -1832,7 +1839,7 @@ class FakeDataQuantaBuilder[T](_dataQuanta: DataQuanta[T])(implicit javaPlanBuil * * @return the created and partially configured [[DataQuanta]] */ - override protected def build: DataQuanta[T] = _dataQuanta + override protected def build: DataQuanta[T] = applyTargetPlatforms(_dataQuanta, this.getTargetPlatforms()) } /** @@ -1852,9 +1859,10 @@ class LogisticRegressionDataQuantaBuilder(inputDataQuanta0: DataQuantaBuilder[_, } override protected def build: DataQuanta[LogisticRegressionModel] = + applyTargetPlatforms( inputDataQuanta0 .dataQuanta() - .trainLogisticRegression(inputDataQuanta1.dataQuanta(), fitIntercept) + .trainLogisticRegression(inputDataQuanta1.dataQuanta(), fitIntercept), this.getTargetPlatforms()) }