Skip to content

Commit 76350f7

Browse files
committed
Remove unnecessary parameters for RecordGradient.
1 parent aea62f6 commit 76350f7

File tree

11 files changed

+65
-29
lines changed

11 files changed

+65
-29
lines changed

src/TensorFlowNET.Core/Contexts/Context.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
using System.Diagnostics;
1919
using System.Linq;
2020
using Tensorflow.Eager;
21+
using static Tensorflow.Binding;
2122

2223
namespace Tensorflow.Contexts
2324
{
@@ -114,6 +115,36 @@ public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params Tenso
114115
}
115116
}
116117

118+
[DebuggerStepThrough]
public Tensors RunInAutoMode2(Func<Tensors> graphAction,
    Func<Tensors> eagerAction,
    Action<Operation> recordGradient,
    Tensors tensors)
{
    // Run eagerly only when eager execution is on AND every input tensor is
    // already an eager tensor; otherwise we must build a graph op instead.
    // (executing_eagerly() is checked first so the Count() scan is skipped
    // entirely in graph mode, matching the original short-circuit order.)
    var shouldRunInEager = executing_eagerly()
        && tensors.Count(x => x.IsEagerTensor) == tensors.Length;

    if (shouldRunInEager)
        return eagerAction();

    if (executing_eagerly())
    {
        // Mixed inputs while in eager mode: temporarily switch the context
        // to graph mode for the duration of graphAction.
        graph_mode();
        try
        {
            return graphAction();
        }
        finally
        {
            // BUGFIX: restore eager mode even when graphAction throws; the
            // original called restore_mode() only on the success path, which
            // left the global context stuck in graph mode after an error.
            restore_mode();
        }
    }

    // Already in graph mode: build the op, and record its gradient when an
    // active tape requires it. Note gradients are recorded only on this
    // branch — presumably the temporary-graph-mode branch above has no tape
    // watching; TODO confirm with callers.
    var result = graphAction();
    if (tf.Runner.MustRecordGradient())
        recordGradient(result[0].op);
    return result;
}
117148
// Releases the native context handle owned by this Context.
public void Dispose()
    => Handle.Dispose();
119150
}

src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ public partial class EagerRunner
1111
public bool RecordGradient(string op_name,
1212
Tensor[] inputs,
1313
object[] attrs,
14-
Tensor[] results)
14+
Tensor[] results,
15+
Func<BackwardFunction> getBackwardFunction = null)
1516
{
1617
var input_ids = MakeTensorIDList(inputs);
1718
var input_dtypes = MakeTensorDtypeList(inputs);
@@ -77,13 +78,20 @@ public bool RecordGradient(string op_name,
7778
else
7879
op_inputs = inputs;
7980

80-
TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes,
81-
() => GetGradientFunction(op_name, inputs, attrs, results));
82-
81+
TapeSetRecordOperation(op_name, inputs, results,
82+
getBackwardFunction ?? GetBackwradFunction(op_name, inputs, attrs, results));
8383

8484
return true;
8585
}
8686

