@@ -451,6 +451,8 @@ def nodes_df(self):
451451 "ancestors_span" : child_right - child_left ,
452452 "child_left" : child_left , # FIXME add test for this
453453 "child_right" : child_right , # FIXME add test for this
454+ "child_left" : child_left , # FIXME add test for this
455+ "child_right" : child_right , # FIXME add test for this
454456 "is_sample" : is_sample ,
455457 }
456458 )
@@ -589,7 +591,7 @@ def calc_mutations_per_tree(self):
589591 mutations_per_tree [unique_values ] = counts
590592 return mutations_per_tree
591593
592- def compute_ancestor_spans_heatmap_data (self , win_x_size = 1_000_000 , win_y_size = 500 ):
594+ def compute_ancestor_spans_heatmap_data (self , num_x_bins , num_y_bins ):
593595 """
594596 Calculates the number of nodes overlapping each genomic position-time bin
595597 """
@@ -598,38 +600,38 @@ def compute_ancestor_spans_heatmap_data(self, win_x_size=1_000_000, win_y_size=5
598600 nodes_left = nodes_df .child_left
599601 nodes_right = nodes_df .child_right
600602 nodes_time = nodes_df .time
601- ancestors_span = nodes_df .ancestors_span
602603
603- num_x_wins = int (np .ceil (nodes_right .max () - nodes_left .min ()) / win_x_size )
604- num_y_wins = int (np .ceil (nodes_time .max () / win_y_size ))
605- heatmap_sums = np .zeros ((num_x_wins , num_y_wins ))
606- heatmap_counts = np .zeros ((num_x_wins , num_y_wins ))
604+ x_bins = np .linspace (nodes_left .min (), nodes_right .max (), num_x_bins + 1 )
605+ y_bins = np .linspace (0 , nodes_time .max (), num_y_bins + 1 )
606+ heatmap_counts = np .zeros ((num_x_bins , num_y_bins ))
607607
608- for u in range (len (nodes_left )):
609- x_start = int (
610- np .floor (nodes_left [u ] / win_x_size )
611- ) # map the node span to the x-axis bins it overlaps
612- x_end = int (np .floor (nodes_right [u ] / win_x_size ))
613- y = max (0 , int (np .floor (nodes_time [u ] / win_y_size )) - 1 )
614- heatmap_sums [x_start :x_end , y ] += min (ancestors_span [u ], win_x_size )
615- heatmap_counts [x_start :x_end , y ] += 1
616-
617- avg_spans = heatmap_sums / heatmap_counts
618- indices = np .indices ((num_x_wins , num_y_wins ))
619- x_coords = indices [0 ] * win_x_size
620- y_coords = indices [1 ] * win_y_size
608+ x_starts = np .digitize (nodes_left , x_bins , right = True )
609+ x_ends = np .digitize (nodes_right , x_bins , right = True )
610+ y_starts = np .digitize (nodes_time , y_bins , right = True )
621611
612+ for u in range (len (nodes_left )):
613+ x_start = max (0 , x_starts [u ] - 1 )
614+ x_end = max (0 , x_ends [u ] - 1 )
615+ y_bin = max (0 , y_starts [u ] - 1 )
616+ heatmap_counts [x_start : x_end + 1 , y_bin ] += 1
617+
618+ x_coords = np .repeat (x_bins [:- 1 ], num_y_bins )
619+ y_coords = np .tile (y_bins [:- 1 ], num_x_bins )
620+ overlapping_node_count = heatmap_counts .flatten ()
621+ overlapping_node_count [overlapping_node_count == 0 ] = 1
622+ # FIXME - zero counts are set to 1 above to avoid log10(0), so empty bins also report a count of 1; find a better way
622623 df = pd .DataFrame (
623624 {
624- "genomic_position " : x_coords .flatten (),
625+ "position " : x_coords .flatten (),
625626 "time" : y_coords .flatten (),
626- "average_ancestor_span" : avg_spans .flatten (),
627+ "overlapping_node_count_log10" : np .log10 (overlapping_node_count ),
628+ "overlapping_node_count" : overlapping_node_count ,
627629 }
628630 )
629631 return df .astype (
630632 {
631- "genomic_position " : "int" ,
633+ "position " : "int" ,
632634 "time" : "int" ,
633- "average_ancestor_span " : "float64 " ,
635+ "overlapping_node_count " : "int " ,
634636 }
635637 )
0 commit comments