Skip to content

Commit bf3d1e9

Browse files
committed
Hooray, AdamOptimizer finally works as expected.
1 parent 04c80f3 commit bf3d1e9

File tree

11 files changed

+275
-69
lines changed

11 files changed

+275
-69
lines changed

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Text;
44
using Tensorflow.Operations;
55
using Tensorflow.Operations.Activation;
6+
using static Tensorflow.Python;
67

78
namespace Tensorflow
89
{
@@ -101,6 +102,25 @@ public static Tensor sparse_softmax_cross_entropy_with_logits(Tensor labels = nu
101102
Tensor logits = null, string name = null)
102103
=> nn_ops.sparse_softmax_cross_entropy_with_logits(labels: labels, logits: logits, name: name);
103104

105+
/// <summary>
/// Computes softmax cross entropy between `logits` and `labels`.
/// `labels` is first wrapped in <c>array_ops.stop_gradient</c>, then the work is
/// delegated to <c>softmax_cross_entropy_with_logits_v2</c>.
/// </summary>
/// <param name="labels">Label tensor; replaced by its stop-gradient version before delegation.</param>
/// <param name="logits">Logits tensor, forwarded unchanged.</param>
/// <param name="dim">Forwarded as the <c>axis</c> argument of the v2 function. Defaults to -1.</param>
/// <param name="name">
/// Optional name for the name scope; when null the scope falls back to
/// "softmax_cross_entropy_with_logits_sg". The resolved scope is passed on as the v2 name.
/// </param>
/// <returns>The tensor returned by <c>softmax_cross_entropy_with_logits_v2</c>.</returns>
public static Tensor softmax_cross_entropy_with_logits(Tensor labels, Tensor logits, int dim = -1, string name = null)
{
    var scope_values = new { logits, labels };

    with(ops.name_scope(name, "softmax_cross_entropy_with_logits_sg", scope_values), name_scope =>
    {
        // Create the stop-gradient op inside the scope and remember the
        // resolved scope so the v2 call below is named after it.
        labels = array_ops.stop_gradient(labels, name: "labels_stop_gradient");
        name = name_scope;
    });

    // As in the original, the v2 call happens after the scope block has finished.
    return softmax_cross_entropy_with_logits_v2(labels, logits, axis: dim, name: name);
}
123+
104124
public static Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null)
105125
=> nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name);
106126
}

src/TensorFlowNET.Core/Operations/OpDefLibrary.cs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,28 @@ public Operation _apply_op_helper(string op_type_name, string name = null, Dicti
9494
if (attrs.ContainsKey(input_arg.TypeAttr))
9595
dtype = (DataType)attrs[input_arg.TypeAttr];
9696
else
97-
if (values is Tensor[] values1)
98-
dtype = values1[0].dtype.as_datatype_enum();
97+
switch (values)
98+
{
99+
case Tensor[] values1:
100+
dtype = values1[0].dtype.as_datatype_enum();
101+
break;
102+
case object[] values1:
103+
foreach(var t in values1)
104+
if(t is Tensor tensor)
105+
{
106+
dtype = tensor.dtype.as_datatype_enum();
107+
break;
108+
}
109+
break;
110+
default:
111+
throw new NotImplementedException($"can't infer the dtype for {values.GetType()}");
112+
}
99113

100114
if (dtype == DataType.DtInvalid && default_type_attr_map.ContainsKey(input_arg.TypeAttr))
101115
default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr];
102116
}
103117

104-
if(input_arg.IsRef && dtype != DataType.DtInvalid)
118+
if(!input_arg.IsRef && dtype != DataType.DtInvalid)
105119
dtype = dtype.as_base_dtype();
106120

107121
values = ops.internal_convert_n_to_tensor(values,

src/TensorFlowNET.Core/Operations/Operation.Output.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@ public partial class Operation
1717

1818
private Tensor[] _outputs;
1919
public Tensor[] outputs => _outputs;
20-
#if GRAPH_SERIALIZE
21-
[JsonIgnore]
22-
#endif
20+
2321
public Tensor output => _outputs.FirstOrDefault();
2422

2523
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11
using Google.Protobuf.Collections;
2-
#if GRAPH_SERIALIZE
3-
using Newtonsoft.Json;
4-
#endif
52
using System;
63
using System.Collections.Generic;
74
using System.Linq;
@@ -37,21 +34,11 @@ public partial class Operation : ITensorOrOperation
3734
private Graph _graph;
3835
public string type => OpType;
3936

40-
#if GRAPH_SERIALIZE
41-
[JsonIgnore]
4237
public Graph graph => _graph;
43-
[JsonIgnore]
4438
public int _id => _id_value;
45-
[JsonIgnore]
4639
public int _id_value;
47-
[JsonIgnore]
4840
public Operation op => this;
49-
#else
50-
public Graph graph => _graph;
51-
public int _id => _id_value;
52-
public int _id_value;
53-
public Operation op => this;
54-
#endif
41+
5542
public TF_DataType dtype => TF_DataType.DtInvalid;
5643
private Status status = new Status();
5744

@@ -60,9 +47,6 @@ public partial class Operation : ITensorOrOperation
6047
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));
6148

