Skip to content

Commit ab65884

Browse files
authored
Merge pull request #89 from Genentech/fix-ism
fix bug in ism_predict on multiple sequences
2 parents 292aa93 + 23aa16f commit ab65884

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

src/grelu/interpret/score.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818

1919
def ISM_predict(
20-
seqs: Union[pd.DataFrame, np.ndarray, str],
20+
seqs: Union[pd.DataFrame, np.ndarray, str, List[str]],
2121
model: Callable,
2222
genome: Optional[str] = None,
2323
prediction_transform: Optional[Callable] = None,
@@ -87,11 +87,16 @@ def ISM_predict(
8787
)
8888
# B, L, 4, T, L
8989

90-
# Calculate log ratio w.r.t reference sequence
9190
if compare_func is not None:
91+
92+
# Slice the prediction corresponding to each reference sequence
9293
ref_bases = [BASE_TO_INDEX_HASH[seq[start_pos]] for seq in seqs]
93-
ref_pred = preds[:, [0], [ref_bases], :] # B, 1, 1, T, L
94-
preds = get_compare_func(compare_func, tensor=False)(preds, ref_pred)
94+
ref_preds = np.concatenate(
95+
[preds[None, None, None, i, 0, x] for i, x in enumerate(ref_bases)]
96+
) # B, 1, 1, T, L
97+
98+
# Compare all predictions to the prediction for the corresponding reference sequence
99+
preds = get_compare_func(compare_func, tensor=False)(preds, ref_preds)
95100

96101
# Convert into a dataframe
97102
if return_df:

tests/test_interpret.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,41 @@ def test_marginalize_patterns():
9393

9494

9595
def test_ISM_predict():
96+
97+
# Single sequence
9698
seq = "AA"
9799
expected_preds = np.array([[4.0, 1.0, 2.0, 2.0], [4.0, 1.0, 2.0, 2.0]]).T
98100
preds = ISM_predict(seq, model, compare_func=None)
99101
assert np.allclose(preds.values, expected_preds)
100102
preds = ISM_predict(seq, model, compare_func="log2FC")
101103
assert np.allclose(preds.values, np.log2(expected_preds / 4))
102104

105+
# Multiple sequences
106+
seqs = ["AAA", "CCC"]
107+
expected_preds = np.expand_dims(
108+
np.array(
109+
[
110+
[
111+
[4.0, 2.0, 2.6666667, 2.6666667],
112+
[4.0, 2.0, 2.6666667, 2.6666667],
113+
[4.0, 2.0, 2.6666667, 2.6666667],
114+
],
115+
[
116+
[0.0, -2.0, -1.3333334, -1.3333334],
117+
[0.0, -2.0, -1.3333334, -1.3333334],
118+
[0.0, -2.0, -1.3333334, -1.3333334],
119+
],
120+
]
121+
),
122+
(3, 4),
123+
)
124+
preds = ISM_predict(seqs, model, compare_func=None, return_df=False)
125+
assert np.allclose(preds, expected_preds)
126+
preds = ISM_predict(seqs, model, compare_func="log2FC", return_df=False)
127+
assert np.allclose(
128+
preds, np.log2(np.stack([expected_preds[0] / 4, -expected_preds[1] / 2]))
129+
)
130+
103131

104132
def test_get_attributions():
105133
seq = generate_random_sequences(n=1, seq_len=50, seed=0, output_format="strings")[0]

0 commit comments

Comments
 (0)