3
3
# pyre-strict
4
4
import warnings
5
5
from enum import Enum
6
- from typing import Any , Callable , cast , Dict , Iterable , List , Optional , Tuple , Union
6
+ from typing import (
7
+ Any ,
8
+ Callable ,
9
+ cast ,
10
+ Dict ,
11
+ Iterable ,
12
+ List ,
13
+ Optional ,
14
+ Sequence ,
15
+ Tuple ,
16
+ Union ,
17
+ )
7
18
8
19
import matplotlib
9
20
@@ -74,8 +85,7 @@ def _cumulative_sum_threshold(
74
85
)
75
86
sorted_vals = np .sort (values .flatten ())
76
87
cum_sums = np .cumsum (sorted_vals )
77
- threshold_id = np .where (cum_sums >= cum_sums [- 1 ] * 0.01 * percentile )[0 ][0 ]
78
- # pyre-fixme[7]: Expected `float` but got `ndarray[typing.Any, dtype[typing.Any]]`.
88
+ threshold_id : int = np .where (cum_sums >= cum_sums [- 1 ] * 0.01 * percentile )[0 ][0 ]
79
89
return sorted_vals [threshold_id ]
80
90
81
91
@@ -959,7 +969,7 @@ def __init__(
959
969
self .convergence_score : float = convergence_score
960
970
961
971
962
- def _get_color (attr : int ) -> str :
972
+ def _get_color (attr : float ) -> str :
963
973
# clip values to prevent CSS errors (Values should be from [-1,1])
964
974
attr = max (- 1 , min (1 , attr ))
965
975
if attr > 0 :
@@ -973,8 +983,7 @@ def _get_color(attr: int) -> str:
973
983
return "hsl({}, {}%, {}%)" .format (hue , sat , lig )
974
984
975
985
976
- # pyre-fixme[2]: Parameter must be annotated.
977
- def format_classname (classname ) -> str :
986
+ def format_classname (classname : Union [str , int ]) -> str :
978
987
return '<td><text style="padding-right:2em"><b>{}</b></text></td>' .format (classname )
979
988
980
989
@@ -984,19 +993,24 @@ def format_special_tokens(token: str) -> str:
984
993
return token
985
994
986
995
987
- # pyre-fixme[2]: Parameter must be annotated.
988
- def format_tooltip (item , text ) -> str :
996
+ def format_tooltip (item : str , text : str ) -> str :
989
997
return '<div class="tooltip">{item}\
990
998
<span class="tooltiptext">{text}</span>\
991
999
</div>' .format (
992
1000
item = item , text = text
993
1001
)
994
1002
995
1003
996
- # pyre-fixme[2]: Parameter must be annotated.
997
- def format_word_importances (words , importances ) -> str :
1004
+ def format_word_importances (
1005
+ words : Sequence [str ],
1006
+ importances : Union [Sequence [float ], npt .NDArray [np .number ], Tensor ],
1007
+ ) -> str :
998
1008
if importances is None or len (importances ) == 0 :
999
1009
return "<td></td>"
1010
+ if isinstance (importances , np .ndarray ) or isinstance (importances , Tensor ):
1011
+ assert len (importances .shape ) == 1 , "Expected 1D array, got {}" .format (
1012
+ importances .shape
1013
+ )
1000
1014
assert len (words ) <= len (importances )
1001
1015
tags = ["<td>" ]
1002
1016
for word , importance in zip (words , importances [: len (words )]):
0 commit comments