87+
// Wraps GetGradientFunction in a lazily-evaluated factory so the backward
// function is only constructed if the tape actually needs it.
// NOTE(review): "Backwrad" is a typo for "Backward"; the misspelled name is
// also used at the call site in RecordGradient, so renaming must be done
// file-wide in one change, not here.
Func<BackwardFunction> GetBackwradFunction(string op_name,
    Tensor[] op_inputs,
    object[] attrs,
    Tensor[] op_outputs)
{
    return () => GetGradientFunction(op_name, op_inputs, attrs, op_outputs);
}
94+
8795
BackwardFunction GetGradientFunction(string op_name,
8896
Tensor[] op_inputs,
8997
object[] attrs,

src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ public partial class EagerRunner
1010
void TapeSetRecordBackprop(string op_type,
1111
Tensor[] input_tensors,
1212
TapeTensor[] output_tensors,
13-
long[] input_ids,
14-
TF_DataType[] input_dtypes,
1513
Func<BackwardFunction> backward_function_getter)
1614
{
1715
if (!CouldBackprop())
@@ -22,7 +20,6 @@ void TapeSetRecordBackprop(string op_type,
2220
foreach (var tape in tf.GetTapeSet())
2321
{
2422
tape.RecordOperation(op_type, input_tensors, output_tensors,
25-
input_ids, input_dtypes,
2623
backward_function_getter);
2724
}
2825
}

src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordForwardprop.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ public partial class EagerRunner
99
bool TapeSetRecordForwardprop(string op_type,
1010
Tensor[] input_tensors,
1111
TapeTensor[] output_tensors,
12-
long[] input_ids,
13-
TF_DataType[] input_dtypes,
1412
Func<BackwardFunction> backward_function_getter)
1513
{
1614
if (!CouldForwardprop())

src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@ namespace Tensorflow.Eager
77
{
88
public partial class EagerRunner
99
{
10-
bool TapeSetRecordOperation(string op_type,
10+
public bool TapeSetRecordOperation(string op_type,
1111
Tensor[] input_tensors,
1212
Tensor[] output_tensors,
13-
long[] input_ids,
14-
TF_DataType[] input_dtypes,
1513
Func<BackwardFunction> backward_function_getter)
1614
{
1715
var output_info = new List<TapeTensor>();
@@ -20,11 +18,11 @@ bool TapeSetRecordOperation(string op_type,
2018
return false;
2119

2220
if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info.ToArray(),
23-
input_ids, input_dtypes, backward_function_getter))
21+
backward_function_getter))
2422
return false;
2523

2624
TapeSetRecordBackprop(op_type, input_tensors, output_info.ToArray(),
27-
input_ids, input_dtypes, backward_function_getter);
25+
backward_function_getter);
2826

2927
return true;
3028
}

src/TensorFlowNET.Core/Eager/IEagerRunner.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using Tensorflow.Contexts;
33
using Tensorflow.Gradients;
4+
using static Tensorflow.tensorflow;
45

56
namespace Tensorflow.Eager
67
{
@@ -37,7 +38,8 @@ Tensor[] TFE_TapeGradient(ITape tape,
3738
bool RecordGradient(string op_name,
3839
Tensor[] inputs,
3940
object[] attrs,
40-
Tensor[] results);
41+
Tensor[] results,
42+
Func<BackwardFunction> getBackwardFunction = null);
4143

4244
bool MustRecordGradient();
4345

src/TensorFlowNET.Core/Gradients/Tape.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,7 @@ public bool ShouldRecord(long[] tensor_ids, TF_DataType[] dtypes)
5454
if (tensor_tape_.find(tensor_ids[i]))
5555
{
5656
if (IsDtypeTrainable(dtypes[i]))
57-
{
58-
tf.Logger.Debug($"tape.h->ShouldRecord: should_record = true, tensor_tape_.size()={tensor_tape_.Count}, tensor_ids[{i}]={tensor_ids[i]}");
5957
return true;
60-
}
6158
}
6259
}
6360
return false;

src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
using System.Linq;
2020
using System.Reflection;
2121
using Tensorflow.Gradients;
22+
using static Tensorflow.Binding;
2223

2324
namespace Tensorflow
2425
{
@@ -47,11 +48,17 @@ public static void RegisterFromAssembly()
4748
{
4849
RegisterGradientFunction(m.GetCustomAttribute<RegisterGradient>().Name,
4950
(oper, out_grads) =>
50-
g.InvokeMember(m.Name,
51-
BindingFlags.InvokeMethod,
52-
null,
53-
null,
54-
args: new object[] { oper, out_grads }) as Tensor[]
51+
{
52+
tf.Logger.Debug($"Caculate Gradient: {m.Name}");
53+
var results = g.InvokeMember(m.Name,
54+
BindingFlags.InvokeMethod,
55+
null,
56+
null,
57+
args: new object[] { oper, out_grads }) as Tensor[];
58+
foreach (var result in results.Where(x => x != null))
59+
tf.Logger.Debug($"{result.TensorShape}");
60+
return results;
61+
}
5562
);
5663
}
5764

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ public static Tensor sigmoid_grad(Tensor y, Tensor dy, string name = "SigmoidGra
358358
}
359359

360360
var op = tf.OpDefLib._apply_op_helper("SigmoidGrad", name: name, args: new { y, dy });
361-
361+
362362
return op.output;
363363
}
364364

src/TensorFlowNET.Keras/Engine/Functional.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,9 +335,10 @@ Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask =
335335

336336
var layer_inputs = node.MapArguments(tensor_dict);
337337

338-
tf.Logger.Debug($"{node.Layer}: {node.Layer.Name}");
338+
tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name}");
339339
var outputs = node.Layer.Apply(layer_inputs, is_training: training);
340-
340+
foreach (var output in outputs.Where(x => x != null))
341+
tf.Logger.Debug($"{output.TensorShape}");
341342
// Update tensor_dict for next input
342343
foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs))
343344
tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y));

0 commit comments

Comments
 (0)