Skip to content

Commit c616eea

Browse files
acifonelli authored and Oceania2018 committed
Adding BatchMatMul gradient (#304)
Unit testing the gradient too.
1 parent 6fe5a6c commit c616eea

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

src/TensorFlowNET.Core/Gradients/math_grad.cs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,38 @@ public static Tensor[] _MatMulGrad(Operation op, Tensor[] grads)
168168
return new Tensor[] { grad_a, grad_b };
169169
}
170170

171+
[RegisterGradient("BatchMatMul")]
public static Tensor[] _BatchMatMul(Operation op, Tensor[] grads)
{
    // Gradient for BatchMatMul: z = batch_mat_mul(x, y, adj_x, adj_y).
    // Returns { dL/dx, dL/dy } given dL/dz in grads[0]. The four branches
    // mirror TensorFlow's python math_grad._BatchMatMul, one per (adj_x, adj_y)
    // combination. Inputs are conjugated first so the formulas hold for
    // complex tensors as well.
    var grad = grads[0];
    Tensor grad_a = null, grad_b = null;

    var t_a = (bool)op.get_attr("adj_x");
    var t_b = (bool)op.get_attr("adj_y");
    var a = math_ops.conj(op.inputs[0]);
    var b = math_ops.conj(op.inputs[1]);

    if (!t_a && !t_b)
    {
        // z = a @ b : da = g @ b^T, db = a^T @ g
        grad_a = gen_math_ops.batch_mat_mul(grad, b, adj_y: true);
        grad_b = gen_math_ops.batch_mat_mul(a, grad, adj_x: true);
    }
    else if (!t_a && t_b)
    {
        // z = a @ b^T : da = g @ b, db = g^T @ a
        grad_a = gen_math_ops.batch_mat_mul(grad, b);
        grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true);
    }
    else if (t_a && !t_b)
    {
        // z = a^T @ b : da = b @ g^T, db = a @ g
        // BUG FIX: this branch previously duplicated the (!t_a && t_b) case,
        // producing wrong gradients whenever adj_x was set alone.
        grad_a = gen_math_ops.batch_mat_mul(b, grad, adj_y: true);
        grad_b = gen_math_ops.batch_mat_mul(a, grad);
    }
    else // t_a && t_b
    {
        // z = a^T @ b^T : da = b^T @ g^T, db = g^T @ a^T
        grad_a = gen_math_ops.batch_mat_mul(b, grad, adj_x: true, adj_y: true);
        grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true, adj_y: true);
    }

    return new Tensor[] { grad_a, grad_b };
}
175204

176205
[RegisterGradient("Mean")]

test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45
using Microsoft.VisualStudio.TestTools.UnitTesting;
6+
using NumSharp;
57
using Tensorflow;
68
using static Tensorflow.Python;
79

@@ -30,6 +32,39 @@ public void testGradients()
3032
});
3133
}
3234

35+
[TestMethod]
public void testBatchMatMulGradient()
{
    // a: values 1..18 reshaped to two 3x3 batches; b = a / 2.
    var a = tf.constant(np.array(Enumerable.Range(1, 18).Select(elem => (float)elem).ToArray()), shape: new[] { 2, 3, 3 });
    var b = tf.divide(a, tf.constant(2.0f));
    var c = tf.batch_matmul(a, b);
    var g = tf.gradients(c, new[] { a, b }, stop_gradients: new[] { a, b });

    // Expected gradients, flattened: dC/da (first 18 values) then dC/db.
    var checkG = new[]
    {
        3.0f, 7.5f, 12.0f,
        3.0f, 7.5f, 12.0f,
        3.0f, 7.5f, 12.0f,
        16.5f, 21.0f, 25.5f,
        16.5f, 21.0f, 25.5f,
        16.5f, 21.0f, 25.5f,
        12.0f, 12.0f, 12.0f,
        15.0f, 15.0f, 15.0f,
        18.0f, 18.0f, 18.0f,
        39.0f, 39.0f, 39.0f,
        42.0f, 42.0f, 42.0f,
        45.0f, 45.0f, 45.0f
    };

    using (var sess = tf.Session())
    {
        var result = sess.run(g);
        var resultList = result[0].GetData<float>().ToList();
        resultList.AddRange(result[1].GetData<float>());
        // CollectionAssert.AreEqual takes (expected, actual) in that order;
        // passing them swapped produces misleading failure messages.
        CollectionAssert.AreEqual(checkG, resultList.ToArray());
    }
}
66+
67+
3368
[Ignore("TODO")]
3469
[TestMethod]
3570
public void testUnusedOutput()

0 commit comments

Comments
 (0)