6249
private NodeDef _node_def;
63-
#if GRAPH_SERIALIZE
64-
[JsonIgnore]
65-
#endif
6650
public NodeDef node_def
6751
{
6852
get

src/TensorFlowNET.Core/Operations/array_ops.py.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,13 +492,18 @@ public static Tensor concat(Tensor[] values, int axis, string name = "concat")
492492
{
493493
return with(ops.name_scope(name), scope => {
494494
var t = ops.convert_to_tensor(axis, name: "concat_dim", dtype: TF_DataType.TF_INT32);
495-
return identity(values[0], name = scope);
495+
return identity(values[0], name: scope);
496496
});
497497
}
498498

499499
return gen_array_ops.concat_v2(values, axis, name: name);
500500
}
501501

502+
/// <summary>
/// Concatenates <paramref name="values"/> along <paramref name="axis"/> by forwarding
/// directly to <c>gen_array_ops.concat_v2</c>. The elements are passed through as-is.
/// </summary>
/// <param name="values">Items to concatenate, handed unchanged to <c>concat_v2</c>.</param>
/// <param name="axis">Axis argument handed to <c>concat_v2</c>.</param>
/// <param name="name">Operation name; defaults to "concat".</param>
/// <returns>The tensor produced by <c>gen_array_ops.concat_v2</c>.</returns>
public static Tensor concat(object[] values, int axis, string name = "concat")
    => gen_array_ops.concat_v2(values, axis, name: name);
506+
502507
public static Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0)
503508
=> gen_array_ops.gather_v2(@params, indices, axis, name: name);
504509

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public static class gen_array_ops
1919
/// <param name="axis"></param>
2020
/// <param name="name"></param>
2121
/// <returns></returns>
22-
public static Tensor concat_v2(Tensor[] values, int axis, string name = null)
22+
public static Tensor concat_v2<T>(T[] values, int axis, string name = null)
2323
{
2424
var _op = _op_def_lib._apply_op_helper("ConcatV2", name: name, args: new { values, axis });
2525

src/TensorFlowNET.Core/Operations/nn_ops.cs

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45
using Tensorflow.Operations;
56
using static Tensorflow.Python;
@@ -159,8 +160,9 @@ public static Tensor softmax_cross_entropy_with_logits_v2_helper(Tensor labels,
159160
int axis = -1,
160161
string name = null)
161162
{
162-
return Python.with(ops.name_scope(name, "softmax_cross_entropy_with_logits", new { }), scope =>
163+
return with(ops.name_scope(name, "softmax_cross_entropy_with_logits", new { logits, labels }), scope =>
163164
{
165+
name = scope;
164166
var precise_logits = logits;
165167
var input_rank = array_ops.rank(precise_logits);
166168
var shape = logits.TensorShape;
@@ -170,6 +172,10 @@ public static Tensor softmax_cross_entropy_with_logits_v2_helper(Tensor labels,
170172

171173
var input_shape = array_ops.shape(precise_logits);
172174

175+
// Make precise_logits and labels into matrices.
176+
precise_logits = _flatten_outer_dims(precise_logits);
177+
labels = _flatten_outer_dims(labels);
178+
173179
// Do the actual op computation.
174180
// The second output tensor contains the gradients. We use it in
175181
// _CrossEntropyGrad() in nn_grad but not here.
@@ -186,5 +192,50 @@ public static Tensor softmax_cross_entropy_with_logits_v2_helper(Tensor labels,
186192
return cost;
187193
});
188194
}
195+
196+
/// <summary>
/// Flattens the outer dimensions of `logits` into a single dimension and keeps
/// its last dimension.
/// </summary>
/// <param name="logits">Tensor to reshape.</param>
/// <returns>
/// The result of reshaping `logits` to the target shape <c>[-1, last_dim_size]</c>,
/// where <c>last_dim_size</c> is sliced from the tensor's dynamic shape.
/// </returns>
private static Tensor _flatten_outer_dims(Tensor logits)
{
    var rank = array_ops.rank(logits);

    // Slice one element of the dynamic shape starting at index (rank - 1),
    // i.e. the size of the last dimension.
    var last_dim_size = array_ops.slice(array_ops.shape(logits),
        new[] { math_ops.subtract(rank, 1) },
        new[] { 1 });

    // Target shape [-1, last_dim_size]. The local is deliberately not called
    // `ops`, so it does not shadow the `ops` static class.
    var flat_shape = array_ops.concat(new[] { new[] { -1 }, (object)last_dim_size }, 0);
    var output = array_ops.reshape(logits, flat_shape);

    // This method previously threw NotImplementedException whenever every outer
    // dimension was statically known (including rank-1 input), even though the
    // reshape above had already produced the result.
    // TODO: the Python reference additionally records the static shape
    // [product_of_outer_dims, last_dim] on the output in that case. That is only
    // a static-shape hint, so it is skipped here rather than failing the call.
    return output;
}
189240
}
190241
}

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,11 @@ public partial class Tensor : IDisposable, ITensorOrOperation
2222

2323
private int _id;
2424
private Operation _op;
25-
#if GRAPH_SERIALIZE
26-
[JsonIgnore]
27-
public int Id => _id;
28-
[JsonIgnore]
29-
public Graph graph => op?.graph;
30-
[JsonIgnore]
31-
public Operation op => _op;
32-
[JsonIgnore]
33-
public Tensor[] outputs => op.outputs;
34-
#else
25+
3526
public int Id => _id;
3627
public Graph graph => op?.graph;
3728
public Operation op => _op;
3829
public Tensor[] outputs => op.outputs;
39-
#endif
4030

4131
/// <summary>
4232
/// The string name of this tensor.
@@ -50,18 +40,12 @@ public partial class Tensor : IDisposable, ITensorOrOperation
5040

5141
private TF_DataType _dtype = TF_DataType.DtInvalid;
5242
public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle);
53-
#if GRAPH_SERIALIZE
54-
[JsonIgnore]
55-
#endif
43+
5644
public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle);
57-
#if GRAPH_SERIALIZE
58-
[JsonIgnore]
59-
#endif
45+
6046
public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype);
6147
public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize;
62-
#if GRAPH_SERIALIZE
63-
[JsonIgnore]
64-
#endif
48+
6549
public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
6650
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
6751

