Skip to content

Commit 0f98db9

Browse files
Merge pull request #295 from BrainJS/294-hidden-size-fix
fix: Resolve issue with different size hidden layers for recurrent ne…
2 parents 71864ac + 772eb8c commit 0f98db9

File tree

7 files changed

+50
-27
lines changed

7 files changed

+50
-27
lines changed

browser.js

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* license: MIT (http://opensource.org/licenses/MIT)
77
* author: Heather Arthur <[email protected]>
88
* homepage: https://github.com/brainjs/brain.js#readme
9-
* version: 1.4.3
9+
* version: 1.4.4
1010
*
1111
* acorn:
1212
* license: MIT (http://opensource.org/licenses/MIT)
@@ -4236,13 +4236,11 @@ var RNN = function () {
42364236
}, {
42374237
key: 'mapModel',
42384238
value: function mapModel() {
4239-
var _this = this;
4240-
42414239
var model = this.model;
42424240
var hiddenLayers = model.hiddenLayers;
42434241
var allMatrices = model.allMatrices;
42444242
this.initialLayerInputs = this.hiddenLayers.map(function (size) {
4245-
return new _matrix2.default(_this.hiddenLayers[0], 1);
4243+
return new _matrix2.default(size, 1);
42464244
});
42474245

42484246
this.createInputMatrix();
@@ -4572,8 +4570,6 @@ var RNN = function () {
45724570
}, {
45734571
key: 'fromJSON',
45744572
value: function fromJSON(json) {
4575-
var _this2 = this;
4576-
45774573
var defaults = this.constructor.defaults;
45784574
var options = json.options;
45794575
this.model = null;
@@ -4619,7 +4615,7 @@ var RNN = function () {
46194615
equationConnections: []
46204616
};
46214617
this.initialLayerInputs = this.hiddenLayers.map(function (size) {
4622-
return new _matrix2.default(_this2.hiddenLayers[0], 1);
4618+
return new _matrix2.default(size, 1);
46234619
});
46244620
this.bindEquation();
46254621
}

browser.min.js

Lines changed: 7 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/recurrent/rnn.js

Lines changed: 2 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/recurrent/rnn.js.map

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "brain.js",
33
"description": "Neural network library",
4-
"version": "1.4.3",
4+
"version": "1.4.4",
55
"author": "Heather Arthur <[email protected]>",
66
"repository": {
77
"type": "git",

src/recurrent/rnn.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ export default class RNN {
152152
let model = this.model;
153153
let hiddenLayers = model.hiddenLayers;
154154
let allMatrices = model.allMatrices;
155-
this.initialLayerInputs = this.hiddenLayers.map((size) => new Matrix(this.hiddenLayers[0], 1));
155+
this.initialLayerInputs = this.hiddenLayers.map((size) => new Matrix(size, 1));
156156

157157
this.createInputMatrix();
158158
if (!model.input) throw new Error('net.model.input not set');
@@ -506,7 +506,7 @@ export default class RNN {
506506
equations: [],
507507
equationConnections: [],
508508
};
509-
this.initialLayerInputs = this.hiddenLayers.map((size) => new Matrix(this.hiddenLayers[0], 1));
509+
this.initialLayerInputs = this.hiddenLayers.map((size) => new Matrix(size, 1));
510510
this.bindEquation();
511511
}
512512

test/recurrent/rnn.js

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,28 @@ describe('rnn', () => {
2020
net.initialize();
2121
assert.notEqual(net.model, null);
2222
});
23+
it('can setup different size hiddenLayers', () => {
24+
const inputSize = 2;
25+
const hiddenLayers = [5,4,3];
26+
const networkOptions = {
27+
learningRate: 0.001,
28+
decayRate: 0.75,
29+
inputSize: inputSize,
30+
hiddenLayers,
31+
outputSize: inputSize
32+
};
33+
34+
const net = new RNN(networkOptions);
35+
net.initialize();
36+
net.bindEquation();
37+
assert.equal(net.model.hiddenLayers.length, 3);
38+
assert.equal(net.model.hiddenLayers[0].weight.columns, inputSize);
39+
assert.equal(net.model.hiddenLayers[0].weight.rows, hiddenLayers[0]);
40+
assert.equal(net.model.hiddenLayers[1].weight.columns, hiddenLayers[0]);
41+
assert.equal(net.model.hiddenLayers[1].weight.rows, hiddenLayers[1]);
42+
assert.equal(net.model.hiddenLayers[2].weight.columns, hiddenLayers[1]);
43+
assert.equal(net.model.hiddenLayers[2].weight.rows, hiddenLayers[2]);
44+
});
2345
});
2446
describe('basic operations', () => {
2547
it('starts with zeros in input.deltas', () => {
@@ -354,9 +376,12 @@ describe('rnn', () => {
354376

355377
describe('.fromJSON', () => {
356378
it('can import model from json', () => {
357-
let dataFormatter = new DataFormatter('abcdef'.split(''));
358-
let jsonString = JSON.stringify(new RNN({
359-
inputSize: 6, //<- length
379+
const inputSize = 6;
380+
const hiddenLayers = [10, 20];
381+
const dataFormatter = new DataFormatter('abcdef'.split(''));
382+
const jsonString = JSON.stringify(new RNN({
383+
inputSize, //<- length
384+
hiddenLayers,
360385
inputRange: dataFormatter.characters.length,
361386
outputSize: dataFormatter.characters.length //<- length
362387
}).toJSON(), null, 2);
@@ -368,6 +393,12 @@ describe('rnn', () => {
368393
assert.equal(clone.inputSize, 6);
369394
assert.equal(clone.inputRange, dataFormatter.characters.length);
370395
assert.equal(clone.outputSize, dataFormatter.characters.length);
396+
397+
assert.equal(clone.model.hiddenLayers.length, 2);
398+
assert.equal(clone.model.hiddenLayers[0].weight.columns, inputSize);
399+
assert.equal(clone.model.hiddenLayers[0].weight.rows, hiddenLayers[0]);
400+
assert.equal(clone.model.hiddenLayers[1].weight.columns, hiddenLayers[0]);
401+
assert.equal(clone.model.hiddenLayers[1].weight.rows, hiddenLayers[1]);
371402
});
372403

373404
it('can import model from json using .fromJSON()', () => {

0 commit comments

Comments
 (0)