Skip to content

Commit 4fc093e

Browse files
fy-mengfacebook-github-bot
authored andcommitted
Fix typing for visualization.py (#1624)
Summary: Pull Request resolved: #1624 As title. Reviewed By: vivekmig Differential Revision: D77790233 fbshipit-source-id: 488f2bf8c63e0c3cd2310a23f301f88bf72d1631
1 parent f282901 commit 4fc093e

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

captum/attr/_utils/visualization.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,18 @@
33
# pyre-strict
44
import warnings
55
from enum import Enum
6-
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union
6+
from typing import (
7+
Any,
8+
Callable,
9+
cast,
10+
Dict,
11+
Iterable,
12+
List,
13+
Optional,
14+
Sequence,
15+
Tuple,
16+
Union,
17+
)
718

819
import matplotlib
920

@@ -74,8 +85,7 @@ def _cumulative_sum_threshold(
7485
)
7586
sorted_vals = np.sort(values.flatten())
7687
cum_sums = np.cumsum(sorted_vals)
77-
threshold_id = np.where(cum_sums >= cum_sums[-1] * 0.01 * percentile)[0][0]
78-
# pyre-fixme[7]: Expected `float` but got `ndarray[typing.Any, dtype[typing.Any]]`.
88+
threshold_id: int = np.where(cum_sums >= cum_sums[-1] * 0.01 * percentile)[0][0]
7989
return sorted_vals[threshold_id]
8090

8191

@@ -959,7 +969,7 @@ def __init__(
959969
self.convergence_score: float = convergence_score
960970

961971

962-
def _get_color(attr: int) -> str:
972+
def _get_color(attr: float) -> str:
963973
# clip values to prevent CSS errors (Values should be from [-1,1])
964974
attr = max(-1, min(1, attr))
965975
if attr > 0:
@@ -973,8 +983,7 @@ def _get_color(attr: int) -> str:
973983
return "hsl({}, {}%, {}%)".format(hue, sat, lig)
974984

975985

976-
# pyre-fixme[2]: Parameter must be annotated.
977-
def format_classname(classname) -> str:
986+
def format_classname(classname: Union[str, int]) -> str:
978987
return '<td><text style="padding-right:2em"><b>{}</b></text></td>'.format(classname)
979988

980989

@@ -984,19 +993,24 @@ def format_special_tokens(token: str) -> str:
984993
return token
985994

986995

987-
# pyre-fixme[2]: Parameter must be annotated.
988-
def format_tooltip(item, text) -> str:
996+
def format_tooltip(item: str, text: str) -> str:
989997
return '<div class="tooltip">{item}\
990998
<span class="tooltiptext">{text}</span>\
991999
</div>'.format(
9921000
item=item, text=text
9931001
)
9941002

9951003

996-
# pyre-fixme[2]: Parameter must be annotated.
997-
def format_word_importances(words, importances) -> str:
1004+
def format_word_importances(
1005+
words: Sequence[str],
1006+
importances: Union[Sequence[float], npt.NDArray[np.number], Tensor],
1007+
) -> str:
9981008
if importances is None or len(importances) == 0:
9991009
return "<td></td>"
1010+
if isinstance(importances, np.ndarray) or isinstance(importances, Tensor):
1011+
assert len(importances.shape) == 1, "Expected 1D array, got {}".format(
1012+
importances.shape
1013+
)
10001014
assert len(words) <= len(importances)
10011015
tags = ["<td>"]
10021016
for word, importance in zip(words, importances[: len(words)]):

0 commit comments

Comments
 (0)