Skip to content

Commit 4f3599e

Browse files
authored
Merge pull request #296 from acifonelli/master
Add missing `operator +`s
2 parents 7cb1746 + e13c444 commit 4f3599e

File tree

2 files changed

+160
-0
lines changed

2 files changed

+160
-0
lines changed

src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,13 @@ namespace Tensorflow
2323
{
2424
public partial class Tensor
2525
{
26+
public static Tensor operator +(double x, Tensor y) => BinaryOpWrapper("add", x, y);
27+
public static Tensor operator +(float x, Tensor y) => BinaryOpWrapper("add", x, y);
28+
public static Tensor operator +(int x, Tensor y) => BinaryOpWrapper("add", x, y);
2629
public static Tensor operator +(Tensor x, Tensor y) => BinaryOpWrapper("add", x, y);
2730
public static Tensor operator +(Tensor x, int y) => BinaryOpWrapper("add", x, y);
31+
public static Tensor operator +(Tensor x, float y) => BinaryOpWrapper("add", x, y);
32+
public static Tensor operator +(Tensor x, double y) => BinaryOpWrapper("add", x, y);
2833

2934
public static Tensor operator -(Tensor t1) => gen_math_ops.neg(t1);
3035

test/TensorFlowNET.UnitTest/OperationsTest.cs

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Collections.Generic;
44
using System.Linq;
55
using System.Text;
6+
using NumSharp;
67
using Tensorflow;
78
using Buffer = Tensorflow.Buffer;
89

@@ -60,5 +61,159 @@ public void addInConstant()
6061
Assert.AreEqual((float)o, 9.0f);
6162
}
6263
}
64+
65+
[TestMethod]
66+
public void addOpTests()
67+
{
68+
const int rows = 2; // to avoid broadcasting effect
69+
const int cols = 10;
70+
71+
#region intTest
72+
const int firstIntVal = 2;
73+
const int secondIntVal = 3;
74+
75+
var firstIntFeed = Enumerable.Repeat(firstIntVal, rows * cols).ToArray();
76+
var secondIntFeed = Enumerable.Repeat(secondIntVal, rows * cols).ToArray();
77+
var intResult = firstIntFeed.Sum() + secondIntFeed.Sum();
78+
79+
var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
80+
var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
81+
var c = tf.reduce_sum(tf.reduce_sum(tf.add(a, b), 1));
82+
83+
using (var sess = tf.Session())
84+
{
85+
var o = sess.run(c,
86+
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
87+
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
88+
Assert.AreEqual((int)o, intResult);
89+
}
90+
91+
// Testing `operator +(Tensor x, Tensor y)`
92+
c = tf.reduce_sum(tf.reduce_sum(a + b, 1));
93+
using (var sess = tf.Session())
94+
{
95+
var o = sess.run(c,
96+
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
97+
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
98+
Assert.AreEqual((int)o, intResult);
99+
}
100+
101+
// Testing `operator +(Tensor x, int y)`
102+
c = tf.reduce_sum(tf.reduce_sum(a + secondIntVal, 1));
103+
using (var sess = tf.Session())
104+
{
105+
var o = sess.run(c,
106+
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
107+
Assert.AreEqual((int)o, intResult);
108+
}
109+
110+
// Testing `operator +(int x, Tensor y)`
111+
c = tf.reduce_sum(tf.reduce_sum(secondIntVal + a, 1));
112+
using (var sess = tf.Session())
113+
{
114+
var o = sess.run(c,
115+
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
116+
Assert.AreEqual((int)o, intResult);
117+
}
118+
#endregion
119+
120+
#region floatTest
121+
const float firstFloatVal = 2.0f;
122+
const float secondFloatVal = 3.0f;
123+
124+
var firstFloatFeed = Enumerable.Repeat(firstFloatVal, rows * cols).ToArray();
125+
var secondFloatFeed = Enumerable.Repeat(secondFloatVal, rows * cols).ToArray();
126+
var floatResult = firstFloatFeed.Sum() + secondFloatFeed.Sum();
127+
128+
a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
129+
b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
130+
c = tf.reduce_sum(tf.reduce_sum(tf.add(a, b), 1));
131+
132+
using (var sess = tf.Session())
133+
{
134+
var o = sess.run(c,
135+
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
136+
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
137+
Assert.AreEqual((float)o, floatResult);
138+
}
139+
140+
// Testing `operator +(Tensor x, Tensor y)
141+
c = tf.reduce_sum(tf.reduce_sum(a + b, 1));
142+
using (var sess = tf.Session())
143+
{
144+
var o = sess.run(c,
145+
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
146+
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
147+
Assert.AreEqual((float)o, floatResult);
148+
}
149+
150+
// Testing `operator +(Tensor x, float y)
151+
c = tf.reduce_sum(tf.reduce_sum(a + secondFloatVal, 1));
152+
using (var sess = tf.Session())
153+
{
154+
var o = sess.run(c,
155+
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
156+
Assert.AreEqual((float)o, floatResult);
157+
}
158+
159+
// Testing `operator +(float x, Tensor y)
160+
c = tf.reduce_sum(tf.reduce_sum(secondFloatVal + a, 1));
161+
using (var sess = tf.Session())
162+
{
163+
var o = sess.run(c,
164+
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
165+
Assert.AreEqual((float)o, floatResult);
166+
}
167+
#endregion
168+
169+
#region doubleTest
170+
const double firstDoubleVal = 2.0;
171+
const double secondDoubleVal = 3.0;
172+
173+
var firstDoubleFeed = Enumerable.Repeat(firstDoubleVal, rows * cols).ToArray();
174+
var secondDoubleFeed = Enumerable.Repeat(secondDoubleVal, rows * cols).ToArray();
175+
var doubleResult = firstDoubleFeed.Sum() + secondDoubleFeed.Sum();
176+
177+
a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
178+
b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
179+
c = tf.reduce_sum(tf.reduce_sum(tf.add(a, b), 1));
180+
181+
using (var sess = tf.Session())
182+
{
183+
var o = sess.run(c,
184+
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
185+
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
186+
Assert.AreEqual((double)o, doubleResult);
187+
}
188+
189+
// Testing `operator +(Tensor x, Tensor y)
190+
c = tf.reduce_sum(tf.reduce_sum(a + b, 1));
191+
using (var sess = tf.Session())
192+
{
193+
var o = sess.run(c,
194+
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
195+
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
196+
Assert.AreEqual((double)o, doubleResult);
197+
}
198+
199+
// Testing `operator +(Tensor x, double y)
200+
c = tf.reduce_sum(tf.reduce_sum(a + secondFloatVal, 1));
201+
using (var sess = tf.Session())
202+
{
203+
var o = sess.run(c,
204+
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
205+
Assert.AreEqual((double)o, doubleResult);
206+
}
207+
208+
// Testing `operator +(double x, Tensor y)
209+
c = tf.reduce_sum(tf.reduce_sum(secondFloatVal + a, 1));
210+
using (var sess = tf.Session())
211+
{
212+
var o = sess.run(c,
213+
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
214+
Assert.AreEqual((double)o, doubleResult);
215+
}
216+
#endregion
217+
}
63218
}
64219
}

0 commit comments

Comments
 (0)