From bada7e676b08e2f11ef7fff46b5a52d60d2b7753 Mon Sep 17 00:00:00 2001 From: Joy Gao <17896160+jgao54@users.noreply.github.com> Date: Wed, 13 Aug 2025 15:08:27 -1000 Subject: [PATCH 1/2] enable ShardedCluster + better code sharing for validation --- flow/connectors/mongo/validate.go | 27 ++---- flow/shared/mongo/commands.go | 84 ++++++++++++++++++ flow/shared/mongo/validation.go | 140 ++++++++++++++++++++---------- 3 files changed, 184 insertions(+), 67 deletions(-) create mode 100644 flow/shared/mongo/commands.go diff --git a/flow/connectors/mongo/validate.go b/flow/connectors/mongo/validate.go index bc5410433c..989fd11215 100644 --- a/flow/connectors/mongo/validate.go +++ b/flow/connectors/mongo/validate.go @@ -2,25 +2,20 @@ package connmongo import ( "context" - "errors" - "fmt" "github.com/PeerDB-io/peerdb/flow/generated/protos" shared_mongo "github.com/PeerDB-io/peerdb/flow/shared/mongo" ) func (c *MongoConnector) ValidateCheck(ctx context.Context) error { - version, err := c.GetVersion(ctx) - if err != nil { + if err := shared_mongo.ValidateServerCompatibility(ctx, c.client); err != nil { return err } - cmp, err := shared_mongo.CompareServerVersions(version, shared_mongo.MinSupportedVersion) - if err != nil { + + if err := shared_mongo.ValidateUserRoles(ctx, c.client); err != nil { return err } - if cmp == -1 { - return fmt.Errorf("require minimum mongo version %s", shared_mongo.MinSupportedVersion) - } + return nil } @@ -29,21 +24,9 @@ func (c *MongoConnector) ValidateMirrorSource(ctx context.Context, cfg *protos.F return nil } - if _, err := shared_mongo.GetReplSetGetStatus(ctx, c.client); err != nil { - return err - } - - serverStatus, err := shared_mongo.GetServerStatus(ctx, c.client) - if err != nil { + if err := shared_mongo.ValidateOplogRetention(ctx, c.client); err != nil { return err } - if serverStatus.StorageEngine.Name != "wiredTiger" { - return errors.New("storage engine must be 'wiredTiger'") - } - if 
serverStatus.OplogTruncation.OplogMinRetentionHours == 0 || - serverStatus.OplogTruncation.OplogMinRetentionHours < shared_mongo.MinOplogRetentionHours { - return errors.New("oplog retention must be set to >= 24 hours") - } return nil } diff --git a/flow/shared/mongo/commands.go b/flow/shared/mongo/commands.go new file mode 100644 index 0000000000..d9adb27f6e --- /dev/null +++ b/flow/shared/mongo/commands.go @@ -0,0 +1,84 @@ +package mongo + +import ( + "context" + "fmt" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" +) + +type BuildInfo struct { + Version string `bson:"version"` +} + +func GetBuildInfo(ctx context.Context, client *mongo.Client) (*BuildInfo, error) { + return runCommand[BuildInfo](ctx, client, "buildInfo") +} + +type ReplSetGetStatus struct { + Set string `bson:"set"` + MyState int `bson:"myState"` +} + +func GetReplSetGetStatus(ctx context.Context, client *mongo.Client) (*ReplSetGetStatus, error) { + return runCommand[ReplSetGetStatus](ctx, client, "replSetGetStatus") +} + +type OplogTruncation struct { + OplogMinRetentionHours float64 `bson:"oplogMinRetentionHours"` +} + +type StorageEngine struct { + Name string `bson:"name"` +} + +type ServerStatus struct { + StorageEngine StorageEngine `bson:"storageEngine"` + OplogTruncation OplogTruncation `bson:"oplogTruncation"` +} + +func GetServerStatus(ctx context.Context, client *mongo.Client) (*ServerStatus, error) { + return runCommand[ServerStatus](ctx, client, "serverStatus") +} + +type ConnectionStatus struct { + AuthInfo AuthInfo `bson:"authInfo"` +} + +type AuthInfo struct { + AuthenticatedUserRoles []Role `bson:"authenticatedUserRoles"` +} + +type Role struct { + Role string `bson:"role"` + DB string `bson:"db"` +} + +func GetConnectionStatus(ctx context.Context, client *mongo.Client) (*ConnectionStatus, error) { + return runCommand[ConnectionStatus](ctx, client, "connectionStatus") +} + +type HelloResponse struct { + Msg string `bson:"msg,omitempty"` + Hosts 
[]string `bson:"hosts,omitempty"` +} + +func GetHelloResponse(ctx context.Context, client *mongo.Client) (*HelloResponse, error) { + return runCommand[HelloResponse](ctx, client, "hello") +} + +func runCommand[T any](ctx context.Context, client *mongo.Client, command string) (*T, error) { + singleResult := client.Database("admin").RunCommand(ctx, bson.D{ + bson.E{Key: command, Value: 1}, + }) + if singleResult.Err() != nil { + return nil, fmt.Errorf("'%s' failed: %v", command, singleResult.Err()) + } + + var result T + if err := singleResult.Decode(&result); err != nil { + return nil, fmt.Errorf("'%s' failed: %v", command, err) + } + return &result, nil +} diff --git a/flow/shared/mongo/validation.go b/flow/shared/mongo/validation.go index 752956ca07..7a1d96d7ce 100644 --- a/flow/shared/mongo/validation.go +++ b/flow/shared/mongo/validation.go @@ -2,75 +2,125 @@ package mongo import ( "context" + "errors" "fmt" - "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo" ) const ( MinSupportedVersion = "5.1.0" MinOplogRetentionHours = 24 + + ReplicaSet = "ReplicaSet" + ShardedCluster = "ShardedCluster" ) -type BuildInfo struct { - Version string `bson:"version"` -} +func ValidateServerCompatibility(ctx context.Context, client *mongo.Client) error { + buildInfo, err := GetBuildInfo(ctx, client) + if err != nil { + return err + } -type ReplSetGetStatus struct { - Set string `bson:"set"` - MyState int `bson:"myState"` -} + if cmp, err := CompareServerVersions(buildInfo.Version, MinSupportedVersion); err != nil { + return err + } else if cmp < 0 { + return fmt.Errorf("require minimum mongo version %s", MinSupportedVersion) + } -type OplogTruncation struct { - OplogMinRetentionHours float64 `bson:"oplogMinRetentionHours"` -} + validateStorageEngine := func(instanceCtx context.Context, instanceClient *mongo.Client) error { + ss, err := GetServerStatus(instanceCtx, instanceClient) + if err != nil { + return err + } -type StorageEngine struct { - Name 
string `bson:"name"` -} + if ss.StorageEngine.Name != "wiredTiger" { + return errors.New("only wiredTiger storage engine is supported") + } + return nil + } -type ServerStatus struct { - StorageEngine StorageEngine `bson:"storageEngine"` - OplogTruncation OplogTruncation `bson:"oplogTruncation"` + topologyType, err := GetTopologyType(ctx, client) + if err != nil { + return err + } + + if topologyType == ReplicaSet { + return validateStorageEngine(ctx, client) + } else { + // TODO: run validation on shard + return nil + } } -func GetBuildInfo(ctx context.Context, client *mongo.Client) (*BuildInfo, error) { - singleResult := client.Database("admin").RunCommand(ctx, bson.D{bson.E{Key: "buildInfo", Value: 1}}) - if singleResult.Err() != nil { - return nil, fmt.Errorf("failed to run 'buildInfo' command: %w", singleResult.Err()) +func ValidateUserRoles(ctx context.Context, client *mongo.Client) error { + RequiredRoles := []string{"readAnyDatabase", "clusterMonitor"} + + connectionStatus, err := GetConnectionStatus(ctx, client) + if err != nil { + return err } - var info BuildInfo - if err := singleResult.Decode(&info); err != nil { - return nil, fmt.Errorf("failed to decode BuildInfo: %w", err) + + hasRole := func(roles []Role, targetRole string) bool { + for _, role := range roles { + if role.Role == targetRole { + return true + } + } + return false } - return &info, nil + + for _, requiredRole := range RequiredRoles { + if !hasRole(connectionStatus.AuthInfo.AuthenticatedUserRoles, requiredRole) { + return fmt.Errorf("missing required role: %s", requiredRole) + } + } + + return nil } -func GetReplSetGetStatus(ctx context.Context, client *mongo.Client) (*ReplSetGetStatus, error) { - singleResult := client.Database("admin").RunCommand(ctx, bson.D{ - bson.E{Key: "replSetGetStatus", Value: 1}, - }) - if singleResult.Err() != nil { - return nil, fmt.Errorf("failed to run 'replSetGetStatus' command: %w", singleResult.Err()) +func ValidateOplogRetention(ctx context.Context, 
client *mongo.Client) error { + validateOplogRetention := func(instanceCtx context.Context, instanceClient *mongo.Client) error { + ss, err := GetServerStatus(instanceCtx, instanceClient) + if err != nil { + return err + } + if ss.OplogTruncation.OplogMinRetentionHours == 0 || + ss.OplogTruncation.OplogMinRetentionHours < MinOplogRetentionHours { + return fmt.Errorf("oplog retention must be set to >= %d hours, but got %f", + MinOplogRetentionHours, ss.OplogTruncation.OplogMinRetentionHours) + } + return nil } - var status ReplSetGetStatus - if err := singleResult.Decode(&status); err != nil { - return nil, fmt.Errorf("failed to decode ReplSetGetStatus: %w", err) + + topology, err := GetTopologyType(ctx, client) + if err != nil { + return err + } + if topology == ReplicaSet { + return validateOplogRetention(ctx, client) + } else { + // TODO: run validation on shard + return nil } - return &status, nil } -func GetServerStatus(ctx context.Context, client *mongo.Client) (*ServerStatus, error) { - singleResult := client.Database("admin").RunCommand(ctx, bson.D{ - bson.E{Key: "serverStatus", Value: 1}, - }) - if singleResult.Err() != nil { - return nil, fmt.Errorf("failed to run 'serverStatus' command: %w", singleResult.Err()) +func GetTopologyType(ctx context.Context, client *mongo.Client) (string, error) { + hello, err := GetHelloResponse(ctx, client) + if err != nil { + return "", err } - var status ServerStatus - if err := singleResult.Decode(&status); err != nil { - return nil, fmt.Errorf("failed to decode ServerStatus: %w", err) + + // Only replica set has 'hosts' field + // https://www.mongodb.com/docs/manual/reference/command/hello/#mongodb-data-hello.hosts + if len(hello.Hosts) > 0 { + return ReplicaSet, nil + } + + // Only a sharded cluster has the 'msg' field, and its value is 'isdbgrid' + // https://www.mongodb.com/docs/manual/reference/command/hello/#mongodb-data-hello.msg + if hello.Msg == "isdbgrid" { + return ShardedCluster, nil } - return &status, nil + return "", errors.New("topology 
type must be ReplicaSet or ShardedCluster") } From 6b262de5a67f4910a7d2a15bfa7f3e7c2107d433 Mon Sep 17 00:00:00 2001 From: Joy Gao <17896160+jgao54@users.noreply.github.com> Date: Fri, 15 Aug 2025 08:16:26 -1000 Subject: [PATCH 2/2] mongo: pr review --- flow/shared/mongo/commands.go | 20 ++++++++++---------- flow/shared/mongo/validation.go | 18 ++++++------------ 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/flow/shared/mongo/commands.go b/flow/shared/mongo/commands.go index d9adb27f6e..a48b0dd82f 100644 --- a/flow/shared/mongo/commands.go +++ b/flow/shared/mongo/commands.go @@ -12,7 +12,7 @@ type BuildInfo struct { Version string `bson:"version"` } -func GetBuildInfo(ctx context.Context, client *mongo.Client) (*BuildInfo, error) { +func GetBuildInfo(ctx context.Context, client *mongo.Client) (BuildInfo, error) { return runCommand[BuildInfo](ctx, client, "buildInfo") } @@ -21,7 +21,7 @@ type ReplSetGetStatus struct { MyState int `bson:"myState"` } -func GetReplSetGetStatus(ctx context.Context, client *mongo.Client) (*ReplSetGetStatus, error) { +func GetReplSetGetStatus(ctx context.Context, client *mongo.Client) (ReplSetGetStatus, error) { return runCommand[ReplSetGetStatus](ctx, client, "replSetGetStatus") } @@ -38,7 +38,7 @@ type ServerStatus struct { OplogTruncation OplogTruncation `bson:"oplogTruncation"` } -func GetServerStatus(ctx context.Context, client *mongo.Client) (*ServerStatus, error) { +func GetServerStatus(ctx context.Context, client *mongo.Client) (ServerStatus, error) { return runCommand[ServerStatus](ctx, client, "serverStatus") } @@ -55,7 +55,7 @@ type Role struct { DB string `bson:"db"` } -func GetConnectionStatus(ctx context.Context, client *mongo.Client) (*ConnectionStatus, error) { +func GetConnectionStatus(ctx context.Context, client *mongo.Client) (ConnectionStatus, error) { return runCommand[ConnectionStatus](ctx, client, "connectionStatus") } @@ -64,21 +64,21 @@ type HelloResponse struct { Hosts []string 
`bson:"hosts,omitempty"` } -func GetHelloResponse(ctx context.Context, client *mongo.Client) (*HelloResponse, error) { +func GetHelloResponse(ctx context.Context, client *mongo.Client) (HelloResponse, error) { return runCommand[HelloResponse](ctx, client, "hello") } -func runCommand[T any](ctx context.Context, client *mongo.Client, command string) (*T, error) { +func runCommand[T any](ctx context.Context, client *mongo.Client, command string) (T, error) { + var result T singleResult := client.Database("admin").RunCommand(ctx, bson.D{ bson.E{Key: command, Value: 1}, }) if singleResult.Err() != nil { - return nil, fmt.Errorf("'%s' failed: %v", command, singleResult.Err()) + return result, fmt.Errorf("'%s' failed: %w", command, singleResult.Err()) } - var result T if err := singleResult.Decode(&result); err != nil { - return nil, fmt.Errorf("'%s' failed: %v", command, err) + return result, fmt.Errorf("'%s' failed: %w", command, err) } - return &result, nil + return result, nil } diff --git a/flow/shared/mongo/validation.go b/flow/shared/mongo/validation.go index 7a1d96d7ce..59318d84f3 100644 --- a/flow/shared/mongo/validation.go +++ b/flow/shared/mongo/validation.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "go.mongodb.org/mongo-driver/v2/mongo" ) @@ -16,6 +17,8 @@ const ( ShardedCluster = "ShardedCluster" ) +var RequiredRoles = [...]string{"readAnyDatabase", "clusterMonitor"} + func ValidateServerCompatibility(ctx context.Context, client *mongo.Client) error { buildInfo, err := GetBuildInfo(ctx, client) if err != nil { @@ -54,24 +57,15 @@ func ValidateServerCompatibility(ctx context.Context, client *mongo.Client) erro } func ValidateUserRoles(ctx context.Context, client *mongo.Client) error { - RequiredRoles := []string{"readAnyDatabase", "clusterMonitor"} - connectionStatus, err := GetConnectionStatus(ctx, client) if err != nil { return err } - hasRole := func(roles []Role, targetRole string) bool { - for _, role := range roles { - if role.Role == 
targetRole { - return true - } - } - return false - } - for _, requiredRole := range RequiredRoles { - if !hasRole(connectionStatus.AuthInfo.AuthenticatedUserRoles, requiredRole) { + if !slices.ContainsFunc(connectionStatus.AuthInfo.AuthenticatedUserRoles, func(r Role) bool { + return r.Role == requiredRole + }) { return fmt.Errorf("missing required role: %s", requiredRole) } }