Skip to content
This repository was archived by the owner on Mar 11, 2021. It is now read-only.

Commit f194b0b

Browse files
committed
add avg_stones commands
1 parent 792b24b commit f194b0b

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

dual_net.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -284,10 +284,13 @@ def model_fn(features, labels, mode, params):
284284
train_op = optimizer.minimize(combined_cost, global_step=global_step)
285285

286286
# Computations to be executed on CPU, outside of the main TPU queues.
287-
def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor,
288-
value_tensor, policy_cost, value_cost,
289-
l2_cost, combined_cost, step,
290-
est_mode=tf.estimator.ModeKeys.TRAIN):
287+
def eval_metrics_host_call_fn(
288+
features,
289+
policy_output, value_output,
290+
pi_tensor, value_tensor,
291+
policy_cost, value_cost,
292+
l2_cost, combined_cost,
293+
step, est_mode=tf.estimator.ModeKeys.TRAIN):
291294
policy_entropy = -tf.reduce_mean(tf.reduce_sum(
292295
policy_output * tf.log(policy_output), axis=1))
293296
# pi_tensor is one_hot when generated from sgfs (for supervised learning)
@@ -306,6 +309,8 @@ def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor,
306309

307310
value_cost_normalized = value_cost / params['value_cost_weight']
308311
avg_value_observed = tf.reduce_mean(value_tensor)
312+
avg_stones_black = tf.reduce_mean(tf.reduce_sum(features[:,:,:,1], [1,2]))
313+
avg_stones_white = tf.reduce_mean(tf.reduce_sum(features[:,:,:,0], [1,2]))
309314

310315
with tf.variable_scope('metrics'):
311316
metric_ops = {
@@ -315,13 +320,17 @@ def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor,
315320
'l2_cost': tf.metrics.mean(l2_cost),
316321
'policy_entropy': tf.metrics.mean(policy_entropy),
317322
'combined_cost': tf.metrics.mean(combined_cost),
318-
'avg_value_observed': tf.metrics.mean(avg_value_observed),
319323
'policy_accuracy_top_1': tf.metrics.mean(policy_output_in_top1),
320324
'policy_accuracy_top_3': tf.metrics.mean(policy_output_in_top3),
321325
'policy_top_1_confidence': tf.metrics.mean(policy_top_1_confidence),
326+
'value_confidence': tf.metrics.mean(tf.abs(value_output)),
327+
328+
# Metrics about input data
322329
'policy_target_top_1_confidence': tf.metrics.mean(
323330
policy_target_top_1_confidence),
324-
'value_confidence': tf.metrics.mean(tf.abs(value_output)),
331+
'avg_value_observed': tf.metrics.mean(avg_value_observed),
332+
'avg_stones_black': tf.metrics.mean(avg_stones_black),
333+
'avg_stones_white': tf.metrics.mean(avg_stones_white),
325334
}
326335

327336
if est_mode == tf.estimator.ModeKeys.EVAL:
@@ -349,6 +358,7 @@ def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor,
349358
return summary.all_summary_ops() + [cond_reset_op]
350359

351360
metric_args = [
361+
features,
352362
policy_output,
353363
value_output,
354364
labels['pi_tensor'],

0 commit comments

Comments
 (0)