Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func createPodGroup(ctx context.Context, app *rayv1.RayCluster) *v1alpha1.PodGro
},
},
Spec: v1alpha1.PodGroupSpec{
MinMember: utils.CalculateDesiredReplicas(ctx, app) + 1, // +1 for the head pod
MinMember: utils.CalculateDesiredReplicas(app) + 1, // +1 for the head pod
MinResources: utils.CalculateDesiredResources(app),
},
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package volcano

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -159,7 +158,7 @@ func TestCreatePodGroupForRayCluster(t *testing.T) {

cluster := createTestRayCluster(1)

minMember := utils.CalculateDesiredReplicas(context.Background(), &cluster) + 1
minMember := utils.CalculateDesiredReplicas(&cluster) + 1
totalResource := utils.CalculateDesiredResources(&cluster)
pg := createPodGroup(&cluster, getAppPodGroupName(&cluster), minMember, totalResource)

Expand All @@ -183,7 +182,7 @@ func TestCreatePodGroupForRayCluster_NumOfHosts2(t *testing.T) {

cluster := createTestRayCluster(2)

minMember := utils.CalculateDesiredReplicas(context.Background(), &cluster) + 1
minMember := utils.CalculateDesiredReplicas(&cluster) + 1
totalResource := utils.CalculateDesiredResources(&cluster)
pg := createPodGroup(&cluster, getAppPodGroupName(&cluster), minMember, totalResource)

Expand Down
4 changes: 2 additions & 2 deletions ray-operator/controllers/ray/raycluster_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv
continue
}
// numExpectedWorkerPods stores the target number of pods for this worker group.
numExpectedWorkerPods := int(utils.GetWorkerGroupDesiredReplicas(ctx, worker))
numExpectedWorkerPods := int(utils.GetWorkerGroupDesiredReplicas(worker))
logger.Info("reconcilePods", "desired workerReplicas (always adhering to minReplicas/maxReplica)", numExpectedWorkerPods, "worker group", worker.GroupName, "maxReplicas", worker.MaxReplicas, "minReplicas", worker.MinReplicas, "replicas", worker.Replicas)

workerPods := corev1.PodList{}
Expand Down Expand Up @@ -1169,7 +1169,7 @@ func (r *RayClusterReconciler) calculateStatus(ctx context.Context, instance *ra

newInstance.Status.ReadyWorkerReplicas = utils.CalculateReadyReplicas(runtimePods)
newInstance.Status.AvailableWorkerReplicas = utils.CalculateAvailableReplicas(runtimePods)
newInstance.Status.DesiredWorkerReplicas = utils.CalculateDesiredReplicas(ctx, newInstance)
newInstance.Status.DesiredWorkerReplicas = utils.CalculateDesiredReplicas(newInstance)
newInstance.Status.MinWorkerReplicas = utils.CalculateMinReplicas(newInstance)
newInstance.Status.MaxWorkerReplicas = utils.CalculateMaxReplicas(newInstance)

Expand Down
19 changes: 6 additions & 13 deletions ray-operator/controllers/ray/utils/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,34 +335,27 @@ func GenerateIdentifier(clusterName string, nodeType rayv1.RayNodeType) string {
return fmt.Sprintf("%s-%s", clusterName, nodeType)
}

func GetWorkerGroupDesiredReplicas(ctx context.Context, workerGroupSpec rayv1.WorkerGroupSpec) int32 {
log := ctrl.LoggerFrom(ctx)
func GetWorkerGroupDesiredReplicas(workerGroupSpec rayv1.WorkerGroupSpec) int32 {
// Always adhere to min/max replicas constraints.
var workerReplicas int32
if workerGroupSpec.Suspend != nil && *workerGroupSpec.Suspend {
return 0
}
if *workerGroupSpec.MinReplicas > *workerGroupSpec.MaxReplicas {
log.Info("minReplicas is greater than maxReplicas, using maxReplicas as desired replicas. "+
"Please fix this to avoid any unexpected behaviors.", "minReplicas", *workerGroupSpec.MinReplicas, "maxReplicas", *workerGroupSpec.MaxReplicas)
workerReplicas = *workerGroupSpec.MaxReplicas
} else if workerGroupSpec.Replicas == nil || *workerGroupSpec.Replicas < *workerGroupSpec.MinReplicas {
// Replicas is impossible to be nil as it has a default value assigned in the CRD.
// Add this check to make testing easier.
// Validation for replicas/min/max should be enforced in validation.go before reconcile proceeds.
// Here we only compute the desired replicas within the already-validated bounds.
if workerGroupSpec.Replicas == nil {
workerReplicas = *workerGroupSpec.MinReplicas
} else if *workerGroupSpec.Replicas > *workerGroupSpec.MaxReplicas {
workerReplicas = *workerGroupSpec.MaxReplicas
} else {
workerReplicas = *workerGroupSpec.Replicas
}
return workerReplicas * workerGroupSpec.NumOfHosts
}

// CalculateDesiredReplicas calculates the desired number of worker replicas at the cluster level.
func CalculateDesiredReplicas(ctx context.Context, cluster *rayv1.RayCluster) int32 {
func CalculateDesiredReplicas(cluster *rayv1.RayCluster) int32 {
count := int32(0)
for _, nodeGroup := range cluster.Spec.WorkerGroupSpecs {
count += GetWorkerGroupDesiredReplicas(ctx, nodeGroup)
count += GetWorkerGroupDesiredReplicas(nodeGroup)
}

return count
Expand Down
31 changes: 7 additions & 24 deletions ray-operator/controllers/ray/utils/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,6 @@ func TestGenerateHeadServiceName(t *testing.T) {
}

func TestGetWorkerGroupDesiredReplicas(t *testing.T) {
ctx := context.Background()
// Test 1: `WorkerGroupSpec.Replicas` is nil.
// `Replicas` cannot be nil in a real RayCluster CR because the CRD assigns it a default value.
numOfHosts := int32(1)
Expand All @@ -562,37 +561,21 @@ func TestGetWorkerGroupDesiredReplicas(t *testing.T) {
MinReplicas: &minReplicas,
MaxReplicas: &maxReplicas,
}
assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), minReplicas)
assert.Equal(t, GetWorkerGroupDesiredReplicas(workerGroupSpec), minReplicas)

// Test 2: `WorkerGroupSpec.Replicas` is not nil and is within the range.
replicas := int32(3)
workerGroupSpec.Replicas = &replicas
assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), replicas)
assert.Equal(t, GetWorkerGroupDesiredReplicas(workerGroupSpec), replicas)

// Test 3: `WorkerGroupSpec.Replicas` is not nil but is more than maxReplicas.
replicas = int32(6)
workerGroupSpec.Replicas = &replicas
assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), maxReplicas)

// Test 4: `WorkerGroupSpec.Replicas` is not nil but is less than minReplicas.
replicas = int32(0)
workerGroupSpec.Replicas = &replicas
assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), minReplicas)

