Skip to content

GODRIVER-3473 Add client-side validation for maxAwaitTime+op timeout #2130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions internal/mongoutil/mongoutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
package mongoutil

import (
"context"
"reflect"
"time"

"go.mongodb.org/mongo-driver/v2/mongo/options"
)
Expand Down Expand Up @@ -83,3 +85,23 @@ func HostsFromURI(uri string) ([]string, error) {

return opts.Hosts, nil
}

// ValidMaxAwaitTimeMS will return "false" if maxAwaitTimeMS is set, timeoutMS
// is set to a non-zero value, and maxAwaitTimeMS is greater than or equal to
// timeoutMS. Otherwise, the timeouts are valid.
func ValidMaxAwaitTimeMS(ctx context.Context, timeout, maxAwaiTime *time.Duration) bool {
if maxAwaiTime == nil {
return true
}

if deadline, ok := ctx.Deadline(); ok {
ctxTimeout := time.Until(deadline)
timeout = &ctxTimeout
}

if timeout == nil {
return true
}

return *timeout <= 0 || *maxAwaiTime < *timeout
}
78 changes: 78 additions & 0 deletions internal/mongoutil/mongoutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
package mongoutil

import (
"context"
"strings"
"testing"
"time"

"go.mongodb.org/mongo-driver/v2/internal/assert"
"go.mongodb.org/mongo-driver/v2/internal/ptrutil"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)

Expand All @@ -32,3 +36,77 @@ func BenchmarkNewOptions(b *testing.B) {
}
})
}

