Skip to content

Commit a2471de

Browse files
committed
Added trend encoders to ESGP priors
1 parent 140ca49 commit a2471de

File tree

3 files changed

+20
-15
lines changed

3 files changed

+20
-15
lines changed

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/models/bayes/ESGPPrior.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import io.github.mandar2812.dynaml.analysis.PartitionedVectorField
66
import io.github.mandar2812.dynaml.kernels.LocalScalarKernel
77
import io.github.mandar2812.dynaml.modelpipe.ESGPPipe4
88
import io.github.mandar2812.dynaml.models.sgp.ESGPModel
9-
import io.github.mandar2812.dynaml.pipes.{DataPipe, MetaPipe}
9+
import io.github.mandar2812.dynaml.pipes.{DataPipe, Encoder, MetaPipe}
1010
import io.github.mandar2812.dynaml.probability.BlockedMESNRV
1111
import spire.algebra.{Field, InnerProductSpace}
1212
import io.github.mandar2812.dynaml.algebra.PartitionedMatrixOps._
@@ -44,6 +44,8 @@ abstract class ESGPPrior[I: ClassTag, MeanFuncParams](
4444

4545
val meanFunctionPipe: MetaPipe[MeanFuncParams, I, Double]
4646

47+
val trendParamsEncoder: Encoder[MeanFuncParams, Map[String, Double]]
48+
4749
private var globalOptConfig = Map(
4850
"globalOpt" -> "GS",
4951
"gridSize" -> "3",
@@ -107,14 +109,14 @@ abstract class ESGPPrior[I: ClassTag, MeanFuncParams](
107109
}
108110

109111
/**
110-
* @author mandar2812 date 21/02/2017.
111-
*
112112
* An extended skew gaussian process prior with a
113113
* linear trend function.
114+
*
115+
* @author mandar2812 date 21/02/2017.
114116
* */
115117
class LinearTrendESGPrior[I: ClassTag](
116-
cov: LocalScalarKernel[I],
117-
n: LocalScalarKernel[I],
118+
cov: LocalScalarKernel[I], n: LocalScalarKernel[I],
119+
override val trendParamsEncoder: Encoder[(I, Double), Map[String, Double]],
118120
lambda: Double, tau: Double,
119121
trendParams: I, intercept: Double)(
120122
implicit inner: InnerProductSpace[I, Double]) extends

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/models/bayes/GaussianProcessPrior.scala

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ import io.github.mandar2812.dynaml.probability.{MatrixNormalRV, MultGaussianPRV}
4646
* */
4747
abstract class GaussianProcessPrior[I: ClassTag, MeanFuncParams](
4848
val covariance: LocalScalarKernel[I],
49-
val noiseCovariance: LocalScalarKernel[I],
50-
val trendParamsEncoder: Encoder[MeanFuncParams, Map[String, Double]]) extends
49+
val noiseCovariance: LocalScalarKernel[I]) extends
5150
StochasticProcessPrior[
5251
I, Double, PartitionedVector,
5352
MultGaussianPRV, MultGaussianPRV,
@@ -61,6 +60,8 @@ abstract class GaussianProcessPrior[I: ClassTag, MeanFuncParams](
6160

6261
def meanFuncParams_(p: MeanFuncParams): Unit
6362

63+
val trendParamsEncoder: Encoder[MeanFuncParams, Map[String, Double]]
64+
6465
protected val initial_covariance_state: Map[String, Double] = covariance.state ++ noiseCovariance.state
6566

6667
val meanFunctionPipe: MetaPipe[MeanFuncParams, I, Double]
@@ -153,9 +154,9 @@ object GaussianProcessPrior {
153154
covariance: LocalScalarKernel[I],
154155
noiseCovariance: LocalScalarKernel[I],
155156
meanFPipe: MetaPipe[MeanFuncParams, I, Double],
156-
trendParamsEncoder: Encoder[MeanFuncParams, Map[String, Double]],
157+
trendEncoder: Encoder[MeanFuncParams, Map[String, Double]],
157158
initialParams: MeanFuncParams): GaussianProcessPrior[I, MeanFuncParams] =
158-
new GaussianProcessPrior[I, MeanFuncParams](covariance, noiseCovariance, trendParamsEncoder) {
159+
new GaussianProcessPrior[I, MeanFuncParams](covariance, noiseCovariance) {
159160

160161
private var params = initialParams
161162

@@ -164,6 +165,8 @@ object GaussianProcessPrior {
164165
override def meanFuncParams_(p: MeanFuncParams) = params = p
165166

166167
override val meanFunctionPipe = meanFPipe
168+
169+
override val trendParamsEncoder = trendEncoder
167170
}
168171

169172
}
@@ -176,10 +179,10 @@ object GaussianProcessPrior {
176179
* */
177180
class LinearTrendGaussianPrior[I: ClassTag](
178181
cov: LocalScalarKernel[I], n: LocalScalarKernel[I],
179-
trendParamsEncoder: Encoder[(I, Double), Map[String, Double]],
182+
override val trendParamsEncoder: Encoder[(I, Double), Map[String, Double]],
180183
trendParams: I, intercept: Double)(
181184
implicit inner: InnerProductSpace[I, Double]) extends
182-
GaussianProcessPrior[I, (I, Double)](cov, n, trendParamsEncoder) with
185+
GaussianProcessPrior[I, (I, Double)](cov, n) with
183186
LinearTrendStochasticPrior[I, MultGaussianPRV, MultGaussianPRV, AbstractGPRegressionModel[Seq[(I, Double)], I]]{
184187

185188
override val innerProduct = inner
@@ -193,6 +196,7 @@ class LinearTrendGaussianPrior[I: ClassTag](
193196
override val meanFunctionPipe = MetaPipe(
194197
(parameters: (I, Double)) => (x: I) => inner.dot(parameters._1, x) + parameters._2
195198
)
199+
196200
}
197201

198202
/**
@@ -215,11 +219,10 @@ class LinearTrendGaussianPrior[I: ClassTag](
215219
abstract class CoRegGPPrior[I: ClassTag, J: ClassTag, MeanFuncParams](
216220
covarianceI: LocalScalarKernel[I], covarianceJ: LocalScalarKernel[J],
217221
noiseCovarianceI: LocalScalarKernel[I], noiseCovarianceJ: LocalScalarKernel[J],
218-
trendParamsEncoder: Encoder[MeanFuncParams, Map[String, Double]]) extends
222+
override val trendParamsEncoder: Encoder[MeanFuncParams, Map[String, Double]]) extends
219223
GaussianProcessPrior[(I,J), MeanFuncParams](
220224
covarianceI:*covarianceJ,
221-
noiseCovarianceI:*noiseCovarianceJ,
222-
trendParamsEncoder) {
225+
noiseCovarianceI:*noiseCovarianceJ) {
223226

224227
self =>
225228

scripts/stochasticPriors.sc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ val n = new MAKernel(0.8)
3131
val gp_prior = new LinearTrendGaussianPrior[Double](gsmKernel, n, trendEncoder, 0.0, 0.0)
3232
//gp_prior.hyperPrior_(hyp_prior)
3333

34-
val sgp_prior = new LinearTrendESGPrior[Double](gsmKernel, n, 0.75, 0.1, 0.0, 0.0)
34+
val sgp_prior = new LinearTrendESGPrior[Double](gsmKernel, n, trendEncoder, 0.75, 0.1, 0.0, 0.0)
3535
sgp_prior.hyperPrior_(sgp_hyp_prior)
3636

3737
val xs = Seq.tabulate[Double](20)(0.5*_)

0 commit comments

Comments
 (0)