Skip to content

Commit c990362

Browse files
committed
Plot the task marginal importance for GTE
1 parent 6bd1d71 commit c990362

File tree

1 file changed

+52
-6
lines changed

1 file changed

+52
-6
lines changed

kaggle_environments/envs/werewolf/eval/metrics.py

Lines changed: 52 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -484,11 +484,12 @@ def _run_gte_evaluation(self, num_samples: int):
484484
agents = sorted(list(self.metrics.keys()))
485485
rnd = np.random.default_rng(42)
486486

487-
ratings, joints, _, contributions, self.gte_game = self._bootstrap_stats(
487+
ratings, joints, marginals, contributions, self.gte_game = self._bootstrap_stats(
488488
rnd, self.games, agents, self.gte_tasks, num_samples=num_samples
489489
)
490490
self.gte_ratings = ratings
491491
self.gte_joint = joints
492+
self.gte_marginals = marginals
492493
self.gte_contributions_raw = contributions
493494

494495
ratings_mean, ratings_std = self.gte_ratings
@@ -573,9 +574,14 @@ def _bootstrap_stats(self, rnd, games, agents, tasks, num_samples=10):
573574
ratings_std = [np.std(r, axis=0) for r in zip(*ratings)]
574575
joints_mean = np.mean(joints, axis=0)
575576
joints_std = np.std(joints, axis=0)
577+
578+
marginals_by_dim = list(zip(*marginals))
579+
marginals_mean = [np.mean(m, axis=0) for m in marginals_by_dim]
580+
marginals_std = [np.std(m, axis=0) for m in marginals_by_dim]
581+
576582
contributions_mean = np.mean(contributions, axis=0)
577583
contributions_std = np.std(contributions, axis=0)
578-
return (ratings_mean, ratings_std), (joints_mean, joints_std), None, (contributions_mean, contributions_std), \
584+
return (ratings_mean, ratings_std), (joints_mean, joints_std), (marginals_mean, marginals_std), (contributions_mean, contributions_std), \
579585
games[0]
580586

581587
def plot_gte_evaluation(self, top_k: int = 100, output_path="gte_evaluation.html"):
@@ -614,7 +620,7 @@ def plot_gte_evaluation(self, top_k: int = 100, output_path="gte_evaluation.html
614620

615621
# The game object has 'metric' and 'agent' as players.
616622
# rating_player=1 is 'agent', contrib_player=0 is 'metric'.
617-
chart = _gte_rating_contribution_chart(
623+
rating_chart = _gte_rating_contribution_chart(
618624
game=self.gte_game,
619625
joint=joint_avg,
620626
contributions=contribution_avg,
@@ -624,9 +630,48 @@ def plot_gte_evaluation(self, top_k: int = 100, output_path="gte_evaluation.html
624630
contrib_player=0,
625631
top_k=top_k)
626632

633+
# --- Chart 2: Task Importance (Marginal Probability) ---
634+
# marginals[0] is for Player 0 (Metric/Task)
635+
# self.gte_marginals is (marginals_mean, marginals_std)
636+
# marginals_mean is [mean_p0, mean_p1], marginals_std is [std_p0, std_p1]
637+
task_marginals_mean = self.gte_marginals[0][0]
638+
task_marginals_std = self.gte_marginals[1][0]
639+
640+
task_importance_df = pd.DataFrame({
641+
'metric': tasks,
642+
'importance': task_marginals_mean,
643+
'std': task_marginals_std
644+
})
645+
# Half-width of a 95% interval under a normal approximation (1.96 * std)
646+
task_importance_df['ci'] = task_importance_df['std'] * 1.96
647+
task_importance_df['min_val'] = task_importance_df['importance'] - task_importance_df['ci']
648+
task_importance_df['max_val'] = task_importance_df['importance'] + task_importance_df['ci']
649+
650+
base = alt.Chart(task_importance_df).encode(
651+
y=alt.Y('metric:N', sort='-x', title="Task"),
652+
x=alt.X('importance:Q', title="Marginal Probability (Importance)"),
653+
)
654+
655+
bars = base.mark_bar().encode(
656+
tooltip=['metric', 'importance', 'std']
657+
)
658+
659+
error_bars = alt.Chart(task_importance_df).mark_rule(color='black').encode(
660+
y=alt.Y('metric:N', sort='-x'),
661+
x=alt.X('min_val:Q'),
662+
x2=alt.X2('max_val:Q'),
663+
tooltip=['metric', 'importance', 'std']
664+
)
665+
666+
importance_chart = (bars + error_bars).properties(
667+
title="Task Importance (Marginal Probability in Equilibrium)"
668+
)
669+
670+
final_chart = alt.vconcat(rating_chart, importance_chart).resolve_scale(color='independent')
671+
627672
if output_path:
628-
chart.save(output_path)
629-
return chart
673+
final_chart.save(output_path)
674+
return final_chart
630675

631676
def print_results(self):
632677
"""Prints a formatted summary of the evaluation results."""
@@ -736,7 +781,8 @@ def plot_metrics(self, output_path="metrics.html"):
736781

737782
error_bars = base.mark_errorbar(extent='ci').encode(
738783
y='value:Q',
739-
yError='std:Q'
784+
yError='std:Q',
785+
color=alt.value('black')
740786
)
741787

742788
chart = (bars + error_bars).properties(

0 commit comments

Comments
 (0)