
Commit 831dea5

fy-meng authored and facebook-github-bot committed
Allow more variable input to LLMAttributionResult (#1627)
Summary:
1. Allow `LLMAttributionResult` to be initialized with generic array data (lists, np.ndarray) and perform sanity checks on their shapes.
2. During visualization, the text tokens are now `repr`'d so that non-word characters (e.g. newlines) are visualized correctly.

Reviewed By: craymichael

Differential Revision: D78197863
1 parent 4fc093e commit 831dea5
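
As a rough usage sketch of change (1) (not taken from the commit; the token strings and attribution values below are invented), the result object can now be constructed from plain lists or NumPy arrays, and the new property setters convert them to tensors and validate their shapes:

```python
import numpy as np

from captum.attr._core.llm_attr import LLMAttributionResult

# 3 input tokens and 2 output tokens, so seq_attr must have length 3 and
# token_attr must have shape (2, 3); both are converted to torch Tensors.
result = LLMAttributionResult(
    seq_attr=[0.1, -0.3, 0.7],        # a plain list is now accepted
    token_attr=np.zeros((2, 3)),      # an np.ndarray is now accepted
    input_tokens=["Hello", "\n", "world"],
    output_tokens=["Hi", "!"],
)
print(result.seq_attr.shape)    # torch.Size([3])
print(result.token_attr.shape)  # torch.Size([2, 3])
```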

File tree

1 file changed: +72 −6 lines


captum/attr/_core/llm_attr.py

Lines changed: 72 additions & 6 deletions
@@ -5,13 +5,14 @@
 from abc import ABC
 from copy import copy
 from dataclasses import dataclass
-from textwrap import shorten
+from textwrap import dedent, shorten
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
 
 import matplotlib.colors as mcolors
 
 import matplotlib.pyplot as plt
 import numpy as np
+import numpy.typing as npt
 
 import torch
 from captum._utils.typing import TokenizerLike
@@ -51,10 +52,70 @@ class LLMAttributionResult:
     It also provides utilities to help present and plot the result in different forms.
     """
 
-    seq_attr: Tensor
-    token_attr: Optional[Tensor]
     input_tokens: List[str]
     output_tokens: List[str]
+    # pyre-ignore[13]: initialized via a property setter
+    _seq_attr: Tensor
+    _token_attr: Optional[Tensor] = None
+
+    def __init__(
+        self,
+        seq_attr: npt.ArrayLike,
+        token_attr: Optional[npt.ArrayLike],
+        input_tokens: List[str],
+        output_tokens: List[str],
+    ) -> None:
+        self.input_tokens = input_tokens
+        self.output_tokens = output_tokens
+        self.seq_attr = seq_attr
+        self.token_attr = token_attr
+
+    @property
+    def seq_attr(self) -> Tensor:
+        return self._seq_attr
+
+    @seq_attr.setter
+    def seq_attr(self, seq_attr: npt.ArrayLike) -> None:
+        if isinstance(seq_attr, Tensor):
+            self._seq_attr = seq_attr
+        else:
+            self._seq_attr = torch.tensor(seq_attr)
+        # IDEA: in the future we might want to support higher dim seq_attr
+        # (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes)
+        assert len(self._seq_attr.shape) == 1, "seq_attr must be a 1D tensor"
+
+        assert (
+            len(self.input_tokens) == self._seq_attr.shape[0]
+        ), "seq_attr and input_tokens must have the same length"
+
+    @property
+    def token_attr(self) -> Optional[Tensor]:
+        return self._token_attr
+
+    @token_attr.setter
+    def token_attr(self, token_attr: Optional[npt.ArrayLike]) -> None:
+        if isinstance(token_attr, Tensor):
+            self._token_attr = token_attr
+        elif token_attr is None:
+            # can't combine with previous clause, linter unhappy ¯\_(ツ)_/¯
+            self._token_attr = None
+        else:
+            self._token_attr = torch.tensor(token_attr)
+        # IDEA: in the future we might want to support higher dim seq_attr
+        if self._token_attr is not None:
+            assert len(self._token_attr.shape) == 2, "token_attr must be a 2D tensor"
+
+        if self._token_attr is not None:
+            assert self._token_attr.shape == (
+                len(self.output_tokens),
+                len(self.input_tokens),
+            ), dedent(
+                f"""\
+                Expect token_attr to have shape
+                {len(self.output_tokens), len(self.input_tokens)},
+                got {self._token_attr.shape}
+                """
+            )
 
     @property
     def seq_attr_dict(self) -> Dict[str, float]:
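
A minimal sketch of the shape checks added above (example values are invented): mismatched inputs now fail fast at construction time with a descriptive AssertionError instead of surfacing later during plotting.

```python
from captum.attr._core.llm_attr import LLMAttributionResult

# seq_attr has 2 entries but there are 3 input tokens, so the length
# assertion in the seq_attr setter fires immediately.
try:
    LLMAttributionResult(
        seq_attr=[0.1, 0.2],
        token_attr=None,
        input_tokens=["a", "b", "c"],
        output_tokens=["x"],
    )
except AssertionError as err:
    print(err)  # seq_attr and input_tokens must have the same length
```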
@@ -124,10 +185,14 @@ def plot_token_attr(
 
         # Show all ticks and label them with the respective list entries.
         shortened_tokens = [
-            shorten(t, width=50, placeholder="...") for t in self.input_tokens
+            shorten(repr(t)[1:-1], width=50, placeholder="...")
+            for t in self.input_tokens
         ]
         ax.set_xticks(np.arange(data.shape[1]), labels=shortened_tokens)
-        ax.set_yticks(np.arange(data.shape[0]), labels=self.output_tokens)
+        ax.set_yticks(
+            np.arange(data.shape[0]),
+            labels=[repr(token)[1:-1] for token in self.output_tokens],
+        )
 
         # Let the horizontal axes labeling appear on top.
         ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
@@ -175,7 +240,8 @@ def plot_seq_attr(
         fig.set_size_inches(max(data.shape[0] / 2, 6.4), max(data.shape[0] / 4, 4.8))
 
         shortened_tokens = [
-            shorten(t, width=50, placeholder="...") for t in self.input_tokens
+            shorten(repr(t)[1:-1], width=50, placeholder="...")
+            for t in self.input_tokens
         ]
         ax.set_xticks(range(data.shape[0]), labels=shortened_tokens)
 
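
For change (2), a short sketch of what the `repr(...)[1:-1]` trick does to a tick label (the token value is invented): `repr` escapes non-word characters such as newlines, and the `[1:-1]` slice drops the surrounding quotes, so a literal newline token is rendered as the two characters `\n` rather than producing an empty or broken label.

```python
from textwrap import shorten

token = "\n"  # a raw newline would break or blank out the matplotlib tick label
label = shorten(repr(token)[1:-1], width=50, placeholder="...")
print(label)  # prints the escaped form: \n
```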
