Skip to content

Commit db4a57b

Browse files
committed
Merge branch 'main' into bump-golangci-lint
* main: fix(mongodb): replica set initialization & connection handling (testcontainers#2984)
2 parents 7c186d0 + 675acc3 commit db4a57b

File tree

2 files changed

+58
-8
lines changed

2 files changed

+58
-8
lines changed

modules/mongodb/mongodb.go

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
_ "embed"
77
"errors"
88
"fmt"
9+
"net"
10+
"net/url"
911
"time"
1012

1113
"github.com/testcontainers/testcontainers-go"
@@ -125,10 +127,23 @@ func (c *MongoDBContainer) ConnectionString(ctx context.Context) (string, error)
125127
if err != nil {
126128
return "", err
127129
}
130+
u := url.URL{
131+
Scheme: "mongodb",
132+
Host: net.JoinHostPort(host, port.Port()),
133+
Path: "/",
134+
}
135+
128136
if c.username != "" && c.password != "" {
129-
return fmt.Sprintf("mongodb://%s:%s@%s:%s", c.username, c.password, host, port.Port()), nil
137+
u.User = url.UserPassword(c.username, c.password)
138+
}
139+
140+
if c.replicaSet != "" {
141+
q := url.Values{}
142+
q.Add("replicaSet", c.replicaSet)
143+
u.RawQuery = q.Encode()
130144
}
131-
return c.Endpoint(ctx, "mongodb")
145+
146+
return u.String(), nil
132147
}
133148

134149
func setupEntrypointForAuth(req *testcontainers.GenericContainerRequest) {
@@ -186,7 +201,6 @@ func initiateReplicaSet(req *testcontainers.GenericContainerRequest, cli mongoCl
186201
replSetName,
187202
ip,
188203
)
189-
190204
return wait.ForExec(cmd).WaitUntilReady(ctx, c)
191205
},
192206
},

modules/mongodb/mongodb_test.go

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package mongodb_test
22

33
import (
44
"context"
5+
"fmt"
6+
"net/url"
57
"testing"
68

79
"github.com/stretchr/testify/require"
@@ -125,18 +127,52 @@ func TestMongoDB(t *testing.T) {
125127
endpoint, err := mongodbContainer.ConnectionString(ctx)
126128
require.NoError(tt, err)
127129

128-
// Force direct connection to the container to avoid the replica set
129-
// connection string that is returned by the container itself when
130-
// using the replica set option.
130+
// Force direct connection to the container.
131131
mongoClient, err := mongo.Connect(ctx, options.Client().ApplyURI(endpoint).SetDirect(true))
132132
require.NoError(tt, err)
133133

134134
err = mongoClient.Ping(ctx, nil)
135135
require.NoError(tt, err)
136-
require.Equal(t, "test", mongoClient.Database("test").Name())
136+
require.Equal(tt, "test", mongoClient.Database("test").Name())
137137

138-
_, err = mongoClient.Database("testcontainer").Collection("test").InsertOne(context.Background(), bson.M{})
138+
// Basic insert test.
139+
_, err = mongoClient.Database("testcontainer").Collection("test").InsertOne(ctx, bson.M{})
139140
require.NoError(tt, err)
141+
142+
// If the container is configured with a replica set, run the change stream test.
143+
if hasReplica, _ := hasReplicaSet(endpoint); hasReplica {
144+
coll := mongoClient.Database("test").Collection("changes")
145+
stream, err := coll.Watch(ctx, mongo.Pipeline{})
146+
require.NoError(tt, err)
147+
defer stream.Close(ctx)
148+
149+
doc := bson.M{"message": "hello change streams"}
150+
_, err = coll.InsertOne(ctx, doc)
151+
require.NoError(tt, err)
152+
153+
require.True(tt, stream.Next(ctx))
154+
var changeEvent bson.M
155+
err = stream.Decode(&changeEvent)
156+
require.NoError(tt, err)
157+
158+
opType, ok := changeEvent["operationType"].(string)
159+
require.True(tt, ok, "Expected operationType field")
160+
require.Equal(tt, "insert", opType, "Expected operationType to be 'insert'")
161+
162+
fullDoc, ok := changeEvent["fullDocument"].(bson.M)
163+
require.True(tt, ok, "Expected fullDocument field")
164+
require.Equal(tt, "hello change streams", fullDoc["message"])
165+
}
140166
})
141167
}
142168
}
169+
170+
// hasReplicaSet checks if the connection string includes a replicaSet query parameter.
171+
func hasReplicaSet(connStr string) (bool, error) {
172+
u, err := url.Parse(connStr)
173+
if err != nil {
174+
return false, fmt.Errorf("parse connection string: %w", err)
175+
}
176+
q := u.Query()
177+
return q.Get("replicaSet") != "", nil
178+
}

0 commit comments

Comments
 (0)