Skip to content

Commit fd32b84

Browse files
committed
Add ZipDataset, fix paths_and_labels_to_dataset.
1 parent 36273fd commit fd32b84

File tree

13 files changed

+162
-13
lines changed

13 files changed

+162
-13
lines changed

src/TensorFlowNET.Core/Data/DatasetManager.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,16 @@ public IDatasetV2 from_tensor(Tensor tensors)
2525
public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels)
2626
=> new TensorSliceDataset(features, labels);
2727

28+
public IDatasetV2 from_tensor_slices(NDArray array)
29+
=> new TensorSliceDataset(array);
30+
2831
public IDatasetV2 range(int count, TF_DataType output_type = TF_DataType.TF_INT64)
2932
=> new RangeDataset(count, output_type: output_type);
3033

3134
public IDatasetV2 range(int start, int stop, int step = 1, TF_DataType output_type = TF_DataType.TF_INT64)
3235
=> new RangeDataset(stop, start: start, step: step, output_type: output_type);
36+
37+
public IDatasetV2 zip(params IDatasetV2[] ds)
38+
=> new ZipDataset(ds);
3339
}
3440
}

src/TensorFlowNET.Core/Data/TensorSliceDataset.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,16 @@ namespace Tensorflow.Data
1111
{
1212
public class TensorSliceDataset : DatasetSource
1313
{
14+
public TensorSliceDataset(NDArray array)
15+
{
16+
var element = tf.constant(array);
17+
_tensors = new[] { element };
18+
var batched_spec = new[] { element.ToTensorSpec() };
19+
structure = batched_spec.Select(x => x._unbatch()).ToArray();
20+
21+
variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes);
22+
}
23+
1424
public TensorSliceDataset(Tensor features, Tensor labels)
1525
{
1626
_tensors = new[] { features, labels };
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
6+
namespace Tensorflow
7+
{
8+
public class ZipDataset : DatasetV2
9+
{
10+
dataset_ops ops = new dataset_ops();
11+
public ZipDataset(params IDatasetV2[] ds)
12+
{
13+
var input_datasets = ds.Select(x => x.variant_tensor).ToArray();
14+
structure = ds.Select(x => x.structure[0]).ToArray();
15+
variant_tensor = ops.zip_dataset(input_datasets, output_types, output_shapes);
16+
}
17+
}
18+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.ArgsDefinition
6+
{
7+
public class RescalingArgs : LayerArgs
8+
{
9+
public float Scale { get; set; }
10+
public float Offset { get; set; }
11+
}
12+
}

src/TensorFlowNET.Core/Keras/KerasApi.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,17 @@ public Tensor Input(TensorShape shape = null,
7272

7373
public class LayersApi
7474
{
75-
public Layer Dense(int units,
75+
public Rescaling Rescaling(float scale,
76+
float offset = 0,
77+
TensorShape input_shape = null)
78+
=> new Rescaling(new RescalingArgs
79+
{
80+
Scale = scale,
81+
Offset = offset,
82+
InputShape = input_shape
83+
});
84+
85+
public Dense Dense(int units,
7686
Activation activation = null,
7787
TensorShape input_shape = null)
7888
=> new Dense(new DenseArgs
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.ArgsDefinition;
5+
using Tensorflow.Keras.Engine;
6+
7+
namespace Tensorflow.Keras.Layers
8+
{
9+
/// <summary>
10+
/// Multiply inputs by `scale` and adds `offset`.
11+
/// </summary>
12+
public class Rescaling : Layer
13+
{
14+
RescalingArgs args;
15+
Tensor scale;
16+
Tensor offset;
17+
18+
public Rescaling(RescalingArgs args) : base(args)
19+
{
20+
this.args = args;
21+
}
22+
23+
protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
24+
{
25+
scale = math_ops.cast(args.Scale, args.DType);
26+
throw new NotImplementedException("");
27+
}
28+
}
29+
}
Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using static Tensorflow.Binding;
45

56
namespace Tensorflow.Keras.Preprocessings
67
{
78
public partial class DatasetUtils
89
{
9-
10+
public IDatasetV2 labels_to_dataset(int[] labels, string label_mode, int num_classes)
11+
{
12+
var label_ds = tf.data.Dataset.from_tensor_slices(labels);
13+
if (label_mode == "binary")
14+
throw new NotImplementedException("");
15+
else if(label_mode == "categorical")
16+
throw new NotImplementedException("");
17+
return label_ds;
18+
}
1019
}
1120
}

src/TensorFlowNET.Core/Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ public partial class DatasetUtils
3131
else if (subset == "validation")
3232
{
3333
Console.WriteLine($"Using {num_val_samples} files for validation.");
34-
samples = samples[samples.Length..];
35-
labels = labels[samples.Length..];
34+
samples = samples[(samples.Length - num_val_samples)..];
35+
labels = labels[(labels.Length - num_val_samples)..];
3636
}
3737
else
3838
throw new NotImplementedException("");

src/TensorFlowNET.Core/Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public partial class Preprocessing
2525
/// <param name="interpolation"></param>
2626
/// <param name="follow_links"></param>
2727
/// <returns></returns>
28-
public Tensor image_dataset_from_directory(string directory,
28+
public IDatasetV2 image_dataset_from_directory(string directory,
2929
string labels = "inferred",
3030
string label_mode = "int",
3131
string[] class_names = null,
@@ -52,8 +52,11 @@ public Tensor image_dataset_from_directory(string directory,
5252

5353
(image_paths, label_list) = tf.keras.preprocessing.dataset_utils.get_training_or_validation_split(image_paths, label_list, validation_split, subset);
5454

55-
paths_and_labels_to_dataset(image_paths, image_size, num_channels, label_list, label_mode, class_name_list.Length, interpolation);
56-
throw new NotImplementedException("");
55+
var dataset = paths_and_labels_to_dataset(image_paths, image_size, num_channels, label_list, label_mode, class_name_list.Length, interpolation);
56+
if (shuffle)
57+
dataset = dataset.shuffle(batch_size * 8, seed: seed);
58+
dataset = dataset.batch(batch_size);
59+
return dataset;
5760
}
5861
}
5962
}

src/TensorFlowNET.Core/Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,44 @@
1-
using System;
1+
using NumSharp;
2+
using System;
23
using System.Globalization;
4+
using System.Threading.Tasks;
35
using static Tensorflow.Binding;
46

57
namespace Tensorflow.Keras
68
{
79
public partial class Preprocessing
810
{
9-
public Tensor paths_and_labels_to_dataset(string[] image_paths,
11+
public IDatasetV2 paths_and_labels_to_dataset(string[] image_paths,
1012
TensorShape image_size,
1113
int num_channels,
1214
int[] labels,
1315
string label_mode,
1416
int num_classes,
1517
string interpolation)
1618
{
17-
foreach (var image_path in image_paths)
18-
path_to_image(image_path, image_size, num_channels, interpolation);
19+
Shape shape = (image_paths.Length, image_size.dims[0], image_size.dims[1], num_channels);
20+
Console.WriteLine($"Allocating memory for shape{shape}, {NPTypeCode.Float}");
21+
var data = np.zeros(shape, NPTypeCode.Float);
1922

20-
throw new NotImplementedException("");
23+
for (var i = 0; i < image_paths.Length; i++)
24+
{
25+
var image = path_to_image(image_paths[i], image_size, num_channels, interpolation);
26+
data[i] = image.numpy();
27+
if (i % 100 == 0)
28+
Console.WriteLine($"Filled {i}/{image_paths.Length} data into memory.");
29+
}
30+
31+
var img_ds = tf.data.Dataset.from_tensor_slices(data);
32+
33+
if (label_mode == "int")
34+
{
35+
var label_ds = tf.keras.preprocessing.dataset_utils.labels_to_dataset(labels, label_mode, num_classes);
36+
img_ds = tf.data.Dataset.zip(img_ds, label_ds);
37+
}
38+
else
39+
throw new NotImplementedException("");
40+
41+
return img_ds;
2142
}
2243

2344
Tensor path_to_image(string path, TensorShape image_size, int num_channels, string interpolation)

0 commit comments

Comments
 (0)