@@ -484,11 +484,12 @@ def _run_gte_evaluation(self, num_samples: int):
484484 agents = sorted (list (self .metrics .keys ()))
485485 rnd = np .random .default_rng (42 )
486486
487- ratings , joints , _ , contributions , self .gte_game = self ._bootstrap_stats (
487+ ratings , joints , marginals , contributions , self .gte_game = self ._bootstrap_stats (
488488 rnd , self .games , agents , self .gte_tasks , num_samples = num_samples
489489 )
490490 self .gte_ratings = ratings
491491 self .gte_joint = joints
492+ self .gte_marginals = marginals
492493 self .gte_contributions_raw = contributions
493494
494495 ratings_mean , ratings_std = self .gte_ratings
@@ -573,9 +574,14 @@ def _bootstrap_stats(self, rnd, games, agents, tasks, num_samples=10):
573574 ratings_std = [np .std (r , axis = 0 ) for r in zip (* ratings )]
574575 joints_mean = np .mean (joints , axis = 0 )
575576 joints_std = np .std (joints , axis = 0 )
577+
578+ marginals_by_dim = list (zip (* marginals ))
579+ marginals_mean = [np .mean (m , axis = 0 ) for m in marginals_by_dim ]
580+ marginals_std = [np .std (m , axis = 0 ) for m in marginals_by_dim ]
581+
576582 contributions_mean = np .mean (contributions , axis = 0 )
577583 contributions_std = np .std (contributions , axis = 0 )
578- return (ratings_mean , ratings_std ), (joints_mean , joints_std ), None , (contributions_mean , contributions_std ), \
584+ return (ratings_mean , ratings_std ), (joints_mean , joints_std ), ( marginals_mean , marginals_std ) , (contributions_mean , contributions_std ), \
579585 games [0 ]
580586
581587 def plot_gte_evaluation (self , top_k : int = 100 , output_path = "gte_evaluation.html" ):
@@ -614,7 +620,7 @@ def plot_gte_evaluation(self, top_k: int = 100, output_path="gte_evaluation.html
614620
615621 # The game object has 'metric' and 'agent' as players.
616622 # rating_player=1 is 'agent', contrib_player=0 is 'metric'.
617- chart = _gte_rating_contribution_chart (
623+ rating_chart = _gte_rating_contribution_chart (
618624 game = self .gte_game ,
619625 joint = joint_avg ,
620626 contributions = contribution_avg ,
@@ -624,9 +630,48 @@ def plot_gte_evaluation(self, top_k: int = 100, output_path="gte_evaluation.html
624630 contrib_player = 0 ,
625631 top_k = top_k )
626632
633+ # --- Chart 2: Task Importance (Marginal Probability) ---
634+ # marginals[0] is for Player 0 (Metric/Task)
635+ # self.gte_marginals is (marginals_mean, marginals_std)
636+ # marginals_mean is [mean_p0, mean_p1], marginals_std is [std_p0, std_p1]
637+ task_marginals_mean = self .gte_marginals [0 ][0 ]
638+ task_marginals_std = self .gte_marginals [1 ][0 ]
639+
640+ task_importance_df = pd .DataFrame ({
641+ 'metric' : tasks ,
642+ 'importance' : task_marginals_mean ,
643+ 'std' : task_marginals_std
644+ })
645+ # 95% CI
646+ task_importance_df ['ci' ] = task_importance_df ['std' ] * 1.96
647+ task_importance_df ['min_val' ] = task_importance_df ['importance' ] - task_importance_df ['ci' ]
648+ task_importance_df ['max_val' ] = task_importance_df ['importance' ] + task_importance_df ['ci' ]
649+
650+ base = alt .Chart (task_importance_df ).encode (
651+ y = alt .Y ('metric:N' , sort = '-x' , title = "Task" ),
652+ x = alt .X ('importance:Q' , title = "Marginal Probability (Importance)" ),
653+ )
654+
655+ bars = base .mark_bar ().encode (
656+ tooltip = ['metric' , 'importance' , 'std' ]
657+ )
658+
659+ error_bars = alt .Chart (task_importance_df ).mark_rule (color = 'black' ).encode (
660+ y = alt .Y ('metric:N' , sort = '-x' ),
661+ x = alt .X ('min_val:Q' ),
662+ x2 = alt .X2 ('max_val:Q' ),
663+ tooltip = ['metric' , 'importance' , 'std' ]
664+ )
665+
666+ importance_chart = (bars + error_bars ).properties (
667+ title = "Task Importance (Marginal Probability in Equilibrium)"
668+ )
669+
670+ final_chart = alt .vconcat (rating_chart , importance_chart ).resolve_scale (color = 'independent' )
671+
627672 if output_path :
628- chart .save (output_path )
629- return chart
673+ final_chart .save (output_path )
674+ return final_chart
630675
631676 def print_results (self ):
632677 """Prints a formatted summary of the evaluation results."""
@@ -736,7 +781,8 @@ def plot_metrics(self, output_path="metrics.html"):
736781
737782 error_bars = base .mark_errorbar (extent = 'ci' ).encode (
738783 y = 'value:Q' ,
739- yError = 'std:Q'
784+ yError = 'std:Q' ,
785+ color = alt .value ('black' )
740786 )
741787
742788 chart = (bars + error_bars ).properties (
0 commit comments