Skip to content

Commit 6fe5a6c

Browse files
committed
tried to implement MnistModelLoader
1 parent e8e0243 commit 6fe5a6c

File tree

3 files changed

+152
-2
lines changed

3 files changed

+152
-2
lines changed

src/TensorFlowHub/MnistModelLoader.cs

Lines changed: 141 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,154 @@
22
using System.Threading.Tasks;
33
using System.Collections.Generic;
44
using System.Text;
5+
using System.IO;
56
using NumSharp;
67

78
namespace Tensorflow.Hub
89
{
910
public class MnistModelLoader : IModelLoader<MnistDataSet>
1011
{
11-
public Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting)
12+
private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/";
13+
private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz";
14+
private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
15+
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
16+
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";
17+
18+
public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting)
19+
{
20+
if (setting.TrainSize.HasValue && setting.ValidationSize >= setting.TrainSize.Value)
21+
throw new ArgumentException("Validation set should be smaller than training set");
22+
23+
var sourceUrl = setting.SourceUrl;
24+
25+
if (string.IsNullOrEmpty(sourceUrl))
26+
sourceUrl = DEFAULT_SOURCE_URL;
27+
28+
// load train images
29+
await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES)
30+
.ShowProgressInConsole(setting.ShowProgressInConsole);
31+
32+
await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir)
33+
.ShowProgressInConsole(setting.ShowProgressInConsole);
34+
35+
var trainImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_IMAGES)), limit: setting.TrainSize);
36+
37+
// load train labels
38+
await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS)
39+
.ShowProgressInConsole(setting.ShowProgressInConsole);
40+
41+
await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir)
42+
.ShowProgressInConsole(setting.ShowProgressInConsole);
43+
44+
var trainLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_LABELS)), one_hot: setting.OneHot, limit: setting.TrainSize);
45+
46+
// load test images
47+
await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES)
48+
.ShowProgressInConsole(setting.ShowProgressInConsole);
49+
50+
await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir)
51+
.ShowProgressInConsole(setting.ShowProgressInConsole);
52+
53+
var testImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_IMAGES)), limit: setting.TestSize);
54+
55+
// load test labels
56+
await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS)
57+
.ShowProgressInConsole(setting.ShowProgressInConsole);
58+
59+
await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir)
60+
.ShowProgressInConsole(setting.ShowProgressInConsole);
61+
62+
var testLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_LABELS)), one_hot: setting.OneHot, limit: setting.TestSize);
63+
64+
var end = trainImages.shape[0];
65+
66+
var validationSize = setting.ValidationSize;
67+
68+
var validationImages = trainImages[np.arange(validationSize)];
69+
var validationLabels = trainLabels[np.arange(validationSize)];
70+
71+
trainImages = trainImages[np.arange(validationSize, end)];
72+
trainLabels = trainLabels[np.arange(validationSize, end)];
73+
74+
var dtype = setting.DtType;
75+
var reshape = setting.ReShape;
76+
77+
var train = new MnistDataSet(trainImages, trainLabels, dtype, reshape);
78+
var validation = new MnistDataSet(validationImages, validationLabels, dtype, reshape);
79+
var test = new MnistDataSet(trainImages, trainLabels, dtype, reshape);
80+
81+
return new Datasets<MnistDataSet>(train, validation, test);
82+
}
83+
84+
private NDArray ExtractImages(string file, int? limit = null)
85+
{
86+
using (var bytestream = new FileStream(file, FileMode.Open))
87+
{
88+
var magic = Read32(bytestream);
89+
if (magic != 2051)
90+
throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}");
91+
92+
var num_images = Read32(bytestream);
93+
num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit);
94+
95+
var rows = Read32(bytestream);
96+
var cols = Read32(bytestream);
97+
98+
var buf = new byte[rows * cols * num_images];
99+
100+
bytestream.Read(buf, 0, buf.Length);
101+
102+
var data = np.frombuffer(buf, np.uint8);
103+
data = data.reshape((int)num_images, (int)rows, (int)cols, 1);
104+
105+
return data;
106+
}
107+
}
108+
109+
private NDArray ExtractLabels(string file, bool one_hot = false, int num_classes = 10, int? limit = null)
110+
{
111+
using (var bytestream = new FileStream(file, FileMode.Open))
112+
{
113+
var magic = Read32(bytestream);
114+
if (magic != 2049)
115+
throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}");
116+
117+
var num_items = Read32(bytestream);
118+
num_items = limit == null ? num_items : Math.Min(num_items, (uint)limit);
119+
120+
var buf = new byte[num_items];
121+
122+
bytestream.Read(buf, 0, buf.Length);
123+
124+
var labels = np.frombuffer(buf, np.uint8);
125+
126+
if (one_hot)
127+
return DenseToOneHot(labels, num_classes);
128+
129+
return labels;
130+
}
131+
}
132+
133+
private NDArray DenseToOneHot(NDArray labels_dense, int num_classes)
134+
{
135+
var num_labels = labels_dense.shape[0];
136+
var index_offset = np.arange(num_labels) * num_classes;
137+
var labels_one_hot = np.zeros(num_labels, num_classes);
138+
139+
for(int row = 0; row < num_labels; row++)
140+
{
141+
var col = labels_dense.Data<byte>(row);
142+
labels_one_hot.SetData(1.0, row, col);
143+
}
144+
145+
return labels_one_hot;
146+
}
147+
148+
private uint Read32(FileStream bytestream)
12149
{
13-
throw new NotImplementedException();
150+
var buffer = new byte[sizeof(uint)];
151+
var count = bytestream.Read(buffer, 0, 4);
152+
return np.frombuffer(buffer, ">u4").Data<uint>(0);
14153
}
15154
}
16155
}

src/TensorFlowHub/ModelLoadSetting.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@ public class ModelLoadSetting
1515
public int? TrainSize { get; set; }
1616
public int? TestSize { get; set; }
1717
public string SourceUrl { get; set; }
18+
public bool ShowProgressInConsole { get; set; }
1819
}
1920
}

src/TensorFlowHub/Utils.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,16 @@ public static async Task UnzipAsync<TDataSet>(this IModelLoader<TDataSet> modelL
7171

7272
public static async Task ShowProgressInConsole(this Task task)
7373
{
74+
await ShowProgressInConsole(task, true);
75+
}
76+
77+
public static async Task ShowProgressInConsole(this Task task, bool enable)
78+
{
79+
if (!enable)
80+
{
81+
await task;
82+
}
83+
7484
var cts = new CancellationTokenSource();
7585
var showProgressTask = ShowProgressInConsole(cts);
7686

0 commit comments

Comments
 (0)