Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,14 @@ abstract class BasicDataQuantaBuilder[This <: DataQuantaBuilder[_, Out], Out](im

override def withTargetPlatform(platform: Platform): This = {
this.targetPlatforms += platform
this.asInstanceOf[This]
this.asInstanceOf[This]
}

def getTargetPlatforms(): ListBuffer[Platform] = this.targetPlatforms

def applyTargetPlatforms(op: DataQuanta[Out], targetPlatforms: ListBuffer[Platform]) = {
targetPlatforms.foreach(platform => op.withTargetPlatforms(platform))
op
}

def withUdfJarOf(cls: Class[_]): This = this.withUdfJar(ReflectionUtils.getDeclaringJar(cls))
Expand Down Expand Up @@ -682,7 +689,7 @@ class UnarySourceDataQuantaBuilder[This <: DataQuantaBuilder[_, Out], Out](sourc
(implicit javaPlanBuilder: JavaPlanBuilder)
extends BasicDataQuantaBuilder[This, Out] {

override protected def build: DataQuanta[Out] = javaPlanBuilder.planBuilder.load(source)(this.classTag)
override protected def build: DataQuanta[Out] = applyTargetPlatforms(javaPlanBuilder.planBuilder.load(source)(this.classTag), this.getTargetPlatforms())

}

Expand Down Expand Up @@ -747,7 +754,7 @@ class MapDataQuantaBuilder[In, Out](inputDataQuanta: DataQuantaBuilder[_, In],
this
}

override protected def build = inputDataQuanta.dataQuanta().mapJava(udf, this.udfLoadProfileEstimator)
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().mapJava(udf, this.udfLoadProfileEstimator), this.getTargetPlatforms())

}

Expand All @@ -762,7 +769,7 @@ class ProjectionDataQuantaBuilder[In, Out](inputDataQuanta: DataQuantaBuilder[_,
(implicit javaPlanBuilder: JavaPlanBuilder)
extends BasicDataQuantaBuilder[ProjectionDataQuantaBuilder[In, Out], Out] {

override protected def build = inputDataQuanta.dataQuanta().project(fieldNames.toSeq)
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().project(fieldNames.toSeq), this.getTargetPlatforms())

}

Expand Down Expand Up @@ -834,9 +841,9 @@ class FilterDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T], udf:
this
}

override protected def build = inputDataQuanta.dataQuanta().filterJava(
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().filterJava(
udf, this.sqlUdf, this.selectivity, this.udfLoadProfileEstimator
)
), this.getTargetPlatforms())

}

Expand Down Expand Up @@ -905,7 +912,7 @@ class SortDataQuantaBuilder[T, Key](inputDataQuanta: DataQuantaBuilder[_, T],
}

override protected def build =
inputDataQuanta.dataQuanta().sortJava(keyUdf)(this.keyTag)
applyTargetPlatforms(inputDataQuanta.dataQuanta().sortJava(keyUdf)(this.keyTag), this.getTargetPlatforms())

}

Expand Down Expand Up @@ -968,9 +975,9 @@ class FlatMapDataQuantaBuilder[In, Out](inputDataQuanta: DataQuantaBuilder[_, In
this
}

override protected def build = inputDataQuanta.dataQuanta().flatMapJava(
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().flatMapJava(
udf, this.selectivity, this.udfLoadProfileEstimator
)
), this.getTargetPlatforms())

}

Expand Down Expand Up @@ -1034,9 +1041,9 @@ class MapPartitionsDataQuantaBuilder[In, Out](inputDataQuanta: DataQuantaBuilder
this
}

override protected def build = inputDataQuanta.dataQuanta().mapPartitionsJava(
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().mapPartitionsJava(
udf, this.selectivity, this.udfLoadProfileEstimator
)
), this.getTargetPlatforms())

}

Expand Down Expand Up @@ -1102,7 +1109,7 @@ class SampleDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T], sampl
}

override protected def build =
inputDataQuanta.dataQuanta().sampleDynamicJava(sampleSizeFunction, this.datasetSize, this.seed, this.sampleMethod)
applyTargetPlatforms(inputDataQuanta.dataQuanta().sampleDynamicJava(sampleSizeFunction, this.datasetSize, this.seed, this.sampleMethod), this.getTargetPlatforms())

}

Expand Down Expand Up @@ -1166,9 +1173,7 @@ class ReduceByDataQuantaBuilder[Key, T](inputDataQuanta: DataQuantaBuilder[_, T]
this
}

