Skip to content

Commit 03336c8

Browse files
committed
add Char CNN example.
1 parent 771a828 commit 03336c8

File tree

5 files changed

+166
-26
lines changed

5 files changed

+166
-26
lines changed

test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ public class CnnTextClassification : IExample
4040
float loss_value = 0;
4141
double max_accuracy = 0;
4242

43+
int alphabet_size = -1;
4344
int vocabulary_size = -1;
4445
NDArray train_x, valid_x, train_y, valid_y;
4546

@@ -117,10 +118,18 @@ public void PrepareData()
117118
Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));
118119

119120
Console.WriteLine("Building dataset...");
121+
var (x, y) = (new int[0][], new int[0]);
120122

121-
var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
122-
vocabulary_size = len(word_dict);
123-
var (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
123+
if(ModelName == "char_cnn")
124+
{
125+
(x, y, alphabet_size) = DataHelpers.build_char_dataset(TRAIN_PATH, "char_cnn", CHAR_MAX_LEN);
126+
}
127+
else
128+
{
129+
var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
130+
vocabulary_size = len(word_dict);
131+
(x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
132+
}
124133

125134
Console.WriteLine("\tDONE ");
126135

@@ -162,6 +171,9 @@ public Graph BuildGraph()
162171
case "word_cnn":
163172
textModel = new WordCnn(vocabulary_size, WORD_MAX_LEN, NUM_CLASS);
164173
break;
174+
case "char_cnn":
175+
textModel = new CharCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);
176+
break;
165177
}
166178

167179
return graph;

