Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 5 additions & 22 deletions flow/connectors/mongo/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,20 @@ package connmongo

import (
"context"
"errors"
"fmt"

"github.com/PeerDB-io/peerdb/flow/generated/protos"
shared_mongo "github.com/PeerDB-io/peerdb/flow/shared/mongo"
)

func (c *MongoConnector) ValidateCheck(ctx context.Context) error {
version, err := c.GetVersion(ctx)
if err != nil {
if err := shared_mongo.ValidateServerCompatibility(ctx, c.client); err != nil {
return err
}
cmp, err := shared_mongo.CompareServerVersions(version, shared_mongo.MinSupportedVersion)
if err != nil {

if err := shared_mongo.ValidateUserRoles(ctx, c.client); err != nil {
return err
}
if cmp == -1 {
return fmt.Errorf("require minimum mongo version %s", shared_mongo.MinSupportedVersion)
}

return nil
}

Expand All @@ -29,21 +24,9 @@ func (c *MongoConnector) ValidateMirrorSource(ctx context.Context, cfg *protos.F
return nil
}

if _, err := shared_mongo.GetReplSetGetStatus(ctx, c.client); err != nil {
return err
}

serverStatus, err := shared_mongo.GetServerStatus(ctx, c.client)
if err != nil {
if err := shared_mongo.ValidateOplogRetention(ctx, c.client); err != nil {
return err
}
if serverStatus.StorageEngine.Name != "wiredTiger" {
return errors.New("storage engine must be 'wiredTiger'")
}
if serverStatus.OplogTruncation.OplogMinRetentionHours == 0 ||
serverStatus.OplogTruncation.OplogMinRetentionHours < shared_mongo.MinOplogRetentionHours {
return errors.New("oplog retention must be set to >= 24 hours")
}

return nil
}
84 changes: 84 additions & 0 deletions flow/shared/mongo/commands.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package mongo

import (
"context"
"fmt"

"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
)

type BuildInfo struct {
Version string `bson:"version"`
}

func GetBuildInfo(ctx context.Context, client *mongo.Client) (BuildInfo, error) {
return runCommand[BuildInfo](ctx, client, "buildInfo")
}

type ReplSetGetStatus struct {
Set string `bson:"set"`
MyState int `bson:"myState"`
}

func GetReplSetGetStatus(ctx context.Context, client *mongo.Client) (ReplSetGetStatus, error) {
return runCommand[ReplSetGetStatus](ctx, client, "replSetGetStatus")
}

type OplogTruncation struct {
OplogMinRetentionHours float64 `bson:"oplogMinRetentionHours"`
}

type StorageEngine struct {
Name string `bson:"name"`
}

type ServerStatus struct {
StorageEngine StorageEngine `bson:"storageEngine"`
OplogTruncation OplogTruncation `bson:"oplogTruncation"`
}

func GetServerStatus(ctx context.Context, client *mongo.Client) (ServerStatus, error) {
return runCommand[ServerStatus](ctx, client, "serverStatus")
}

type ConnectionStatus struct {
AuthInfo AuthInfo `bson:"authInfo"`
}

type AuthInfo struct {
AuthenticatedUserRoles []Role `bson:"authenticatedUserRoles"`
}

type Role struct {
Role string `bson:"role"`
DB string `bson:"db"`
}

func GetConnectionStatus(ctx context.Context, client *mongo.Client) (ConnectionStatus, error) {
return runCommand[ConnectionStatus](ctx, client, "connectionStatus")
}

type HelloResponse struct {
Msg string `bson:"msg,omitempty"`
Hosts []string `bson:"hosts,omitempty"`
}

func GetHelloResponse(ctx context.Context, client *mongo.Client) (HelloResponse, error) {
return runCommand[HelloResponse](ctx, client, "hello")
}

func runCommand[T any](ctx context.Context, client *mongo.Client, command string) (T, error) {
var result T
singleResult := client.Database("admin").RunCommand(ctx, bson.D{
bson.E{Key: command, Value: 1},
})
if singleResult.Err() != nil {
return result, fmt.Errorf("'%s' failed: %v", command, singleResult.Err())
}

if err := singleResult.Decode(&result); err != nil {
return result, fmt.Errorf("'%s' failed: %v", command, err)
}
return result, nil
}
134 changes: 89 additions & 45 deletions flow/shared/mongo/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,75 +2,119 @@ package mongo

import (
"context"
"errors"
"fmt"
"slices"

"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
)

const (
MinSupportedVersion = "5.1.0"
MinOplogRetentionHours = 24

ReplicaSet = "ReplicaSet"
ShardedCluster = "ShardedCluster"
)

type BuildInfo struct {
Version string `bson:"version"`
}
var RequiredRoles = [...]string{"readAnyDatabase", "clusterMonitor"}

type ReplSetGetStatus struct {
Set string `bson:"set"`
MyState int `bson:"myState"`
}
func ValidateServerCompatibility(ctx context.Context, client *mongo.Client) error {
buildInfo, err := GetBuildInfo(ctx, client)
if err != nil {
return err
}

type OplogTruncation struct {
OplogMinRetentionHours float64 `bson:"oplogMinRetentionHours"`
}
if cmp, err := CompareServerVersions(buildInfo.Version, MinSupportedVersion); err != nil {
return err
} else if cmp < 0 {
return fmt.Errorf("require minimum mongo version %s", MinSupportedVersion)
}

type StorageEngine struct {
Name string `bson:"name"`
}
validateStorageEngine := func(instanceCtx context.Context, instanceClient *mongo.Client) error {
ss, err := GetServerStatus(instanceCtx, instanceClient)
if err != nil {
return err
}

type ServerStatus struct {
StorageEngine StorageEngine `bson:"storageEngine"`
OplogTruncation OplogTruncation `bson:"oplogTruncation"`
if ss.StorageEngine.Name != "wiredTiger" {
return errors.New("only wiredTiger storage engine is supported")
}
return nil
}

topologyType, err := GetTopologyType(ctx, client)
if err != nil {
return err
}

if topologyType == ReplicaSet {
return validateStorageEngine(ctx, client)
} else {
// TODO: run validation on shard
return nil
}
}

func GetBuildInfo(ctx context.Context, client *mongo.Client) (*BuildInfo, error) {
singleResult := client.Database("admin").RunCommand(ctx, bson.D{bson.E{Key: "buildInfo", Value: 1}})
if singleResult.Err() != nil {
return nil, fmt.Errorf("failed to run 'buildInfo' command: %w", singleResult.Err())
func ValidateUserRoles(ctx context.Context, client *mongo.Client) error {
connectionStatus, err := GetConnectionStatus(ctx, client)
if err != nil {
return err
}
var info BuildInfo
if err := singleResult.Decode(&info); err != nil {
return nil, fmt.Errorf("failed to decode BuildInfo: %w", err)

for _, requiredRole := range RequiredRoles {
if !slices.ContainsFunc(connectionStatus.AuthInfo.AuthenticatedUserRoles, func(r Role) bool {
return r.Role == requiredRole
}) {
return fmt.Errorf("missing required role: %s", requiredRole)
}
}
return &info, nil

return nil
}

func GetReplSetGetStatus(ctx context.Context, client *mongo.Client) (*ReplSetGetStatus, error) {
singleResult := client.Database("admin").RunCommand(ctx, bson.D{
bson.E{Key: "replSetGetStatus", Value: 1},
})
if singleResult.Err() != nil {
return nil, fmt.Errorf("failed to run 'replSetGetStatus' command: %w", singleResult.Err())
func ValidateOplogRetention(ctx context.Context, client *mongo.Client) error {
validateOplogRetention := func(instanceCtx context.Context, instanceClient *mongo.Client) error {
ss, err := GetServerStatus(instanceCtx, instanceClient)
if err != nil {
return err
}
if ss.OplogTruncation.OplogMinRetentionHours == 0 ||
ss.OplogTruncation.OplogMinRetentionHours < MinOplogRetentionHours {
return fmt.Errorf("oplog retention must be set to >= 24 hours, but got %f",
ss.OplogTruncation.OplogMinRetentionHours)
}
return nil
}
var status ReplSetGetStatus
if err := singleResult.Decode(&status); err != nil {
return nil, fmt.Errorf("failed to decode ReplSetGetStatus: %w", err)

topology, err := GetTopologyType(ctx, client)
if err != nil {
return err
}
if topology == ReplicaSet {
return validateOplogRetention(ctx, client)
} else {
// TODO: run validation on shard
return nil
}
return &status, nil
}

func GetServerStatus(ctx context.Context, client *mongo.Client) (*ServerStatus, error) {
singleResult := client.Database("admin").RunCommand(ctx, bson.D{
bson.E{Key: "serverStatus", Value: 1},
})
if singleResult.Err() != nil {
return nil, fmt.Errorf("failed to run 'serverStatus' command: %w", singleResult.Err())
func GetTopologyType(ctx context.Context, client *mongo.Client) (string, error) {
hello, err := GetHelloResponse(ctx, client)
if err != nil {
return "", err
}

// Only replica set has 'hosts' field
// https://www.mongodb.com/docs/manual/reference/command/hello/#mongodb-data-hello.hosts
if len(hello.Hosts) > 0 {
return ReplicaSet, nil
}
var status ServerStatus
if err := singleResult.Decode(&status); err != nil {
return nil, fmt.Errorf("failed to decode ServerStatus: %w", err)

// Only sharded cluster has 'msg' field, and equals to 'isdbgrid'
// https://www.mongodb.com/docs/manual/reference/command/hello/#mongodb-data-hello.msg
if hello.Msg == "isdbgrid" {
return ShardedCluster, nil
}
return &status, nil
return "", errors.New("topology type must be ReplicaSet or ShardedCluster")
}
Loading