Skip to content

Commit 6f8beab

Browse files
BanycOceania2018
authored andcommitted
Add test case for Tensor.assign #653
The test case is currently not passed yet.
1 parent fed264a commit 6f8beab

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

test/TensorFlowNET.UnitTest/Basics/TensorTest.cs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,5 +298,45 @@ public unsafe void tensor_resize()
298298

299299
tf.compat.v1.disable_eager_execution();
300300
}
301+
302+
/// <summary>
303+
/// Assign tensor to slice of other tensor.
304+
/// </summary>
305+
[TestMethod]
306+
public void TestAssignOfficial()
307+
{
308+
// example from https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__
309+
310+
// python
311+
// import tensorflow as tf
312+
// A = tf.Variable([[1,2,3], [4,5,6], [7,8,9]], dtype=tf.float32)
313+
// with tf.compat.v1.Session() as sess:
314+
// sess.run(tf.compat.v1.global_variables_initializer())
315+
// print(sess.run(A[:2, :2])) # => [[1,2], [4,5]]
316+
317+
// op = A[:2,:2].assign(22. * tf.ones((2, 2)))
318+
// print(sess.run(op)) # => [[22, 22, 3], [22, 22, 6], [7,8,9]]
319+
320+
// C#
321+
// [[1,2,3], [4,5,6], [7,8,9]]
322+
double[][] initial = new double[][]
323+
{
324+
new double[] { 1, 2, 3 },
325+
new double[] { 4, 5, 6 },
326+
new double[] { 7, 8, 9 }
327+
};
328+
Tensor A = tf.Variable(initial, dtype: tf.float32);
329+
// Console.WriteLine(A[":2", ":2"]); // => [[1,2], [4,5]]
330+
Tensor result1 = A[":2", ":2"];
331+
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 1, 2 }, result1[0].ToArray<double>()));
332+
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 4, 5 }, result1[1].ToArray<double>()));
333+
334+
// An unhandled exception of type 'System.ArgumentException' occurred in TensorFlow.NET.dll: 'Dimensions {2, 2, and {2, 2, are not compatible'
335+
Tensor op = A[":2", ":2"].assign(22.0 * tf.ones((2, 2)));
336+
// Console.WriteLine(op); // => [[22, 22, 3], [22, 22, 6], [7,8,9]]
337+
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 22, 22, 3 }, op[0].ToArray<double>()));
338+
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 22, 22, 6 }, op[1].ToArray<double>()));
339+
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 7, 8, 9 }, op[2].ToArray<double>()));
340+
}
301341
}
302342
}

0 commit comments

Comments
 (0)