9
9
using NumSharp ;
10
10
using Tensorflow ;
11
11
using Tensorflow . Sessions ;
12
+ using TensorFlowNET . Examples . Text ;
12
13
using TensorFlowNET . Examples . Utility ;
13
14
using static Tensorflow . Python ;
14
15
@@ -24,24 +25,27 @@ public class CnnTextClassification : IExample
24
25
public int ? DataLimit = null ;
25
26
public bool IsImportingGraph { get ; set ; } = false ;
26
27
27
- private const string dataDir = "word_cnn " ;
28
- private string dataFileName = "dbpedia_csv.tar.gz" ;
28
+ const string dataDir = "cnn_text " ;
29
+ string dataFileName = "dbpedia_csv.tar.gz" ;
29
30
30
- private const string TRAIN_PATH = "word_cnn /dbpedia_csv/train.csv";
31
- private const string TEST_PATH = "word_cnn /dbpedia_csv/test.csv";
31
+ string TRAIN_PATH = $ " { dataDir } /dbpedia_csv/train.csv";
32
+ string TEST_PATH = $ " { dataDir } /dbpedia_csv/test.csv";
32
33
33
- private const int NUM_CLASS = 14 ;
34
- private const int BATCH_SIZE = 64 ;
35
- private const int NUM_EPOCHS = 10 ;
36
- private const int WORD_MAX_LEN = 100 ;
37
- private const int CHAR_MAX_LEN = 1014 ;
34
+ int NUM_CLASS = 14 ;
35
+ int BATCH_SIZE = 64 ;
36
+ int NUM_EPOCHS = 10 ;
37
+ int WORD_MAX_LEN = 100 ;
38
+ int CHAR_MAX_LEN = 1014 ;
38
39
39
- protected float loss_value = 0 ;
40
+ float loss_value = 0 ;
40
41
double max_accuracy = 0 ;
41
42
42
- int vocabulary_size = 50000 ;
43
+ int vocabulary_size = - 1 ;
43
44
NDArray train_x , valid_x , train_y , valid_y ;
44
45
46
+ ITextModel textModel ;
47
+ public string ModelName = "word_cnn" ; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn
48
+
45
49
public bool Run ( )
46
50
{
47
51
PrepareData ( ) ;
@@ -68,7 +72,7 @@ public bool Run()
68
72
return ( train_x , valid_x , train_y , valid_y ) ;
69
73
}
70
74
71
- private static void FillWithShuffledLabels ( int [ ] [ ] x , int [ ] y , int [ ] [ ] shuffled_x , int [ ] shuffled_y , Random random , Dictionary < int , HashSet < int > > labels )
75
+ private void FillWithShuffledLabels ( int [ ] [ ] x , int [ ] y , int [ ] [ ] shuffled_x , int [ ] shuffled_y , Random random , Dictionary < int , HashSet < int > > labels )
72
76
{
73
77
int i = 0 ;
74
78
var label_keys = labels . Keys . ToArray ( ) ;
@@ -114,10 +118,8 @@ public void PrepareData()
114
118
115
119
Console . WriteLine ( "Building dataset..." ) ;
116
120
117
- int alphabet_size = 0 ;
118
-
119
121
var word_dict = DataHelpers . build_word_dict ( TRAIN_PATH ) ;
120
- // vocabulary_size = len(word_dict);
122
+ vocabulary_size = len ( word_dict ) ;
121
123
var ( x , y ) = DataHelpers . build_word_dataset ( TRAIN_PATH , word_dict , WORD_MAX_LEN ) ;
122
124
123
125
Console . WriteLine ( "\t DONE " ) ;
@@ -155,83 +157,19 @@ public Graph BuildGraph()
155
157
{
156
158
var graph = tf . Graph ( ) . as_default ( ) ;
157
159
158
- var embedding_size = 128 ;
159
- var learning_rate = 0.001f ;
160
- var filter_sizes = new int [ 3 , 4 , 5 ] ;
161
- var num_filters = 100 ;
162
- var document_max_len = 100 ;
163
-
164
- var x = tf . placeholder ( tf . int32 , new TensorShape ( - 1 , document_max_len ) , name : "x" ) ;
165
- var y = tf . placeholder ( tf . int32 , new TensorShape ( - 1 ) , name : "y" ) ;
166
- var is_training = tf . placeholder ( tf . @bool , new TensorShape ( ) , name : "is_training" ) ;
167
- var global_step = tf . Variable ( 0 , trainable : false ) ;
168
- var keep_prob = tf . where ( is_training , 0.5f , 1.0f ) ;
169
- Tensor x_emb = null ;
170
-
171
- with ( tf . name_scope ( "embedding" ) , scope =>
172
- {
173
- var init_embeddings = tf . random_uniform ( new int [ ] { vocabulary_size , embedding_size } ) ;
174
- var embeddings = tf . get_variable ( "embeddings" , initializer : init_embeddings ) ;
175
- x_emb = tf . nn . embedding_lookup ( embeddings , x ) ;
176
- x_emb = tf . expand_dims ( x_emb , - 1 ) ;
177
- } ) ;
178
-
179
- var pooled_outputs = new List < Tensor > ( ) ;
180
- for ( int len = 0 ; len < filter_sizes . Rank ; len ++ )
160
+ switch ( ModelName )
181
161
{
182
- int filter_size = filter_sizes . GetLength ( len ) ;
183
- var conv = tf . layers . conv2d (
184
- x_emb ,
185
- filters : num_filters ,
186
- kernel_size : new int [ ] { filter_size , embedding_size } ,
187
- strides : new int [ ] { 1 , 1 } ,
188
- padding : "VALID" ,
189
- activation : tf . nn . relu ( ) ) ;
190
-
191
- var pool = tf . layers . max_pooling2d (
192
- conv ,
193
- pool_size : new [ ] { document_max_len - filter_size + 1 , 1 } ,
194
- strides : new [ ] { 1 , 1 } ,
195
- padding : "VALID" ) ;
196
-
197
- pooled_outputs . Add ( pool ) ;
162
+ case "word_cnn" :
163
+ textModel = new WordCnn ( vocabulary_size , WORD_MAX_LEN , NUM_CLASS ) ;
164
+ break ;
198
165
}
199
166
200
- var h_pool = tf . concat ( pooled_outputs , 3 ) ;
201
- var h_pool_flat = tf . reshape ( h_pool , new TensorShape ( - 1 , num_filters * filter_sizes . Rank ) ) ;
202
- Tensor h_drop = null ;
203
- with ( tf . name_scope ( "dropout" ) , delegate
204
- {
205
- h_drop = tf . nn . dropout ( h_pool_flat , keep_prob ) ;
206
- } ) ;
207
-
208
- Tensor logits = null ;
209
- Tensor predictions = null ;
210
- with ( tf . name_scope ( "output" ) , delegate
211
- {
212
- logits = tf . layers . dense ( h_drop , NUM_CLASS ) ;
213
- predictions = tf . argmax ( logits , - 1 , output_type : tf . int32 ) ;
214
- } ) ;
215
-
216
- with ( tf . name_scope ( "loss" ) , delegate
217
- {
218
- var sscel = tf . nn . sparse_softmax_cross_entropy_with_logits ( logits : logits , labels : y ) ;
219
- var loss = tf . reduce_mean ( sscel ) ;
220
- var adam = tf . train . AdamOptimizer ( learning_rate ) ;
221
- var optimizer = adam . minimize ( loss , global_step : global_step ) ;
222
- } ) ;
223
-
224
- with ( tf . name_scope ( "accuracy" ) , delegate
225
- {
226
- var correct_predictions = tf . equal ( predictions , y ) ;
227
- var accuracy = tf . reduce_mean ( tf . cast ( correct_predictions , TF_DataType . TF_FLOAT ) , name : "accuracy" ) ;
228
- } ) ;
229
-
230
167
return graph ;
231
168
}
232
169
233
- private bool Train ( Session sess , Graph graph )
170
+ public void Train ( Session sess )
234
171
{
172
+ var graph = tf . get_default_graph ( ) ;
235
173
var stopwatch = Stopwatch . StartNew ( ) ;
236
174
237
175
sess . run ( tf . global_variables_initializer ( ) ) ;
@@ -263,10 +201,7 @@ private bool Train(Session sess, Graph graph)
263
201
loss_value = result [ 2 ] ;
264
202
var step = ( int ) result [ 1 ] ;
265
203
if ( step % 10 == 0 )
266
- {
267
- var estimate = TimeSpan . FromSeconds ( ( stopwatch . Elapsed . TotalSeconds / i ) * total ) ;
268
- Console . WriteLine ( $ "Training on batch { i } /{ total } loss: { loss_value } . Estimated training time: { estimate } ") ;
269
- }
204
+ Console . WriteLine ( $ "Training on batch { i } /{ total } loss: { loss_value . ToString ( "0.0000" ) } .") ;
270
205
271
206
if ( step % 100 == 0 )
272
207
{
@@ -289,7 +224,7 @@ private bool Train(Session sess, Graph graph)
289
224
290
225
var valid_accuracy = sum_accuracy / cnt ;
291
226
292
- print ( $ "\n Validation Accuracy = { valid_accuracy } \n ") ;
227
+ print ( $ "\n Validation Accuracy = { valid_accuracy . ToString ( "P" ) } \n ") ;
293
228
294
229
// Save model
295
230
if ( valid_accuracy > max_accuracy )
@@ -300,13 +235,6 @@ private bool Train(Session sess, Graph graph)
300
235
}
301
236
}
302
237
}
303
-
304
- return max_accuracy > 0.9 ;
305
- }
306
-
307
- public void Train ( Session sess )
308
- {
309
- Train ( sess , sess . graph ) ;
310
238
}
311
239
312
240
public void Predict ( Session sess )
0 commit comments