Skip to content

Commit d75366c

Browse files
committed
fix Variable[slice].assign() #653
1 parent 6f8beab commit d75366c

File tree

13 files changed

+271
-161
lines changed

13 files changed

+271
-161
lines changed

src/TensorFlowNET.Core/APIs/tf.ops.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@ public void add_to_collection<T>(string name, T value)
2727
public void add_to_collections<T>(List<string> names, T value)
2828
=> get_default_graph().add_to_collections(names, value);
2929

30-
public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
31-
=> state_ops.assign(@ref, value, validate_shape, use_locking, name);
32-
3330
public Tensor assign(IVariableV1 @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
3431
=> state_ops.assign(@ref, value, validate_shape, use_locking, name);
3532

src/TensorFlowNET.Core/Contexts/Context.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ public void restore_mode()
9191
context_switches.Pop();
9292
}
9393

94-
[DebuggerStepThrough]
94+
// [DebuggerStepThrough]
9595
public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params Tensor[] tensors)
9696
{
9797
var shouldRunInEager = executing_eagerly()
@@ -115,7 +115,7 @@ public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params Tenso
115115
}
116116
}
117117

118-
[DebuggerStepThrough]
118+
// [DebuggerStepThrough]
119119
public Tensors RunInAutoMode2(Func<Tensors> graphAction,
120120
Func<Tensors> eagerAction,
121121
Action<Operation> recordGradient,

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,30 @@ public static Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tenso
593593
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
594594
input, begin, end, strides);
595595

596+
public static Operation resource_strided_slice_assign(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value,
597+
int begin_mask = 0,
598+
int end_mask = 0,
599+
int ellipsis_mask = 0,
600+
int new_axis_mask = 0,
601+
int shrink_axis_mask = 0,
602+
string name = null)
603+
=> tf.Context.RunInAutoMode(()
604+
=> tf.OpDefLib._apply_op_helper("ResourceStridedSliceAssign", name, new
605+
{
606+
input, begin, end, strides, value,
607+
begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask
608+
}).output, ()
609+
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
610+
"ResourceStridedSliceAssign", name,
611+
null,
612+
input, begin, end, strides, value,
613+
"begin_mask", begin_mask,
614+
"end_mask", end_mask,
615+
"ellipsis_mask", ellipsis_mask,
616+
"new_axis_mask", new_axis_mask,
617+
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
618+
input, begin, end, strides, value);
619+
596620
public static Tensor strided_slice<T>(Tensor input, T[] begin, T[] end, T[] strides,
597621
int begin_mask = 0,
598622
int end_mask = 0,

src/TensorFlowNET.Core/Operations/resource_variable_ops.cs

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -34,43 +34,6 @@ public static ITensorOrOperation shape_safe_assign_variable_handle(Tensor handle
3434
name: name);
3535
}
3636

