Skip to content

Commit e07d26b

Browse files
committed
Fix concurrent usage of Wilcoxon test
Access to the memoized cache is not thread-safe: the shared cache is always cleared at the start of each computation, which results in bugs when the test is run concurrently.
1 parent 6e2e754 commit e07d26b

File tree

1 file changed

+66
-127
lines changed

1 file changed

+66
-127
lines changed
Lines changed: 66 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
/*
22
* The baseCode project
3-
*
3+
*
44
* Copyright (c) 2006 University of British Columbia
5-
*
5+
*
66
* Licensed under the Apache License, Version 2.0 (the "License");
77
* you may not use this file except in compliance with the License.
88
* You may obtain a copy of the License at
@@ -18,49 +18,43 @@
1818
*/
1919
package ubic.basecode.math;
2020

21+
import cern.colt.list.DoubleArrayList;
22+
import cern.jet.math.Arithmetic;
23+
import cern.jet.stat.Probability;
24+
import org.slf4j.Logger;
25+
import org.slf4j.LoggerFactory;
26+
2127
import java.math.BigInteger;
2228
import java.util.Collections;
29+
import java.util.HashMap;
2330
import java.util.List;
2431
import java.util.Map;
25-
import java.util.concurrent.ConcurrentHashMap;
26-
27-
import org.slf4j.Logger;
28-
import org.slf4j.LoggerFactory;
29-
30-
import cern.colt.list.DoubleArrayList;
31-
import cern.jet.math.Arithmetic;
32-
import cern.jet.stat.Probability;
3332

