@@ -93,13 +93,41 @@ def test_marginalize_patterns():
9393
9494
9595def test_ISM_predict ():
96+
97+ # Single sequence
9698 seq = "AA"
9799 expected_preds = np .array ([[4.0 , 1.0 , 2.0 , 2.0 ], [4.0 , 1.0 , 2.0 , 2.0 ]]).T
98100 preds = ISM_predict (seq , model , compare_func = None )
99101 assert np .allclose (preds .values , expected_preds )
100102 preds = ISM_predict (seq , model , compare_func = "log2FC" )
101103 assert np .allclose (preds .values , np .log2 (expected_preds / 4 ))
102104
105+ # Multiple sequences
106+ seqs = ["AAA" , "CCC" ]
107+ expected_preds = np .expand_dims (
108+ np .array (
109+ [
110+ [
111+ [4.0 , 2.0 , 2.6666667 , 2.6666667 ],
112+ [4.0 , 2.0 , 2.6666667 , 2.6666667 ],
113+ [4.0 , 2.0 , 2.6666667 , 2.6666667 ],
114+ ],
115+ [
116+ [0.0 , - 2.0 , - 1.3333334 , - 1.3333334 ],
117+ [0.0 , - 2.0 , - 1.3333334 , - 1.3333334 ],
118+ [0.0 , - 2.0 , - 1.3333334 , - 1.3333334 ],
119+ ],
120+ ]
121+ ),
122+ (3 , 4 ),
123+ )
124+ preds = ISM_predict (seqs , model , compare_func = None , return_df = False )
125+ assert np .allclose (preds , expected_preds )
126+ preds = ISM_predict (seqs , model , compare_func = "log2FC" , return_df = False )
127+ assert np .allclose (
128+ preds , np .log2 (np .stack ([expected_preds [0 ] / 4 , - expected_preds [1 ] / 2 ]))
129+ )
130+
103131
104132def test_get_attributions ():
105133 seq = generate_random_sequences (n = 1 , seq_len = 50 , seed = 0 , output_format = "strings" )[0 ]
0 commit comments