Skip to content

Commit 36273fd

Browse files
committed
Fix tf.decode_jpeg and SetOpAttrScalar for float.
1 parent e575583 commit 36273fd

16 files changed

+221
-57
lines changed

src/TensorFlowNET.Core/APIs/tf.image.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616

1717
using System.Collections.Generic;
1818
using Tensorflow.IO;
19+
using static Tensorflow.Binding;
1920

2021
namespace Tensorflow
2122
{
@@ -59,6 +60,10 @@ public Tensor resize_images(Tensor images, Tensor size, string method = ResizeMe
5960
string name = null)
6061
=> image_ops_impl.resize_images(images, size, method, preserve_aspect_ratio, antialias, name);
6162

63+
public Tensor resize_images_v2(Tensor images, TensorShape size, string method = ResizeMethod.BILINEAR, bool preserve_aspect_ratio = false, bool antialias = false,
64+
string name = null)
65+
=> image_ops_impl.resize_images(images, tf.constant(size.dims), method, preserve_aspect_ratio, antialias, name);
66+
6267
public Tensor resize_images_with_pad(Tensor image, int target_height, int target_width, string method, bool antialias)
6368
=> image_ops_impl.resize_images_with_pad(image, target_height, target_width, method, antialias);
6469

@@ -160,7 +165,7 @@ public Tensor decode_jpeg(Tensor contents,
160165
int ratio = 1,
161166
bool fancy_upscaling = true,
162167
bool try_recover_truncated = false,
163-
float acceptable_fraction = 1,
168+
int acceptable_fraction = 1,
164169
string dct_method = "",
165170
string name = null)
166171
=> gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio,

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,9 @@ bool SetOpAttrScalar(Context ctx, SafeOpHandle op,
376376
case TF_AttrType.TF_ATTR_INT:
377377
c_api.TFE_OpSetAttrInt(op, key, Convert.ToInt64(value));
378378
break;
379+
case TF_AttrType.TF_ATTR_FLOAT:
380+
c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value));
381+
break;
379382
case TF_AttrType.TF_ATTR_SHAPE:
380383
var dims = (value as int[]).Select(x => (long)x).ToArray();
381384
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status.Handle);

src/TensorFlowNET.Core/Eager/c_api.eager.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,9 @@ public static void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals
176176
[DllImport(TensorFlowLibName)]
177177
public static extern void TFE_OpSetAttrInt(SafeOpHandle op, string attr_name, long value);
178178

179+
[DllImport(TensorFlowLibName)]
180+
public static extern void TFE_OpSetAttrFloat(SafeOpHandle op, string attr_name, float value);
181+
179182
/// <summary>
180183
///
181184
/// </summary>
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.IO;
4+
using System.Linq;
5+
using System.Text;
6+
7+
namespace Tensorflow.Keras.Preprocessings
8+
{
9+
public partial class DatasetUtils
10+
{
11+
/// <summary>
12+
/// Potentially restict samples & labels to a training or validation split.
13+
/// </summary>
14+
/// <param name="samples"></param>
15+
/// <param name="labels"></param>
16+
/// <param name="validation_split"></param>
17+
/// <param name="subset"></param>
18+
/// <returns></returns>
19+
public (T1[], T2[]) get_training_or_validation_split<T1, T2>(T1[] samples,
20+
T2[] labels,
21+
float validation_split,
22+
string subset)
23+
{
24+
var num_val_samples = Convert.ToInt32(samples.Length * validation_split);
25+
if (subset == "training")
26+
{
27+
Console.WriteLine($"Using {samples.Length - num_val_samples} files for training.");
28+
samples = samples[..^num_val_samples];
29+
labels = labels[..^num_val_samples];
30+
}
31+
else if (subset == "validation")
32+
{
33+
Console.WriteLine($"Using {num_val_samples} files for validation.");
34+
samples = samples[samples.Length..];
35+
labels = labels[samples.Length..];
36+
}
37+
else
38+
throw new NotImplementedException("");
39+
40+
return (samples, labels);
41+
}
42+
}
43+
}

src/TensorFlowNET.Core/Keras/Preprocessings/DatasetUtils.index_directory.cs

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
using System;
1+
using NumSharp;
2+
using System;
23
using System.Collections.Generic;
4+
using System.IO;
5+
using System.Linq;
36
using System.Text;
47

