Skip to content

Commit 1fc3838

Browse files
authored
Merge pull request #299 from acifonelli/master
Add missing `operator -`s
2 parents e1db889 + 1e54c41 commit 1fc3838

File tree

2 files changed

+188
-1
lines changed

2 files changed

+188
-1
lines changed

src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,13 @@ public partial class Tensor
3333

3434
public static Tensor operator -(Tensor t1) => gen_math_ops.neg(t1);
3535

36+
public static Tensor operator -(double x, Tensor y) => BinaryOpWrapper("sub", x, y);
37+
public static Tensor operator -(float x, Tensor y) => BinaryOpWrapper("sub", x, y);
38+
public static Tensor operator -(int x, Tensor y) => BinaryOpWrapper("sub", x, y);
3639
public static Tensor operator -(Tensor x, Tensor y) => BinaryOpWrapper("sub", x, y);
3740
public static Tensor operator -(Tensor x, int y) => BinaryOpWrapper("sub", x, y);
41+
public static Tensor operator -(Tensor x, float y) => BinaryOpWrapper("sub", x, y);
3842
public static Tensor operator -(Tensor x, double y) => BinaryOpWrapper("sub", x, y);
39-
public static Tensor operator -(float x, Tensor y) => BinaryOpWrapper("sub", x, y);
4043

4144
public static Tensor operator *(float x, Tensor y) => BinaryOpWrapper("mul", x, y);
4245
public static Tensor operator *(double x, Tensor y) => BinaryOpWrapper("mul", x, y);

