Skip to content

Commit e5dc65a

Browse files
committed
Support SigmoidFocalCrossEntropy, better for imbalanced multi-class task.
1 parent 55cc4d0 commit e5dc65a

File tree

4 files changed

+88
-0
lines changed

4 files changed

+88
-0
lines changed

src/TensorFlowNET.Core/Keras/Losses/ILossesApi.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,19 @@ ILossFunc Huber(string reduction = null,
3838

3939
ILossFunc LogCosh(string reduction = null,
4040
string name = null);
41+
42+
/// <summary>
43+
/// Implements the focal loss function.
44+
/// </summary>
45+
/// <param name="from_logits"></param>
46+
/// <param name="alpha"></param>
47+
/// <param name="gamma"></param>
48+
/// <param name="reduction"></param>
49+
/// <param name="name"></param>
50+
/// <returns></returns>
51+
ILossFunc SigmoidFocalCrossEntropy(bool from_logits = false,
52+
float alpha = 0.25f,
53+
float gamma = 2.0f,
54+
string reduction = "none",
55+
string name = "sigmoid_focal_crossentropy");
4156
}

src/TensorFlowNET.Keras/Losses/LossesApi.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,16 @@ public ILossFunc Huber(string reduction = null, string name = null, Tensor delta
3737

3838
public ILossFunc LogCosh(string reduction = null, string name = null)
3939
=> new LogCosh(reduction: reduction, name: name);
40+
41+
public ILossFunc SigmoidFocalCrossEntropy(bool from_logits = false,
42+
float alpha = 0.25F,
43+
float gamma = 2,
44+
string reduction = "none",
45+
string name = "sigmoid_focal_crossentropy")
46+
=> new SigmoidFocalCrossEntropy(from_logits: from_logits,
47+
alpha: alpha,
48+
gamma: gamma,
49+
reduction: reduction,
50+
name: name);
4051
}
4152
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
using static HDF.PInvoke.H5L.info_t;
2+
3+
namespace Tensorflow.Keras.Losses;
4+
5+
public class SigmoidFocalCrossEntropy : LossFunctionWrapper, ILossFunc
6+
{
7+
float _alpha;
8+
float _gamma;
9+
10+
public SigmoidFocalCrossEntropy(bool from_logits = false,
11+
float alpha = 0.25f,
12+
float gamma = 2.0f,
13+
string reduction = "none",
14+
string name = "sigmoid_focal_crossentropy") :
15+
base(reduction: reduction,
16+
name: name,
17+
from_logits: from_logits)
18+
{
19+
_alpha = alpha;
20+
_gamma = gamma;
21+
}
22+
23+
24+
public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
25+
{
26+
y_true = tf.cast(y_true, dtype: y_pred.dtype);
27+
var ce = keras.backend.binary_crossentropy(y_true, y_pred, from_logits: from_logits);
28+
var pred_prob = from_logits ? tf.sigmoid(y_pred) : y_pred;
29+
30+
var p_t = (y_true * pred_prob) + ((1f - y_true) * (1f - pred_prob));
31+
Tensor alpha_factor = constant_op.constant(1.0f);
32+
Tensor modulating_factor = constant_op.constant(1.0f);
33+
34+
if(_alpha > 0)
35+
{
36+
var alpha = tf.cast(constant_op.constant(_alpha), dtype: y_true.dtype);
37+
alpha_factor = y_true * alpha + (1f - y_true) * (1f - alpha);
38+
}
39+
40+
if (_gamma > 0)
41+
{
42+
var gamma = tf.cast(constant_op.constant(_gamma), dtype: y_true.dtype);
43+
modulating_factor = tf.pow(1f - p_t, gamma);
44+
}
45+
46+
return tf.reduce_sum(alpha_factor * modulating_factor * ce, axis = -1);
47+
}
48+
}

test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Text;
66
using System.Threading.Tasks;
77
using Tensorflow;
8+
using Tensorflow.NumPy;
89
using TensorFlowNET.Keras.UnitTest;
910
using static Tensorflow.Binding;
1011
using static Tensorflow.KerasApi;
@@ -48,4 +49,17 @@ public void BinaryCrossentropy()
4849
loss = bce.Call(y_true, y_pred);
4950
Assert.AreEqual(new float[] { 0.23515666f, 1.4957594f}, loss.numpy());
5051
}
52+
53+
/// <summary>
54+
/// https://www.tensorflow.org/addons/api_docs/python/tfa/losses/SigmoidFocalCrossEntropy
55+
/// </summary>
56+
[TestMethod]
57+
public void SigmoidFocalCrossEntropy()
58+
{
59+
var y_true = np.expand_dims(np.array(new[] { 1.0f, 1.0f, 0 }));
60+
var y_pred = np.expand_dims(np.array(new[] { 0.97f, 0.91f, 0.03f }));
61+
var bce = tf.keras.losses.SigmoidFocalCrossEntropy();
62+
var loss = bce.Call(y_true, y_pred);
63+
Assert.AreEqual(new[] { 6.8532745e-06f, 1.909787e-04f, 2.0559824e-05f }, loss.numpy());
64+
}
5165
}

0 commit comments

Comments
 (0)