Skip to content

Commit c90ba62

Browse files
committed
RHOAIENG-39073: Add priority class support
1 parent d202cd4 commit c90ba62

File tree

3 files changed

+355
-23
lines changed

3 files changed

+355
-23
lines changed

src/codeflare_sdk/common/kueue/kueue.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,16 @@
1313
# limitations under the License.
1414

1515
from typing import Optional, List
16+
import logging
1617
from codeflare_sdk.common import _kube_api_error_handling
1718
from codeflare_sdk.common.kubernetes_cluster.auth import config_check, get_api_client
1819
from kubernetes import client
1920
from kubernetes.client.exceptions import ApiException
2021

2122
from ...common.utils import get_current_namespace
2223

24+
logger = logging.getLogger(__name__)
25+
2326

2427
def get_default_kueue_name(namespace: str) -> Optional[str]:
2528
"""
@@ -144,6 +147,59 @@ def local_queue_exists(namespace: str, local_queue_name: str) -> bool:
144147
return False
145148

146149

150+
def priority_class_exists(priority_class_name: str) -> Optional[bool]:
151+
"""
152+
Checks if a WorkloadPriorityClass with the provided name exists in the cluster.
153+
154+
WorkloadPriorityClass is a cluster-scoped resource.
155+
156+
Args:
157+
priority_class_name (str):
158+
The name of the WorkloadPriorityClass to check for existence.
159+
160+
Returns:
161+
Optional[bool]:
162+
True if the WorkloadPriorityClass exists, False if it doesn't exist,
163+
None if we cannot verify (e.g., permission denied).
164+
"""
165+
try:
166+
config_check()
167+
api_instance = client.CustomObjectsApi(get_api_client())
168+
# Try to get the specific WorkloadPriorityClass by name
169+
api_instance.get_cluster_custom_object(
170+
group="kueue.x-k8s.io",
171+
version="v1beta1",
172+
plural="workloadpriorityclasses",
173+
name=priority_class_name,
174+
)
175+
return True
176+
except client.ApiException as e:
177+
if e.status == 404:
178+
# Not found - doesn't exist
179+
return False
180+
elif e.status == 403:
181+
# Permission denied - can't verify, return None
182+
logger.warning(
183+
f"Permission denied when checking WorkloadPriorityClass '{priority_class_name}'. "
184+
f"Cannot verify if it exists."
185+
)
186+
return None
187+
else:
188+
# Other API errors - log and return None (best effort)
189+
logger.warning(
190+
f"Error checking WorkloadPriorityClass '{priority_class_name}': {e.reason}. "
191+
f"Cannot verify if it exists."
192+
)
193+
return None
194+
except Exception as e: # pragma: no cover
195+
# Unexpected errors - log and return None (best effort)
196+
logger.warning(
197+
f"Unexpected error checking WorkloadPriorityClass '{priority_class_name}': {str(e)}. "
198+
f"Cannot verify if it exists."
199+
)
200+
return None
201+
202+
147203
def add_queue_label(item: dict, namespace: str, local_queue: Optional[str]):
148204
"""
149205
Adds a local queue name label to the provided item.

src/codeflare_sdk/ray/rayjobs/rayjob.py

Lines changed: 65 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
from typing import Dict, Any, Optional, Tuple, Union
2424

2525
from ray.runtime_env import RuntimeEnv
26-
from codeflare_sdk.common.kueue.kueue import get_default_kueue_name
26+
from codeflare_sdk.common.kueue.kueue import (
27+
get_default_kueue_name,
28+
priority_class_exists,
29+
)
2730
from codeflare_sdk.common.utils.constants import MOUNT_PATH
2831