override protected def build =
inputDataQuanta.dataQuanta().reduceByKeyJava(keyUdf, udf, this.udfLoadProfileEstimator)

override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().reduceByKeyJava(keyUdf, udf, this.udfLoadProfileEstimator), this.getTargetPlatforms())
}

/**
Expand Down Expand Up @@ -1214,8 +1219,7 @@ class GroupByDataQuantaBuilder[Key, T](inputDataQuanta: DataQuantaBuilder[_, T],
this
}

override protected def build =
inputDataQuanta.dataQuanta().groupByKeyJava(keyUdf, this.keyUdfLoadProfileEstimator)
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().groupByKeyJava(keyUdf, this.keyUdfLoadProfileEstimator), this.getTargetPlatforms())

}

Expand All @@ -1227,7 +1231,7 @@ class GroupByDataQuantaBuilder[Key, T](inputDataQuanta: DataQuantaBuilder[_, T],
class GlobalGroupDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T])(implicit javaPlanBuilder: JavaPlanBuilder)
extends BasicDataQuantaBuilder[GlobalGroupDataQuantaBuilder[T], java.lang.Iterable[T]] {

override protected def build = inputDataQuanta.dataQuanta().group()
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().group(), this.getTargetPlatforms())

}

Expand Down Expand Up @@ -1269,7 +1273,7 @@ class GlobalReduceDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T],
this
}

override protected def build = inputDataQuanta.dataQuanta().reduceJava(udf, this.udfLoadProfileEstimator)
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().reduceJava(udf, this.udfLoadProfileEstimator), this.getTargetPlatforms())

}

Expand All @@ -1286,7 +1290,7 @@ class UnionDataQuantaBuilder[T](inputDataQuanta0: DataQuantaBuilder[_, T],

override def getOutputTypeTrap = inputDataQuanta0.outputTypeTrap

override protected def build = inputDataQuanta0.dataQuanta().union(inputDataQuanta1.dataQuanta())
override protected def build = applyTargetPlatforms(inputDataQuanta0.dataQuanta().union(inputDataQuanta1.dataQuanta()), this.getTargetPlatforms())

}

Expand All @@ -1303,7 +1307,7 @@ class IntersectDataQuantaBuilder[T](inputDataQuanta0: DataQuantaBuilder[_, T],

override def getOutputTypeTrap = inputDataQuanta0.outputTypeTrap

override protected def build = inputDataQuanta0.dataQuanta().intersect(inputDataQuanta1.dataQuanta())
override protected def build = applyTargetPlatforms(inputDataQuanta0.dataQuanta().intersect(inputDataQuanta1.dataQuanta()), this.getTargetPlatforms())

}

Expand Down Expand Up @@ -1427,7 +1431,7 @@ class JoinDataQuantaBuilder[In0, In1, Key](inputDataQuanta0: DataQuantaBuilder[_
})

override protected def build =
inputDataQuanta0.dataQuanta().joinJava(keyUdf0, inputDataQuanta1.dataQuanta(), keyUdf1)(inputDataQuanta1.classTag, this.keyTag)
applyTargetPlatforms(inputDataQuanta0.dataQuanta().joinJava(keyUdf0, inputDataQuanta1.dataQuanta(), keyUdf1)(inputDataQuanta1.classTag, this.keyTag), this.getTargetPlatforms())

}

Expand All @@ -1452,8 +1456,8 @@ class DLTrainingDataQuantaBuilder[In0, In1](inputDataQuanta0: DataQuantaBuilder[
}

override protected def build =
inputDataQuanta0.dataQuanta()
.dlTrainingJava(model, option, inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag)
applyTargetPlatforms(inputDataQuanta0.dataQuanta()
.dlTrainingJava(model, option, inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag), this.getTargetPlatforms())
}

/**
Expand All @@ -1474,8 +1478,9 @@ class PredictDataQuantaBuilder[In1, Out](inputDataQuanta0: DataQuantaBuilder[_,
}

override protected def build =
applyTargetPlatforms(
inputDataQuanta0.dataQuanta().
predictJava(inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag, ClassTag.apply(outClass))
predictJava(inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag, ClassTag.apply(outClass)), this.getTargetPlatforms())
}

/**
Expand Down Expand Up @@ -1587,7 +1592,7 @@ class CoGroupDataQuantaBuilder[In0, In1, Key](inputDataQuanta0: DataQuantaBuilde
}

override protected def build =
inputDataQuanta0.dataQuanta().coGroupJava(keyUdf0, inputDataQuanta1.dataQuanta(), keyUdf1)(inputDataQuanta1.classTag, this.keyTag)
applyTargetPlatforms(inputDataQuanta0.dataQuanta().coGroupJava(keyUdf0, inputDataQuanta1.dataQuanta(), keyUdf1)(inputDataQuanta1.classTag, this.keyTag), this.getTargetPlatforms())

}

Expand All @@ -1609,7 +1614,7 @@ class CartesianDataQuantaBuilder[In0, In1](inputDataQuanta0: DataQuantaBuilder[_
}

override protected def build =
inputDataQuanta0.dataQuanta().cartesian(inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag)
applyTargetPlatforms(inputDataQuanta0.dataQuanta().cartesian(inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag), this.getTargetPlatforms())

}

Expand All @@ -1627,7 +1632,7 @@ class ZipWithIdDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T])
this.outputTypeTrap.dataSetType = dataSetType[RT2[_, _]]
}

override protected def build = inputDataQuanta.dataQuanta().zipWithId
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().zipWithId, this.getTargetPlatforms())

}

Expand All @@ -1643,7 +1648,7 @@ class DistinctDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T])
// Reuse the input TypeTrap to enforce type equality between input and output.
override def getOutputTypeTrap: TypeTrap = inputDataQuanta.outputTypeTrap

override protected def build = inputDataQuanta.dataQuanta().distinct
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().distinct, this.getTargetPlatforms())

}

Expand All @@ -1661,7 +1666,7 @@ class CountDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T])
this.outputTypeTrap.dataSetType = dataSetType[java.lang.Long]
}

override protected def build = inputDataQuanta.dataQuanta().count
override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().count, this.getTargetPlatforms())

}

Expand Down Expand Up @@ -1689,7 +1694,8 @@ class CustomOperatorDataQuantaBuilder[T](operator: Operator,
val dataQuanta = javaPlanBuilder.planBuilder.customOperator(operator, inputDataQuanta.map(_.dataQuanta()): _*)
buildCache.cache(dataQuanta)
}
buildCache(outputIndex)

applyTargetPlatforms(buildCache(outputIndex), this.getTargetPlatforms())
}

}
Expand Down Expand Up @@ -1767,10 +1773,10 @@ class DoWhileDataQuantaBuilder[T, ConvOut](inputDataQuanta: DataQuantaBuilder[_,
this
}

override protected def build =
override protected def build = applyTargetPlatforms(
inputDataQuanta.dataQuanta().doWhileJava[ConvOut](
conditionUdf, dataQuantaBodyBuilder, this.numExpectedIterations, this.udfLoadProfileEstimator
)(this.convOutClassTag)
)(this.convOutClassTag), this.getTargetPlatforms())


/**
Expand Down Expand Up @@ -1808,10 +1814,11 @@ class RepeatDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T],
// TODO: We could improve by combining the TypeTraps in the body loop.

override protected def build =
applyTargetPlatforms(
inputDataQuanta.dataQuanta().repeat(numRepetitions, startDataQuanta => {
val loopStartbuilder = new FakeDataQuantaBuilder(startDataQuanta)
bodyBuilder(loopStartbuilder).dataQuanta()
})
}), this.getTargetPlatforms())

}

Expand All @@ -1832,7 +1839,7 @@ class FakeDataQuantaBuilder[T](_dataQuanta: DataQuanta[T])(implicit javaPlanBuil
*
* @return the created and partially configured [[DataQuanta]]
*/
override protected def build: DataQuanta[T] = _dataQuanta
override protected def build: DataQuanta[T] = applyTargetPlatforms(_dataQuanta, this.getTargetPlatforms())
}

/**
Expand All @@ -1852,9 +1859,10 @@ class LogisticRegressionDataQuantaBuilder(inputDataQuanta0: DataQuantaBuilder[_,
}

override protected def build: DataQuanta[LogisticRegressionModel] =
applyTargetPlatforms(
inputDataQuanta0
.dataQuanta()
.trainLogisticRegression(inputDataQuanta1.dataQuanta(), fitIntercept)
.trainLogisticRegression(inputDataQuanta1.dataQuanta(), fitIntercept), this.getTargetPlatforms())


}
Expand Down
Loading