test/TensorFlowNET.UnitTest/OperationsTest.cs

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,5 +215,189 @@ public void addOpTests()
215215
}
216216
#endregion
217217
}
218+
219+
[TestMethod]
220+
public void subOpTests()
221+
{
222+
const int rows = 2; // to avoid broadcasting effect
223+
const int cols = 10;
224+
225+
#region intTest
226+
const int firstIntVal = -2;
227+
const int secondIntVal = 3;
228+
229+
var firstIntFeed = Enumerable.Repeat(firstIntVal, rows * cols).ToArray();
230+
var secondIntFeed = Enumerable.Repeat(secondIntVal, rows * cols).ToArray();
231+
var intResult = firstIntFeed.Sum() - secondIntFeed.Sum();
232+
var intResultTwo = -firstIntFeed.Sum();
233+
234+
var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
235+
var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
236+
var c = tf.reduce_sum(tf.reduce_sum(tf.sub(a, b), 1));
237+
238+
using (var sess = tf.Session())
239+
{
240+
var o = sess.run(c,
241+
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
242+
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
243+
Assert.AreEqual((int)o, intResult);
244+
}
245+
246+
// Testing `operator -(Tensor x, Tensor y)
247+
c = tf.reduce_sum(tf.reduce_sum(a - b, 1));
248+
using (var sess = tf.Session())
249+
{
250+
var o = sess.run(c,
251+
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
252+
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
253+
Assert.AreEqual((int)o, intResult);
254+
}
255+
256+
// Testing `operator -(Tensor x, int y)
257+
c = tf.reduce_sum(tf.reduce_sum(a - secondIntVal, 1));
258+
using (var sess = tf.Session())
259+
{
260+
var o = sess.run(c,
261+
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
262+
Assert.AreEqual((int)o, intResult);
263+
}
264+
265+
// Testing `operator -(int x, Tensor y)
266+
c = tf.reduce_sum(tf.reduce_sum(secondIntVal - a, 1));
267+
using (var sess = tf.Session())
268+
{
269+
var o = sess.run(c,
270+
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
271+
Assert.AreEqual((int)o, Math.Abs(intResult));
272+
}
273+
274+
// Testing `operator -(Tensor x)
275+
c = tf.reduce_sum(tf.reduce_sum(-a, 1));
276+
using (var sess = tf.Session())
277+
{
278+
var o = sess.run(c,
279+
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
280+
Assert.AreEqual((int)o, intResultTwo);
281+
}
282+
#endregion
283+
284+
#region floatTest
285+
const float firstFloatVal = -2.0f;
286+
const float secondFloatVal = 3.0f;
287+
288+
var firstFloatFeed = Enumerable.Repeat(firstFloatVal, rows * cols).ToArray();
289+
var secondFloatFeed = Enumerable.Repeat(secondFloatVal, rows * cols).ToArray();
290+
var floatResult = firstFloatFeed.Sum() - secondFloatFeed.Sum();
291+
var floatResultTwo = -firstFloatFeed.Sum();
292+
293+
a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
294+
b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
295+
c = tf.reduce_sum(tf.reduce_sum(tf.sub(a, b), 1));
296+
297+
using (var sess = tf.Session())
298+
{
299+
var o = sess.run(c,
300+
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
301+
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
302+
Assert.AreEqual((float)o, floatResult);
303+
}
304+
305+
// Testing `operator -(Tensor x, Tensor y)
306+
c = tf.reduce_sum(tf.reduce_sum(a - b, 1));
307+
using (var sess = tf.Session())
308+
{
309+
var o = sess.run(c,
310+
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
311+
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
312+
Assert.AreEqual((float)o, floatResult);
313+
}
314+
315+
// Testing `operator -(Tensor x, float y)
316+
c = tf.reduce_sum(tf.reduce_sum(a - secondFloatVal, 1));
317+
using (var sess = tf.Session())
318+
{
319+
var o = sess.run(c,
320+
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
321+
Assert.AreEqual((float)o, floatResult);
322+
}
323+
324+
// Testing `operator -(float x, Tensor y)
325+
c = tf.reduce_sum(tf.reduce_sum(secondFloatVal - a, 1));
326+
using (var sess = tf.Session())
327+
{
328+
var o = sess.run(c,
329+
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
330+
Assert.AreEqual((float)o, Math.Abs(floatResult));
331+
}
332+
333+
// Testing `operator -(Tensor x)
334+
c = tf.reduce_sum(tf.reduce_sum(-a, 1));
335+
using (var sess = tf.Session())
336+
{
337+
var o = sess.run(c,
338+
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
339+
Assert.AreEqual((float)o, floatResultTwo);
340+
}
341+
#endregion
342+
343+
#region doubleTest
344+
const double firstDoubleVal = -2.0;
345+
const double secondDoubleVal = 3.0;
346+
347+
var firstDoubleFeed = Enumerable.Repeat(firstDoubleVal, rows * cols).ToArray();
348+
var secondDoubleFeed = Enumerable.Repeat(secondDoubleVal, rows * cols).ToArray();
349+
var doubleResult = firstDoubleFeed.Sum() - secondDoubleFeed.Sum();
350+
var doubleResultTwo = -firstDoubleFeed.Sum();
351+
352+
a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
353+
b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
354+
c = tf.reduce_sum(tf.reduce_sum(tf.sub(a, b), 1));
355+
356+
using (var sess = tf.Session())
357+
{
358+
var o = sess.run(c,
359+
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
360+
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
361+
Assert.AreEqual((double)o, doubleResult);
362+
}
363+
364+
// Testing `operator -(Tensor x, Tensor y)
365+
c = tf.reduce_sum(tf.reduce_sum(a - b, 1));
366+
using (var sess = tf.Session())
367+
{
368+
var o = sess.run(c,
369+
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
370+
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
371+
Assert.AreEqual((double)o, doubleResult);
372+
}
373+
374+
// Testing `operator -(Tensor x, double y)
375+
c = tf.reduce_sum(tf.reduce_sum(a - secondFloatVal, 1));
376+
using (var sess = tf.Session())
377+
{
378+
var o = sess.run(c,
379+
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
380+
Assert.AreEqual((double)o, doubleResult);
381+
}
382+
383+
// Testing `operator -(double x, Tensor y)
384+
c = tf.reduce_sum(tf.reduce_sum(secondFloatVal - a, 1));
385+
using (var sess = tf.Session())
386+
{
387+
var o = sess.run(c,
388+
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
389+
Assert.AreEqual((double)o, Math.Abs(doubleResult));
390+
}
391+
392+
// Testing `operator -(Tensor x)
393+
c = tf.reduce_sum(tf.reduce_sum(-a, 1));
394+
using (var sess = tf.Session())
395+
{
396+
var o = sess.run(c,
397+
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
398+
Assert.AreEqual((double)o, doubleResultTwo);
399+
}
400+
#endregion
401+
}
218402
}
219403
}

0 commit comments

Comments
 (0)