3433
/**
3534
* Implements methods from supplementary file I of "Comparing functional annotation analyses with Catmap", Thomas
3635
* Breslin, Patrik Edén and Morten Krogh, BMC Bioinformatics 2004, 5:193 doi:10.1186/1471-2105-5-193
3736
* <p>
3837
* Note that in the Catmap code, zero-based ranks are used, but these are converted to one-based before computation of
3938
* pvalues. Therefore this code uses one-based ranks throughout.
40-
*
39+
*
4140
* @author pavlidis
4241
* @version Id
4342
* @see ROC
4443
*/
4544
public class Wilcoxon {
4645

47-
private static final Map<CacheKey, BigInteger> cache = new ConcurrentHashMap<CacheKey, BigInteger>();
4846

4947
/**
5048
* For smaller sample sizes, we compute exactly. Below 1e5 we start to notice some loss of precision (like one part
5149
* in 1e5). Setting this too high really slows things down for high-throughput applications.
5250
*/
53-
private static long LIMIT_FOR_APPROXIMATION = 100000L;
51+
private static final long LIMIT_FOR_APPROXIMATION = 100000L;
5452

55-
private static Logger log = LoggerFactory.getLogger( Wilcoxon.class );
53+
private static final Logger log = LoggerFactory.getLogger( Wilcoxon.class );
5654

5755
/**
5856
* Convenience method that computes a p-value using input of two double arrays. They must not contain missing values
5957
* or ties.
60-
*
61-
* @param a
62-
* @param b
63-
* @return
6458
*/
6559
public static double exactWilcoxonP( double[] a, double[] b ) {
6660
int fullLength = a.length + b.length;
@@ -80,12 +74,6 @@ public static double exactWilcoxonP( double[] a, double[] b ) {
8074
return pExact( fullLength, a.length, aSum );
8175
}
8276

83-
/**
84-
* @param N
85-
* @param n
86-
* @param R
87-
* @return
88-
*/
8977
public static double exactWilcoxonP( int N, int n, int R ) {
9078
if ( R > LIMIT_FOR_APPROXIMATION ) {
9179
throw new IllegalArgumentException( "Computation of exact wilcoxon for large values of R will fail." );
@@ -95,22 +83,16 @@ public static double exactWilcoxonP( int N, int n, int R ) {
9583

9684
/**
9785
* Only use when you know there are no ties.
98-
*
99-
* @param N
100-
* @param n
101-
* @param R
102-
* @return
10386
*/
10487
public static double wilcoxonP( int N, int n, long R ) {
10588
return wilcoxonP( N, n, R, false );
10689
}
10790

10891
/**
109-
* @param N number of all Items
110-
* @param n number of class Items
111-
* @param R rankSum for items in the class. (one-based)
92+
* @param N number of all Items
93+
* @param n number of class Items
94+
* @param R rankSum for items in the class. (one-based)
11295
* @param ties set to true if you know there are ties
113-
* @return
11496
*/
11597
public static double wilcoxonP( int N, int n, long R, boolean ties ) {
11698

@@ -120,7 +102,7 @@ public static double wilcoxonP( int N, int n, long R, boolean ties ) {
120102

121103
if ( ( !ties )
122104
&& ( ( ( long ) N * ( long ) n <= LIMIT_FOR_APPROXIMATION && n * R <= LIMIT_FOR_APPROXIMATION && ( long ) N
123-
* ( long ) n * R <= LIMIT_FOR_APPROXIMATION ) || ( R < N && n * Math.pow( R, 2 ) <= LIMIT_FOR_APPROXIMATION ) ) ) {
105+
* ( long ) n * R <= LIMIT_FOR_APPROXIMATION ) || ( R < N && n * Math.pow( R, 2 ) <= LIMIT_FOR_APPROXIMATION ) ) ) {
124106
if ( log.isDebugEnabled() ) log.debug( "Using exact method (" + N * n * R + ")" );
125107
return pExact( N, n, ( int ) R );
126108
}
@@ -137,9 +119,8 @@ public static double wilcoxonP( int N, int n, long R, boolean ties ) {
137119
}
138120

139121
/**
140-
* @param N total number of items (in and not in the class)
122+
* @param N total number of items (in and not in the class)
141123
* @param ranks of items in the class (one-based)
142-
* @return
143124
*/
144125
public static double wilcoxonP( int N, List<Double> ranks ) {
145126

@@ -164,57 +145,34 @@ public static double wilcoxonP( int N, List<Double> ranks ) {
164145
return wilcoxonP( N, ranks.size(), rankSum, ties );
165146
}
166147

167-
private static void addToCache( long N, long n, long R, BigInteger value ) {
168-
cache.put( new CacheKey( N, n, R ), value );
169-
}
170-
171-
/**
172-
* @param n0
173-
* @param n02
174-
* @param r0
175-
* @return
176-
*/
177-
private static boolean cacheContains( long N, long n, long R ) {
178-
return cache.containsKey( new CacheKey( N, n, R ) );
179-
}
180-
181148
/**
182149
* Direct port from catmap code. Exact computation of the number of ways n items can be drawn from a total of N
183150
* items with a rank sum of R or better (lower).
184-
*
185-
* @param N0
186-
* @param n0
151+
*
187152
* @param R0 rank sum, 1-based (best rank is 1)
188-
* @return
189153
*/
190154
private static BigInteger computeA__( int N0, int n0, int R0 ) {
191-
cache.clear();
192-
if ( R0 < N0 ) N0 = R0;
155+
Map<CacheKey, BigInteger> cache = new HashMap<>();
193156

194-
// if ( cacheContains( N0, n0, R0 ) ) {
195-
// return getFromCache( N0, n0, R0 );
196-
// }
157+
if ( R0 < N0 ) N0 = R0;
197158

198159
if ( N0 == 0 && n0 == 0 ) return BigInteger.ONE;
199160

200161
for ( int N = 1; N <= N0; N++ ) {
201-
if ( N > 2 ) removeFromCache( N - 2 );
202-
203162
/* n has to be less than N */
204163
long min_n = Math.max( 0, n0 + N - N0 );
205164
long max_n = Math.min( n0, N );
206165

207-
assert min_n >= 0;
208166
assert max_n >= min_n;
209167

210168
for ( long n = min_n; n <= max_n; n++ ) {
211169

212170
/* The rank sum is in the interval n(n+1)/2 to n(2N-n+1)/2. Other values need not be looked at. */
213171
long bestPossibleRankSum = n * ( n + 1 ) / 2;
214-
long worstPossibleRankSum = n * ( 2 * N - n + 1 ) / 2;
172+
long worstPossibleRankSum = n * ( 2L * N - n + 1 ) / 2;
215173

216174
/* Ensure value looked at is valid for the original set of parameters. */
217-
long min_r = Math.max( bestPossibleRankSum, R0 - ( N0 + N + 1 ) * ( N0 - N ) / 2 );
175+
long min_r = Math.max( bestPossibleRankSum, R0 - ( N0 + N + 1L ) * ( N0 - N ) / 2 );
218176
long max_r = Math.min( worstPossibleRankSum, R0 );
219177

220178
assert min_r >= 0;
@@ -227,59 +185,40 @@ private static BigInteger computeA__( int N0, int n0, int R0 ) {
227185
n0, R0 );
228186

229187
/* R greater than this, have already computed it in parts */
230-
long foo = n * ( 2 * N - n - 1 ) / 2;
188+
long foo = n * ( 2L * N - n - 1 ) / 2;
231189

232190
/* R less than this, we have already computed it in parts */
233191
long bar = N + ( n - 1 ) * n / 2;
234192

235193
for ( long r = min_r; r <= max_r; r++ ) {
236194

237195
if ( n == 0 || n == N || r == bestPossibleRankSum ) {
238-
addToCache( N, n, r, BigInteger.ONE );
196+
addToCache( cache, N, n, r, BigInteger.ONE );
239197

240198
} else if ( r > foo ) {
241-
addToCache( N, n, r, getFromCache( N - 1, n, foo ).add( getFromCache( N - 1, n - 1, r - N ) ) );
199+
addToCache( cache, N, n, r, getFromCache( cache, N - 1, n, foo ).add( getFromCache( cache, N - 1, n - 1, r - N ) ) );
242200

243201
} else if ( r < bar ) {
244-
addToCache( N, n, r, getFromCache( N - 1, n, r ) );
202+
addToCache( cache, N, n, r, getFromCache( cache, N - 1, n, r ) );
245203

246204
} else {
247-
addToCache( N, n, r, getFromCache( N - 1, n, r ).add( getFromCache( N - 1, n - 1, r - N ) ) );
205+
addToCache( cache, N, n, r, getFromCache( cache, N - 1, n, r ).add( getFromCache( cache, N - 1, n - 1, r - N ) ) );
248206
}
249207
}
250208
}
251209
}
252-
return getFromCache( N0, n0, R0 );
253-
}
254210

255-
/**
256-
* @param N
257-
* @param n
258-
* @param R
259-
* @return
260-
*/
261-
private static BigInteger getFromCache( long N, long n, long R ) {
262-
263-
if ( !cacheContains( N, n, R ) ) {
264-
throw new IllegalStateException( "No value stored for N=" + N + ", n=" + n + ", R=" + R );
265-
}
266-
return cache.get( new CacheKey( N, n, R ) );
211+
return getFromCache( cache, N0, n0, R0 );
267212
}
268213

269214
/**
270-
* @param N
271-
* @param n
272-
* @param r rank sum, 1-based (best rank is 1).
273-
* @return
215+
* @param R rank sum, 1-based (best rank is 1).
274216
*/
275217
private static double pExact( int N, int n, int R ) {
276218
return computeA__( N, n, R ).doubleValue() / Arithmetic.binomial( N, n );
277219
}
278220

279221
/**
280-
* @param N
281-
* @param n
282-
* @param R
283222
* @return Upper-tail probability for Wilcoxon rank-sum test.
284223
*/
285224
private static double pGaussian( long N, long n, long R ) {
@@ -292,11 +231,6 @@ private static double pGaussian( long N, long n, long R ) {
292231

293232
/**
294233
* Directly ported from catmap.
295-
*
296-
* @param N
297-
* @param n
298-
* @param R
299-
* @return
300234
*/
301235
private static double pVolume( int N, int n, long R ) {
302236

@@ -337,46 +271,51 @@ private static double pVolume( int N, int n, long R ) {
337271

338272
}
339273

340-
/**
341-
* @param i
342-
*/
343-
private static void removeFromCache( int N ) {
344-
cache.remove( N );
274+
private static void addToCache( Map<CacheKey, BigInteger> cache, long N, long n, long R, BigInteger value ) {
275+
cache.put( new CacheKey( N, n, R ), value );
345276
}
346277

347-
}
278+
private static boolean cacheContains( Map<CacheKey, BigInteger> cache, long N, long n, long R ) {
279+
return cache.containsKey( new CacheKey( N, n, R ) );
280+
}
348281

349-
class CacheKey {
350-
private long n;
351-
private long N;
352-
private long R;
282+
private static BigInteger getFromCache( Map<CacheKey, BigInteger> cache, long N, long n, long R ) {
353283

354-
public CacheKey( long N, long n, long R ) {
355-
super();
356-
this.N = N;
357-
this.n = n;
358-
this.R = R;
284+
if ( !cacheContains( cache, N, n, R ) ) {
285+
throw new IllegalStateException( "No value stored for N=" + N + ", n=" + n + ", R=" + R );
286+
}
287+
return cache.get( new CacheKey( N, n, R ) );
359288
}
360289

361-
@Override
362-
public boolean equals( Object obj ) {
363-
CacheKey other = ( CacheKey ) obj;
290+
private static class CacheKey {
291+
private final long n;
292+
private final long N;
293+
private final long R;
364294

365-
if ( N != other.N ) return false;
366-
if ( n != other.n ) return false;
367-
if ( R != other.R ) return false;
295+
public CacheKey( long N, long n, long R ) {
296+
this.N = N;
297+
this.n = n;
298+
this.R = R;
299+
}
368300

369-
return true;
370-
}
301+
@Override
302+
public boolean equals( Object obj ) {
303+
if ( obj instanceof CacheKey ) {
304+
CacheKey other = ( CacheKey ) obj;
305+
return N == other.N && n == other.n && R == other.R;
306+
} else {
307+
return false;
308+
}
309+
}
371310

372-
@Override
373-
public int hashCode() {
374-
final int prime = 31;
375-
long result = 1;
376-
result = prime * result + N;
377-
result = prime * result + n;
378-
result = prime * result + R;
379-
return ( int ) result; // problem: overflows?
311+
@Override
312+
public int hashCode() {
313+
final int prime = 31;
314+
long result = 1;
315+
result = prime * result + N;
316+
result = prime * result + n;
317+
result = prime * result + R;
318+
return ( int ) result; // problem: overflows?
319+
}
380320
}
381-
382321
}

0 commit comments

Comments
 (0)