Skip to content

Commit 25d0f3f

Browse files
committed
Added multivariate gaussian scaler class
1 parent 78f3dff commit 25d0f3f

File tree

1 file changed

+29
-2
lines changed

1 file changed

+29
-2
lines changed
Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
package io.github.mandar2812.dynaml.utils
22

3-
import breeze.linalg.DenseVector
3+
import breeze.linalg.{DenseMatrix, DenseVector, cholesky, inv}
44
import io.github.mandar2812.dynaml.pipes.{ReversibleScaler, Scaler}
55

66
/**
7-
* Created by mandar on 17/6/16.
7+
* @author mandar2812 date: 17/6/16.
8+
*
9+
* Scales attributes of a vector pattern using the sample mean and variance of
10+
* each dimension. This assumes that there is no covariance between the data
11+
* dimensions.
12+
*
13+
* @param mean Sample mean of the data
14+
* @param sigma Sample variance of each data dimension
815
*/
916
case class GaussianScaler(mean: DenseVector[Double], sigma: DenseVector[Double])
1017
extends ReversibleScaler[DenseVector[Double]]{
@@ -13,3 +20,23 @@ case class GaussianScaler(mean: DenseVector[Double], sigma: DenseVector[Double])
1320

1421
override def run(data: DenseVector[Double]): DenseVector[Double] = (data-mean) :/ sigma
1522
}
23+
24+
25+
/**
26+
* Scales the attributes of a data pattern using the sample mean and covariance matrix
27+
* calculated on the data set. This allows standardization of multivariate data sets
28+
* where the covariance of individual data dimensions is not negligible.
29+
*
30+
* @param mean Sample mean of data
31+
* @param sigma Sample covariance matrix of data.
32+
* */
33+
case class MVGaussianScaler(mean: DenseVector[Double], sigma: DenseMatrix[Double])
34+
extends ReversibleScaler[DenseVector[Double]] {
35+
36+
val sigmaInverse = cholesky(inv(sigma))
37+
38+
override val i: Scaler[DenseVector[Double]] =
39+
Scaler((pattern: DenseVector[Double]) => (inv(sigmaInverse.t) * pattern) + mean)
40+
41+
override def run(data: DenseVector[Double]): DenseVector[Double] = sigmaInverse.t * (data - mean)
42+
}

0 commit comments

Comments
 (0)