@@ -106,11 +106,49 @@ object SVMKernel {
106106 }
107107 }
108108
109+ def buildKernelGradMatrix [S <: Seq [T ], T ](
110+ data1 : S ,
111+ hyper_parameters : Seq [String ],
112+ eval : (T , T ) => Double ,
113+ evalGrad : (String ) => (T , T ) => Double ):
114+ Map [String , DenseMatrix [Double ]] = {
115+
116+ val (rows, cols) = (data1.length, data1.length)
117+ logger.info(" Constructing Kernel/Grad Matrices" )
118+ logger.info(" Dimensions: " + rows + " x " + cols)
119+
120+ val keys = Seq (" kernel-matrix" ) ++ hyper_parameters
121+
122+ optimize {
123+ utils.combine(Seq (data1.zipWithIndex, data1.zipWithIndex))
124+ .filter(s => s.head._2 >= s.last._2)
125+ .flatMap(s => {
126+ keys.map(k =>
127+ if (k == " kernel-matrix" ) (k, ((s.head._2, s.last._2), eval(s.head._1, s.last._1)))
128+ else (k, ((s.head._2, s.last._2), evalGrad(k)(s.head._1, s.last._1))))
129+ }).groupBy(_._1).map(cl => {
130+
131+ if (cl._1 == " kernel-matrix" ) logger.info(" Constructing Kernel Matrix" )
132+ else logger.info(" Constructing Grad Matrix for: " + cl._1)
133+
134+ val kernelIndex = cl._2.map(_._2).toMap
135+
136+ (
137+ cl._1,
138+ DenseMatrix .tabulate[Double ](rows, cols){
139+ (i, j) => if (i >= j) kernelIndex((i,j)) else kernelIndex((j,i))
140+ }
141+ )
142+ })
143+ }
144+ }
145+
146+
109147 /**
110148 * Returns the kernel matrix along with
111149 * its derivatives for each hyper-parameter.
112150 * */
113- def buildKernelGradMatrix [S <: Seq [T ], T ](
151+ def buildCrossKernelGradMatrix [S <: Seq [T ], T ](
114152 data1 : S , data2 : S ,
115153 hyper_parameters : Seq [String ],
116154 eval : (T , T ) => Double ,
@@ -125,7 +163,6 @@ object SVMKernel {
125163
126164 optimize {
127165 utils.combine(Seq (data1.zipWithIndex, data2.zipWithIndex))
128- .filter(s => s.head._2 >= s.last._2)
129166 .flatMap(s => {
130167 keys.map(k =>
131168 if (k == " kernel-matrix" ) (k, ((s.head._2, s.last._2), eval(s.head._1, s.last._1)))
@@ -140,7 +177,7 @@ object SVMKernel {
140177 (
141178 cl._1,
142179 DenseMatrix .tabulate[Double ](rows, cols){
143- (i, j) => if (i >= j) kernelIndex((i,j)) else kernelIndex((j,i ))
180+ (i, j) => kernelIndex((i,j))
144181 }
145182 )
146183 })
@@ -249,13 +286,23 @@ object SVMKernel {
249286 print(" \n " )
250287 logger.info(" :- Partition: " + partitionIndex)
251288
252- SVMKernel .buildKernelGradMatrix(
253- c.head._1, c.last._1,
254- hyper_parameters,
255- eval, evalGrad).map(cluster => {
289+ if (partitionIndex._1 == partitionIndex._2) {
290+ SVMKernel .buildKernelGradMatrix(
291+ c.head._1,
292+ hyper_parameters,
293+ eval, evalGrad).map(cluster => {
294+ (cluster._1, (partitionIndex, cluster._2))
295+ }).toSeq
296+
297+ } else {
298+ SVMKernel .buildCrossKernelGradMatrix(
299+ c.head._1, c.last._1,
300+ hyper_parameters,
301+ eval, evalGrad).map(cluster => {
256302 (cluster._1, (partitionIndex, cluster._2))
257303 }).toSeq
258304
305+ }
259306 }).groupBy(_._1).map(cluster => {
260307
261308 val hyp = cluster._1
0 commit comments