diff --git a/internal/mongoutil/mongoutil.go b/internal/mongoutil/mongoutil.go index 0345b96e8f..3675979b17 100644 --- a/internal/mongoutil/mongoutil.go +++ b/internal/mongoutil/mongoutil.go @@ -7,7 +7,9 @@ package mongoutil import ( + "context" "reflect" + "time" "go.mongodb.org/mongo-driver/v2/mongo/options" ) @@ -83,3 +85,23 @@ func HostsFromURI(uri string) ([]string, error) { return opts.Hosts, nil } + +// ValidMaxAwaitTimeMS will return "false" if maxAwaitTimeMS is set, timeoutMS +// is set to a non-zero value, and maxAwaitTimeMS is greater than or equal to +// timeoutMS. Otherwise, the timeouts are valid. +func ValidMaxAwaitTimeMS(ctx context.Context, timeout, maxAwaiTime *time.Duration) bool { + if maxAwaiTime == nil { + return true + } + + if deadline, ok := ctx.Deadline(); ok { + ctxTimeout := time.Until(deadline) + timeout = &ctxTimeout + } + + if timeout == nil { + return true + } + + return *timeout <= 0 || *maxAwaiTime < *timeout +} diff --git a/internal/mongoutil/mongoutil_test.go b/internal/mongoutil/mongoutil_test.go index 661ee5f5bb..c9fa75720d 100644 --- a/internal/mongoutil/mongoutil_test.go +++ b/internal/mongoutil/mongoutil_test.go @@ -7,9 +7,13 @@ package mongoutil import ( + "context" "strings" "testing" + "time" + "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/internal/ptrutil" "go.mongodb.org/mongo-driver/v2/mongo/options" ) @@ -32,3 +36,77 @@ func BenchmarkNewOptions(b *testing.B) { } }) } + +func TestValidChangeStreamTimeouts(t *testing.T) { + tests := []struct { + name string + parent context.Context + maxAwaitTimeout, timeout *time.Duration + wantTimeout time.Duration + want bool + }{ + { + name: "no context deadline and no timeouts", + parent: context.Background(), + maxAwaitTimeout: nil, + timeout: nil, + wantTimeout: 0, + want: true, + }, + { + name: "no context deadline and maxAwaitTimeout", + parent: context.Background(), + maxAwaitTimeout: ptrutil.Ptr(time.Duration(1)), + timeout: nil, + wantTimeout: 0, + want: true, + }, + { + name: "no context deadline and timeout", + parent: context.Background(), + maxAwaitTimeout: nil, + timeout: ptrutil.Ptr(time.Duration(1)), + wantTimeout: 0, + want: true, + }, + { + name: "no context deadline and maxAwaitTime gt timeout", + parent: context.Background(), + maxAwaitTimeout: ptrutil.Ptr(time.Duration(2)), + timeout: ptrutil.Ptr(time.Duration(1)), + wantTimeout: 0, + want: false, + }, + { + name: "no context deadline and maxAwaitTime lt timeout", + parent: context.Background(), + maxAwaitTimeout: ptrutil.Ptr(time.Duration(1)), + timeout: ptrutil.Ptr(time.Duration(2)), + wantTimeout: 0, + want: true, + }, + { + name: "no context deadline and maxAwaitTime eq timeout", + parent: context.Background(), + maxAwaitTimeout: ptrutil.Ptr(time.Duration(1)), + timeout: ptrutil.Ptr(time.Duration(1)), + wantTimeout: 0, + want: false, + }, + { + name: "no context deadline and maxAwaitTime with negative timeout", + parent: context.Background(), + maxAwaitTimeout: ptrutil.Ptr(time.Duration(1)), + timeout: ptrutil.Ptr(time.Duration(-1)), + wantTimeout: 0, + want: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got := ValidMaxAwaitTimeMS(test.parent, test.timeout, test.maxAwaitTimeout) + assert.Equal(t, test.want, got) + }) + } +} diff --git a/internal/spectest/skip.go b/internal/spectest/skip.go index 70866530e1..0addb75e3d 100644 --- a/internal/spectest/skip.go +++ b/internal/spectest/skip.go @@ -582,7 +582,6 @@ var skipTests = map[string][]string{ "TestUnifiedSpec/client-side-operations-timeout/tests/sessions-override-operation-timeoutMS.json/timeoutMS_applied_to_withTransaction", "TestUnifiedSpec/client-side-operations-timeout/tests/sessions-override-timeoutMS.json", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_timeoutMode_is_cursor_lifetime", - "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_greater_than_timeoutMS", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_applied_to_find", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_not_set", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_set", @@ -817,19 +816,8 @@ var skipTests = map[string][]string{ "TestUnifiedSpec/transactions-convenient-api/tests/unified/commit.json/withTransaction_commits_after_callback_returns", }, - // GODRIVER-3473: the implementation of DRIVERS-2868 makes it clear that the - // Go Driver does not correctly implement the following validation for - // tailable awaitData cursors: - // - // Drivers MUST error if this option is set, timeoutMS is set to a - // non-zero value, and maxAwaitTimeMS is greater than or equal to - // timeoutMS. - // - // Once GODRIVER-3473 is completed, we can continue running these tests. - "When constructing tailable awaitData cusors must validate, timeoutMS is set to a non-zero value, and maxAwaitTimeMS is greater than or equal to timeoutMS (GODRIVER-3473)": { + "Address CSOT Compliance Issue in Timeout Handling for Cursor Constructors (GODRIVER-3480)": { "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/apply_remaining_timeoutMS_if_less_than_maxAwaitTimeMS", - "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS", - "TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS", }, } diff --git a/mongo/change_stream.go b/mongo/change_stream.go index 009e68e4e4..62aabf6688 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -12,7 +12,6 @@ import ( "fmt" "reflect" "strconv" - "time" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/internal/csot" @@ -103,33 +102,6 @@ type changeStreamConfig struct { crypt driver.Crypt } -// validChangeStreamTimeouts will return "false" if maxAwaitTimeMS is set, -// timeoutMS is set to a non-zero value, and maxAwaitTimeMS is greater than or -// equal to timeoutMS. Otherwise, the timeouts are valid. -func validChangeStreamTimeouts(ctx context.Context, cs *ChangeStream) bool { - if cs.options == nil || cs.client == nil { - return true - } - - maxAwaitTime := cs.options.MaxAwaitTime - timeout := cs.client.timeout - - if maxAwaitTime == nil { - return true - } - - if deadline, ok := ctx.Deadline(); ok { - ctxTimeout := time.Until(deadline) - timeout = &ctxTimeout - } - - if timeout == nil { - return true - } - - return *timeout <= 0 || *maxAwaitTime < *timeout -} - func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline interface{}, opts ...options.Lister[options.ChangeStreamOptions]) (*ChangeStream, error) { if ctx == nil { @@ -145,6 +117,10 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in return nil, err } + if c := config.client; c != nil && !mongoutil.ValidMaxAwaitTimeMS(ctx, c.timeout, args.MaxAwaitTime) { + return nil, fmt.Errorf("MaxAwaitTime must be less than the operation timeout") + } + cs := &ChangeStream{ client: config.client, bsonOpts: config.bsonOpts, @@ -696,7 +672,10 @@ func (cs *ChangeStream) next(ctx context.Context, nonBlocking bool) bool { } func (cs *ChangeStream) loopNext(ctx context.Context, nonBlocking bool) { - if !validChangeStreamTimeouts(ctx, cs) { + // Sending a maxAwaitTimeMS option to the server that is less than or equal to + // the operation timeout will result in a socket timeout error. + if csOpts := cs.options; csOpts != nil && cs.client != nil && + !mongoutil.ValidMaxAwaitTimeMS(ctx, cs.client.timeout, csOpts.MaxAwaitTime) { cs.err = fmt.Errorf("MaxAwaitTime must be less than the operation timeout") return diff --git a/mongo/change_stream_test.go b/mongo/change_stream_test.go index 8e722764a8..c2752a2c16 100644 --- a/mongo/change_stream_test.go +++ b/mongo/change_stream_test.go @@ -7,12 +7,9 @@ package mongo import ( - "context" "testing" - "time" "go.mongodb.org/mongo-driver/v2/internal/assert" - "go.mongodb.org/mongo-driver/v2/mongo/options" ) func TestChangeStream(t *testing.T) { @@ -30,96 +27,3 @@ func TestChangeStream(t *testing.T) { assert.Nil(t, err, "Close error: %v", err) }) } - -func TestValidChangeStreamTimeouts(t *testing.T) { - t.Parallel() - - newDurPtr := func(dur time.Duration) *time.Duration { - return &dur - } - - tests := []struct { - name string - parent context.Context - maxAwaitTimeout, timeout *time.Duration - wantTimeout time.Duration - want bool - }{ - { - name: "no context deadline and no timeouts", - parent: context.Background(), - maxAwaitTimeout: nil, - timeout: nil, - wantTimeout: 0, - want: true, - }, - { - name: "no context deadline and maxAwaitTimeout", - parent: context.Background(), - maxAwaitTimeout: newDurPtr(1), - timeout: nil, - wantTimeout: 0, - want: true, - }, - { - name: "no context deadline and timeout", - parent: context.Background(), - maxAwaitTimeout: nil, - timeout: newDurPtr(1), - wantTimeout: 0, - want: true, - }, - { - name: "no context deadline and maxAwaitTime gt timeout", - parent: context.Background(), - maxAwaitTimeout: newDurPtr(2), - timeout: newDurPtr(1), - wantTimeout: 0, - want: false, - }, - { - name: "no context deadline and maxAwaitTime lt timeout", - parent: context.Background(), - maxAwaitTimeout: newDurPtr(1), - timeout: newDurPtr(2), - wantTimeout: 0, - want: true, - }, - { - name: "no context deadline and maxAwaitTime eq timeout", - parent: context.Background(), - maxAwaitTimeout: newDurPtr(1), - timeout: newDurPtr(1), - wantTimeout: 0, - want: false, - }, - { - name: "no context deadline and maxAwaitTime with negative timeout", - parent: context.Background(), - maxAwaitTimeout: newDurPtr(1), - timeout: newDurPtr(-1), - wantTimeout: 0, - want: true, - }, - } - - for _, test := range tests { - test := test // Capture the range variable - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - cs := &ChangeStream{ - options: &options.ChangeStreamOptions{ - MaxAwaitTime: test.maxAwaitTimeout, - }, - client: &Client{ - timeout: test.timeout, - }, - } - - got := validChangeStreamTimeouts(test.parent, cs) - assert.Equal(t, test.want, got) - }) - } -} diff --git a/mongo/collection.go b/mongo/collection.go index d7693c4245..80caa5dfab 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -954,6 +954,12 @@ func aggregate(a aggregateParams, opts ...options.Lister[options.AggregateOption return nil, err } + // Sending a maxAwaitTimeMS option to the server that is less than or equal to + // the operation timeout will result in a socket timeout error. + if c := a.client; c != nil && !mongoutil.ValidMaxAwaitTimeMS(a.ctx, c.timeout, args.MaxAwaitTime) { + return nil, fmt.Errorf("MaxAwaitTime must be less than the operation timeout") + } + cursorOpts := a.client.createBaseCursorOptions() cursorOpts.MarshalValueEncoderFn = newEncoderFn(a.bsonOpts, a.registry) @@ -1347,11 +1353,16 @@ func (coll *Collection) find( omitMaxTimeMS bool, args *options.FindOptions, ) (cur *Cursor, err error) { - if ctx == nil { ctx = context.Background() } + // Sending a maxAwaitTimeMS option to the server that is less than or equal to + // the operation timeout will result in a socket timeout error. + if c := coll.client; c != nil && !mongoutil.ValidMaxAwaitTimeMS(ctx, c.timeout, args.MaxAwaitTime) { + return nil, fmt.Errorf("MaxAwaitTime must be less than the operation timeout") + } + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return nil, err diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go index 6d6cd211a5..0a3ec2cda1 100644 --- a/x/mongo/driver/batch_cursor.go +++ b/x/mongo/driver/batch_cursor.go @@ -381,8 +381,8 @@ func (bc *BatchCursor) getMore(ctx context.Context) { bc.err = Operation{ CommandFn: func(dst []byte, _ description.SelectedServer) ([]byte, error) { - // If maxAwaitTime > remaining timeoutMS - minRoundTripTime, then use - // send remaining TimeoutMS - minRoundTripTime allowing the server an + // If maxAwaitTime > remaining timeoutMS - minRoundTripTime, then send + // remaining TimeoutMS - minRoundTripTime, allowing the server an // opportunity to respond with an empty batch. var maxTimeMS int64 if bc.maxAwaitTime != nil {