Skip to content

Commit 3378a31

Browse files
committed
[WIP]: Decomposable kernels
1 parent 489c726 commit 3378a31

File tree

1 file changed

+38
-1
lines changed

1 file changed

+38
-1
lines changed

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/kernels/LocalScalarKernel.scala

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package io.github.mandar2812.dynaml.kernels
22

33
import breeze.linalg.DenseMatrix
44
import io.github.mandar2812.dynaml.algebra.{PartitionedPSDMatrix, PartitionedVector}
5+
import io.github.mandar2812.dynaml.pipes.{DataPipe, Encoder}
56

67
/**
78
* Scalar Kernel defines algebraic behavior for kernels of the form
@@ -79,4 +80,40 @@ abstract class CompositeCovariance[T]
7980
*
8081
* for example K((x1, y1), (x1, y2)) = k1(x1,x2) + k2(y1,y2)
8182
*/
82-
trait DecomposableCovariance extends CompositeCovariance[PartitionedVector]
83+
class DecomposableCovariance[S](kernels: LocalScalarKernel[S]*)(
84+
implicit encoding: Encoder[S, Array[S]],
85+
reducer: DataPipe[Array[Double], Double]) extends CompositeCovariance[S] {
86+
87+
val kernelMap = kernels.map(k => (k.toString.split(".").last, k)).toMap
88+
89+
override val hyper_parameters: List[String] = kernels.map(k => {
90+
val id = k.toString.split(".").last
91+
k.hyper_parameters.map(h => id+"/"+h)
92+
}).reduceLeft(_++_)
93+
94+
blocked_hyper_parameters = kernels.map(k => {
95+
val id = k.toString.split(".").last
96+
k.blocked_hyper_parameters.map(h => id+"/"+h)
97+
}).reduceLeft(_++_)
98+
99+
override def setHyperParameters(h: Map[String, Double]): DecomposableCovariance.this.type = {
100+
//group the hyper params by kernel id
101+
h.toSeq.map(kv => {
102+
val idS = kv._1.split("/")
103+
(idS.head, (idS.last, kv._2))
104+
}).groupBy(_._1).map(hypC => {
105+
val kid = hypC._1
106+
val hyper_params = hypC._2.map(_._2).toMap
107+
kernelMap(kid).setHyperParameters(hyper_params)
108+
})
109+
this
110+
}
111+
112+
override def evaluate(x: S, y: S): Double = {
113+
val (xs, ys) = (encoding*encoding)((x,y))
114+
reducer(xs.zip(ys).zip(kernels).map(coupleAndKern => {
115+
val (u,v) = coupleAndKern._1
116+
coupleAndKern._2.evaluate(u,v)
117+
}))
118+
}
119+
}

0 commit comments

Comments
 (0)