Skip to content

Commit 471bdc5

Browse files
committed
Added class for Mean centering
1 parent 95ebe6f commit 471bdc5

File tree

3 files changed

+56
-5
lines changed

3 files changed

+56
-5
lines changed

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/DynaMLPipe.scala

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import io.github.mandar2812.dynaml.models.sgp.ESGPModel
2828
import io.github.mandar2812.dynaml.optimization._
2929
import io.github.mandar2812.dynaml.pipes._
3030
import io.github.mandar2812.dynaml.probability.ContinuousDistrRV
31-
import io.github.mandar2812.dynaml.utils.{GaussianScaler, MVGaussianScaler, MinMaxScaler}
31+
import io.github.mandar2812.dynaml.utils.{GaussianScaler, MVGaussianScaler, MeanScaler, MinMaxScaler}
3232
import io.github.mandar2812.dynaml.wavelets.{GroupedHaarWaveletFilter, HaarWaveletFilter, InvGroupedHaarWaveletFilter, InverseHaarWaveletFilter}
3333
import org.apache.log4j.Logger
3434
import org.apache.spark.rdd.RDD
@@ -383,6 +383,31 @@ object DynaMLPipe {
383383
(result, (featuresScaler, targetsScaler))
384384
})
385385

386+
/**
  * Returns a pipe which takes a data set of (features, targets) pairs
  * and mean centers it.
  *
  * The first pattern of the stream fixes the feature and target dimensions.
  * Every pattern is concatenated into one joint vector, the joint statistics
  * are obtained from `utils.getStats`, and the resulting mean vector is split
  * back into a features part and a targets part, each wrapped in a [[MeanScaler]].
  *
  * @param standardize Set to true to obtain the mean centered data, and to false
  *                    to obtain the original data unchanged. The [[MeanScaler]]
  *                    instances are returned in both cases.
  * @return A pipe yielding the (possibly centered) data together with the
  *         features and targets [[MeanScaler]] instances.
  * */
def calculateMeanScales(standardize: Boolean = true): DataPipe[
  Stream[(DenseVector[Double], DenseVector[Double])],
  (Stream[(DenseVector[Double], DenseVector[Double])], (MeanScaler, MeanScaler))] =
  DataPipe((data: Stream[(DenseVector[Double], DenseVector[Double])]) => {

    // Dimensions are read off the first pattern of the stream.
    val (firstFeatures, firstTargets) = data.head
    val featureDim = firstFeatures.length
    val targetDim = firstTargets.length

    // Join features and targets so that a single pass yields both means.
    val jointPatterns = data.map {
      case (features, targets) => DenseVector(features.toArray ++ targets.toArray)
    }.toList

    // Only the first component (the mean) of the statistics is needed here.
    val jointMean = utils.getStats(jointPatterns)._1

    val featuresScaler = MeanScaler(jointMean(0 until featureDim))

    val targetsScaler = MeanScaler(jointMean(featureDim until featureDim + targetDim))

    val result =
      if (!standardize) data
      else (featuresScaler * targetsScaler)(data)

    (result, (featuresScaler, targetsScaler))
  })
409+
410+
386411
/**
387412
* Multivariate version of [[calculateGaussianScales]]
388413
* @param standardize Set to true if one wants the standardized data and false if one

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/utils/GaussianScaler.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,19 @@ import breeze.linalg.{DenseMatrix, DenseVector, cholesky, inv}
2222
import io.github.mandar2812.dynaml.pipes.{ReversibleScaler, Scaler}
2323

2424
/**
25-
* @author mandar2812 date: 17/6/16.
26-
*
2725
* Scales attributes of a vector pattern using the sample mean and variance of
2826
* each dimension. This assumes that there is no covariance between the data
2927
* dimensions.
3028
*
3129
* @param mean Sample mean of the data
3230
* @param sigma Sample variance of each data dimension
33-
*/
31+
* @author mandar2812 date: 17/6/16.
32+
*
33+
* */
3434
case class GaussianScaler(mean: DenseVector[Double], sigma: DenseVector[Double])
3535
extends ReversibleScaler[DenseVector[Double]]{
3636
override val i: Scaler[DenseVector[Double]] =
37-
Scaler((pattern: DenseVector[Double]) => (pattern :* sigma) + mean)
37+
Scaler((pattern: DenseVector[Double]) => (pattern *:* sigma) + mean)
3838

3939
override def run(data: DenseVector[Double]): DenseVector[Double] = (data-mean) :/ sigma
4040

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package io.github.mandar2812.dynaml.utils
2+
3+
import breeze.linalg.DenseVector
4+
import io.github.mandar2812.dynaml.pipes.{ReversibleScaler, Scaler}
5+
6+
/**
  * A reversible scaler which shifts vector patterns by a fixed `center`.
  *
  * The forward transformation ([[run]]) subtracts `center` from its input,
  * the inverse scaler ([[i]]) adds `center` back.
  *
  * @param center The vector subtracted from every input pattern.
  * @author mandar date 30/05/2017.
  * */
case class MeanScaler(center: DenseVector[Double]) extends ReversibleScaler[DenseVector[Double]] {

  /** Inverse transformation: adds `center` back to a centered pattern. */
  override val i: Scaler[DenseVector[Double]] =
    Scaler((shifted: DenseVector[Double]) => shifted + center)

  /** Forward transformation: subtracts `center` from the pattern. */
  override def run(data: DenseVector[Double]): DenseVector[Double] = data - center

  /** Restricts this scaler to the dimensions selected by the range `r`. */
  def apply(r: Range): MeanScaler = {
    val slicedCenter = center(r)
    MeanScaler(slicedCenter)
  }

  /** Extracts the scalar scaler acting on dimension `n` alone. */
  def apply(n: Int): UnivariateMeanScaler = {
    val component = center(n)
    UnivariateMeanScaler(component)
  }

}
20+
21+
/**
  * Scalar counterpart of [[MeanScaler]]: shifts a number by a fixed `center`.
  *
  * @param center The value subtracted from every input.
  * */
case class UnivariateMeanScaler(center: Double) extends ReversibleScaler[Double] {

  /** Inverse transformation: adds `center` back to a centered value. */
  override val i: Scaler[Double] =
    Scaler((shifted: Double) => shifted + center)

  /** Forward transformation: subtracts `center` from the value. */
  override def run(data: Double): Double = data - center
}

0 commit comments

Comments
 (0)