Commit 771a828

move word cnn to separate class.
1 parent f35b146 commit 771a828

File tree

7 files changed: +138 −414 lines changed


docs/source/NeuralNetwork.md

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# Neural Network
+# Chapter. Neural Network
 
 In this chapter, we'll learn how to build the graph of a neural network model. The key advantage of a neural network compared to a Linear Classifier is that it can separate data which is not linearly separable. We'll implement this model to classify hand-written digit images from the MNIST dataset.
 
test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs

Lines changed: 25 additions & 97 deletions
@@ -9,6 +9,7 @@
 using NumSharp;
 using Tensorflow;
 using Tensorflow.Sessions;
+using TensorFlowNET.Examples.Text;
 using TensorFlowNET.Examples.Utility;
 using static Tensorflow.Python;
 

@@ -24,24 +25,27 @@ public class CnnTextClassification : IExample
         public int? DataLimit = null;
         public bool IsImportingGraph { get; set; } = false;
 
-        private const string dataDir = "word_cnn";
-        private string dataFileName = "dbpedia_csv.tar.gz";
+        const string dataDir = "cnn_text";
+        string dataFileName = "dbpedia_csv.tar.gz";
 
-        private const string TRAIN_PATH = "word_cnn/dbpedia_csv/train.csv";
-        private const string TEST_PATH = "word_cnn/dbpedia_csv/test.csv";
+        string TRAIN_PATH = $"{dataDir}/dbpedia_csv/train.csv";
+        string TEST_PATH = $"{dataDir}/dbpedia_csv/test.csv";
 
-        private const int NUM_CLASS = 14;
-        private const int BATCH_SIZE = 64;
-        private const int NUM_EPOCHS = 10;
-        private const int WORD_MAX_LEN = 100;
-        private const int CHAR_MAX_LEN = 1014;
+        int NUM_CLASS = 14;
+        int BATCH_SIZE = 64;
+        int NUM_EPOCHS = 10;
+        int WORD_MAX_LEN = 100;
+        int CHAR_MAX_LEN = 1014;
 
-        protected float loss_value = 0;
+        float loss_value = 0;
         double max_accuracy = 0;
 
-        int vocabulary_size = 50000;
+        int vocabulary_size = -1;
         NDArray train_x, valid_x, train_y, valid_y;
 
+        ITextModel textModel;
+        public string ModelName = "word_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn
+
         public bool Run()
         {
             PrepareData();
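
The new textModel field and the public ModelName knob turn this example into a front end for several text-classification architectures (word_cnn, char_cnn, vd_cnn, word_rnn, att_rnn, rcnn). The diff does not include the new ITextModel type itself; below is a minimal sketch, assuming it is just a common marker that each model implements while building its graph in its constructor (a WordCnn body reconstructed from the deleted BuildGraph code follows that hunk below):

    namespace TensorFlowNET.Examples.Text
    {
        // Assumed shape: a marker interface so CnnTextClassification can hold
        // any of the models behind one field; each implementation is expected
        // to add its ops to the current default graph when constructed.
        interface ITextModel
        {
        }
    }
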
@@ -68,7 +72,7 @@ public bool Run()
             return (train_x, valid_x, train_y, valid_y);
         }
 
-        private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
+        private void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
         {
             int i = 0;
             var label_keys = labels.Keys.ToArray();
@@ -114,10 +118,8 @@ public void PrepareData()
 
             Console.WriteLine("Building dataset...");
 
-            int alphabet_size = 0;
-
             var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
-            //vocabulary_size = len(word_dict);
+            vocabulary_size = len(word_dict);
             var (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
 
             Console.WriteLine("\tDONE ");
@@ -155,83 +157,19 @@ public Graph BuildGraph()
         {
             var graph = tf.Graph().as_default();
 
-            var embedding_size = 128;
-            var learning_rate = 0.001f;
-            var filter_sizes = new int[3, 4, 5];
-            var num_filters = 100;
-            var document_max_len = 100;
-
-            var x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
-            var y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
-            var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
-            var global_step = tf.Variable(0, trainable: false);
-            var keep_prob = tf.where(is_training, 0.5f, 1.0f);
-            Tensor x_emb = null;
-
-            with(tf.name_scope("embedding"), scope =>
-            {
-                var init_embeddings = tf.random_uniform(new int[] { vocabulary_size, embedding_size });
-                var embeddings = tf.get_variable("embeddings", initializer: init_embeddings);
-                x_emb = tf.nn.embedding_lookup(embeddings, x);
-                x_emb = tf.expand_dims(x_emb, -1);
-            });
-
-            var pooled_outputs = new List<Tensor>();
-            for (int len = 0; len < filter_sizes.Rank; len++)
+            switch (ModelName)
             {
-                int filter_size = filter_sizes.GetLength(len);
-                var conv = tf.layers.conv2d(
-                    x_emb,
-                    filters: num_filters,
-                    kernel_size: new int[] { filter_size, embedding_size },
-                    strides: new int[] { 1, 1 },
-                    padding: "VALID",
-                    activation: tf.nn.relu());
-
-                var pool = tf.layers.max_pooling2d(
-                    conv,
-                    pool_size: new[] { document_max_len - filter_size + 1, 1 },
-                    strides: new[] { 1, 1 },
-                    padding: "VALID");
-
-                pooled_outputs.Add(pool);
+                case "word_cnn":
+                    textModel = new WordCnn(vocabulary_size, WORD_MAX_LEN, NUM_CLASS);
+                    break;
             }
 
-            var h_pool = tf.concat(pooled_outputs, 3);
-            var h_pool_flat = tf.reshape(h_pool, new TensorShape(-1, num_filters * filter_sizes.Rank));
-            Tensor h_drop = null;
-            with(tf.name_scope("dropout"), delegate
-            {
-                h_drop = tf.nn.dropout(h_pool_flat, keep_prob);
-            });
-
-            Tensor logits = null;
-            Tensor predictions = null;
-            with(tf.name_scope("output"), delegate
-            {
-                logits = tf.layers.dense(h_drop, NUM_CLASS);
-                predictions = tf.argmax(logits, -1, output_type: tf.int32);
-            });
-
-            with(tf.name_scope("loss"), delegate
-            {
-                var sscel = tf.nn.sparse_softmax_cross_entropy_with_logits(logits: logits, labels: y);
-                var loss = tf.reduce_mean(sscel);
-                var adam = tf.train.AdamOptimizer(learning_rate);
-                var optimizer = adam.minimize(loss, global_step: global_step);
-            });
-
-            with(tf.name_scope("accuracy"), delegate
-            {
-                var correct_predictions = tf.equal(predictions, y);
-                var accuracy = tf.reduce_mean(tf.cast(correct_predictions, TF_DataType.TF_FLOAT), name: "accuracy");
-            });
-
             return graph;
         }
 
-        private bool Train(Session sess, Graph graph)
+        public void Train(Session sess)
        {
+            var graph = tf.get_default_graph();
             var stopwatch = Stopwatch.StartNew();
 
             sess.run(tf.global_variables_initializer());
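
For reference, the extracted class presumably carries the deleted graph-construction code above more or less intact. Here is a sketch of WordCnn reassembled from those deleted lines: the constructor signature comes from the new WordCnn(vocabulary_size, WORD_MAX_LEN, NUM_CLASS) call site, the name_scope wrappers around the dropout/output/loss/accuracy ops are collapsed for brevity, and the actual file in the repo may differ:

    using System.Collections.Generic;
    using Tensorflow;
    using static Tensorflow.Python;

    namespace TensorFlowNET.Examples.Text
    {
        class WordCnn : ITextModel
        {
            public WordCnn(int vocabulary_size, int document_max_len, int num_class)
            {
                var embedding_size = 128;
                var learning_rate = 0.001f;
                // Quirk kept from the original: a 3-D array whose dimension
                // lengths (3, 4, 5) serve as the convolution filter sizes.
                var filter_sizes = new int[3, 4, 5];
                var num_filters = 100;

                // Placeholders the training loop feeds by name.
                var x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
                var y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
                var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
                var global_step = tf.Variable(0, trainable: false);
                var keep_prob = tf.where(is_training, 0.5f, 1.0f);

                // Embed word ids and add a channel axis for the 2-D conv layers.
                Tensor x_emb = null;
                with(tf.name_scope("embedding"), scope =>
                {
                    var init_embeddings = tf.random_uniform(new int[] { vocabulary_size, embedding_size });
                    var embeddings = tf.get_variable("embeddings", initializer: init_embeddings);
                    x_emb = tf.expand_dims(tf.nn.embedding_lookup(embeddings, x), -1);
                });

                // One convolution + max-pool branch per filter size.
                var pooled_outputs = new List<Tensor>();
                for (int len = 0; len < filter_sizes.Rank; len++)
                {
                    int filter_size = filter_sizes.GetLength(len);
                    var conv = tf.layers.conv2d(x_emb, filters: num_filters,
                        kernel_size: new int[] { filter_size, embedding_size },
                        strides: new int[] { 1, 1 }, padding: "VALID",
                        activation: tf.nn.relu());
                    var pool = tf.layers.max_pooling2d(conv,
                        pool_size: new[] { document_max_len - filter_size + 1, 1 },
                        strides: new[] { 1, 1 }, padding: "VALID");
                    pooled_outputs.Add(pool);
                }

                // Concatenate branches, then dropout -> logits -> loss/accuracy.
                var h_pool_flat = tf.reshape(tf.concat(pooled_outputs, 3),
                    new TensorShape(-1, num_filters * filter_sizes.Rank));
                var h_drop = tf.nn.dropout(h_pool_flat, keep_prob);
                var logits = tf.layers.dense(h_drop, num_class);
                var predictions = tf.argmax(logits, -1, output_type: tf.int32);
                var sscel = tf.nn.sparse_softmax_cross_entropy_with_logits(logits: logits, labels: y);
                var loss = tf.reduce_mean(sscel);
                var optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step: global_step);
                var accuracy = tf.reduce_mean(tf.cast(tf.equal(predictions, y), TF_DataType.TF_FLOAT), name: "accuracy");
            }
        }
    }

Splitting the graph code out this way leaves BuildGraph as a thin dispatcher, so the other model names in the switch can be filled in later without touching the training loop.
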
@@ -263,10 +201,7 @@ private bool Train(Session sess, Graph graph)
             loss_value = result[2];
             var step = (int)result[1];
             if (step % 10 == 0)
-            {
-                var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total);
-                Console.WriteLine($"Training on batch {i}/{total} loss: {loss_value}. Estimated training time: {estimate}");
-            }
+                Console.WriteLine($"Training on batch {i}/{total} loss: {loss_value.ToString("0.0000")}.");
 
             if (step % 100 == 0)
             {
@@ -289,7 +224,7 @@ private bool Train(Session sess, Graph graph)
 
             var valid_accuracy = sum_accuracy / cnt;
 
-            print($"\nValidation Accuracy = {valid_accuracy}\n");
+            print($"\nValidation Accuracy = {valid_accuracy.ToString("P")}\n");
 
             // Save model
             if (valid_accuracy > max_accuracy)
@@ -300,13 +235,6 @@ private bool Train(Session sess, Graph graph)
                     }
                 }
             }
-
-            return max_accuracy > 0.9;
-        }
-
-        public void Train(Session sess)
-        {
-            Train(sess, sess.graph);
         }
 
         public void Predict(Session sess)
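
A side note on the two logging changes: "0.0000" and "P" are standard .NET numeric format strings, so the loss now prints with four fixed decimal places and the validation accuracy as a percentage. For example:

    // Standard .NET numeric formatting, as used in the new log lines above.
    Console.WriteLine(0.1234567f.ToString("0.0000")); // 0.1235
    Console.WriteLine(0.8734.ToString("P"));          // 87.34% (exact spacing is culture-dependent)
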