58
namespace Tensorflow.Keras.Preprocessings
@@ -20,14 +23,45 @@ public partial class DatasetUtils
2023
/// file_paths, labels, class_names
2124
/// </returns>
2225
public (string[], int[], string[]) index_directory(string directory,
23-
string labels,
24-
string[] formats,
25-
string class_names = null,
26+
string[] formats = null,
27+
string[] class_names = null,
2628
bool shuffle = true,
2729
int? seed = null,
2830
bool follow_links = false)
2931
{
30-
throw new NotImplementedException("");
32+
var labels = new List<int>();
33+
var file_paths = new List<string>();
34+
35+
var class_dirs = Directory.GetDirectories(directory);
36+
class_names = class_dirs.Select(x => x.Split(Path.DirectorySeparatorChar)[^1]).ToArray();
37+
38+
for (var label = 0; label < class_dirs.Length; label++)
39+
{
40+
var files = Directory.GetFiles(class_dirs[label]);
41+
file_paths.AddRange(files);
42+
labels.AddRange(Enumerable.Range(0, files.Length).Select(x => label));
43+
}
44+
45+
var return_labels = new int[labels.Count];
46+
var return_file_paths = new string[file_paths.Count];
47+
48+
if (shuffle)
49+
{
50+
if (!seed.HasValue)
51+
seed = np.random.randint((long)1e6);
52+
var random_index = np.arange(labels.Count);
53+
var rng = np.random.RandomState(seed.Value);
54+
rng.shuffle(random_index);
55+
var index = random_index.ToArray<int>();
56+
57+
for (int i = 0; i< labels.Count; i++)
58+
{
59+
return_labels[i] = labels[index[i]];
60+
return_file_paths[i] = file_paths[index[i]];
61+
}
62+
}
63+
64+
return (return_file_paths, return_labels, class_names);
3165
}
3266
}
3367
}

src/TensorFlowNET.Core/Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public partial class Preprocessing
2828
public Tensor image_dataset_from_directory(string directory,
2929
string labels = "inferred",
3030
string label_mode = "int",
31-
string class_names = null,
31+
string[] class_names = null,
3232
string color_mode = "rgb",
3333
int batch_size = 32,
3434
TensorShape image_size = null,
@@ -44,13 +44,15 @@ public Tensor image_dataset_from_directory(string directory,
4444
num_channels = 3;
4545
// C:/Users/haipi/.keras/datasets/flower_photos
4646
var (image_paths, label_list, class_name_list) = tf.keras.preprocessing.dataset_utils.index_directory(directory,
47-
labels,
48-
WHITELIST_FORMATS,
47+
formats: WHITELIST_FORMATS,
4948
class_names: class_names,
5049
shuffle: shuffle,
5150
seed: seed,
5251
follow_links: follow_links);
5352

53+
(image_paths, label_list) = tf.keras.preprocessing.dataset_utils.get_training_or_validation_split(image_paths, label_list, validation_split, subset);
54+
55+
paths_and_labels_to_dataset(image_paths, image_size, num_channels, label_list, label_mode, class_name_list.Length, interpolation);
5456
throw new NotImplementedException("");
5557
}
5658
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using System;
2+
using System.Globalization;
3+
using static Tensorflow.Binding;
4+
5+
namespace Tensorflow.Keras
6+
{
7+
public partial class Preprocessing
8+
{
9+
public Tensor paths_and_labels_to_dataset(string[] image_paths,
10+
TensorShape image_size,
11+
int num_channels,
12+
int[] labels,
13+
string label_mode,
14+
int num_classes,
15+
string interpolation)
16+
{
17+
foreach (var image_path in image_paths)
18+
path_to_image(image_path, image_size, num_channels, interpolation);
19+
20+
throw new NotImplementedException("");
21+
}
22+
23+
Tensor path_to_image(string path, TensorShape image_size, int num_channels, string interpolation)
24+
{
25+
var img = tf.io.read_file(path);
26+
img = tf.image.decode_image(
27+
img, channels: num_channels, expand_animations: false);
28+
img = tf.image.resize_images_v2(img, image_size, method: interpolation);
29+
return img;
30+
}
31+
}
32+
}

src/TensorFlowNET.Core/Operations/control_flow_ops.cs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public static Operation Assert(Tensor condition, object[] data, int? summarize =
8686

8787
var guarded_assert = cond(condition, false_assert, true_assert, name: "AssertGuard");
8888

89-
return guarded_assert[0].op;
89+
return guarded_assert == null ? null : guarded_assert[0].op;
9090
});
9191
}
9292

