Skip to content

Commit 4309275

Browse files
committed
fix eager tensor resize. #608
1 parent 7f507fa commit 4309275

File tree

4 files changed

+34
-25
lines changed

4 files changed

+34
-25
lines changed

src/TensorFlowNET.Core/APIs/tf.image.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,8 @@ public Tensor extract_glimpse(Tensor input, Tensor size, Tensor offsets, bool ce
203203
=> image_ops_impl.non_max_suppression_padded(boxes, scores, max_output_size, iou_threshold, score_threshold, pad_to_max_output_size,
204204
name, sorted_input, canonicalized_coordinates, tile_size);
205205

206-
public Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null)
207-
=> gen_image_ops.resize_bilinear(images, size, align_corners: align_corners, name: name);
206+
public Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, bool half_pixel_centers = false, string name = null)
207+
=> gen_image_ops.resize_bilinear(images, size, align_corners: align_corners, half_pixel_centers: half_pixel_centers, name: name);
208208

209209
public Tensor resize_images(Tensor images, Tensor size, string method = ResizeMethod.BILINEAR,
210210
bool preserve_aspect_ratio = false, string name = null)

src/TensorFlowNET.Core/Operations/gen_image_ops.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,11 +155,21 @@ public static Tensor decode_bmp(Tensor contents,
155155
}
156156
}
157157

158-
public static Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null)
158+
public static Tensor resize_bilinear(Tensor images,
159+
Tensor size,
160+
bool align_corners = false,
161+
bool half_pixel_centers = false,
162+
string name = null)
159163
{
160164
if (tf.Context.executing_eagerly())
161165
{
162-
throw new NotImplementedException("resize_bilinear");
166+
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
167+
"ResizeBilinear", name,
168+
null,
169+
images, size,
170+
"align_corners", align_corners,
171+
"half_pixel_centers", half_pixel_centers);
172+
return results[0];
163173
}
164174
else
165175
{

src/TensorFlowNET.Core/Tensors/constant_op.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ public static Tensor _constant_impl(object value,
6464

6565
var num_t = t.TensorShape.num_elements();
6666
if (num_t == shape.num_elements())
67-
throw new NotImplementedException("");
68-
if(num_t == 1)
67+
return _eager_reshape(t, shape, tf.Context);
68+
if (num_t == 1)
6969
{
7070
if (t.dtype == dtypes.@bool)
7171
throw new NotImplementedException("");
@@ -100,6 +100,16 @@ public static Tensor _constant_impl(object value,
100100
return op.outputs[0];
101101
}
102102

103+
private static Tensor _eager_reshape(EagerTensor tensor, int[] shape, Context ctx)
104+
{
105+
var attr_t = tensor.dtype.as_datatype_enum();
106+
var dims_t = convert_to_eager_tensor(shape, ctx, dtypes.int32);
107+
var inputs_flat = new[] { tensor, dims_t };
108+
var attrs = new object[] { "T", attr_t, "Tshape", TF_DataType.TF_INT32 };
109+
var result = tf.Runner.Execute(ctx, "Reshape", 1, inputs_flat, attrs);
110+
return result[0];
111+
}
112+
103113
private static Tensor _eager_fill(int[] dims, EagerTensor value, Context ctx)
104114
{
105115
var attr_t = value.dtype.as_datatype_enum();

test/TensorFlowNET.UnitTest/Basics/TensorTest.cs

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -280,36 +280,25 @@ public void boolean_mask()
280280
[TestMethod]
281281
public unsafe void tensor_resize()
282282
{
283+
tf.enable_eager_execution();
284+
283285
var imageArray = new float[256 * 256 * 3];
284286

285287
using var newSize = tf.convert_to_tensor(new int[] { 100, 100 });
286288

287-
using (var t = new Tensor(imageArray, new long[] { 1, 256, 256, 3 }, tf.float32))
289+
using (var t = tf.constant(imageArray, tf.float32, (1, 256, 256, 3)))
288290
{
289291
Assert.IsFalse(t.IsDisposed);
290292
Assert.AreEqual(256 * 256 * 3 * sizeof(float), (int)t.bytesize);
291293

292294
using var resized = tf.image.resize_bilinear(t, newSize);
293-
EXPECT_EQ((int)resized.shape[0], 1);
294-
EXPECT_EQ((int)resized.shape[1], 100);
295-
EXPECT_EQ((int)resized.shape[2], 100);
296-
EXPECT_EQ((int)resized.shape[3], 3);
295+
EXPECT_EQ(resized.shape[0], 1);
296+
EXPECT_EQ(resized.shape[1], 100);
297+
EXPECT_EQ(resized.shape[2], 100);
298+
EXPECT_EQ(resized.shape[3], 3);
297299
}
298300

299-
fixed (float* ptr = &imageArray[0])
300-
{
301-
using (var t = new Tensor((IntPtr)ptr, new long[] { imageArray.Length }, tf.float32, 4 * imageArray.Length))
302-
{
303-
Assert.IsFalse(t.IsDisposed);
304-
Assert.AreEqual(256 * 256 * 3 * sizeof(float), (int)t.bytesize);
305-
306-
using var resized = tf.image.resize_bilinear(t, newSize);
307-
EXPECT_EQ((int)resized.shape[0], 1);
308-
EXPECT_EQ((int)resized.shape[1], 100);
309-
EXPECT_EQ((int)resized.shape[2], 100);
310-
EXPECT_EQ((int)resized.shape[3], 3);
311-
}
312-
}
301+
tf.compat.v1.disable_eager_execution();
313302
}
314303
}
315304
}

0 commit comments

Comments
 (0)