16
16
17
17
package com .mongodb .internal .connection ;
18
18
19
+ import com .mongodb .ClusterFixture ;
19
20
import com .mongodb .ConnectionString ;
20
21
import com .mongodb .MongoClientSettings ;
21
22
import com .mongodb .MongoCommandException ;
41
42
import org .bson .Document ;
42
43
import org .junit .jupiter .api .AfterEach ;
43
44
import org .junit .jupiter .api .BeforeEach ;
45
+ import org .junit .jupiter .api .DisplayName ;
44
46
import org .junit .jupiter .api .Test ;
45
47
import org .junit .jupiter .params .ParameterizedTest ;
46
48
import org .junit .jupiter .params .provider .Arguments ;
47
49
import org .junit .jupiter .params .provider .MethodSource ;
48
- import org .junit .jupiter .params .provider .ValueSource ;
49
50
50
51
import java .io .IOException ;
51
52
import java .lang .reflect .Field ;
79
80
import static com .mongodb .MongoCredential .TOKEN_RESOURCE_KEY ;
80
81
import static com .mongodb .assertions .Assertions .assertNotNull ;
81
82
import static com .mongodb .testing .MongoAssertions .assertCause ;
82
- import static java .lang .Math .min ;
83
83
import static java .lang .String .format ;
84
84
import static java .lang .System .getenv ;
85
85
import static java .util .Arrays .asList ;
@@ -215,9 +215,9 @@ public void test2p1ValidCallbackInputs() {
215
215
+ " expectedTimeoutThreshold={3}" )
216
216
@ MethodSource
217
217
void testValidCallbackInputsTimeoutWhenTimeoutMsIsSet (final String testName ,
218
- final int timeoutMs ,
219
- final int serverSelectionTimeoutMS ,
220
- final int expectedTimeoutThreshold ) {
218
+ final long timeoutMs ,
219
+ final long serverSelectionTimeoutMS ,
220
+ final long expectedTimeoutThreshold ) {
221
221
TestCallback callback1 = createCallback ();
222
222
223
223
OidcCallback callback2 = (context ) -> {
@@ -242,40 +242,50 @@ void testValidCallbackInputsTimeoutWhenTimeoutMsIsSet(final String testName,
242
242
assertEquals (1 , callback1 .getInvocations ());
243
243
long elapsed = msElapsedSince (start );
244
244
245
- assertFalse (elapsed > (timeoutMs == 0 ? serverSelectionTimeoutMS : min (serverSelectionTimeoutMS , timeoutMs )),
245
+
246
+ assertFalse (elapsed > minTimeout (timeoutMs , serverSelectionTimeoutMS ),
246
247
format ("Elapsed time %d is greater then minimum of serverSelectionTimeoutMS and timeoutMs, which is %d. "
247
248
+ "This indicates that the callback was not called with the expected timeout." ,
248
- min (serverSelectionTimeoutMS , timeoutMs ),
249
- elapsed ));
249
+ elapsed ,
250
+ minTimeout (timeoutMs , serverSelectionTimeoutMS )));
251
+
250
252
}
251
253
}
252
254
253
255
private static Stream <Arguments > testValidCallbackInputsTimeoutWhenTimeoutMsIsSet () {
256
+ long rtt = ClusterFixture .getPrimaryRTT ();
254
257
return Stream .of (
255
258
Arguments .of ("serverSelectionTimeoutMS honored for oidc callback if it's lower than timeoutMS" ,
256
- 1000 , // timeoutMS
257
- 500 , // serverSelectionTimeoutMS
258
- 499 ), // expectedTimeoutThreshold
259
+ 1000 + rtt , // timeoutMS
260
+ 500 + rtt , // serverSelectionTimeoutMS
261
+ 499 + rtt ), // expectedTimeoutThreshold
259
262
Arguments .of ("timeoutMS honored for oidc callback if it's lower than serverSelectionTimeoutMS" ,
260
- 500 , // timeoutMS
261
- 1000 , // serverSelectionTimeoutMS
262
- 499 ), // expectedTimeoutThreshold
263
+ 500 + rtt , // timeoutMS
264
+ 1000 + rtt , // serverSelectionTimeoutMS
265
+ 499 + rtt ), // expectedTimeoutThreshold
266
+ Arguments .of ("timeoutMS honored for oidc callback if serverSelectionTimeoutMS is infinite" ,
267
+ 500 + rtt , // timeoutMS
268
+ -1 , // serverSelectionTimeoutMS
269
+ 499 + rtt ), // expectedTimeoutThreshold,
263
270
Arguments .of ("serverSelectionTimeoutMS honored for oidc callback if timeoutMS=0" ,
264
271
0 , // infinite timeoutMS
265
- 500 , // serverSelectionTimeoutMS
266
- 499 ) // expectedTimeoutThreshold
272
+ 500 + rtt , // serverSelectionTimeoutMS
273
+ 499 + rtt ) // expectedTimeoutThreshold
267
274
);
268
275
}
269
276
270
277
// Not a prose test
271
- @ ParameterizedTest (name = "test callback timeout when server selection timeout is "
272
- + "infinite and timeoutMs is set to {0}" )
273
- @ ValueSource (ints = {0 , 100 })
274
- void testCallbackTimeoutWhenServerSelectionTimeoutIsInfiniteTimeoutMsIsSet (final int timeoutMs ) {
278
+ @ Test
279
+ @ DisplayName ("test callback timeout when serverSelectionTimeoutMS and timeoutMS are infinite" )
280
+ void testCallbackTimeoutWhenServerSelectionTimeoutMsIsInfiniteTimeoutMsIsSet () {
275
281
TestCallback callback1 = createCallback ();
282
+ Duration expectedTimeout = ChronoUnit .FOREVER .getDuration ();
276
283
277
284
OidcCallback callback2 = (context ) -> {
278
- assertEquals (context .getTimeout (), ChronoUnit .FOREVER .getDuration ());
285
+ assertEquals (expectedTimeout , context .getTimeout (),
286
+ format ("Expected timeout to be infinite (%s), but was %s" ,
287
+ expectedTimeout , context .getTimeout ()));
288
+
279
289
return callback1 .onRequest (context );
280
290
};
281
291
@@ -284,7 +294,7 @@ void testCallbackTimeoutWhenServerSelectionTimeoutIsInfiniteTimeoutMsIsSet(final
284
294
builder .serverSelectionTimeout (
285
295
-1 , // -1 means infinite
286
296
TimeUnit .MILLISECONDS ))
287
- .timeout (timeoutMs , TimeUnit .MILLISECONDS )
297
+ .timeout (0 , TimeUnit .MILLISECONDS )
288
298
.build ();
289
299
290
300
try (MongoClient mongoClient = createMongoClient (clientSettings )) {
@@ -1242,4 +1252,10 @@ public TestCallback createHumanCallback() {
1242
1252
private long msElapsedSince (final long timeOfStart ) {
1243
1253
return TimeUnit .NANOSECONDS .toMillis (System .nanoTime () - timeOfStart );
1244
1254
}
1255
+
1256
+ private static long minTimeout (final long timeoutMs , final long serverSelectionTimeoutMS ) {
1257
+ long timeoutMsEffective = timeoutMs != 0 ? timeoutMs : Long .MAX_VALUE ;
1258
+ long serverSelectionTimeoutMSEffective = serverSelectionTimeoutMS != -1 ? serverSelectionTimeoutMS : Long .MAX_VALUE ;
1259
+ return Math .min (timeoutMsEffective , serverSelectionTimeoutMSEffective );
1260
+ }
1245
1261
}
0 commit comments