test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@ public static (int[][], int[]) build_word_dataset(string path, Dictionary<string
5555

5656
public static (int[][], int[], int) build_char_dataset(string path, string model, int document_max_len, int? limit = null, bool shuffle=true)
5757
{
58-
if (model != "vd_cnn")
59-
throw new NotImplementedException(model);
6058
string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’'\"/|_#$%ˆ&*˜‘+=<>()[]{} ";
6159
/*if (step == "train")
6260
df = pd.read_csv(TRAIN_PATH, names =["class", "title", "content"]);*/
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using Tensorflow;
6+
using static Tensorflow.Python;
7+
8+
namespace TensorFlowNET.Examples.Text
9+
{
10+
public class CharCnn : ITextModel
11+
{
12+
public CharCnn(int alphabet_size, int document_max_len, int num_class)
13+
{
14+
var learning_rate = 0.001f;
15+
var filter_sizes = new int[] { 7, 7, 3, 3, 3, 3 };
16+
var num_filters = 256;
17+
var kernel_initializer = tf.truncated_normal_initializer(stddev: 0.05f);
18+
19+
var x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
20+
var y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
21+
var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
22+
var global_step = tf.Variable(0, trainable: false);
23+
var keep_prob = tf.where(is_training, 0.5f, 1.0f);
24+
25+
var x_one_hot = tf.one_hot(x, alphabet_size);
26+
var x_expanded = tf.expand_dims(x_one_hot, -1);
27+
28+
// ============= Convolutional Layers =============
29+
Tensor pool1 = null, pool2 = null;
30+
Tensor conv3 = null, conv4 = null, conv5 = null, conv6 = null;
31+
Tensor h_pool = null;
32+
33+
with(tf.name_scope("conv-maxpool-1"), delegate
34+
{
35+
var conv1 = tf.layers.conv2d(x_expanded,
36+
filters: num_filters,
37+
kernel_size: new[] { filter_sizes[0], alphabet_size },
38+
kernel_initializer: kernel_initializer,
39+
activation: tf.nn.relu());
40+
41+
pool1 = tf.layers.max_pooling2d(conv1,
42+
pool_size: new[] { 3, 1 },
43+
strides: new[] { 3, 1 });
44+
pool1 = tf.transpose(pool1, new[] { 0, 1, 3, 2 });
45+
});
46+
47+
with(tf.name_scope("conv-maxpool-2"), delegate
48+
{
49+
var conv2 = tf.layers.conv2d(pool1,
50+
filters: num_filters,
51+
kernel_size: new[] {filter_sizes[1], num_filters },
52+
kernel_initializer: kernel_initializer,
53+
activation: tf.nn.relu());
54+
55+
pool2 = tf.layers.max_pooling2d(conv2,
56+
pool_size: new[] { 3, 1 },
57+
strides: new[] { 3, 1 });
58+
pool2 = tf.transpose(pool2, new[] { 0, 1, 3, 2 });
59+
});
60+
61+
with(tf.name_scope("conv-3"), delegate
62+
{
63+
conv3 = tf.layers.conv2d(pool2,
64+
filters: num_filters,
65+
kernel_size: new[] { filter_sizes[2], num_filters },
66+
kernel_initializer: kernel_initializer,
67+
activation: tf.nn.relu());
68+
conv3 = tf.transpose(conv3, new[] { 0, 1, 3, 2 });
69+
});
70+
71+
with(tf.name_scope("conv-4"), delegate
72+
{
73+
conv4 = tf.layers.conv2d(conv3,
74+
filters: num_filters,
75+
kernel_size: new[] { filter_sizes[3], num_filters },
76+
kernel_initializer: kernel_initializer,
77+
activation: tf.nn.relu());
78+
conv4 = tf.transpose(conv4, new[] { 0, 1, 3, 2 });
79+
});
80+
81+
with(tf.name_scope("conv-5"), delegate
82+
{
83+
conv5 = tf.layers.conv2d(conv4,
84+
filters: num_filters,
85+
kernel_size: new[] { filter_sizes[4], num_filters },
86+
kernel_initializer: kernel_initializer,
87+
activation: tf.nn.relu());
88+
conv5 = tf.transpose(conv5, new[] { 0, 1, 3, 2 });
89+
});
90+
91+
with(tf.name_scope("conv-maxpool-6"), delegate
92+
{
93+
conv6 = tf.layers.conv2d(conv5,
94+
filters: num_filters,
95+
kernel_size: new[] { filter_sizes[5], num_filters },
96+
kernel_initializer: kernel_initializer,
97+
activation: tf.nn.relu());
98+
99+
var pool6 = tf.layers.max_pooling2d(conv6,
100+
pool_size: new[] { 3, 1 },
101+
strides: new[] { 3, 1 });
102+
pool6 = tf.transpose(pool6, new[] { 0, 2, 1, 3 });
103+
104+
h_pool = tf.reshape(pool6, new[] { -1, 34 * num_filters });
105+
});
106+
107+
// ============= Fully Connected Layers =============
108+
Tensor fc1_out = null, fc2_out = null;
109+
Tensor logits = null;
110+
Tensor predictions = null;
111+
112+
with(tf.name_scope("fc-1"), delegate
113+
{
114+
fc1_out = tf.layers.dense(h_pool,
115+
1024,
116+
activation: tf.nn.relu(),
117+
kernel_initializer: kernel_initializer);
118+
});
119+
120+
with(tf.name_scope("fc-2"), delegate
121+
{
122+
fc2_out = tf.layers.dense(fc1_out,
123+
1024,
124+
activation: tf.nn.relu(),
125+
kernel_initializer: kernel_initializer);
126+
});
127+
128+
with(tf.name_scope("fc-3"), delegate
129+
{
130+
logits = tf.layers.dense(fc2_out,
131+
num_class,
132+
kernel_initializer: kernel_initializer);
133+
predictions = tf.argmax(logits, -1, output_type: tf.int32);
134+
});
135+
136+
with(tf.name_scope("loss"), delegate
137+
{
138+
var y_one_hot = tf.one_hot(y, num_class);
139+
var loss = tf.reduce_mean(
140+
tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot));
141+
var optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step: global_step);
142+
});
143+
144+
with(tf.name_scope("accuracy"), delegate
145+
{
146+
var correct_predictions = tf.equal(predictions, y);
147+
var accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32), name: "accuracy");
148+
});
149+
}
150+
}
151+
}

test/TensorFlowNET.Examples/TextProcess/cnn_models/ITextModel.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,5 @@ namespace TensorFlowNET.Examples.Text
77
{
88
// Marker interface implemented by the text-classification models (WordCnn,
// CharCnn, ...). Intentionally empty: each model builds its graph in its
// constructor and the trainer fetches tensors/ops from the graph by name
// ("x", "y", "is_training", "accuracy", ...) instead of through members here.
interface ITextModel
{
}
1411
}

test/TensorFlowNET.Examples/TextProcess/cnn_models/WordCnn.cs

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,6 @@ namespace TensorFlowNET.Examples.Text
99
{
1010
public class WordCnn : ITextModel
1111
{
12-
private int embedding_size;
13-
private int[] filter_sizes;
14-
private int[] num_filters;
15-
private int[] num_blocks;
16-
private float learning_rate;
17-
private IInitializer cnn_initializer;
18-
private IInitializer fc_initializer;
19-
public Tensor x { get; private set; }
20-
public Tensor y { get; private set; }
21-
public Tensor is_training { get; private set; }
22-
private RefVariable global_step;
23-
private RefVariable embeddings;
24-
private Tensor x_emb;
25-
private Tensor x_expanded;
26-
private Tensor logits;
27-
private Tensor predictions;
28-
private Tensor loss;
29-
3012
public WordCnn(int vocabulary_size, int document_max_len, int num_class)
3113
{
3214
var embedding_size = 128;

0 commit comments

Comments
 (0)