Skip to content

Commit 8f42762

Browse files
committed
Add control_inputs in Operation #141
1 parent 6663724 commit 8f42762

File tree

12 files changed

+327
-27
lines changed

12 files changed

+327
-27
lines changed
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using Tensorflow.Eager;
6+
7+
namespace Tensorflow
8+
{
9+
public partial class Graph
10+
{
11+
public Context _control_flow_context;
12+
13+
private Queue<_ControlDependenciesController> _graph_control_dependencies_stack = new Queue<_ControlDependenciesController>();
14+
public Queue<_ControlDependenciesController> _control_dependencies_stack
15+
{
16+
get
17+
{
18+
return _graph_control_dependencies_stack;
19+
}
20+
set
21+
{
22+
_graph_control_dependencies_stack = value;
23+
}
24+
}
25+
26+
/// <summary>
27+
/// For an op that takes `input_ops` as inputs, compute control inputs.
28+
/// </summary>
29+
/// <param name="input_ops">The data input ops for an op to be created.</param>
30+
/// <returns>A list of control inputs for the op to be created.</returns>
31+
private Operation[] _control_dependencies_for_inputs(Operation[] input_ops)
32+
{
33+
Operation[] ret = new Operation[0];
34+
35+
foreach(var controller in _control_dependencies_stack)
36+
{
37+
bool dominated = false;
38+
// If any of the input_ops already depends on the inputs from controller,
39+
// we say that the new op is dominated (by that input), and we therefore
40+
// do not need to add control dependencies for this controller's inputs.
41+
foreach(var op in input_ops)
42+
{
43+
if (controller.op_in_group(op))
44+
{
45+
dominated = true;
46+
break;
47+
}
48+
}
49+
50+
if (!dominated)
51+
ret = controller.control_inputs.Where(x => !input_ops.Contains(x)).ToArray();
52+
}
53+
54+
return ret;
55+
}
56+
57+
public _ControlDependenciesController control_dependencies(Operation[] control_inputs)
58+
{
59+
if (control_inputs == null)
60+
return new _ControlDependenciesController(this, null);
61+
62+
var control_ops = new List<Operation>();
63+
foreach (var c in control_inputs)
64+
{
65+
control_ops.Add(c);
66+
}
67+
68+
return new _ControlDependenciesController(this, control_ops);
69+
}
70+
71+
/// <summary>
72+
/// Returns the current control flow context.
73+
/// </summary>
74+
/// <returns>A context object.</returns>
75+
public Context _get_control_flow_context()
76+
{
77+
return _control_flow_context;
78+
}
79+
80+
/// <summary>
81+
/// Sets the current control flow context.
82+
/// </summary>
83+
/// <param name="ctx">a context object.</param>
84+
public void _set_control_flow_context(Context ctx)
85+
{
86+
_control_flow_context = ctx;
87+
}
88+
89+
public void _push_control_dependencies_controller(_ControlDependenciesController controller)
90+
{
91+
_control_dependencies_stack.Enqueue(controller);
92+
}
93+
94+
public void _pop_control_dependencies_controller(_ControlDependenciesController controller)
95+
{
96+
_control_dependencies_stack.Dequeue();
97+
}
98+
99+
public void _record_op_seen_by_control_dependencies(Operation op)
100+
{
101+
foreach (var controller in _control_dependencies_stack)
102+
controller.add_op(op);
103+
}
104+
}
105+
}

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -142,19 +142,9 @@ public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataTy
142142
return op;
143143
}
144144

145-
/// <summary>
146-
/// For an op that takes `input_ops` as inputs, compute control inputs.
147-
/// </summary>
148-
/// <param name="input_ops">The data input ops for an op to be created.</param>
149-
/// <returns>A list of control inputs for the op to be created.</returns>
150-
private Operation[] _control_dependencies_for_inputs(Operation[] input_ops)
151-
{
152-
return new Operation[0];
153-
}
154-
155145
private void _create_op_helper(Operation op, bool compute_device = true)
156146
{
157-
147+
_record_op_seen_by_control_dependencies(op);
158148
}
159149