// Test 5: `WorkerGroupSpec.Replicas` is nil and minReplicas is greater than maxReplicas.
workerGroupSpec.Replicas = nil
workerGroupSpec.MinReplicas = &maxReplicas
workerGroupSpec.MaxReplicas = &minReplicas
assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), *workerGroupSpec.MaxReplicas)

// Test 6: `WorkerGroupSpec.Suspend` is true.
// Test 3: `WorkerGroupSpec.Suspend` is true.
suspend := true
workerGroupSpec.MinReplicas = &maxReplicas
workerGroupSpec.MaxReplicas = &minReplicas
workerGroupSpec.Suspend = &suspend
assert.Zero(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec))
assert.Zero(t, GetWorkerGroupDesiredReplicas(workerGroupSpec))

// Test 7: `WorkerGroupSpec.NumOfHosts` is 4.
// Test 4: `WorkerGroupSpec.NumOfHosts` is 4.
numOfHosts = int32(4)
replicas = int32(5)
suspend = false
Expand All @@ -601,7 +584,7 @@ func TestGetWorkerGroupDesiredReplicas(t *testing.T) {
workerGroupSpec.Suspend = &suspend
workerGroupSpec.MinReplicas = &minReplicas
workerGroupSpec.MaxReplicas = &maxReplicas
assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), replicas*numOfHosts)
assert.Equal(t, GetWorkerGroupDesiredReplicas(workerGroupSpec), replicas*numOfHosts)
}

