diff --git a/ray-operator/controllers/ray/batchscheduler/scheduler-plugins/scheduler_plugins.go b/ray-operator/controllers/ray/batchscheduler/scheduler-plugins/scheduler_plugins.go
index 7f8cc178531..939d97eec68 100644
--- a/ray-operator/controllers/ray/batchscheduler/scheduler-plugins/scheduler_plugins.go
+++ b/ray-operator/controllers/ray/batchscheduler/scheduler-plugins/scheduler_plugins.go
@@ -62,7 +62,7 @@ func createPodGroup(ctx context.Context, app *rayv1.RayCluster) *v1alpha1.PodGro
 			},
 		},
 		Spec: v1alpha1.PodGroupSpec{
-			MinMember:    utils.CalculateDesiredReplicas(ctx, app) + 1, // +1 for the head pod
+			MinMember:    utils.CalculateDesiredReplicas(app) + 1, // +1 for the head pod
 			MinResources: utils.CalculateDesiredResources(app),
 		},
 	}
diff --git a/ray-operator/controllers/ray/batchscheduler/volcano/volcano_scheduler_test.go b/ray-operator/controllers/ray/batchscheduler/volcano/volcano_scheduler_test.go
index d137c01f76c..ce09baa4f0e 100644
--- a/ray-operator/controllers/ray/batchscheduler/volcano/volcano_scheduler_test.go
+++ b/ray-operator/controllers/ray/batchscheduler/volcano/volcano_scheduler_test.go
@@ -1,7 +1,6 @@
 package volcano
 
 import (
-	"context"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -159,7 +158,7 @@ func TestCreatePodGroupForRayCluster(t *testing.T) {
 	cluster := createTestRayCluster(1)
 
-	minMember := utils.CalculateDesiredReplicas(context.Background(), &cluster) + 1
+	minMember := utils.CalculateDesiredReplicas(&cluster) + 1
 	totalResource := utils.CalculateDesiredResources(&cluster)
 	pg := createPodGroup(&cluster, getAppPodGroupName(&cluster), minMember, totalResource)
@@ -183,7 +182,7 @@ func TestCreatePodGroupForRayCluster_NumOfHosts2(t *testing.T) {
 	cluster := createTestRayCluster(2)
 
-	minMember := utils.CalculateDesiredReplicas(context.Background(), &cluster) + 1
+	minMember := utils.CalculateDesiredReplicas(&cluster) + 1
 	totalResource := utils.CalculateDesiredResources(&cluster)
 	pg := createPodGroup(&cluster, getAppPodGroupName(&cluster), minMember, totalResource)
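Both batch-scheduler integrations above derive the gang size the same way: the cluster-wide desired worker count plus one for the head pod. A minimal sketch of that arithmetic, with hypothetical group sizes (not values from this PR):

```go
package main

import "fmt"

// A minimal sketch of the PodGroup MinMember arithmetic shared by the
// scheduler-plugins and Volcano integrations. The two worker groups and
// their sizes below are hypothetical.
func main() {
	// Per-group desired replicas = replicas * numOfHosts.
	groupA := int32(3) * int32(1) // replicas=3, numOfHosts=1 -> 3
	groupB := int32(2) * int32(4) // replicas=2, numOfHosts=4 -> 8

	desiredWorkers := groupA + groupB
	minMember := desiredWorkers + 1 // +1 for the head pod

	fmt.Println("PodGroup MinMember:", minMember) // 12
}
```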
diff --git a/ray-operator/controllers/ray/raycluster_controller.go b/ray-operator/controllers/ray/raycluster_controller.go
index 17d6616f039..86877b81520 100644
--- a/ray-operator/controllers/ray/raycluster_controller.go
+++ b/ray-operator/controllers/ray/raycluster_controller.go
@@ -642,7 +642,7 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv
 			continue
 		}
 		// workerReplicas will store the target number of pods for this worker group.
-		numExpectedWorkerPods := int(utils.GetWorkerGroupDesiredReplicas(ctx, worker))
+		numExpectedWorkerPods := int(utils.GetWorkerGroupDesiredReplicas(worker))
 		logger.Info("reconcilePods", "desired workerReplicas (always adhering to minReplicas/maxReplica)", numExpectedWorkerPods, "worker group", worker.GroupName, "maxReplicas", worker.MaxReplicas, "minReplicas", worker.MinReplicas, "replicas", worker.Replicas)
 
 		workerPods := corev1.PodList{}
@@ -1169,7 +1169,7 @@ func (r *RayClusterReconciler) calculateStatus(ctx context.Context, instance *ra
 	newInstance.Status.ReadyWorkerReplicas = utils.CalculateReadyReplicas(runtimePods)
 	newInstance.Status.AvailableWorkerReplicas = utils.CalculateAvailableReplicas(runtimePods)
-	newInstance.Status.DesiredWorkerReplicas = utils.CalculateDesiredReplicas(ctx, newInstance)
+	newInstance.Status.DesiredWorkerReplicas = utils.CalculateDesiredReplicas(newInstance)
 	newInstance.Status.MinWorkerReplicas = utils.CalculateMinReplicas(newInstance)
 	newInstance.Status.MaxWorkerReplicas = utils.CalculateMaxReplicas(newInstance)
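For context on why numExpectedWorkerPods matters: the reconciler compares it against the worker pods that actually exist and creates or deletes the difference. A sketch of that decision under hypothetical counts; this is not the reconciler's actual code:

```go
package main

import "fmt"

// Not the reconciler's actual code: a sketch of how the target pod count
// (numExpectedWorkerPods) drives scaling for one worker group. The counts
// below are hypothetical.
func main() {
	numExpectedWorkerPods := 6 // e.g. replicas=3, numOfHosts=2
	runningWorkerPods := 4

	switch diff := numExpectedWorkerPods - runningWorkerPods; {
	case diff > 0:
		fmt.Printf("create %d worker pod(s)\n", diff)
	case diff < 0:
		fmt.Printf("delete %d surplus worker pod(s)\n", -diff)
	default:
		fmt.Println("worker group already at its target size")
	}
}
```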
diff --git a/ray-operator/controllers/ray/utils/util.go b/ray-operator/controllers/ray/utils/util.go
index 3bb63f79189..0eeb970278a 100644
--- a/ray-operator/controllers/ray/utils/util.go
+++ b/ray-operator/controllers/ray/utils/util.go
@@ -335,23 +335,16 @@ func GenerateIdentifier(clusterName string, nodeType rayv1.RayNodeType) string {
 	return fmt.Sprintf("%s-%s", clusterName, nodeType)
 }
 
-func GetWorkerGroupDesiredReplicas(ctx context.Context, workerGroupSpec rayv1.WorkerGroupSpec) int32 {
-	log := ctrl.LoggerFrom(ctx)
+func GetWorkerGroupDesiredReplicas(workerGroupSpec rayv1.WorkerGroupSpec) int32 {
 	// Always adhere to min/max replicas constraints.
 	var workerReplicas int32
 	if workerGroupSpec.Suspend != nil && *workerGroupSpec.Suspend {
 		return 0
 	}
-	if *workerGroupSpec.MinReplicas > *workerGroupSpec.MaxReplicas {
-		log.Info("minReplicas is greater than maxReplicas, using maxReplicas as desired replicas. "+
-			"Please fix this to avoid any unexpected behaviors.", "minReplicas", *workerGroupSpec.MinReplicas, "maxReplicas", *workerGroupSpec.MaxReplicas)
-		workerReplicas = *workerGroupSpec.MaxReplicas
-	} else if workerGroupSpec.Replicas == nil || *workerGroupSpec.Replicas < *workerGroupSpec.MinReplicas {
-		// Replicas is impossible to be nil as it has a default value assigned in the CRD.
-		// Add this check to make testing easier.
+	// Validation for replicas/min/max should be enforced in validation.go before reconcile proceeds.
+	// Here we only compute the desired replicas within the already-validated bounds.
+	if workerGroupSpec.Replicas == nil {
 		workerReplicas = *workerGroupSpec.MinReplicas
-	} else if *workerGroupSpec.Replicas > *workerGroupSpec.MaxReplicas {
-		workerReplicas = *workerGroupSpec.MaxReplicas
 	} else {
 		workerReplicas = *workerGroupSpec.Replicas
 	}
@@ -359,10 +352,10 @@ func GetWorkerGroupDesiredReplicas(ctx context.Context, workerGroupSpec rayv1.Wo
 }
 
 // CalculateDesiredReplicas calculate desired worker replicas at the cluster level
-func CalculateDesiredReplicas(ctx context.Context, cluster *rayv1.RayCluster) int32 {
+func CalculateDesiredReplicas(cluster *rayv1.RayCluster) int32 {
 	count := int32(0)
 	for _, nodeGroup := range cluster.Spec.WorkerGroupSpecs {
-		count += GetWorkerGroupDesiredReplicas(ctx, nodeGroup)
+		count += GetWorkerGroupDesiredReplicas(nodeGroup)
 	}
 
 	return count
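The net behavioral change in GetWorkerGroupDesiredReplicas: the old code silently clamped out-of-range replicas into [minReplicas, maxReplicas], while the new code trusts the spec because validation.go now rejects invalid ranges up front when autoscaling is disabled. An illustrative paraphrase, not the real functions, with Suspend and NumOfHosts handling omitted for brevity:

```go
package main

import "fmt"

// oldDesired mirrors the removed branches: out-of-range values were
// silently clamped into [min, max], with only a log warning.
func oldDesired(replicas *int32, min, max int32) int32 {
	switch {
	case min > max:
		return max
	case replicas == nil || *replicas < min:
		return min
	case *replicas > max:
		return max
	default:
		return *replicas
	}
}

// newDesired mirrors the simplified logic: a nil Replicas falls back to
// MinReplicas, and anything else is taken as-is because validation.go has
// already rejected out-of-range specs (when autoscaling is disabled).
func newDesired(replicas *int32, min int32) int32 {
	if replicas == nil {
		return min
	}
	return *replicas
}

func main() {
	r := int32(6)
	fmt.Println(oldDesired(&r, 1, 5)) // 5: silently clamped
	fmt.Println(newDesired(&r, 1))    // 6: such a spec is now rejected earlier
}
```

With autoscaling enabled, Replicas may legitimately sit outside the bounds for a while; the new function passes it through and leaves convergence to the autoscaler.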
diff --git a/ray-operator/controllers/ray/utils/util_test.go b/ray-operator/controllers/ray/utils/util_test.go
index 851e37af3ea..15ef9f21b1f 100644
--- a/ray-operator/controllers/ray/utils/util_test.go
+++ b/ray-operator/controllers/ray/utils/util_test.go
@@ -550,7 +550,6 @@ func TestGenerateHeadServiceName(t *testing.T) {
 }
 
 func TestGetWorkerGroupDesiredReplicas(t *testing.T) {
-	ctx := context.Background()
 	// Test 1: `WorkerGroupSpec.Replicas` is nil.
 	// `Replicas` is impossible to be nil in a real RayCluster CR as it has a default value assigned in the CRD.
 	numOfHosts := int32(1)
@@ -562,37 +561,21 @@ func TestGetWorkerGroupDesiredReplicas(t *testing.T) {
 		MinReplicas: &minReplicas,
 		MaxReplicas: &maxReplicas,
 	}
-	assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), minReplicas)
+	assert.Equal(t, GetWorkerGroupDesiredReplicas(workerGroupSpec), minReplicas)
 
 	// Test 2: `WorkerGroupSpec.Replicas` is not nil and is within the range.
 	replicas := int32(3)
 	workerGroupSpec.Replicas = &replicas
-	assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), replicas)
+	assert.Equal(t, GetWorkerGroupDesiredReplicas(workerGroupSpec), replicas)
 
-	// Test 3: `WorkerGroupSpec.Replicas` is not nil but is more than maxReplicas.
-	replicas = int32(6)
-	workerGroupSpec.Replicas = &replicas
-	assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), maxReplicas)
-
-	// Test 4: `WorkerGroupSpec.Replicas` is not nil but is less than minReplicas.
-	replicas = int32(0)
-	workerGroupSpec.Replicas = &replicas
-	assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), minReplicas)
-
-	// Test 5: `WorkerGroupSpec.Replicas` is nil and minReplicas is less than maxReplicas.
-	workerGroupSpec.Replicas = nil
-	workerGroupSpec.MinReplicas = &maxReplicas
-	workerGroupSpec.MaxReplicas = &minReplicas
-	assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), *workerGroupSpec.MaxReplicas)
-
-	// Test 6: `WorkerGroupSpec.Suspend` is true.
+	// Test 3: `WorkerGroupSpec.Suspend` is true.
 	suspend := true
 	workerGroupSpec.MinReplicas = &maxReplicas
 	workerGroupSpec.MaxReplicas = &minReplicas
 	workerGroupSpec.Suspend = &suspend
-	assert.Zero(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec))
+	assert.Zero(t, GetWorkerGroupDesiredReplicas(workerGroupSpec))
 
-	// Test 7: `WorkerGroupSpec.NumOfHosts` is 4.
+	// Test 4: `WorkerGroupSpec.NumOfHosts` is 4.
 	numOfHosts = int32(4)
 	replicas = int32(5)
 	suspend = false
@@ -601,7 +584,7 @@ func TestGetWorkerGroupDesiredReplicas(t *testing.T) {
 	workerGroupSpec.Suspend = &suspend
 	workerGroupSpec.MinReplicas = &minReplicas
 	workerGroupSpec.MaxReplicas = &maxReplicas
-	assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), replicas*numOfHosts)
+	assert.Equal(t, GetWorkerGroupDesiredReplicas(workerGroupSpec), replicas*numOfHosts)
 }
 
 func TestCalculateMinAndMaxReplicas(t *testing.T) {
@@ -798,7 +781,7 @@ func TestCalculateDesiredReplicas(t *testing.T) {
 				},
 			},
 		}
-			assert.Equal(t, CalculateDesiredReplicas(context.Background(), &cluster), tc.answer)
+			assert.Equal(t, CalculateDesiredReplicas(&cluster), tc.answer)
 		})
 	}
 }
diff --git a/ray-operator/controllers/ray/utils/validation.go b/ray-operator/controllers/ray/utils/validation.go
index edda0b772d5..d311737cf17 100644
--- a/ray-operator/controllers/ray/utils/validation.go
+++ b/ray-operator/controllers/ray/utils/validation.go
@@ -39,10 +39,34 @@ func ValidateRayClusterSpec(spec *rayv1.RayClusterSpec, annotations map[string]s
 		return fmt.Errorf("headGroupSpec should have at least one container")
 	}
 
+	// Check if autoscaling is enabled once to avoid repeated calls
+	isAutoscalingEnabled := IsAutoscalingEnabled(spec)
+
 	for _, workerGroup := range spec.WorkerGroupSpecs {
 		if len(workerGroup.Template.Spec.Containers) == 0 {
 			return fmt.Errorf("workerGroupSpec should have at least one container")
 		}
+		// When autoscaling is enabled, MinReplicas and MaxReplicas are optional
+		// as users can manually update them and the autoscaler will handle the adjustment.
+		if !isAutoscalingEnabled && (workerGroup.MinReplicas == nil || workerGroup.MaxReplicas == nil) {
+			return fmt.Errorf("worker group %s must set both minReplicas and maxReplicas when autoscaling is disabled", workerGroup.GroupName)
+		}
+		if workerGroup.MinReplicas != nil && *workerGroup.MinReplicas < 0 {
+			return fmt.Errorf("worker group %s has negative minReplicas %d", workerGroup.GroupName, *workerGroup.MinReplicas)
+		}
+		if workerGroup.MaxReplicas != nil && *workerGroup.MaxReplicas < 0 {
+			return fmt.Errorf("worker group %s has negative maxReplicas %d", workerGroup.GroupName, *workerGroup.MaxReplicas)
+		}
+		// When autoscaling is enabled, the Ray Autoscaler will manage replicas and
+		// eventually adjust them to fall within minReplicas/maxReplicas bounds.
+		if workerGroup.Replicas != nil && !isAutoscalingEnabled && workerGroup.MinReplicas != nil && workerGroup.MaxReplicas != nil {
+			if *workerGroup.Replicas < *workerGroup.MinReplicas {
+				return fmt.Errorf("worker group %s has replicas %d smaller than minReplicas %d", workerGroup.GroupName, *workerGroup.Replicas, *workerGroup.MinReplicas)
+			}
+			if *workerGroup.Replicas > *workerGroup.MaxReplicas {
+				return fmt.Errorf("worker group %s has replicas %d greater than maxReplicas %d", workerGroup.GroupName, *workerGroup.Replicas, *workerGroup.MaxReplicas)
+			}
+		}
 	}
 
 	if annotations[RayFTEnabledAnnotationKey] != "" && spec.GcsFaultToleranceOptions != nil {
@@ -93,9 +117,6 @@ func ValidateRayClusterSpec(spec *rayv1.RayClusterSpec, annotations map[string]s
 		}
 	}
 
-	// Check if autoscaling is enabled once to avoid repeated calls
-	isAutoscalingEnabled := IsAutoscalingEnabled(spec)
-
 	// Validate that RAY_enable_autoscaler_v2 environment variable is not set to "1" or "true" when autoscaler is disabled
 	if !isAutoscalingEnabled {
 		if envVar, exists := EnvVarByName(RAY_ENABLE_AUTOSCALER_V2, spec.HeadGroupSpec.Template.Spec.Containers[RayContainerIndex].Env); exists {
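The added checks above, condensed into a runnable standalone sketch. The workerGroup struct here is a stand-in for rayv1.WorkerGroupSpec (only the fields the checks touch are kept), and the error messages mirror the ones added in validation.go:

```go
package main

import "fmt"

// Condensed sketch of the worker-group checks added to ValidateRayClusterSpec.
// workerGroup is an illustrative stand-in for rayv1.WorkerGroupSpec.
type workerGroup struct {
	GroupName          string
	Replicas, Min, Max *int32
}

func validateWorkerGroup(g workerGroup, autoscaling bool) error {
	// min/max are only mandatory when no autoscaler can fill them in.
	if !autoscaling && (g.Min == nil || g.Max == nil) {
		return fmt.Errorf("worker group %s must set both minReplicas and maxReplicas when autoscaling is disabled", g.GroupName)
	}
	if g.Min != nil && *g.Min < 0 {
		return fmt.Errorf("worker group %s has negative minReplicas %d", g.GroupName, *g.Min)
	}
	if g.Max != nil && *g.Max < 0 {
		return fmt.Errorf("worker group %s has negative maxReplicas %d", g.GroupName, *g.Max)
	}
	// With autoscaling enabled, the Ray Autoscaler will pull replicas back
	// into bounds on its own, so the range check is skipped.
	if g.Replicas != nil && !autoscaling && g.Min != nil && g.Max != nil {
		if *g.Replicas < *g.Min {
			return fmt.Errorf("worker group %s has replicas %d smaller than minReplicas %d", g.GroupName, *g.Replicas, *g.Min)
		}
		if *g.Replicas > *g.Max {
			return fmt.Errorf("worker group %s has replicas %d greater than maxReplicas %d", g.GroupName, *g.Replicas, *g.Max)
		}
	}
	return nil
}

func main() {
	minR, maxR, replicas := int32(1), int32(5), int32(6)
	g := workerGroup{GroupName: "small-group", Replicas: &replicas, Min: &minR, Max: &maxR}
	fmt.Println(validateWorkerGroup(g, false)) // rejected: replicas 6 greater than maxReplicas 5
	fmt.Println(validateWorkerGroup(g, true))  // <nil>: the autoscaler will reconcile it
}
```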