@@ -423,8 +423,6 @@ public static Tensor cond(Tensor pred,
423423
return true_fn() as Tensor;
424424
else
425425
return false_fn() as Tensor;
426-
427-
return null;
428426
}
429427

430428
// Add the Switch to the graph.
@@ -507,8 +505,6 @@ public static Tensor[] cond<T>(Tensor pred,
507505
return true_fn() as Tensor[];
508506
else
509507
return false_fn() as Tensor[];
510-
511-
return null;
512508
}
513509

514510
// Add the Switch to the graph.

src/TensorFlowNET.Core/Operations/gen_image_ops.cs

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,24 @@ public static Tensor decode_jpeg(Tensor contents,
6666
int ratio = 1,
6767
bool fancy_upscaling = true,
6868
bool try_recover_truncated = false,
69-
float acceptable_fraction = 1,
69+
int acceptable_fraction = 1,
7070
string dct_method = "",
7171
string name = null)
7272
{
7373
// Add nodes to the TensorFlow graph.
7474
if (tf.Context.executing_eagerly())
7575
{
76-
throw new NotImplementedException("decode_jpeg");
76+
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
77+
"DecodeJpeg", name,
78+
null,
79+
contents,
80+
"channels", channels,
81+
"ratio", ratio,
82+
"fancy_upscaling", fancy_upscaling,
83+
"try_recover_truncated", try_recover_truncated,
84+
"acceptable_fraction", acceptable_fraction,
85+
"dct_method", dct_method);
86+
return results[0];
7787
}
7888
else
7989
{
@@ -171,17 +181,42 @@ public static Tensor resize_bilinear(Tensor images,
171181
"half_pixel_centers", half_pixel_centers);
172182
return results[0];
173183
}
174-
else
184+
185+
var _op = tf.OpDefLib._apply_op_helper("ResizeBilinear", name: name, args: new
175186
{
176-
var _op = tf.OpDefLib._apply_op_helper("ResizeBilinear", name: name, args: new
177-
{
178-
images,
179-
size,
180-
align_corners
181-
});
187+
images,
188+
size,
189+
align_corners
190+
});
182191

183-
return _op.outputs[0];
192+
return _op.outputs[0];
193+
}
194+
195+
public static Tensor resize_bicubic(Tensor images,
196+
Tensor size,
197+
bool align_corners = false,
198+
bool half_pixel_centers = false,
199+
string name = null)
200+
{
201+
if (tf.Context.executing_eagerly())
202+
{
203+
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
204+
"ResizeBicubic", name,
205+
null,
206+
images, size,
207+
"align_corners", align_corners,
208+
"half_pixel_centers", half_pixel_centers);
209+
return results[0];
184210
}
211+
212+
var _op = tf.OpDefLib._apply_op_helper("ResizeBicubic", name: name, args: new
213+
{
214+
images,
215+
size,
216+
align_corners
217+
});
218+
219+
return _op.outputs[0];
185220
}
186221

187222
public static Tensor resize_nearest_neighbor<Tsize>(Tensor images, Tsize size, bool align_corners = false,

src/TensorFlowNET.Core/Operations/gen_ops.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24715,13 +24715,12 @@ public static Tensor resize_area (Tensor images, Tensor size, bool? align_corner
2471524715
/// <remarks>
2471624716
/// Input images can be of different types but output images are always float.
2471724717
/// </remarks>
24718-
public static Tensor resize_bicubic (Tensor images, Tensor size, bool? align_corners = null, string name = "ResizeBicubic")
24718+
public static Tensor resize_bicubic (Tensor images, Tensor size, bool align_corners = false, bool half_pixel_centers = false, string name = "ResizeBicubic")
2471924719
{
2472024720
var dict = new Dictionary<string, object>();
2472124721
dict["images"] = images;
2472224722
dict["size"] = size;
24723-
if (align_corners.HasValue)
24724-
dict["align_corners"] = align_corners.Value;
24723+
dict["align_corners"] = align_corners;
2472524724
var op = tf.OpDefLib._apply_op_helper("ResizeBicubic", name: name, keywords: dict);
2472624725
return op.output;
2472724726
}

0 commit comments

Comments
 (0)