@@ -298,5 +298,45 @@ public unsafe void tensor_resize()
298
298
299
299
tf . compat . v1 . disable_eager_execution ( ) ;
300
300
}
301
+
302
+ /// <summary>
303
+ /// Assign tensor to slice of other tensor.
304
+ /// </summary>
305
+ [ TestMethod ]
306
+ public void TestAssignOfficial ( )
307
+ {
308
+ // example from https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__
309
+
310
+ // python
311
+ // import tensorflow as tf
312
+ // A = tf.Variable([[1,2,3], [4,5,6], [7,8,9]], dtype=tf.float32)
313
+ // with tf.compat.v1.Session() as sess:
314
+ // sess.run(tf.compat.v1.global_variables_initializer())
315
+ // print(sess.run(A[:2, :2])) # => [[1,2], [4,5]]
316
+
317
+ // op = A[:2,:2].assign(22. * tf.ones((2, 2)))
318
+ // print(sess.run(op)) # => [[22, 22, 3], [22, 22, 6], [7,8,9]]
319
+
320
+ // C#
321
+ // [[1,2,3], [4,5,6], [7,8,9]]
322
+ double [ ] [ ] initial = new double [ ] [ ]
323
+ {
324
+ new double [ ] { 1 , 2 , 3 } ,
325
+ new double [ ] { 4 , 5 , 6 } ,
326
+ new double [ ] { 7 , 8 , 9 }
327
+ } ;
328
+ Tensor A = tf . Variable ( initial , dtype : tf . float32 ) ;
329
+ // Console.WriteLine(A[":2", ":2"]); // => [[1,2], [4,5]]
330
+ Tensor result1 = A [ ":2" , ":2" ] ;
331
+ Assert . IsTrue ( Enumerable . SequenceEqual ( new double [ ] { 1 , 2 } , result1 [ 0 ] . ToArray < double > ( ) ) ) ;
332
+ Assert . IsTrue ( Enumerable . SequenceEqual ( new double [ ] { 4 , 5 } , result1 [ 1 ] . ToArray < double > ( ) ) ) ;
333
+
334
+ // An unhandled exception of type 'System.ArgumentException' occurred in TensorFlow.NET.dll: 'Dimensions {2, 2, and {2, 2, are not compatible'
335
+ Tensor op = A [ ":2" , ":2" ] . assign ( 22.0 * tf . ones ( ( 2 , 2 ) ) ) ;
336
+ // Console.WriteLine(op); // => [[22, 22, 3], [22, 22, 6], [7,8,9]]
337
+ Assert . IsTrue ( Enumerable . SequenceEqual ( new double [ ] { 22 , 22 , 3 } , op [ 0 ] . ToArray < double > ( ) ) ) ;
338
+ Assert . IsTrue ( Enumerable . SequenceEqual ( new double [ ] { 22 , 22 , 6 } , op [ 1 ] . ToArray < double > ( ) ) ) ;
339
+ Assert . IsTrue ( Enumerable . SequenceEqual ( new double [ ] { 7 , 8 , 9 } , op [ 2 ] . ToArray < double > ( ) ) ) ;
340
+ }
301
341
}
302
342
}
0 commit comments