Skip to content

Commit 6663724

Browse files
committed
seperate Input and Output implementation from Operation
1 parent 1b347a8 commit 6663724

File tree

4 files changed

+128
-108
lines changed

4 files changed

+128
-108
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Runtime.InteropServices;
5+
using System.Text;
6+
7+
namespace Tensorflow
8+
{
9+
public partial class Operation
10+
{
11+
public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index));
12+
public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index));
13+
public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status);
14+
public int NumInputs => c_api.TF_OperationNumInputs(_handle);
15+
private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray();
16+
17+
private InputList _inputs;
18+
public InputList inputs
19+
{
20+
get
21+
{
22+
if (_inputs == null)
23+
{
24+
var retval = new Tensor[NumInputs];
25+
26+
for (int i = 0; i < NumInputs; i++)
27+
{
28+
var tf_outpus = Input(i);
29+
var op = new Operation(tf_outpus.oper);
30+
retval[i] = op.outputs[tf_outpus.index];
31+
}
32+
33+
_inputs = new InputList(retval);
34+
}
35+
36+
return _inputs;
37+
}
38+
}
39+
40+
public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle);
41+
42+
public unsafe Operation[] GetControlInputs()
43+
{
44+
var control_inputs = new Operation[NumControlInputs];
45+
46+
if (NumControlInputs > 0)
47+
{
48+
IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlInputs);
49+
c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs);
50+
for (int i = 0; i < NumControlInputs; i++)
51+
{
52+
var handle = control_input_handle + Marshal.SizeOf<IntPtr>() * i;
53+
control_inputs[i] = new Operation(*(IntPtr*)handle);
54+
}
55+
}
56+
57+
return control_inputs;
58+
}
59+
}
60+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Runtime.InteropServices;
4+
using System.Text;
5+
6+
namespace Tensorflow
7+
{
8+
public partial class Operation
9+
{
10+
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle);
11+
public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index));
12+
public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status);
13+
14+
private Tensor[] _outputs;
15+
public Tensor[] outputs
16+
{
17+
get
18+
{
19+
if (_outputs == null)
20+
{
21+
_outputs = new Tensor[NumOutputs];
22+
23+
for (int i = 0; i < NumOutputs; i++)
24+
_outputs[i] = new Tensor(this, i, OutputType(i));
25+
}
26+
27+
return _outputs;
28+
}
29+
}
30+
31+
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);
32+
public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index));
33+
34+
public unsafe TF_Input[] OutputConsumers(int index, int max_consumers)
35+
{
36+
int size = Marshal.SizeOf<TF_Input>();
37+
var handle = Marshal.AllocHGlobal(size);
38+
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers);
39+
var consumers = new TF_Input[num];
40+
for (int i = 0; i < num; i++)
41+
{
42+
consumers[i] = Marshal.PtrToStructure<TF_Input>(handle + i * size);
43+
}
44+
45+
return consumers;
46+
}
47+
48+
public unsafe Operation[] GetControlOutputs()
49+
{
50+
var control_outputs = new Operation[NumControlOutputs];
51+
52+
if (NumControlOutputs > 0)
53+
{
54+
IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs);
55+
c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs);
56+
for (int i = 0; i < NumControlInputs; i++)
57+
{
58+
var handle = control_output_handle + Marshal.SizeOf<IntPtr>() * i;
59+
control_outputs[i] = new Operation(*(IntPtr*)handle);
60+
}
61+
}
62+
63+
return control_outputs;
64+
}
65+
}
66+
}

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 1 addition & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
namespace Tensorflow
88
{
9-
public class Operation
9+
public partial class Operation
1010
{
1111
private readonly IntPtr _handle;
1212

@@ -20,112 +20,6 @@ public class Operation
2020
public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle));
2121
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));
2222

23-
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle);
24-
public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index));
25-
public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status);
26-
27-
public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index));
28-
public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index));
29-
public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status);
30-
public int NumInputs => c_api.TF_OperationNumInputs(_handle);
31-
32-
public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index));
33-
public unsafe TF_Input[] OutputConsumers(int index, int max_consumers)
34-
{
35-
int size = Marshal.SizeOf<TF_Input>();
36-
var handle = Marshal.AllocHGlobal(size);
37-
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers);
38-
var consumers = new TF_Input[num];
39-
for (int i = 0; i < num; i++)
40-
{
41-
consumers[i] = Marshal.PtrToStructure<TF_Input>(handle + i * size);
42-
}
43-
44-
return consumers;
45-
}
46-
47-
public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle);
48-
49-
public unsafe Operation[] GetControlInputs()
50-
{
51-
var control_inputs = new Operation[NumControlInputs];
52-
53-
if (NumControlInputs > 0)
54-
{
55-
IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlInputs);
56-
c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs);
57-
for (int i = 0; i < NumControlInputs; i++)
58-
{
59-
var handle = control_input_handle + Marshal.SizeOf<IntPtr>() * i;
60-
control_inputs[i] = new Operation(*(IntPtr*)handle);
61-
}
62-
}
63-
64-
return control_inputs;
65-
}
66-
67-
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);
68-
69-
public unsafe Operation[] GetControlOutputs()
70-
{
71-
var control_outputs = new Operation[NumControlOutputs];
72-
73-
if (NumControlOutputs > 0)
74-
{
75-
IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs);
76-
c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs);
77-
for (int i = 0; i < NumControlInputs; i++)
78-
{
79-
var handle = control_output_handle + Marshal.SizeOf<IntPtr>() * i;
80-
control_outputs[i] = new Operation(*(IntPtr*)handle);
81-
}
82-
}
83-
84-
return control_outputs;
85-
}
86-
87-
private Tensor[] _outputs;
88-
public Tensor[] outputs
89-
{
90-
get
91-
{
92-
if (_outputs == null)
93-
{
94-
_outputs = new Tensor[NumOutputs];
95-
96-
for (int i = 0; i < NumOutputs; i++)
97-
_outputs[i] = new Tensor(this, i, OutputType(i));
98-
}
99-
100-
return _outputs;
101-
}
102-
}
103-
104-
private InputList _inputs;
105-
public InputList inputs
106-
{
107-
get
108-
{
109-
if (_inputs == null)
110-
{
111-
var retval = new Tensor[NumInputs];
112-
113-
for (int i = 0; i < NumInputs; i++)
114-
{
115-
var tf_outpus = Input(i);
116-
var op = new Operation(tf_outpus.oper);
117-
retval[i] = op.outputs[tf_outpus.index];
118-
}
119-
120-
_inputs = new InputList(retval);
121-
}
122-
123-
return _inputs;
124-
}
125-
}
126-
127-
private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray();
128-
12923
private NodeDef _node_def;
13024
public NodeDef node_def
13125
{

test/TensorFlowNET.UnitTest/VariableTest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public void Add()
4949

5050
using (var session = tf.Session())
5151
{
52-
var sm = session.run(model);
52+
session.run(x.initializer);
5353
for(int i = 0; i < 5; i++)
5454
{
5555
var x1 = x + 1;

0 commit comments

Comments
 (0)