2932
from codeflare_sdk.common.utils.utils import get_ray_image_for_python_version
@@ -69,6 +72,7 @@ def __init__(
6972
ttl_seconds_after_finished: int = 0,
7073
active_deadline_seconds: Optional[int] = None,
7174
local_queue: Optional[str] = None,
75+
priority_class: Optional[str] = None,
7276
):
7377
"""
7478
Initialize a RayJob instance.
@@ -86,11 +90,13 @@ def __init__(
8690
ttl_seconds_after_finished: Seconds to wait before cleanup after job finishes (default: 0)
8791
active_deadline_seconds: Maximum time the job can run before being terminated (optional)
8892
local_queue: The Kueue LocalQueue to submit the job to (optional)
93+
priority_class: The Kueue WorkloadPriorityClass name for preemption control (optional).
8994
9095
Note:
9196
- True if cluster_config is provided (new cluster will be cleaned up)
9297
- False if cluster_name is provided (existing cluster will not be shut down)
9398
- User can explicitly set this value to override auto-detection
99+
- Kueue labels (queue and priority) can be applied to both new and existing clusters
94100
"""
95101
if cluster_name is None and cluster_config is None:
96102
raise ValueError(
@@ -124,6 +130,7 @@ def __init__(
124130
self.ttl_seconds_after_finished = ttl_seconds_after_finished
125131
self.active_deadline_seconds = active_deadline_seconds
126132
self.local_queue = local_queue
133+
self.priority_class = priority_class
127134

128135
if namespace is None:
129136
detected_namespace = get_current_namespace()
@@ -165,6 +172,7 @@ def submit(self) -> str:
165172
# Validate configuration before submitting
166173
self._validate_ray_version_compatibility()
167174
self._validate_working_dir_entrypoint()
175+
self._validate_priority_class()
168176

169177
# Extract files from entrypoint and runtime_env working_dir
170178
files = extract_all_local_files(self)
@@ -243,26 +251,35 @@ def _build_rayjob_cr(self) -> Dict[str, Any]:
243251
# Extract files once and use for both runtime_env and submitter pod
244252
files = extract_all_local_files(self)
245253

254+
# Build Kueue labels and annotations for all jobs (new and existing clusters)
246255
labels = {}
247-
# If cluster_config is provided, use the local_queue from the cluster_config
248-
if self._cluster_config is not None:
249-
if self.local_queue:
250-
labels["kueue.x-k8s.io/queue-name"] = self.local_queue
256+
257+
# Queue name label - apply to all jobs when explicitly specified
258+
# For new clusters, also auto-detect default queue if not specified
259+
if self.local_queue:
260+
labels["kueue.x-k8s.io/queue-name"] = self.local_queue
261+
elif self._cluster_config is not None:
262+
# Only auto-detect default queue for new clusters
263+
default_queue = get_default_kueue_name(self.namespace)
264+
if default_queue:
265+
labels["kueue.x-k8s.io/queue-name"] = default_queue
251266
else:
252-
default_queue = get_default_kueue_name(self.namespace)
253-
if default_queue:
254-
labels["kueue.x-k8s.io/queue-name"] = default_queue
255-
else:
256-
# No default queue found, use "default" as fallback
257-
labels["kueue.x-k8s.io/queue-name"] = "default"
258-
logger.warning(
259-
f"No default Kueue LocalQueue found in namespace '{self.namespace}'. "
260-
f"Using 'default' as the queue name. If a LocalQueue named 'default' "
261-
f"does not exist, the RayJob submission will fail. "
262-
f"To fix this, please explicitly specify the 'local_queue' parameter."
263-
)
267+
# No default queue found, use "default" as fallback
268+
labels["kueue.x-k8s.io/queue-name"] = "default"
269+
logger.warning(
270+
f"No default Kueue LocalQueue found in namespace '{self.namespace}'. "
271+
f"Using 'default' as the queue name. If a LocalQueue named 'default' "
272+
f"does not exist, the RayJob submission will fail. "
273+
f"To fix this, please explicitly specify the 'local_queue' parameter."
274+
)
275+
276+
# Priority class label - apply when specified
277+
if self.priority_class:
278+
labels["kueue.x-k8s.io/priority-class"] = self.priority_class
264279

265-
rayjob_cr["metadata"]["labels"] = labels
280+
# Apply labels to metadata
281+
if labels:
282+
rayjob_cr["metadata"]["labels"] = labels
266283

267284
# When using Kueue (queue label present), start with suspend=true
268285
# Kueue will unsuspend the job once the workload is admitted
@@ -450,6 +467,36 @@ def _validate_cluster_config_image(self):
450467
elif is_warning:
451468
warnings.warn(f"Cluster config image: {message}")
452469

470+
def _validate_priority_class(self):
471+
"""
472+
Validate that the priority class exists in the cluster (best effort).
473+
474+
Raises ValueError if the priority class is definitively known not to exist.
475+
If we cannot verify (e.g., permission denied), logs a warning and allows submission.
476+
"""
477+
if self.priority_class:
478+
logger.debug(f"Validating priority class '{self.priority_class}'...")
479+
exists = priority_class_exists(self.priority_class)
480+
481+
if exists is False:
482+
# Definitively doesn't exist - fail validation
483+
print(
484+
f"❌ Priority class '{self.priority_class}' does not exist in the cluster. "
485+
f"Submission cancelled."
486+
)
487+
raise ValueError(
488+
f"Priority class '{self.priority_class}' does not exist"
489+
)
490+
elif exists is None:
491+
# Cannot verify - log warning and allow submission
492+
logger.warning(
493+
f"Could not verify if priority class '{self.priority_class}' exists. "
494+
f"Proceeding with submission - Kueue will validate on admission."
495+
)
496+
else:
497+
# exists is True - validation passed
498+
logger.debug(f"Priority class '{self.priority_class}' verified.")
499+
453500
def _validate_working_dir_entrypoint(self):
454501
"""
455502
Validate entrypoint file configuration.

0 commit comments

Comments
 (0)