37-
/// <summary>
38-
///
39-
/// </summary>
40-
/// <param name="self"></param>
41-
/// <param name="value"></param>
42-
/// <param name="use_locking"></param>
43-
/// <param name="read_value"></param>
44-
/// <returns>
45-
/// If `read_value` is `True`, this method will return the new value of the
46-
/// variable after the assignment has completed.Otherwise, when in graph mode
47-
/// it will return the `Operation` that does the assignment, and when in eager
48-
/// mode it will return `None`.
49-
/// </returns>
50-
public static Operation assign(this Tensor self, Tensor value, bool use_locking = false, string name = null, bool read_value = true)
51-
{
52-
var value_tensor = ops.convert_to_tensor(value, dtype: self.dtype);
53-
self.assert_is_compatible_with(value_tensor);
54-
var assign_op = gen_resource_variable_ops.assign_variable_op(self, value_tensor, name: name);
55-
if (read_value)
56-
{
57-
return self._lazy_read(assign_op);
58-
}
59-
60-
return assign_op;
61-
}
62-
63-
public static Operation _lazy_read(this Tensor self, Operation op)
64-
{
65-
variable_accessed(self);
66-
throw new NotImplementedException();
67-
}
68-
69-
public static void variable_accessed(this Tensor variable)
70-
{
71-
throw new NotImplementedException();
72-
}
73-
7437
public static bool is_resource_variable(IVariableV1 var)
7538
{
7639
return var is ResourceVariable;
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class ParsedSliceArgs
8+
{
9+
public int[] Begin { get; set; }
10+
public Tensor PackedBegin { get; set; }
11+
public int[] End { get; set; }
12+
public Tensor PackedEnd { get; set; }
13+
public int[] Strides { get; set; }
14+
public Tensor PackedStrides { get; set; }
15+
public int BeginMask { get; set; }
16+
public int EndMask { get; set; }
17+
public int ShrinkAxisMask { get; set; }
18+
public int NewAxisMask { get; set; }
19+
public int EllipsisMask { get; set; }
20+
}
21+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using NumSharp;
2+
3+
namespace Tensorflow
4+
{
5+
public partial class Tensor
6+
{
7+
/// <summary>
8+
/// Used to keep the original variable when slicing
9+
/// </summary>
10+
public ResourceVariable OriginalVar { get; set; }
11+
public ParsedSliceArgs OriginalVarSlice { get; set; }
12+
13+
public ResourceVariable assign(Tensor tensor)
14+
{
15+
if (OriginalVar != null)
16+
{
17+
OriginalVar.StridedSliceAssign(tensor, OriginalVarSlice);
18+
return OriginalVar;
19+
}
20+
else
21+
throw new RuntimeError("Operation doesn't support.");
22+
}
23+
}
24+
}

src/TensorFlowNET.Core/Tensors/Tensor.Index.cs

Lines changed: 11 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -30,81 +30,28 @@ public Tensor this[params Slice[] slices]
3030
{
3131
get
3232
{
33-
var begin = new List<int>();
34-
var end = new List<int>();
35-
var strides = new List<int>();
33+
var args = tensor_util.ParseSlices(slices);
3634

37-
var index = 0;
38-
var (new_axis_mask, shrink_axis_mask) = (0, 0);
39-
var (begin_mask, end_mask) = (0, 0);
40-
var ellipsis_mask = 0;
41-
42-
foreach (var s in slices)
43-
{
44-
if (s.IsNewAxis)
45-
{
46-
begin.Add(0);
47-
end.Add(0);
48-
strides.Add(1);
49-
new_axis_mask |= (1 << index);
50-
}
51-
else if (s.IsEllipsis)
52-
{
53-
begin.Add(0);
54-
end.Add(0);
55-
strides.Add(1);
56-
ellipsis_mask |= (1 << index);
57-
}
58-
else
59-
{
60-
if (s.Start.HasValue)
61-
{
62-
begin.Add(s.Start.Value);
63-
}
64-
else
65-
{
66-
begin.Add(0);
67-
begin_mask |= (1 << index);
68-
}
69-
70-
if (s.Stop.HasValue)
71-
{
72-
end.Add(s.Stop.Value);
73-
}
74-
else
75-
{
76-
end.Add(0);
77-
end_mask |= (1 << index);
78-
}
79-
80-
strides.Add(s.Step);
81-
if (s.IsIndex)
82-
shrink_axis_mask |= (1 << index);
83-
}
84-
85-
index += 1;
86-
}
87-
88-
return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
35+
return tf_with(ops.name_scope(null, "strided_slice", args), scope =>
8936
{
9037
string name = scope;
91-
if (begin != null)
38+
if (args.Begin != null)
9239
{
9340
var (packed_begin, packed_end, packed_strides) =
94-
(array_ops.stack(begin.ToArray()),
95-
array_ops.stack(end.ToArray()),
96-
array_ops.stack(strides.ToArray()));
41+
(array_ops.stack(args.Begin),
42+
array_ops.stack(args.End),
43+
array_ops.stack(args.Strides));
9744

9845
return gen_array_ops.strided_slice(
9946
this,
10047
packed_begin,
10148
packed_end,
10249
packed_strides,
103-
begin_mask: begin_mask,
104-
end_mask: end_mask,
105-
shrink_axis_mask: shrink_axis_mask,
106-
new_axis_mask: new_axis_mask,
107-
ellipsis_mask: ellipsis_mask,
50+
begin_mask: args.BeginMask,
51+
end_mask: args.EndMask,
52+
shrink_axis_mask: args.ShrinkAxisMask,
53+
new_axis_mask: args.NewAxisMask,
54+
ellipsis_mask: args.EllipsisMask,
10855
name: name);
10956
}
11057

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616

1717
using NumSharp;
1818
using System;
19+
using System.Collections.Generic;
1920
using System.Linq;
2021
using System.Text;
2122
using Tensorflow.Eager;
@@ -584,5 +585,75 @@ public static string to_numpy_string(Tensor tensor)
584585
return nd.ToString();
585586
}
586587
}
588+
589+
public static ParsedSliceArgs ParseSlices(Slice[] slices)
590+
{
591+
var begin = new List<int>();
592+
var end = new List<int>();
593+
var strides = new List<int>();
594+
595+
var index = 0;
596+
var (new_axis_mask, shrink_axis_mask) = (0, 0);
597+
var (begin_mask, end_mask) = (0, 0);
598+
var ellipsis_mask = 0;
599+
600+
foreach (var s in slices)
601+
{
602+
if (s.IsNewAxis)
603+
{
604+
begin.Add(0);
605+
end.Add(0);
606+
strides.Add(1);
607+
new_axis_mask |= (1 << index);
608+
}
609+
else if (s.IsEllipsis)
610+
{
611+
begin.Add(0);
612+
end.Add(0);
613+
strides.Add(1);
614+
ellipsis_mask |= (1 << index);
615+
}
616+
else
617+
{
618+
if (s.Start.HasValue)
619+
{
620+
begin.Add(s.Start.Value);
621+
}
622+
else
623+
{
624+
begin.Add(0);
625+
begin_mask |= (1 << index);
626+
}
627+
628+
if (s.Stop.HasValue)
629+
{
630+
end.Add(s.Stop.Value);
631+
}
632+
else
633+
{
634+
end.Add(0);
635+
end_mask |= (1 << index);
636+
}
637+
638+
strides.Add(s.Step);
639+
if (s.IsIndex)
640+
shrink_axis_mask |= (1 << index);
641+
}
642+
643+
index += 1;
644+
}
645+
646+
return new ParsedSliceArgs
647+
{
648+
Begin = begin.ToArray(),
649+
End = end.ToArray(),
650+
Strides = strides.ToArray(),
651+
BeginMask = begin_mask,
652+
EndMask = end_mask,
653+
EllipsisMask = ellipsis_mask,
654+
ShrinkAxisMask = shrink_axis_mask,
655+
NewAxisMask = new_axis_mask
656+
};
657+
}
587658
}
588659
}

src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,22 @@ public Tensor assign<T>(T value, bool use_locking = false, string name = null, b
8989
return assign_op;
9090
}
9191

92+
public void StridedSliceAssign(Tensor value, ParsedSliceArgs slice)
93+
{
94+
_strided_slice_assign(slice.PackedBegin, slice.PackedEnd, slice.PackedStrides, value);
95+
}
96+
97+
void _strided_slice_assign(Tensor begin, Tensor end, Tensor strides, Tensor value, string name = null,
98+
int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0)
99+
{
100+
var op = gen_array_ops.resource_strided_slice_assign(handle, begin, end, strides, value,
101+
begin_mask: begin_mask,
102+
end_mask: end_mask,
103+
ellipsis_mask: ellipsis_mask,
104+
new_axis_mask: new_axis_mask,
105+
shrink_axis_mask: shrink_axis_mask);
106+
}
107+
92108
public IVariableV1 assign_lazy_load(Tensor value, string name = null)
93109
{
94110
var value_tensor = ops.convert_to_tensor(value, dtype: dtype);

0 commit comments

Comments
 (0)