Skip to content

Commit fd818b6

Browse files
committed
MixtureMachine API
1 parent 45adb7b commit fd818b6

File tree

5 files changed

+37
-10
lines changed

5 files changed

+37
-10
lines changed

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/modelpipe/MixturePipe.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,15 @@ import breeze.stats.distributions.{ContinuousDistr, Moments}
55
import io.github.mandar2812.dynaml.algebra.{PartitionedPSDMatrix, PartitionedVector}
66
import io.github.mandar2812.dynaml.models.gp.AbstractGPRegressionModel
77
import io.github.mandar2812.dynaml.models.stp.{AbstractSTPRegressionModel, MVStudentsTModel}
8-
import io.github.mandar2812.dynaml.models.{ContinuousProcessModel, GenContinuousMixtureModel, SecondOrderProcessModel, StochasticProcessMixtureModel}
8+
import io.github.mandar2812.dynaml.models.{
9+
ContinuousProcessModel, GenContinuousMixtureModel,
10+
SecondOrderProcessModel, StochasticProcessMixtureModel}
911
import io.github.mandar2812.dynaml.optimization.GloballyOptimizable
1012
import io.github.mandar2812.dynaml.pipes.DataPipe2
1113
import io.github.mandar2812.dynaml.probability.{ContinuousRVWithDistr, MatrixTRV, MultGaussianPRV, MultStudentsTPRV}
12-
import io.github.mandar2812.dynaml.probability.distributions.{BlockedMultiVariateGaussian, BlockedMultivariateStudentsT, HasErrorBars, MatrixT}
14+
import io.github.mandar2812.dynaml.probability.distributions.{
15+
BlockedMultiVariateGaussian, BlockedMultivariateStudentsT,
16+
HasErrorBars, MatrixT}
1317

1418
import scala.reflect.ClassTag
1519

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/models/gp/WarpedGPModel.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@ import scala.reflect.ClassTag
1515

1616
/**
1717
* ::Experimental::
18-
* @author mandar date 02/01/2017.
1918
*
2019
* A warped Gaussian Process.
21-
*/
20+
* @author mandar date 02/01/2017.
21+
*
22+
* */
2223
@Experimental
2324
class WarpedGPModel[T, I:ClassTag](p: AbstractGPRegressionModel[T, I])(
2425
warpingFunc: PushforwardMap[Double, Double, Double])(
Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ import scala.reflect.ClassTag
1616
* @tparam I The index set/input domain of the GP model.
1717
* @author mandar2812 date 15/06/2017.
1818
* */
19-
class ProbGPMixtureMachine[T, I: ClassTag](
19+
class GPMixtureMachine[T, I: ClassTag](
2020
model: AbstractGPRegressionModel[T, I]) extends
21-
MixtureMachine[T, I, Double, PartitionedVector, PartitionedPSDMatrix, BlockedMultiVariateGaussian,
22-
MultGaussianPRV, AbstractGPRegressionModel[T, I]](model) {
21+
MixtureMachine[
22+
T, I, Double, PartitionedVector, PartitionedPSDMatrix,
23+
BlockedMultiVariateGaussian, MultGaussianPRV,
24+
AbstractGPRegressionModel[T, I]](model) {
2325

2426
val (kernelPipe, noisePipe) = (system.covariance.asPipe, system.noiseModel.asPipe)
2527

@@ -33,7 +35,8 @@ class ProbGPMixtureMachine[T, I: ClassTag](
3335
(model_state: Map[String, Double]) =>
3436
AbstractGPRegressionModel(
3537
kernelPipe(model_state), noisePipe(model_state),
36-
system.mean)(system.data, system.npoints))
38+
system.mean)(system.data, system.npoints)
39+
)
3740

3841
override val mixturePipe = new GPMixturePipe[T, I]
3942

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/optimization/MixtureMachine.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,5 +131,24 @@ BaseProcess <: ContinuousProcessModel[T, I, Y, W1]
131131
)
132132
}
133133

134+
}
134135

136+
object MixtureMachine {
137+
138+
def apply[
139+
T, I: ClassTag, Y, YDomain, YDomainVar,
140+
BaseDistr <: ContinuousDistr[YDomain] with Moments[YDomain, YDomainVar] with HasErrorBars[YDomain],
141+
W1 <: ContinuousRVWithDistr[YDomain, BaseDistr],
142+
BaseProcess <: ContinuousProcessModel[T, I, Y, W1]
143+
with SecondOrderProcessModel[T, I, Y, Double, DenseMatrix[Double], W1]
144+
with GloballyOptimizable](model: BaseProcess)(
145+
confModelPipe: DataPipe[Map[String, Double], BaseProcess],
146+
mixtPipe: DataPipe2[Seq[BaseProcess], DenseVector[Double], GenContinuousMixtureModel[
147+
T, I, Y, YDomain, YDomainVar,
148+
BaseDistr, W1, BaseProcess]]) =
149+
new MixtureMachine[T, I, Y, YDomain, YDomainVar, BaseDistr, W1, BaseProcess](model) {
150+
override val confToModel = confModelPipe
151+
override val mixturePipe = mixtPipe
152+
}
153+
135154
}

scripts/stochasticPriors.sc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import io.github.mandar2812.dynaml.models.bayes.{LinearTrendESGPrior, LinearTren
55
import io.github.mandar2812.dynaml.probability._
66
import com.quantifind.charts.Highcharts._
77
import io.github.mandar2812.dynaml.analysis.implicits._
8-
import io.github.mandar2812.dynaml.optimization.ProbGPMixtureMachine
8+
import io.github.mandar2812.dynaml.optimization.GPMixtureMachine
99
import io.github.mandar2812.dynaml.pipes.Encoder
1010
import io.github.mandar2812.dynaml.probability.distributions.UnivariateGaussian
1111

@@ -63,7 +63,7 @@ val sgpModel = sgp_prior.posteriorModel(dataset)
6363
gp_prior.globalOptConfig_(Map("gridStep" -> "0.0", "gridSize" -> "1", "globalOpt" -> "GS", "policy" -> "GS"))
6464
val gpModel1 = gp_prior.posteriorModel(dataset)
6565

66-
val mixt_machine = new ProbGPMixtureMachine(gpModel1).setPrior(hyp_prior).setGridSize(2).setStepSize(0.50).setLogScale(true).setMaxIterations(200).setNumSamples(3)
66+
val mixt_machine = new GPMixtureMachine(gpModel1).setPrior(hyp_prior).setGridSize(2).setStepSize(0.50).setLogScale(true).setMaxIterations(200).setNumSamples(3)
6767

6868
val (mix_model, mixt_model_conf) = mixt_machine.optimize(gp_prior.covariance.effective_state ++ gp_prior.noiseCovariance.effective_state)
6969

0 commit comments

Comments
 (0)