@@ -5,13 +5,14 @@
 from abc import ABC
 from copy import copy
 from dataclasses import dataclass
-from textwrap import shorten
+from textwrap import dedent, shorten
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
 
 import matplotlib.colors as mcolors
 
 import matplotlib.pyplot as plt
 import numpy as np
+import numpy.typing as npt
 
 import torch
 from captum._utils.typing import TokenizerLike
@@ -51,10 +52,70 @@ class LLMAttributionResult:
     It also provides utilities to help present and plot the result in different forms.
     """
 
-    seq_attr: Tensor
-    token_attr: Optional[Tensor]
     input_tokens: List[str]
     output_tokens: List[str]
+    # pyre-ignore[13]: initialized via a property setter
+    _seq_attr: Tensor
+    _token_attr: Optional[Tensor] = None
+
+    def __init__(
+        self,
+        seq_attr: npt.ArrayLike,
+        token_attr: Optional[npt.ArrayLike],
+        input_tokens: List[str],
+        output_tokens: List[str],
+    ) -> None:
+        self.input_tokens = input_tokens
+        self.output_tokens = output_tokens
+        self.seq_attr = seq_attr
+        self.token_attr = token_attr
+
+    @property
+    def seq_attr(self) -> Tensor:
+        return self._seq_attr
+
+    @seq_attr.setter
+    def seq_attr(self, seq_attr: npt.ArrayLike) -> None:
+        if isinstance(seq_attr, Tensor):
+            self._seq_attr = seq_attr
+        else:
+            self._seq_attr = torch.tensor(seq_attr)
+        # IDEA: in the future we might want to support higher dim seq_attr
+        # (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes)
+        assert len(self._seq_attr.shape) == 1, "seq_attr must be a 1D tensor"
+
+        assert (
+            len(self.input_tokens) == self._seq_attr.shape[0]
+        ), "seq_attr and input_tokens must have the same length"
+
+    @property
+    def token_attr(self) -> Optional[Tensor]:
+        return self._token_attr
+
+    @token_attr.setter
+    def token_attr(self, token_attr: Optional[npt.ArrayLike]) -> None:
+        if isinstance(token_attr, Tensor):
+            self._token_attr = token_attr
+        elif token_attr is None:
+            # can't combine with previous clause, linter unhappy ¯\_(ツ)_/¯
+            self._token_attr = None
+        else:
+            self._token_attr = torch.tensor(token_attr)
+        # IDEA: in the future we might want to support higher dim token_attr
+        if self._token_attr is not None:
+            assert len(self._token_attr.shape) == 2, "token_attr must be a 2D tensor"
+
+        if self._token_attr is not None:
+            assert self._token_attr.shape == (
+                len(self.output_tokens),
+                len(self.input_tokens),
+            ), dedent(
+                f"""\
+                Expect token_attr to have shape
+                {len(self.output_tokens), len(self.input_tokens)},
+                got {self._token_attr.shape}
+                """
+            )
 
     @property
     def seq_attr_dict(self) -> Dict[str, float]:
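What the new setters buy: `seq_attr` and `token_attr` now accept any `npt.ArrayLike`, coerce it to a `torch.Tensor`, and validate shapes against the token lists on every assignment, not just at construction. A minimal sketch (not part of the diff; assumes `LLMAttributionResult` is importable from `captum.attr` and uses made-up token values):

```python
import numpy as np
from captum.attr import LLMAttributionResult  # assumed public export

# Lists and ndarrays are converted to torch.Tensor by the property setters.
result = LLMAttributionResult(
    seq_attr=[0.1, 0.7, 0.2],     # 1D, length must equal len(input_tokens)
    token_attr=np.zeros((2, 3)),  # 2D, shape (len(output_tokens), len(input_tokens))
    input_tokens=["The", " cat", " sat"],
    output_tokens=[" on", " mats"],
)

# The shape checks also fire on re-assignment:
# result.seq_attr = np.ones(4)  # AssertionError: seq_attr and input_tokens
#                               # must have the same length
```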
@@ -124,10 +185,14 @@ def plot_token_attr(
 
         # Show all ticks and label them with the respective list entries.
         shortened_tokens = [
-            shorten(t, width=50, placeholder="...") for t in self.input_tokens
+            shorten(repr(t)[1:-1], width=50, placeholder="...")
+            for t in self.input_tokens
         ]
         ax.set_xticks(np.arange(data.shape[1]), labels=shortened_tokens)
-        ax.set_yticks(np.arange(data.shape[0]), labels=self.output_tokens)
+        ax.set_yticks(
+            np.arange(data.shape[0]),
+            labels=[repr(token)[1:-1] for token in self.output_tokens],
+        )
 
         # Let the horizontal axes labeling appear on top.
         ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
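The `repr(t)[1:-1]` idiom is there so tokens containing control characters don't corrupt the tick labels: `repr` escapes them, and the slice strips the surrounding quotes. A quick illustration of the idiom (not from the diff):

```python
token = "\n"
print(repr(token))        # '\n' -> escaped, but still wrapped in quotes
print(repr(token)[1:-1])  # \n   -> two printable characters, safe as a tick label

token = "\t<eos>"
print(repr(token)[1:-1])  # \t<eos>
```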
@@ -175,7 +240,8 @@ def plot_seq_attr(
         fig.set_size_inches(max(data.shape[0] / 2, 6.4), max(data.shape[0] / 4, 4.8))
 
         shortened_tokens = [
-            shorten(t, width=50, placeholder="...") for t in self.input_tokens
+            shorten(repr(t)[1:-1], width=50, placeholder="...")
+            for t in self.input_tokens
         ]
         ax.set_xticks(range(data.shape[0]), labels=shortened_tokens)
 
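For completeness, a hedged usage sketch of the two plotting helpers on the `result` object from the earlier example. The `show` flag and the `(fig, ax)` return value are assumptions based on the surrounding method bodies; they are not shown in this diff:

```python
import matplotlib.pyplot as plt

# Assumed signature: plot_token_attr(show=False) -> (Figure, Axes)
fig, ax = result.plot_token_attr(show=False)  # heatmap: output tokens x input tokens
fig.savefig("token_attr.png")

fig, ax = result.plot_seq_attr(show=False)    # bar chart over input tokens
plt.show()
```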