Skip to content

Commit ce94a5c

Browse files
committed
base classes for Tensorflow.Hub
1 parent d674d51 commit ce94a5c

File tree

9 files changed

+160
-10
lines changed

9 files changed

+160
-10
lines changed

src/TensorFlowHub/Class1.cs

Lines changed: 0 additions & 8 deletions
This file was deleted.

src/TensorFlowHub/DataSetBase.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using NumSharp;
5+
6+
namespace Tensorflow.Hub
7+
{
8+
public abstract class DataSetBase : IDataSet
9+
{
10+
public NDArray Data { get; protected set; }
11+
public NDArray Labels { get; protected set; }
12+
}
13+
}

src/TensorFlowHub/Datasets.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using NumSharp;
5+
6+
namespace Tensorflow.Hub
7+
{
8+
public class Datasets<TDataSet> where TDataSet : IDataSet
9+
{
10+
public TDataSet Train { get; private set; }
11+
12+
public TDataSet Validation { get; private set; }
13+
14+
public TDataSet Test { get; private set; }
15+
16+
public Datasets(TDataSet train, TDataSet validation, TDataSet test)
17+
{
18+
Train = train;
19+
Validation = validation;
20+
Test = test;
21+
}
22+
23+
public (NDArray, NDArray) Randomize(NDArray x, NDArray y)
24+
{
25+
var perm = np.random.permutation(y.shape[0]);
26+
np.random.shuffle(perm);
27+
return (x[perm], y[perm]);
28+
}
29+
30+
/// <summary>
31+
/// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method)
32+
/// </summary>
33+
/// <param name="x"></param>
34+
/// <param name="y"></param>
35+
/// <param name="start"></param>
36+
/// <param name="end"></param>
37+
/// <returns></returns>
38+
public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end)
39+
{
40+
var slice = new Slice(start, end);
41+
var x_batch = x[slice];
42+
var y_batch = y[slice];
43+
return (x_batch, y_batch);
44+
}
45+
}
46+
}

src/TensorFlowHub/IDataSet.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using NumSharp;
5+
6+
namespace Tensorflow.Hub
7+
{
8+
public interface IDataSet
9+
{
10+
NDArray Data { get; }
11+
NDArray Labels { get; }
12+
}
13+
}

src/TensorFlowHub/IModelLoader.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.Threading.Tasks;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
using NumSharp;
6+
7+
namespace Tensorflow.Hub
8+
{
9+
public interface IModelLoader<TDataSet>
10+
where TDataSet : IDataSet
11+
{
12+
Task<Datasets<TDataSet>> LoadAsync(ModelLoadSetting setting);
13+
}
14+
}

src/TensorFlowHub/MnistDataSet.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using NumSharp;
5+
using Tensorflow;
6+
7+
namespace Tensorflow.Hub
8+
{
9+
public class MnistDataSet : DataSetBase
10+
{
11+
public int NumOfExamples { get; private set; }
12+
public int EpochsCompleted { get; private set; }
13+
public int IndexInEpoch { get; private set; }
14+
15+
public MnistDataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
16+
{
17+
EpochsCompleted = 0;
18+
IndexInEpoch = 0;
19+
20+
NumOfExamples = images.shape[0];
21+
22+
images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);
23+
images.astype(dtype.as_numpy_datatype());
24+
images = np.multiply(images, 1.0f / 255.0f);
25+
Data = images;
26+
27+
labels.astype(dtype.as_numpy_datatype());
28+
Labels = labels;
29+
}
30+
}
31+
}

src/TensorFlowHub/MnistModelLoader.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using System;
2+
using System.Threading.Tasks;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
using NumSharp;
6+
7+
namespace Tensorflow.Hub
8+
{
9+
public class MnistModelLoader : IModelLoader<MnistDataSet>
10+
{
11+
public Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting)
12+
{
13+
throw new NotImplementedException();
14+
}
15+
}
16+
}

src/TensorFlowHub/ModelLoadSetting.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using NumSharp;
5+
6+
namespace Tensorflow.Hub
7+
{
8+
public class ModelLoadSetting
9+
{
10+
public string TrainDir { get; set; }
11+
public bool OneHot { get; set; }
12+
public TF_DataType DtType { get; set; } = TF_DataType.TF_FLOAT;
13+
public bool ReShape { get; set; }
14+
public int ValidationSize { get; set; } = 5000;
15+
public int? TrainSize { get; set; }
16+
public int? TestSize { get; set; }
17+
public string SourceUrl { get; set; }
18+
}
19+
}
Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
<Project Sdk="Microsoft.NET.Sdk">
2-
32
<PropertyGroup>
3+
<AssemblyName>TensorFlow.Net.Hub</AssemblyName>
4+
<RootNamespace>Tensorflow.Hub</RootNamespace>
45
<TargetFramework>netstandard2.0</TargetFramework>
56
</PropertyGroup>
6-
7+
<ItemGroup>
8+
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
9+
</ItemGroup>
10+
<ItemGroup>
11+
<PackageReference Include="NumSharp" Version="0.10.4" />
12+
</ItemGroup>
713
</Project>

0 commit comments

Comments
 (0)