
Commit 042fa97

Version 0.12: Added Batch SGD Implementation
1 parent a4bbeb8 commit 042fa97

4 files changed, +38 -7 lines changed

pom.xml

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@
   <modelVersion>4.0.0</modelVersion>
   <groupId>org.kuleuven.esat</groupId>
   <artifactId>bayesLearn</artifactId>
-  <version>0.11</version>
+  <version>0.12</version>
   <inceptionYear>2008</inceptionYear>
   <properties>
     <scala.version>2.10.4</scala.version>

src/main/scala/org/kuleuven/esat/bayesLearn.scala

Lines changed: 1 addition & 1 deletion

@@ -51,7 +51,7 @@ object bayesLearn extends App {
       " \n \\/__/ \\/__/ \\/__/ \\|__| "+
         "\\/__/ ")

-    echo("\nWelcome to Bayes Learn v 0.11\nInteractive Scala shell")
+    echo("\nWelcome to Bayes Learn v 0.12\nInteractive Scala shell")
     echo("STADIUS ESAT KU Leuven (2015)\n")
   }
 }

src/main/scala/org/kuleuven/esat/graphicalModels/GaussianLinearModel.scala

Lines changed: 6 additions & 0 deletions

@@ -44,6 +44,12 @@ private[graphicalModels] class GaussianLinearModel(
     this
   }

+  def setBatchFraction(f: Double): this.type = {
+    assert(f >= 0.0 && f <= 1.0, "Mini-Batch Fraction should be between 0.0 and 1.0")
+    this.optimizer.setMiniBatchFraction(f)
+    this
+  }
+
   def setRegParam(reg: Double): this.type = {
     this.optimizer.setRegParam(reg)
     this
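
Since setBatchFraction returns this.type, it chains with the model's other fluent setters. A minimal usage sketch, assuming a GaussianLinearModel instance named model (the instance itself is hypothetical; only setBatchFraction and setRegParam are confirmed by this diff):

    // Sample roughly 40% of the edges per gradient step,
    // with a small L2 regularization penalty.
    model
      .setRegParam(0.001)
      .setBatchFraction(0.4)

    // Out-of-range fractions fail fast at configuration time:
    // model.setBatchFraction(1.5)  // throws AssertionError

The assertion rejects any fraction that could not be a valid sampling probability before it ever reaches the optimizer.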

src/main/scala/org/kuleuven/esat/optimization/GradientDescent.scala

Lines changed: 30 additions & 5 deletions

@@ -120,6 +120,9 @@ class GradientDescent (private var gradient: Gradient, private var updater: Updater)
       this.updater,
       this.gradient,
       this.stepSize,
+      initialP,
+      ParamOutEdges,
+      xy,
       this.miniBatchFraction
     )
   }
@@ -144,18 +147,19 @@ object GradientDescent {
     var count = 1
     var oldW: DenseVector[Double] = initial
     var newW = oldW
+    val cumGradient: DenseVector[Double] = DenseVector.zeros(initial.length)
     logger.log(Priority.INFO, "Training model using SGD")
     while(count <= numIterations) {
       val targets = POutEdges.iterator()
       while (targets.hasNext) {
         val (x, y) = xy(targets.next())
-        val (grad, _): (DenseVector[Double], Double) = gradient.compute(x, y, oldW)
-        newW = updater.compute(oldW, grad, stepSize, count, regParam)._1
-        oldW = newW
+        gradient.compute(x, y, oldW, cumGradient)
       }
+      newW = updater.compute(oldW, cumGradient,
+        stepSize, count, regParam)._1
+      oldW = newW
       count += 1
     }
-
     newW
   }
@@ -167,8 +171,29 @@ object GradientDescent {
       updater: Updater,
       gradient: Gradient,
       stepSize: Double,
+      initial: DenseVector[Double],
+      POutEdges: java.lang.Iterable[Edge],
+      xy: (Edge) => (DenseVector[Double], Double),
       miniBatchFraction: Double): DenseVector[Double] = {
-    DenseVector.zeros[Double](10)
+    var count = 1
+    var oldW: DenseVector[Double] = initial
+    var newW = oldW
+    val cumGradient: DenseVector[Double] = DenseVector.zeros(initial.length)
+    logger.log(Priority.INFO, "Training model using SGD")
+    while(count <= numIterations) {
+      val targets = POutEdges.iterator()
+      while (targets.hasNext) {
+        val (x, y) = xy(targets.next())
+        if(scala.util.Random.nextDouble() <= miniBatchFraction) {
+          gradient.compute(x, y, oldW, cumGradient)
+        }
+      }
+      newW = updater.compute(oldW, cumGradient,
+        stepSize, count, regParam)._1
+      oldW = newW
+      count += 1
+    }
+    newW
   }

 }
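
The mini-batch behaviour hinges on the Random.nextDouble() <= miniBatchFraction test: each edge is drawn independently with probability miniBatchFraction, so one iteration touches roughly that fraction of the data. Below is a self-contained sketch of the same update, assuming Breeze's DenseVector; the repository's Gradient, Updater, and Edge abstractions are replaced, purely for illustration, with a plain least-squares gradient, an explicit 1/sqrt(t) step decay, and an IndexedSeq of (x, y) pairs. The sketch also zeroes the accumulated gradient at the start of every iteration and averages it over the sampled batch, so each weight update reflects only the current batch:

    import breeze.linalg.DenseVector
    import scala.util.Random

    object MiniBatchSGDSketch {

      // Gradient of the squared error 0.5 * ((w dot x) - y)^2 at one point.
      def pointGradient(x: DenseVector[Double], y: Double,
                        w: DenseVector[Double]): DenseVector[Double] =
        x * ((w dot x) - y)

      def runBatchSGD(data: IndexedSeq[(DenseVector[Double], Double)],
                      initial: DenseVector[Double],
                      stepSize: Double,
                      regParam: Double,
                      miniBatchFraction: Double,
                      numIterations: Int): DenseVector[Double] = {
        var w = initial.copy
        for (count <- 1 to numIterations) {
          // Fresh accumulator each iteration: the update below should see
          // only gradients from the current batch.
          val cumGradient = DenseVector.zeros[Double](initial.length)
          var sampled = 0
          for ((x, y) <- data) {
            // Keep each point independently with probability miniBatchFraction.
            if (Random.nextDouble() <= miniBatchFraction) {
              cumGradient += pointGradient(x, y, w)
              sampled += 1
            }
          }
          if (sampled > 0) {
            // L2-regularized SGD step with a decaying learning rate,
            // standing in for the repository's Updater.compute.
            val step = stepSize / math.sqrt(count.toDouble)
            w = w * (1.0 - step * regParam) -
              cumGradient * (step / sampled.toDouble)
          }
        }
        w
      }
    }

With miniBatchFraction = 1.0 every point contributes and the loop reduces to full-batch gradient descent; smaller fractions trade extra gradient noise for cheaper iterations.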