160150
public void _add_op(Operation op)
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Eager;
5+
6+
namespace Tensorflow
7+
{
8+
/// <summary>
9+
/// Context manager for `control_dependencies()`
10+
/// </summary>
11+
public class _ControlDependenciesController : IPython
12+
{
13+
private Graph _graph;
14+
private List<Operation> _control_inputs_val;
15+
private List<Operation> _seen_nodes;
16+
private Queue<_ControlDependenciesController> _old_stack;
17+
private bool _new_stack;
18+
private Context _old_control_flow_context;
19+
20+
public Operation[] control_inputs => _control_inputs_val.ToArray();
21+
22+
public _ControlDependenciesController(Graph graph, List<Operation> control_inputs)
23+
{
24+
_graph = graph;
25+
if (control_inputs == null)
26+
{
27+
_control_inputs_val = new List<Operation>();
28+
_new_stack = true;
29+
}
30+
else
31+
{
32+
_control_inputs_val = control_inputs;
33+
_new_stack = false;
34+
}
35+
36+
_seen_nodes = new List<Operation>();
37+
}
38+
39+
public void add_op(Operation op)
40+
{
41+
_seen_nodes.Add(op);
42+
}
43+
44+
public bool op_in_group(Operation op)
45+
{
46+
return _seen_nodes.Contains(op);
47+
}
48+
49+
public void __enter__()
50+
{
51+
if (_new_stack)
52+
{
53+
// Clear the control_dependencies graph.
54+
_old_stack = _graph._control_dependencies_stack;
55+
_graph._control_dependencies_stack = new Queue<_ControlDependenciesController>();
56+
57+
// Clear the control_flow_context too.
58+
_old_control_flow_context = _graph._get_control_flow_context();
59+
_graph._set_control_flow_context(null);
60+
}
61+
62+
_graph._push_control_dependencies_controller(this);
63+
}
64+
65+
public void __exit__()
66+
{
67+
_graph._pop_control_dependencies_controller(this);
68+
if (_new_stack)
69+
{
70+
_graph._control_dependencies_stack = _old_stack;
71+
_graph._set_control_flow_context(_old_control_flow_context);
72+
}
73+
}
74+
75+
public void Dispose()
76+
{
77+
78+
}
79+
}
80+
}

src/TensorFlowNET.Core/Operations/Operation.Input.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,14 @@ public InputList inputs
3939

4040
public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle);
4141

42+
public Operation[] control_inputs
43+
{
44+
get
45+
{
46+
return GetControlInputs();
47+
}
48+
}
49+
4250
public unsafe Operation[] GetControlInputs()
4351
{
4452
var control_inputs = new Operation[NumControlInputs];

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,53 @@ public Operation(Graph g, string opType, string oper_name)
4949
c_api.TF_FinishOperation(desc, status);
5050
}
5151

52-
public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
52+
/// <summary>
53+
/// Creates an `Operation`.
54+
/// </summary>
55+
/// <param name="node_def">`node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.</param>
56+
/// <param name="g">`Graph`. The parent graph.</param>
57+
/// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param>
58+
/// <param name="output_types">list of `DType` objects.</param>
59+
/// <param name="control_inputs">
60+
/// list of operations or tensors from which to have a
61+
/// control dependency.
62+
/// </param>
63+
/// <param name="input_types">
64+
/// List of `DType` objects representing the
65+
/// types of the tensors accepted by the `Operation`. By default
66+
/// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect
67+
/// reference-typed inputs must specify these explicitly.
68+
/// </param>
69+
/// <param name="original_op"></param>
70+
/// <param name="op_def"></param>
71+
public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, Operation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
5372
{
5473
Graph = g;
5574

75+
// Build the list of control inputs.
76+
var control_input_ops = new List<Operation>();
77+
if(control_inputs != null)
78+
{
79+
foreach(var c in control_inputs)
80+
{
81+
switch (c)
82+
{
83+
case Operation c1:
84+
control_input_ops.Add(c1);
85+
break;
86+
default:
87+
throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
88+
}
89+
}
90+
}
91+
92+
// This will be set by self.inputs.
93+
5694
_id_value = Graph._next_id();
5795
if(op_def == null)
5896
op_def = g.GetOpDef(node_def.Op);
5997

60-
_handle = ops._create_c_op(g, node_def, inputs);
98+
_handle = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray());
6199

