diff --git a/flow/connectors/mongo/validate.go b/flow/connectors/mongo/validate.go
index bc5410433c..989fd11215 100644
--- a/flow/connectors/mongo/validate.go
+++ b/flow/connectors/mongo/validate.go
@@ -2,25 +2,20 @@ package connmongo
 
 import (
 	"context"
-	"errors"
-	"fmt"
 
 	"github.com/PeerDB-io/peerdb/flow/generated/protos"
 	shared_mongo "github.com/PeerDB-io/peerdb/flow/shared/mongo"
 )
 
 func (c *MongoConnector) ValidateCheck(ctx context.Context) error {
-	version, err := c.GetVersion(ctx)
-	if err != nil {
+	if err := shared_mongo.ValidateServerCompatibility(ctx, c.client); err != nil {
 		return err
 	}
-	cmp, err := shared_mongo.CompareServerVersions(version, shared_mongo.MinSupportedVersion)
-	if err != nil {
+
+	if err := shared_mongo.ValidateUserRoles(ctx, c.client); err != nil {
 		return err
 	}
-	if cmp == -1 {
-		return fmt.Errorf("require minimum mongo version %s", shared_mongo.MinSupportedVersion)
-	}
+
 	return nil
 }
 
@@ -29,21 +24,9 @@ func (c *MongoConnector) ValidateMirrorSource(ctx context.Context, cfg *protos.F
 		return nil
 	}
 
-	if _, err := shared_mongo.GetReplSetGetStatus(ctx, c.client); err != nil {
-		return err
-	}
-
-	serverStatus, err := shared_mongo.GetServerStatus(ctx, c.client)
-	if err != nil {
+	if err := shared_mongo.ValidateOplogRetention(ctx, c.client); err != nil {
 		return err
 	}
-	if serverStatus.StorageEngine.Name != "wiredTiger" {
-		return errors.New("storage engine must be 'wiredTiger'")
-	}
-	if serverStatus.OplogTruncation.OplogMinRetentionHours == 0 ||
-		serverStatus.OplogTruncation.OplogMinRetentionHours < shared_mongo.MinOplogRetentionHours {
-		return errors.New("oplog retention must be set to >= 24 hours")
-	}
 
 	return nil
 }
diff --git a/flow/shared/mongo/commands.go b/flow/shared/mongo/commands.go
new file mode 100644
index 0000000000..a48b0dd82f
--- /dev/null
+++ b/flow/shared/mongo/commands.go
@@ -0,0 +1,92 @@
+package mongo
+
+import (
+	"context"
+	"fmt"
+
+	"go.mongodb.org/mongo-driver/v2/bson"
+	"go.mongodb.org/mongo-driver/v2/mongo"
+)
+
+type BuildInfo struct {
+	Version string `bson:"version"`
+}
+
+// GetBuildInfo runs the buildInfo command and decodes its reply.
+func GetBuildInfo(ctx context.Context, client *mongo.Client) (BuildInfo, error) {
+	return runCommand[BuildInfo](ctx, client, "buildInfo")
+}
+
+type ReplSetGetStatus struct {
+	Set     string `bson:"set"`
+	MyState int    `bson:"myState"`
+}
+
+// GetReplSetGetStatus runs the replSetGetStatus command and decodes its reply.
+func GetReplSetGetStatus(ctx context.Context, client *mongo.Client) (ReplSetGetStatus, error) {
+	return runCommand[ReplSetGetStatus](ctx, client, "replSetGetStatus")
+}
+
+type OplogTruncation struct {
+	OplogMinRetentionHours float64 `bson:"oplogMinRetentionHours"`
+}
+
+type StorageEngine struct {
+	Name string `bson:"name"`
+}
+
+type ServerStatus struct {
+	StorageEngine   StorageEngine   `bson:"storageEngine"`
+	OplogTruncation OplogTruncation `bson:"oplogTruncation"`
+}
+
+// GetServerStatus runs the serverStatus command and decodes its reply.
+func GetServerStatus(ctx context.Context, client *mongo.Client) (ServerStatus, error) {
+	return runCommand[ServerStatus](ctx, client, "serverStatus")
+}
+
+type ConnectionStatus struct {
+	AuthInfo AuthInfo `bson:"authInfo"`
+}
+
+type AuthInfo struct {
+	AuthenticatedUserRoles []Role `bson:"authenticatedUserRoles"`
+}
+
+type Role struct {
+	Role string `bson:"role"`
+	DB   string `bson:"db"`
+}
+
+// GetConnectionStatus runs the connectionStatus command and decodes its reply.
+func GetConnectionStatus(ctx context.Context, client *mongo.Client) (ConnectionStatus, error) {
+	return runCommand[ConnectionStatus](ctx, client, "connectionStatus")
+}
+
+type HelloResponse struct {
+	Msg   string   `bson:"msg,omitempty"`
+	Hosts []string `bson:"hosts,omitempty"`
+}
+
+// GetHelloResponse runs the hello command and decodes its reply.
+func GetHelloResponse(ctx context.Context, client *mongo.Client) (HelloResponse, error) {
+	return runCommand[HelloResponse](ctx, client, "hello")
+}
+
+// runCommand runs the single-key command {command: 1} against the "admin"
+// database and decodes the reply into T. Errors are wrapped with %w so that
+// callers can still inspect the underlying error with errors.Is / errors.As.
+func runCommand[T any](ctx context.Context, client *mongo.Client, command string) (T, error) {
+	var result T
+	singleResult := client.Database("admin").RunCommand(ctx, bson.D{
+		bson.E{Key: command, Value: 1},
+	})
+	if singleResult.Err() != nil {
+		return result, fmt.Errorf("'%s' failed: %w", command, singleResult.Err())
+	}
+
+	if err := singleResult.Decode(&result); err != nil {
+		return result, fmt.Errorf("'%s' failed: %w", command, err)
+	}
+	return result, nil
+}
diff --git a/flow/shared/mongo/validation.go b/flow/shared/mongo/validation.go
index 752956ca07..59318d84f3 100644
--- a/flow/shared/mongo/validation.go
+++ b/flow/shared/mongo/validation.go
@@ -2,75 +2,131 @@ package mongo
 
 import (
 	"context"
+	"errors"
 	"fmt"
+	"slices"
 
-	"go.mongodb.org/mongo-driver/v2/bson"
 	"go.mongodb.org/mongo-driver/v2/mongo"
 )
 
 const (
 	MinSupportedVersion    = "5.1.0"
 	MinOplogRetentionHours = 24
+
+	ReplicaSet     = "ReplicaSet"
+	ShardedCluster = "ShardedCluster"
 )
 
-type BuildInfo struct {
-	Version string `bson:"version"`
-}
+// RequiredRoles lists the role names ValidateUserRoles requires on the authenticated user.
+var RequiredRoles = [...]string{"readAnyDatabase", "clusterMonitor"}
 
-type ReplSetGetStatus struct {
-	Set     string `bson:"set"`
-	MyState int    `bson:"myState"`
-}
+// ValidateServerCompatibility checks that the server version is at least
+// MinSupportedVersion and, for replica sets, that the storage engine is
+// wiredTiger. Storage engine validation is not yet run for sharded clusters.
+func ValidateServerCompatibility(ctx context.Context, client *mongo.Client) error {
+	buildInfo, err := GetBuildInfo(ctx, client)
+	if err != nil {
+		return err
+	}
 
-type OplogTruncation struct {
-	OplogMinRetentionHours float64 `bson:"oplogMinRetentionHours"`
-}
+	if cmp, err := CompareServerVersions(buildInfo.Version, MinSupportedVersion); err != nil {
+		return err
+	} else if cmp < 0 {
+		return fmt.Errorf("require minimum mongo version %s", MinSupportedVersion)
+	}
 
-type StorageEngine struct {
-	Name string `bson:"name"`
-}
+	validateStorageEngine := func(instanceCtx context.Context, instanceClient *mongo.Client) error {
+		ss, err := GetServerStatus(instanceCtx, instanceClient)
+		if err != nil {
+			return err
+		}
 
-type ServerStatus struct {
-	StorageEngine   StorageEngine   `bson:"storageEngine"`
-	OplogTruncation OplogTruncation `bson:"oplogTruncation"`
+		if ss.StorageEngine.Name != "wiredTiger" {
+			return errors.New("only wiredTiger storage engine is supported")
+		}
+		return nil
+	}
+
+	topologyType, err := GetTopologyType(ctx, client)
+	if err != nil {
+		return err
+	}
+
+	if topologyType == ReplicaSet {
+		return validateStorageEngine(ctx, client)
+	} else {
+		// TODO: run validation on shard
+		return nil
+	}
 }
 
-func GetBuildInfo(ctx context.Context, client *mongo.Client) (*BuildInfo, error) {
-	singleResult := client.Database("admin").RunCommand(ctx, bson.D{bson.E{Key: "buildInfo", Value: 1}})
-	if singleResult.Err() != nil {
-		return nil, fmt.Errorf("failed to run 'buildInfo' command: %w", singleResult.Err())
+// ValidateUserRoles checks that every role in RequiredRoles appears among the
+// authenticated user's roles reported by connectionStatus. Roles are matched
+// by name only; the database the role is defined on is not compared.
+func ValidateUserRoles(ctx context.Context, client *mongo.Client) error {
+	connectionStatus, err := GetConnectionStatus(ctx, client)
+	if err != nil {
+		return err
 	}
-	var info BuildInfo
-	if err := singleResult.Decode(&info); err != nil {
-		return nil, fmt.Errorf("failed to decode BuildInfo: %w", err)
+
+	for _, requiredRole := range RequiredRoles {
+		if !slices.ContainsFunc(connectionStatus.AuthInfo.AuthenticatedUserRoles, func(r Role) bool {
+			return r.Role == requiredRole
+		}) {
+			return fmt.Errorf("missing required role: %s", requiredRole)
+		}
 	}
-	return &info, nil
+
+	return nil
 }
 
-func GetReplSetGetStatus(ctx context.Context, client *mongo.Client) (*ReplSetGetStatus, error) {
-	singleResult := client.Database("admin").RunCommand(ctx, bson.D{
-		bson.E{Key: "replSetGetStatus", Value: 1},
-	})
-	if singleResult.Err() != nil {
-		return nil, fmt.Errorf("failed to run 'replSetGetStatus' command: %w", singleResult.Err())
+// ValidateOplogRetention checks that, for replica sets, oplogMinRetentionHours
+// from serverStatus is set and at least MinOplogRetentionHours. The check is
+// not yet run for sharded clusters.
+func ValidateOplogRetention(ctx context.Context, client *mongo.Client) error {
+	validateOplogRetention := func(instanceCtx context.Context, instanceClient *mongo.Client) error {
+		ss, err := GetServerStatus(instanceCtx, instanceClient)
+		if err != nil {
+			return err
+		}
+		if ss.OplogTruncation.OplogMinRetentionHours == 0 ||
+			ss.OplogTruncation.OplogMinRetentionHours < MinOplogRetentionHours {
+			return fmt.Errorf("oplog retention must be set to >= 24 hours, but got %f",
+				ss.OplogTruncation.OplogMinRetentionHours)
+		}
+		return nil
 	}
-	var status ReplSetGetStatus
-	if err := singleResult.Decode(&status); err != nil {
-		return nil, fmt.Errorf("failed to decode ReplSetGetStatus: %w", err)
+
+	topology, err := GetTopologyType(ctx, client)
+	if err != nil {
+		return err
+	}
+	if topology == ReplicaSet {
+		return validateOplogRetention(ctx, client)
+	} else {
+		// TODO: run validation on shard
+		return nil
 	}
-	return &status, nil
 }
 
-func GetServerStatus(ctx context.Context, client *mongo.Client) (*ServerStatus, error) {
-	singleResult := client.Database("admin").RunCommand(ctx, bson.D{
-		bson.E{Key: "serverStatus", Value: 1},
-	})
-	if singleResult.Err() != nil {
-		return nil, fmt.Errorf("failed to run 'serverStatus' command: %w", singleResult.Err())
+// GetTopologyType classifies the deployment from the hello reply as ReplicaSet
+// or ShardedCluster, and returns an error for any other topology.
+func GetTopologyType(ctx context.Context, client *mongo.Client) (string, error) {
+	hello, err := GetHelloResponse(ctx, client)
+	if err != nil {
+		return "", err
+	}
+
+	// Only a replica set reply has the 'hosts' field.
+	// https://www.mongodb.com/docs/manual/reference/command/hello/#mongodb-data-hello.hosts
+	if len(hello.Hosts) > 0 {
+		return ReplicaSet, nil
 	}
-	var status ServerStatus
-	if err := singleResult.Decode(&status); err != nil {
-		return nil, fmt.Errorf("failed to decode ServerStatus: %w", err)
+
+	// Only a sharded cluster reply has the 'msg' field, and its value is 'isdbgrid'.
+	// https://www.mongodb.com/docs/manual/reference/command/hello/#mongodb-data-hello.msg
+	if hello.Msg == "isdbgrid" {
+		return ShardedCluster, nil
 	}
-	return &status, nil
+	return "", errors.New("topology type must be ReplicaSet or ShardedCluster")
 }