Skip to content

Commit ee0b935

Browse files
committed
Add BasicRNNCell, LayerRNNCell, RNNCell,
Change nest.flatten to generic.
1 parent e72fd53 commit ee0b935

File tree

10 files changed

+183
-18
lines changed

10 files changed

+183
-18
lines changed

docs/source/ConvolutionNeuralNetwork.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,4 +346,5 @@ Get started with the implementation:
346346
}
347347
```
348348

349-
![cnn-reuslt](../assets/cnn-result.png)
349+
![](../assets/cnn-result.png)
350+

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@ limitations under the License.
1616

1717
using System;
1818
using System.Collections.Generic;
19+
using System.Linq;
1920
using System.Text;
2021
using Tensorflow.Operations;
2122
using Tensorflow.Operations.Activation;
23+
using Tensorflow.Util;
2224
using static Tensorflow.Python;
2325

2426
namespace Tensorflow
@@ -68,6 +70,33 @@ public static Tensor dropout(Tensor x, Tensor keep_prob = null, Tensor noise_sha
6870
return nn_ops.dropout_v2(x, rate: rate_tensor, noise_shape: noise_shape, seed: seed, name: name);
6971
}
7072

73+
/// <summary>
74+
/// Creates a recurrent neural network specified by RNNCell `cell`.
75+
/// </summary>
76+
/// <param name="cell">An instance of RNNCell.</param>
77+
/// <param name="inputs">The RNN inputs.</param>
78+
/// <param name="dtype"></param>
79+
/// <param name="swap_memory"></param>
80+
/// <param name="time_major"></param>
81+
/// <returns>A pair (outputs, state)</returns>
82+
public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, TF_DataType dtype = TF_DataType.DtInvalid,
83+
bool swap_memory = false, bool time_major = false)
84+
{
85+
with(variable_scope("rnn"), scope =>
86+
{
87+
VariableScope varscope = scope;
88+
var flat_input = nest.flatten(inputs);
89+
90+
if (!time_major)
91+
{
92+
flat_input = flat_input.Select(x => ops.convert_to_tensor(x)).ToList();
93+
//flat_input = flat_input.Select(x => _transpose_batch_time(x)).ToList();
94+
}
95+
});
96+
97+
throw new NotImplementedException("");
98+
}
99+
71100
public static (Tensor, Tensor) moments(Tensor x,
72101
int[] axes,
73102
string name = null,
Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,48 @@
1-
using System;
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
218
using System.Collections.Generic;
319
using System.Text;
20+
using Tensorflow.Keras.Engine;
21+
using Tensorflow.Operations.Activation;
422

523
namespace Tensorflow
624
{
7-
public class BasicRNNCell
25+
public class BasicRNNCell : LayerRNNCell
826
{
27+
int _num_units;
28+
Func<Tensor, string, Tensor> _activation;
29+
30+
public BasicRNNCell(int num_units,
31+
Func<Tensor, string, Tensor> activation = null,
32+
bool? reuse = null,
33+
string name = null,
34+
TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse,
35+
name: name,
36+
dtype: dtype)
37+
{
38+
// Inputs must be 2-dimensional.
39+
input_spec = new InputSpec(ndim: 2);
40+
41+
_num_units = num_units;
42+
if (activation == null)
43+
_activation = math_ops.tanh;
44+
else
45+
_activation = activation;
46+
}
947
}
1048
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
18+
using System.Collections.Generic;
19+
using System.Text;
20+
21+
namespace Tensorflow
22+
{
23+
public class LayerRNNCell : RNNCell
24+
{
25+
public LayerRNNCell(bool? _reuse = null,
26+
string name = null,
27+
TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: _reuse,
28+
name: name,
29+
dtype: dtype)
30+
{
31+
}
32+
}
33+
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
18+
using System.Collections.Generic;
19+
using System.Text;
20+
21+
namespace Tensorflow
22+
{
23+
/// <summary>
24+
/// Abstract object representing an RNN cell.
25+
///
26+
/// Every `RNNCell` must have the properties below and implement `call` with
27+
/// the signature `(output, next_state) = call(input, state)`. The optional
28+
/// third input argument, `scope`, is allowed for backwards compatibility
29+
/// purposes; but should be left off for new subclasses.
30+
///
31+
/// This definition of cell differs from the definition used in the literature.
32+
/// In the literature, 'cell' refers to an object with a single scalar output.
33+
/// This definition refers to a horizontal array of such units.
34+
///
35+
/// An RNN cell, in the most abstract setting, is anything that has
36+
/// a state and performs some operation that takes a matrix of inputs.
37+
/// This operation results in an output matrix with `self.output_size` columns.
38+
/// If `self.state_size` is an integer, this operation also results in a new
39+
/// state matrix with `self.state_size` columns. If `self.state_size` is a
40+
/// (possibly nested tuple of) TensorShape object(s), then it should return a
41+
/// matching structure of Tensors having shape `[batch_size].concatenate(s)`
42+
/// for each `s` in `self.state_size`.
43+
/// </summary>
44+
public abstract class RNNCell : Layers.Layer
45+
{
46+
/// <summary>
47+
/// Attribute that indicates whether the cell is a TF RNN cell, due to the slight
48+
/// difference between TF and Keras RNN cell.
49+
/// </summary>
50+
protected bool _is_tf_rnn_cell = false;
51+
52+
public RNNCell(bool trainable = true,
53+
string name = null,
54+
TF_DataType dtype = TF_DataType.DtInvalid,
55+
bool? _reuse = null) : base(trainable: trainable,
56+
name: name,
57+
dtype: dtype,
58+
_reuse: _reuse)
59+
{
60+
_is_tf_rnn_cell = true;
61+
}
62+
}
63+
}

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,9 @@ public static Tensor conj(Tensor x, string name = null)
551551
});
552552
}
553553

554+
public static Tensor tanh(Tensor x, string name = null)
555+
=> gen_math_ops.tanh(x, name);
556+
554557
public static Tensor truediv(Tensor x, Tensor y, string name = null)
555558
=> _truediv_python3(x, y, name);
556559

src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ namespace Tensorflow.Operations
77
public class rnn_cell_impl
88
{
99
public BasicRNNCell BasicRNNCell(int num_units)
10-
{
11-
throw new NotImplementedException();
12-
}
10+
=> new BasicRNNCell(num_units);
1311
}
1412
}

src/TensorFlowNET.Core/Util/nest.py.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -214,14 +214,14 @@ public static bool is_sequence(object arg)
214214
//# See the swig file (util.i) for documentation.
215215
//flatten = _pywrap_tensorflow.Flatten
216216

217-
public static List<object> flatten(object structure)
217+
public static List<T> flatten<T>(T structure)
218218
{
219-
var list = new List<object>();
219+
var list = new List<T>();
220220
_flatten_recursive(structure, list);
221221
return list;
222222
}
223223

224-
private static void _flatten_recursive(object obj, List<object> list)
224+
private static void _flatten_recursive<T>(T obj, List<T> list)
225225
{
226226
if (obj is string)
227227
{
@@ -232,7 +232,7 @@ private static void _flatten_recursive(object obj, List<object> list)
232232
{
233233
var dict = obj as IDictionary;
234234
foreach (var key in _sorted(dict))
235-
_flatten_recursive(dict[key], list);
235+
_flatten_recursive((T)dict[key], list);
236236
return;
237237
}
238238
if (obj is NDArray)
@@ -244,7 +244,7 @@ private static void _flatten_recursive(object obj, List<object> list)
244244
{
245245
var structure = obj as IEnumerable;
246246
foreach (var child in structure)
247-
_flatten_recursive(child, list);
247+
_flatten_recursive((T)child, list);
248248
return;
249249
}
250250
list.Add(obj);

test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,12 @@ limitations under the License.
2525
namespace TensorFlowNET.Examples.ImageProcess
2626
{
2727
/// <summary>
28-
/// Convolutional Neural Network classifier for Hand Written Digits
29-
/// CNN architecture with two convolutional layers, followed by two fully-connected layers at the end.
30-
/// Use Stochastic Gradient Descent (SGD) optimizer.
31-
/// http://www.easy-tensorflow.com/tf-tutorials/convolutional-neural-nets-cnns/cnn1
28+
/// Recurrent Neural Network for handwritten digits MNIST.
29+
/// https://medium.com/machine-learning-algorithms/mnist-using-recurrent-neural-network-2d070a5915a2
3230
/// </summary>
3331
public class DigitRecognitionRNN : IExample
3432
{
35-
public bool Enabled { get; set; } = false;
33+
public bool Enabled { get; set; } = true;
3634
public bool IsImportingGraph { get; set; } = false;
3735

3836
public string Name => "MNIST RNN";
@@ -84,6 +82,7 @@ public Graph BuildGraph()
8482
var X = tf.placeholder(tf.float32, new[] { -1, n_steps, n_inputs });
8583
var y = tf.placeholder(tf.int32, new[] { -1 });
8684
var cell = tf.nn.rnn_cell.BasicRNNCell(num_units: n_neurons);
85+
var (output, state) = tf.nn.dynamic_rnn(cell, X, dtype: tf.float32);
8786

8887
return graph;
8988
}
@@ -154,6 +153,7 @@ public void PrepareData()
154153
print("Size of:");
155154
print($"- Training-set:\t\t{len(mnist.train.data)}");
156155
print($"- Validation-set:\t{len(mnist.validation.data)}");
156+
print($"- Test-set:\t\t{len(mnist.test.data)}");
157157
}
158158

159159
public Graph ImportGraph() => throw new NotImplementedException();

test/TensorFlowNET.UnitTest/nest_test/NestTest.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ public void testFlattenAndPack()
7878
self.assertEqual((((restructured_from_flat[1] as object[])[0] as object[])[0] as Hashtable)["y"], 0);
7979

8080
self.assertEqual(new List<object> { 5 }, nest.flatten(5));
81-
flat = nest.flatten(np.array(new[] { 5 }));
82-
self.assertEqual(new object[] { np.array(new int[] { 5 }) }, flat);
81+
var flat1 = nest.flatten(np.array(new[] { 5 }));
82+
self.assertEqual(new object[] { np.array(new int[] { 5 }) }, flat1);
8383

8484
self.assertEqual("a", nest.pack_sequence_as(5, new List<object> { "a" }));
8585
self.assertEqual(np.array(new[] { 5 }),

0 commit comments

Comments
 (0)