@@ -70,9 +54,6 @@ public partial class Tensor : IDisposable, ITensorOrOperation
7054
/// <summary>
7155
/// used for keep other pointer when do implicit operating
7256
/// </summary>
73-
#if GRAPH_SERIALIZE
74-
[JsonIgnore]
75-
#endif
7657
public object Tag { get; set; }
7758

7859
public int[] shape
@@ -140,9 +121,7 @@ public int rank
140121
}
141122
}
142123
}
143-
#if GRAPH_SERIALIZE
144-
[JsonIgnore]
145-
#endif
124+
146125
public int NDims => rank;
147126

148127
public string Device => op.Device;

src/TensorFlowNET.Core/Train/AdamOptimizer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ public override Operation _finish(Operation[] update_ops, string name_scope)
110110
var update_beta2 = beta2_power.assign(beta2_power * _beta2_t, use_locking: _use_locking);
111111

112112
operations.Add(update_beta1);
113-
operations.Add(update_beta1);
113+
operations.Add(update_beta2);
114114
});
115115

116116
return control_flow_ops.group(operations.ToArray(), name: name_scope);

test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,6 @@ public bool Run()
4949

5050
public Graph BuildGraph()
5151
{
52-
var g = tf.Graph();
53-
5452
// Placeholders for inputs (x) and outputs(y)
5553
x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: "X");
5654
y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");
@@ -60,15 +58,16 @@ public Graph BuildGraph()
6058
// Create a fully-connected layer with n_classes nodes as output layer
6159
var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false);
6260
// Define the loss function, optimizer, and accuracy
63-
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels: y, logits: output_logits), name: "loss");
61+
var logits = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits);
62+
loss = tf.reduce_mean(logits, name: "loss");
6463
optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss);
6564
var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
6665
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");
6766

6867
// Network predictions
6968
var cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions");
7069

71-
return g;
70+
return tf.get_default_graph();
7271
}
7372

7473
private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
@@ -93,16 +92,10 @@ private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = tr
9392
return layer;
9493
}
9594

96-
public Graph ImportGraph()
97-
{
98-
throw new NotImplementedException();
99-
}
100-
101-
public bool Predict()
102-
{
103-
throw new NotImplementedException();
104-
}
95+
public Graph ImportGraph() => throw new NotImplementedException();
10596

97+
public bool Predict() => throw new NotImplementedException();
98+
10699
public void PrepareData()
107100
{
108101
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
@@ -112,7 +105,6 @@ public bool Train()
112105
{
113106
// Number of training iterations in each epoch
114107
var num_tr_iter = mnist.train.labels.len / batch_size;
115-
116108
return with(tf.Session(), sess =>
117109
{
118110
var init = tf.global_variables_initializer();
@@ -153,10 +145,9 @@ public bool Train()
153145
print("---------------------------------------------------------");
154146
print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
155147
print("---------------------------------------------------------");
156-
157148
}
158149

159-
return accuracy_val > 0.9;
150+
return accuracy_val > 0.95;
160151
});
161152
}
162153

0 commit comments

Comments
 (0)