Skip to content

Commit b3cd413

Browse files
committed
fix Strides default value for Pooling2D.
1 parent fd32b84 commit b3cd413

File tree

13 files changed

+214
-51
lines changed

13 files changed

+214
-51
lines changed

src/TensorFlowNET.Core/Framework/meta_graph.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ public static (Dictionary<string, IVariableV1>, ITensorOrOperation[]) import_sco
150150
var variables = graph.get_collection<IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES,
151151
scope: scope_to_prepend_to_names);
152152
var var_list = new Dictionary<string, IVariableV1>();
153-
variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v);
153+
// variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v);
154154

155155
return (var_list, imported_return_elements);
156156
}

src/TensorFlowNET.Core/Keras/Engine/Model.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using NumSharp;
22
using System;
33
using Tensorflow.Keras.ArgsDefinition;
4+
using Tensorflow.Keras.Losses;
45
using Tensorflow.Keras.Optimizers;
56

67
namespace Tensorflow.Keras.Engine
@@ -42,6 +43,11 @@ public void compile(string optimizerName, string lossName)
4243
// Prepare list of loss functions, same size of model outputs.
4344
}
4445

46+
public void compile(string optimizerName, ILossFunc lossName)
47+
{
48+
throw new NotImplementedException("");
49+
}
50+
4551
/// <summary>
4652
/// Generates output predictions for the input samples.
4753
/// </summary>

src/TensorFlowNET.Core/Keras/KerasApi.cs

Lines changed: 2 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using Tensorflow.Keras.Datasets;
88
using Tensorflow.Keras.Engine;
99
using Tensorflow.Keras.Layers;
10+
using Tensorflow.Keras.Losses;
1011
using static Tensorflow.Binding;
1112

1213
namespace Tensorflow
@@ -16,6 +17,7 @@ public class KerasApi
1617
public KerasDataset datasets { get; } = new KerasDataset();
1718
public Initializers initializers { get; } = new Initializers();
1819
public LayersApi layers { get; } = new LayersApi();
20+
public LossesApi losses { get; } = new LossesApi();
1921
public Activations activations { get; } = new Activations();
2022
public Preprocessing preprocessing { get; } = new Preprocessing();
2123
public BackendImpl backend { get; } = new BackendImpl();
@@ -69,52 +71,5 @@ public Tensor Input(TensorShape shape = null,
6971

7072
return layer.InboundNodes[0].Outputs;
7173
}
72-
73-
public class LayersApi
74-
{
75-
public Rescaling Rescaling(float scale,
76-
float offset = 0,
77-
TensorShape input_shape = null)
78-
=> new Rescaling(new RescalingArgs
79-
{
80-
Scale = scale,
81-
Offset = offset,
82-
InputShape = input_shape
83-
});
84-
85-
public Dense Dense(int units,
86-
Activation activation = null,
87-
TensorShape input_shape = null)
88-
=> new Dense(new DenseArgs
89-
{
90-
Units = units,
91-
Activation = activation ?? tf.keras.activations.Linear,
92-
InputShape = input_shape
93-
});
94-
95-
/// <summary>
96-
/// Turns positive integers (indexes) into dense vectors of fixed size.
97-
/// </summary>
98-
/// <param name="input_dim"></param>
99-
/// <param name="output_dim"></param>
100-
/// <param name="embeddings_initializer"></param>
101-
/// <param name="mask_zero"></param>
102-
/// <returns></returns>
103-
public Embedding Embedding(int input_dim,
104-
int output_dim,
105-
IInitializer embeddings_initializer = null,
106-
bool mask_zero = false,
107-
TensorShape input_shape = null,
108-
int input_length = -1)
109-
=> new Embedding(new EmbeddingArgs
110-
{
111-
InputDim = input_dim,
112-
OutputDim = output_dim,
113-
MaskZero = mask_zero,
114-
InputShape = input_shape ?? input_length,
115-
InputLength = input_length,
116-
EmbeddingsInitializer = embeddings_initializer
117-
});
118-
}
11974
}
12075
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.ArgsDefinition;
5+
using Tensorflow.Keras.Engine;
6+
using static Tensorflow.Binding;
7+
8+
namespace Tensorflow.Keras.Layers
9+
{
10+
public class LayersApi
11+
{
12+
public Conv2D Conv2D(int filters,
13+
TensorShape kernel_size = null,
14+
string padding = "valid",
15+
string activation = "relu")
16+
=> new Conv2D(new Conv2DArgs
17+
{
18+
Filters = filters,
19+
KernelSize = kernel_size,
20+
Padding = padding,
21+
Activation = GetActivationByName(activation)
22+
});
23+
24+
25+
public Dense Dense(int units,
26+
string activation = "linear",
27+
TensorShape input_shape = null)
28+
=> new Dense(new DenseArgs
29+
{
30+
Units = units,
31+
Activation = GetActivationByName(activation),
32+
InputShape = input_shape
33+
});
34+
35+
/// <summary>
36+
/// Turns positive integers (indexes) into dense vectors of fixed size.
37+
/// </summary>
38+
/// <param name="input_dim"></param>
39+
/// <param name="output_dim"></param>
40+
/// <param name="embeddings_initializer"></param>
41+
/// <param name="mask_zero"></param>
42+
/// <returns></returns>
43+
public Embedding Embedding(int input_dim,
44+
int output_dim,
45+
IInitializer embeddings_initializer = null,
46+
bool mask_zero = false,
47+
TensorShape input_shape = null,
48+
int input_length = -1)
49+
=> new Embedding(new EmbeddingArgs
50+
{
51+
InputDim = input_dim,
52+
OutputDim = output_dim,
53+
MaskZero = mask_zero,
54+
InputShape = input_shape ?? input_length,
55+
InputLength = input_length,
56+
EmbeddingsInitializer = embeddings_initializer
57+
});
58+
59+
public Flatten Flatten(string data_format = null)
60+
=> new Flatten(new FlattenArgs
61+
{
62+
DataFormat = data_format
63+
});
64+
65+
public MaxPooling2D MaxPooling2D(TensorShape pool_size = null,
66+
TensorShape strides = null,
67+
string padding = "valid")
68+
=> new MaxPooling2D(new MaxPooling2DArgs
69+
{
70+
PoolSize = pool_size ?? (2, 2),
71+
Strides = strides,
72+
Padding = padding
73+
});
74+
75+
public Rescaling Rescaling(float scale,
76+
float offset = 0,
77+
TensorShape input_shape = null)
78+
=> new Rescaling(new RescalingArgs
79+
{
80+
Scale = scale,
81+
Offset = offset,
82+
InputShape = input_shape
83+
});
84+
85+
Activation GetActivationByName(string name)
86+
=> name switch
87+
{
88+
"linear" => tf.keras.activations.Linear,
89+
"relu" => tf.keras.activations.Relu,
90+
"sigmoid" => tf.keras.activations.Sigmoid,
91+
"tanh" => tf.keras.activations.Tanh,
92+
_ => tf.keras.activations.Linear
93+
};
94+
}
95+
}