func TestCalculateMinAndMaxReplicas(t *testing.T) {
Expand Down Expand Up @@ -798,7 +781,7 @@ func TestCalculateDesiredReplicas(t *testing.T) {
},
},
}
assert.Equal(t, CalculateDesiredReplicas(context.Background(), &cluster), tc.answer)
assert.Equal(t, CalculateDesiredReplicas(&cluster), tc.answer)
})
}
}
Expand Down
27 changes: 24 additions & 3 deletions ray-operator/controllers/ray/utils/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,34 @@ func ValidateRayClusterSpec(spec *rayv1.RayClusterSpec, annotations map[string]s
return fmt.Errorf("headGroupSpec should have at least one container")
}

// Check if autoscaling is enabled once to avoid repeated calls
isAutoscalingEnabled := IsAutoscalingEnabled(spec)

for _, workerGroup := range spec.WorkerGroupSpecs {
if len(workerGroup.Template.Spec.Containers) == 0 {
return fmt.Errorf("workerGroupSpec should have at least one container")
}
// When autoscaling is enabled, MinReplicas and MaxReplicas are optional
// as users can manually update them and the autoscaler will handle the adjustment.
if !isAutoscalingEnabled && (workerGroup.MinReplicas == nil || workerGroup.MaxReplicas == nil) {
return fmt.Errorf("worker group %s must set both minReplicas and maxReplicas when autoscaling is disabled", workerGroup.GroupName)
}
if workerGroup.MinReplicas != nil && *workerGroup.MinReplicas < 0 {
return fmt.Errorf("worker group %s has negative minReplicas %d", workerGroup.GroupName, *workerGroup.MinReplicas)
}
if workerGroup.MaxReplicas != nil && *workerGroup.MaxReplicas < 0 {
return fmt.Errorf("worker group %s has negative maxReplicas %d", workerGroup.GroupName, *workerGroup.MaxReplicas)
}
// When autoscaling is enabled, the Ray Autoscaler will manage replicas and
// eventually adjust them to fall within minReplicas/maxReplicas bounds.
if workerGroup.Replicas != nil && !isAutoscalingEnabled && workerGroup.MinReplicas != nil && workerGroup.MaxReplicas != nil {
if *workerGroup.Replicas < *workerGroup.MinReplicas {
return fmt.Errorf("worker group %s has replicas %d smaller than minReplicas %d", workerGroup.GroupName, *workerGroup.Replicas, *workerGroup.MinReplicas)
}
if *workerGroup.Replicas > *workerGroup.MaxReplicas {
return fmt.Errorf("worker group %s has replicas %d greater than maxReplicas %d", workerGroup.GroupName, *workerGroup.Replicas, *workerGroup.MaxReplicas)
}
}
}

if annotations[RayFTEnabledAnnotationKey] != "" && spec.GcsFaultToleranceOptions != nil {
Expand Down Expand Up @@ -93,9 +117,6 @@ func ValidateRayClusterSpec(spec *rayv1.RayClusterSpec, annotations map[string]s
}
}

// Check if autoscaling is enabled once to avoid repeated calls
isAutoscalingEnabled := IsAutoscalingEnabled(spec)

// Validate that RAY_enable_autoscaler_v2 environment variable is not set to "1" or "true" when autoscaler is disabled
if !isAutoscalingEnabled {
if envVar, exists := EnvVarByName(RAY_ENABLE_AUTOSCALER_V2, spec.HeadGroupSpec.Template.Spec.Containers[RayContainerIndex].Env); exists {
Expand Down
Loading