62100
output_types = new TF_DataType[NumOutputs];
63101

src/TensorFlowNET.Core/Operations/c_api.ops.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ public partial class c_api
3434
[DllImport(TensorFlowLibName)]
3535
public static extern void TF_AddInput(IntPtr desc, TF_Output input);
3636

37+
/// <summary>
38+
/// Call once per control input to `desc`.
39+
/// </summary>
40+
/// <param name="desc">TF_OperationDescription*</param>
41+
/// <param name="input">TF_Operation*</param>
42+
[DllImport(TensorFlowLibName)]
43+
public static extern void TF_AddControlInput(IntPtr desc, IntPtr input);
44+
3745
/// <summary>
3846
/// For inputs that take a list of tensors.
3947
/// inputs must point to TF_Output[num_inputs].

src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@ public static Operation group(List<Operation> inputs, string name = "")
1313
{
1414
name = namescope;
1515

16-
var ops_on_device = new Dictionary<string, Operation[]>();
17-
1816
// Sorts *inputs according to their devices.
17+
var ops_on_device = new Dictionary<string, Operation[]>();
1918
foreach (var inp in inputs)
2019
{
2120
ops_on_device[inp.Device] = new Operation[] { inp };
@@ -24,7 +23,9 @@ public static Operation group(List<Operation> inputs, string name = "")
2423
// 1-level tree. The root node is the returned NoOp node.
2524
if (ops_on_device.Count == 1)
2625
{
27-
return _GroupControlDeps(ops_on_device.Keys.First(), ops_on_device.Values.First(), name);
26+
var dev = ops_on_device.Keys.First();
27+
var deps = ops_on_device.Values.First();
28+
return _GroupControlDeps(dev, deps, name);
2829
}
2930

3031
// 2-level tree. The root node is the returned NoOp node.
@@ -35,12 +36,21 @@ public static Operation group(List<Operation> inputs, string name = "")
3536

3637
private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = "")
3738
{
38-
if (string.IsNullOrEmpty(dev))
39+
Operation result = null;
40+
41+
Python.with(ops.control_dependencies(deps), delegate
3942
{
40-
return gen_control_flow_ops.no_op(name);
41-
}
43+
if (string.IsNullOrEmpty(dev))
44+
{
45+
result = gen_control_flow_ops.no_op(name);
46+
}
47+
else
48+
{
49+
result = gen_control_flow_ops.no_op(name);
50+
}
51+
});
4252

43-
return null;
53+
return result;
4454
}
4555
}
4656
}

src/TensorFlowNET.Core/Python.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,30 @@ protected void print(object obj)
1313
{
1414
Console.WriteLine(obj.ToString());
1515
}
16+
17+
public static void with(IPython py, Action action)
18+
{
19+
try
20+
{
21+
py.__enter__();
22+
action();
23+
}
24+
catch (Exception ex)
25+
{
26+
throw ex;
27+
}
28+
finally
29+
{
30+
py.__exit__();
31+
py.Dispose();
32+
}
33+
}
34+
}
35+
36+
public interface IPython : IDisposable
37+
{
38+
void __enter__();
39+
40+
void __exit__();
1641
}
1742
}

src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,10 @@ public static implicit operator Tensor(RefVariable var)
2020
{
2121
return var._AsTensor();
2222
}
23+
24+
public static implicit operator RefVariable(Tensor var)
25+
{
26+
return null;
27+
}
2328
}
2429
}

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,5 +166,10 @@ private Operation _safe_initial_value_from_op(Operation op, Dictionary<string, O
166166
// Recursively build initializer expressions for inputs.
167167
return op;
168168
}
169+
170+
public override string ToString()
171+
{
172+
return $"tf.Variable '{name}' shape={shape} dtype={dtype}";
173+
}
169174
}
170175
}

0 commit comments

Comments
 (0)