src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public Pooling2D(Pooling2DArgs args)
3030
{
3131
this.args = args;
3232
args.PoolSize = conv_utils.normalize_tuple(args.PoolSize, 2, "pool_size");
33-
args.Strides = conv_utils.normalize_tuple(args.Strides, 2, "strides");
33+
args.Strides = conv_utils.normalize_tuple(args.Strides ?? args.PoolSize, 2, "strides");
3434
args.Padding = conv_utils.normalize_padding(args.Padding);
3535
args.DataFormat = conv_utils.normalize_data_format(args.DataFormat);
3636
input_spec = new InputSpec(ndim: 4);

src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ public Rescaling(RescalingArgs args) : base(args)
2323
protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
2424
{
2525
scale = math_ops.cast(args.Scale, args.DType);
26-
throw new NotImplementedException("");
26+
offset = math_ops.cast(args.Offset, args.DType);
27+
return math_ops.cast(inputs, args.DType) * scale + offset;
2728
}
2829
}
2930
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.Losses
6+
{
7+
public interface ILossFunc
8+
{
9+
}
10+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.Losses
6+
{
7+
/// <summary>
8+
/// Loss base class.
9+
/// </summary>
10+
public abstract class Loss
11+
{
12+
protected string reduction;
13+
protected string name;
14+
bool _allow_sum_over_batch_size;
15+
string _name_scope;
16+
17+
public Loss(string reduction = ReductionV2.AUTO, string name = null)
18+
{
19+
this.reduction = reduction;
20+
this.name = name;
21+
_allow_sum_over_batch_size = false;
22+
}
23+
24+
void _set_name_scope()
25+
{
26+
_name_scope = name;
27+
}
28+
}
29+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.Losses
6+
{
7+
public class LossFunctionWrapper : Loss
8+
{
9+
Action fn;
10+
11+
public LossFunctionWrapper(Action fn,
12+
string reduction = ReductionV2.AUTO,
13+
string name = null)
14+
: base(reduction: reduction,
15+
name: name)
16+
{
17+
this.fn = fn;
18+
}
19+
}
20+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.Losses
6+
{
7+
public class LossesApi
8+
{
9+
public ILossFunc SparseCategoricalCrossentropy(bool from_logits = false)
10+
=> new SparseCategoricalCrossentropy(from_logits: from_logits);
11+
}
12+
}

0 commit comments

Comments
 (0)