func TestValidChangeStreamTimeouts(t *testing.T) {
tests := []struct {
name string
parent context.Context
maxAwaitTimeout, timeout *time.Duration
wantTimeout time.Duration
want bool
}{
{
name: "no context deadline and no timeouts",
parent: context.Background(),
maxAwaitTimeout: nil,
timeout: nil,
wantTimeout: 0,
want: true,
},
{
name: "no context deadline and maxAwaitTimeout",
parent: context.Background(),
maxAwaitTimeout: ptrutil.Ptr(time.Duration(1)),
timeout: nil,
wantTimeout: 0,
want: true,
},
{
name: "no context deadline and timeout",
parent: context.Background(),
maxAwaitTimeout: nil,
timeout: ptrutil.Ptr(time.Duration(1)),
wantTimeout: 0,
want: true,
},
{
name: "no context deadline and maxAwaitTime gt timeout",
parent: context.Background(),
maxAwaitTimeout: ptrutil.Ptr(time.Duration(2)),
timeout: ptrutil.Ptr(time.Duration(1)),
wantTimeout: 0,
want: false,
},
{
name: "no context deadline and maxAwaitTime lt timeout",
parent: context.Background(),
maxAwaitTimeout: ptrutil.Ptr(time.Duration(1)),
timeout: ptrutil.Ptr(time.Duration(2)),
wantTimeout: 0,
want: true,
},
{
name: "no context deadline and maxAwaitTime eq timeout",
parent: context.Background(),
maxAwaitTimeout: ptrutil.Ptr(time.Duration(1)),
timeout: ptrutil.Ptr(time.Duration(1)),
wantTimeout: 0,
want: false,
},
{
name: "no context deadline and maxAwaitTime with negative timeout",
parent: context.Background(),
maxAwaitTimeout: ptrutil.Ptr(time.Duration(1)),
timeout: ptrutil.Ptr(time.Duration(-1)),
wantTimeout: 0,
want: true,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got := ValidMaxAwaitTimeMS(test.parent, test.timeout, test.maxAwaitTimeout)
assert.Equal(t, test.want, got)
})
}
}
14 changes: 1 addition & 13 deletions internal/spectest/skip.go
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,6 @@ var skipTests = map[string][]string{
"TestUnifiedSpec/client-side-operations-timeout/tests/sessions-override-operation-timeoutMS.json/timeoutMS_applied_to_withTransaction",
"TestUnifiedSpec/client-side-operations-timeout/tests/sessions-override-timeoutMS.json",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_timeoutMode_is_cursor_lifetime",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_greater_than_timeoutMS",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_applied_to_find",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_not_set",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_set",
Expand Down Expand Up @@ -817,19 +816,8 @@ var skipTests = map[string][]string{
"TestUnifiedSpec/transactions-convenient-api/tests/unified/commit.json/withTransaction_commits_after_callback_returns",
},

// GODRIVER-3473: the implementation of DRIVERS-2868 makes it clear that the
// Go Driver does not correctly implement the following validation for
// tailable awaitData cursors:
//
// Drivers MUST error if this option is set, timeoutMS is set to a
// non-zero value, and maxAwaitTimeMS is greater than or equal to
// timeoutMS.
//
// Once GODRIVER-3473 is completed, we can continue running these tests.
"When constructing tailable awaitData cusors must validate, timeoutMS is set to a non-zero value, and maxAwaitTimeMS is greater than or equal to timeoutMS (GODRIVER-3473)": {
"Address CSOT Compliance Issue in Timeout Handling for Cursor Constructors (GODRIVER-3480)": {
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/apply_remaining_timeoutMS_if_less_than_maxAwaitTimeMS",
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS",
"TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS",
},
}

Expand Down
37 changes: 8 additions & 29 deletions mongo/change_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"fmt"
"reflect"
"strconv"
"time"

"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/internal/csot"
Expand Down Expand Up @@ -103,33 +102,6 @@ type changeStreamConfig struct {
crypt driver.Crypt
}

// validChangeStreamTimeouts will return "false" if maxAwaitTimeMS is set,
// timeoutMS is set to a non-zero value, and maxAwaitTimeMS is greater than or
// equal to timeoutMS. Otherwise, the timeouts are valid.
func validChangeStreamTimeouts(ctx context.Context, cs *ChangeStream) bool {
if cs.options == nil || cs.client == nil {
return true
}

maxAwaitTime := cs.options.MaxAwaitTime
timeout := cs.client.timeout

if maxAwaitTime == nil {
return true
}

if deadline, ok := ctx.Deadline(); ok {
ctxTimeout := time.Until(deadline)
timeout = &ctxTimeout
}

if timeout == nil {
return true
}

return *timeout <= 0 || *maxAwaitTime < *timeout
}

func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline interface{},
opts ...options.Lister[options.ChangeStreamOptions]) (*ChangeStream, error) {
if ctx == nil {
Expand All @@ -145,6 +117,10 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in
return nil, err
}

if c := config.client; c != nil && !mongoutil.ValidMaxAwaitTimeMS(ctx, c.timeout, args.MaxAwaitTime) {
return nil, fmt.Errorf("MaxAwaitTime must be less than the operation timeout")
}

cs := &ChangeStream{
client: config.client,
bsonOpts: config.bsonOpts,
Expand Down Expand Up @@ -696,7 +672,10 @@ func (cs *ChangeStream) next(ctx context.Context, nonBlocking bool) bool {
}

func (cs *ChangeStream) loopNext(ctx context.Context, nonBlocking bool) {
if !validChangeStreamTimeouts(ctx, cs) {
// Sending a maxAwaitTimeMS option to the server that is less than or equal to
// the operation timeout will result in a socket timeout error.
if csOpts := cs.options; csOpts != nil && cs.client != nil &&
!mongoutil.ValidMaxAwaitTimeMS(ctx, cs.client.timeout, csOpts.MaxAwaitTime) {
cs.err = fmt.Errorf("MaxAwaitTime must be less than the operation timeout")

return
Expand Down
96 changes: 0 additions & 96 deletions mongo/change_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@
package mongo

import (
"context"
"testing"
"time"

"go.mongodb.org/mongo-driver/v2/internal/assert"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)

func TestChangeStream(t *testing.T) {
Expand All @@ -30,96 +27,3 @@ func TestChangeStream(t *testing.T) {
assert.Nil(t, err, "Close error: %v", err)
})
}

func TestValidChangeStreamTimeouts(t *testing.T) {
t.Parallel()

newDurPtr := func(dur time.Duration) *time.Duration {
return &dur
}

tests := []struct {
name string
parent context.Context
maxAwaitTimeout, timeout *time.Duration
wantTimeout time.Duration
want bool
}{
{
name: "no context deadline and no timeouts",
parent: context.Background(),
maxAwaitTimeout: nil,
timeout: nil,
wantTimeout: 0,
want: true,
},
{
name: "no context deadline and maxAwaitTimeout",
parent: context.Background(),
maxAwaitTimeout: newDurPtr(1),
timeout: nil,
wantTimeout: 0,
want: true,
},
{
name: "no context deadline and timeout",
parent: context.Background(),
maxAwaitTimeout: nil,
timeout: newDurPtr(1),
wantTimeout: 0,
want: true,
},
{
name: "no context deadline and maxAwaitTime gt timeout",
parent: context.Background(),
maxAwaitTimeout: newDurPtr(2),
timeout: newDurPtr(1),
wantTimeout: 0,
want: false,
},
{
name: "no context deadline and maxAwaitTime lt timeout",
parent: context.Background(),
maxAwaitTimeout: newDurPtr(1),
timeout: newDurPtr(2),
wantTimeout: 0,
want: true,
},
{
name: "no context deadline and maxAwaitTime eq timeout",
parent: context.Background(),
maxAwaitTimeout: newDurPtr(1),
timeout: newDurPtr(1),
wantTimeout: 0,
want: false,
},
{
name: "no context deadline and maxAwaitTime with negative timeout",
parent: context.Background(),
maxAwaitTimeout: newDurPtr(1),
timeout: newDurPtr(-1),
wantTimeout: 0,
want: true,
},
}

for _, test := range tests {
test := test // Capture the range variable

t.Run(test.name, func(t *testing.T) {
t.Parallel()

cs := &ChangeStream{
options: &options.ChangeStreamOptions{
MaxAwaitTime: test.maxAwaitTimeout,
},
client: &Client{
timeout: test.timeout,
},
}

got := validChangeStreamTimeouts(test.parent, cs)
assert.Equal(t, test.want, got)
})
}
}
13 changes: 12 additions & 1 deletion mongo/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,12 @@ func aggregate(a aggregateParams, opts ...options.Lister[options.AggregateOption
return nil, err
}

// Sending a maxAwaitTimeMS option to the server that is less than or equal to
// the operation timeout will result in a socket timeout error.
if c := a.client; c != nil && !mongoutil.ValidMaxAwaitTimeMS(a.ctx, c.timeout, args.MaxAwaitTime) {
return nil, fmt.Errorf("MaxAwaitTime must be less than the operation timeout")
}

cursorOpts := a.client.createBaseCursorOptions()

cursorOpts.MarshalValueEncoderFn = newEncoderFn(a.bsonOpts, a.registry)
Expand Down Expand Up @@ -1347,11 +1353,16 @@ func (coll *Collection) find(
omitMaxTimeMS bool,
args *options.FindOptions,
) (cur *Cursor, err error) {

if ctx == nil {
ctx = context.Background()
}

// Sending a maxAwaitTimeMS option to the server that is less than or equal to
// the operation timeout will result in a socket timeout error.
if c := coll.client; c != nil && !mongoutil.ValidMaxAwaitTimeMS(ctx, c.timeout, args.MaxAwaitTime) {
return nil, fmt.Errorf("MaxAwaitTime must be less than the operation timeout")
}

f, err := marshal(filter, coll.bsonOpts, coll.registry)
if err != nil {
return nil, err
Expand Down
4 changes: 2 additions & 2 deletions x/mongo/driver/batch_cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,8 @@ func (bc *BatchCursor) getMore(ctx context.Context) {

bc.err = Operation{
CommandFn: func(dst []byte, _ description.SelectedServer) ([]byte, error) {
// If maxAwaitTime > remaining timeoutMS - minRoundTripTime, then use
// send remaining TimeoutMS - minRoundTripTime allowing the server an
// If maxAwaitTime > remaining timeoutMS - minRoundTripTime, then send
// remaining TimeoutMS - minRoundTripTime, allowing the server an
// opportunity to respond with an empty batch.
var maxTimeMS int64
if bc.maxAwaitTime != nil {
Expand Down
Loading