Skip to content

Commit fcd2cd6

Browse files
committed
nn_ops.in_top_kv2
1 parent 59cbca5 commit fcd2cd6

File tree

3 files changed

+30
-2
lines changed

3 files changed

+30
-2
lines changed

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ public Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding,
134134
=> nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name);
135135

136136
public Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = "InTopK")
137-
=> gen_ops.in_top_k(predictions, targets, k, name);
137+
=> nn_ops.in_top_k(predictions, targets, k, name);
138138

139139
public Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null)
140140
=> gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name);

src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,27 @@ public static Tensor log_softmax(Tensor logits, string name = null)
244244
logits
245245
});
246246

247-
return _op.outputs[0];
247+
return _op.output;
248+
}
249+
250+
/// <summary>
251+
/// Says whether the targets are in the top `K` predictions.
252+
/// </summary>
253+
/// <param name="predictions"></param>
254+
/// <param name="targets"></param>
255+
/// <param name="k"></param>
256+
/// <param name="name"></param>
257+
/// <returns>A `Tensor` of type `bool`.</returns>
258+
public static Tensor in_top_kv2(Tensor predictions, Tensor targets, int k, string name = null)
259+
{
260+
var _op = _op_def_lib._apply_op_helper("InTopKV2", name: name, args: new
261+
{
262+
predictions,
263+
targets,
264+
k
265+
});
266+
267+
return _op.output;
248268
}
249269

250270
public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null)

src/TensorFlowNET.Core/Operations/nn_ops.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,14 @@ private static Tensor _get_noise_shape(Tensor x, Tensor noise_shape)
111111
return noise_shape;
112112
}
113113

114+
public static Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = null)
115+
{
116+
return tf_with(ops.name_scope(name, "in_top_k"), delegate
117+
{
118+
return gen_nn_ops.in_top_kv2(predictions, targets, k, name: name);
119+
});
120+
}
121+
114122
public static Tensor log_softmax(Tensor logits, int axis = -1, string name = null)
115123
{
116124
return _softmax(logits, gen_nn_ops.log_softmax, axis, name);

0 commit comments

Comments
 (0)