Skip to content

Commit dad68ba

Browse files
committed
Corrected :*: member of DecomposableCovariance
1 parent bbf07cf commit dad68ba

File tree

4 files changed

+18
-8
lines changed

4 files changed

+18
-8
lines changed

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/DynaMLPipe.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -581,11 +581,7 @@ object DynaMLPipe {
581581
* Creates an [[Encoder]] which replicates a
582582
* [[DenseVector]] instance n times.
583583
* */
584-
val breezeDVReplicationEncoder = (n: Int) => Encoder((v: DenseVector[Double]) => {
585-
Array.fill(n)(v)
586-
}, (vs: Array[DenseVector[Double]]) => {
587-
vs.head
588-
})
584+
val breezeDVReplicationEncoder = (n: Int) => genericReplicationEncoder[DenseVector[Double]](n)
589585

590586
def trainParametricModel[
591587
G, T, Q, R, S, M <: ParameterizedLearner[G, T, Q, R, S]

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/kernels/LocalScalarKernel.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import io.github.mandar2812.dynaml.DynaMLPipe
55
import io.github.mandar2812.dynaml.algebra.{PartitionedPSDMatrix, PartitionedVector}
66
import io.github.mandar2812.dynaml.pipes.{DataPipe, Encoder}
77

8+
import scala.reflect.ClassTag
9+
810
/**
911
* Scalar Kernel defines algebraic behavior for kernels of the form
1012
* K: Index x Index -> Double, i.e. kernel functions whose output
@@ -127,12 +129,20 @@ class DecomposableCovariance[S](kernels: LocalScalarKernel[S]*)(
127129
coupleAndKern._2.evaluate(u,v)
128130
}))
129131
}
132+
133+
override def gradient(x: S, y: S): Map[String, Double] = {
134+
val (xs, ys) = (encoding*encoding)((x,y))
135+
xs.zip(ys).zip(kernels).map(coupleAndKern => {
136+
val (u,v) = coupleAndKern._1
137+
coupleAndKern._2.gradient(u,v)
138+
}).reduceLeft(_++_)
139+
}
130140
}
131141

132142
object DecomposableCovariance {
133143

134144
val :+: = DataPipe((l: Array[Double]) => l.sum)
135145

136-
val :*: = DataPipe((l: Array[Double]) => l.sum)
146+
val :*: = DataPipe((l: Array[Double]) => l.product)
137147

138148
}

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/models/gp/AbstractGPRegressionModel.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ import io.github.mandar2812.dynaml.optimization.GloballyOptWithGrad
3131
import io.github.mandar2812.dynaml.probability.MultGaussianPRV
3232
import org.apache.log4j.Logger
3333

34+
import scala.reflect.ClassTag
35+
3436
/**
3537
* Single-Output Gaussian Process Regression Model
3638
* Performs gp/spline smoothing/regression with
@@ -46,7 +48,7 @@ import org.apache.log4j.Logger
4648
*/
4749
abstract class AbstractGPRegressionModel[T, I](
4850
cov: LocalScalarKernel[I], n: LocalScalarKernel[I],
49-
data: T, num: Int)
51+
data: T, num: Int)(implicit ev: ClassTag[I])
5052
extends ContinuousProcess[T, I, Double, MultGaussianPRV]
5153
with SecondOrderProcess[T, I, Double, Double, DenseMatrix[Double], MultGaussianPRV]
5254
with GloballyOptWithGrad {

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/models/stp/AbstractSTPRegressionModel.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,16 @@ import io.github.mandar2812.dynaml.probability.MultStudentsTPRV
3232
import io.github.mandar2812.dynaml.probability.distributions.{BlockedMultivariateStudentsT, MultivariateStudentsT}
3333
import org.apache.log4j.Logger
3434

35+
import scala.reflect.ClassTag
36+
3537
/**
3638
* @author mandar2812 date 26/08/16.
3739
* Implementation of a Students' T Regression model.
3840
*/
3941
abstract class AbstractSTPRegressionModel[T, I](
4042
mu: Double, cov: LocalScalarKernel[I],
4143
n: LocalScalarKernel[I],
42-
data: T, num: Int)
44+
data: T, num: Int)(implicit ev: ClassTag[I])
4345
extends ContinuousProcess[T, I, Double, MultStudentsTPRV]
4446
with SecondOrderProcess[T, I, Double, Double, DenseMatrix[Double], MultStudentsTPRV]
4547
with GloballyOptimizable {

0 commit comments

Comments
 (0)