KUBERNETES_DATETIME_FORMAT: Final[str] = "%Y-%m-%dT%H:%M:%SZ"

-DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: Final[
-    str
-] = "kubernetes.dask.org/cooldown-until"
+DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: Final[str] = "kubernetes.dask.org/cooldown-until"

# Load operator plugins from other packages
PLUGINS: list[Any] = []
@@ -59,20 +57,15 @@ def _get_annotations(meta: kopf.Meta) -> dict[str, str]:
    return {
        annotation_key: annotation_value
        for annotation_key, annotation_value in meta.annotations.items()
-        if not any(
-            annotation_key.startswith(namespace)
-            for namespace in _ANNOTATION_NAMESPACES_TO_IGNORE
-        )
+        if not any(annotation_key.startswith(namespace) for namespace in _ANNOTATION_NAMESPACES_TO_IGNORE)
    }


def _get_labels(meta: kopf.Meta) -> dict[str, str]:
    return {
        label_key: label_value
        for label_key, label_value in meta.labels.items()
-        if not any(
-            label_key.startswith(namespace) for namespace in _LABEL_NAMESPACES_TO_IGNORE
-        )
+        if not any(label_key.startswith(namespace) for namespace in _LABEL_NAMESPACES_TO_IGNORE)
    }

@@ -351,21 +344,15 @@ async def daskcluster_create_components(
            annotations.update(**scheduler_spec["metadata"]["annotations"])
        if "labels" in scheduler_spec["metadata"]:
            labels.update(**scheduler_spec["metadata"]["labels"])
-    data = build_scheduler_deployment_spec(
-        name, scheduler_spec.get("spec"), annotations, labels
-    )
+    data = build_scheduler_deployment_spec(name, scheduler_spec.get("spec"), annotations, labels)
    kopf.adopt(data)
    scheduler_deployment = await Deployment(data, namespace=namespace)
    if not await scheduler_deployment.exists():
        await scheduler_deployment.create()
-    logger.info(
-        f"Scheduler deployment {scheduler_deployment.name} created in {namespace}."
-    )
+    logger.info(f"Scheduler deployment {scheduler_deployment.name} created in {namespace}.")

    # Create scheduler service
-    data = build_scheduler_service_spec(
-        name, scheduler_spec.get("service"), annotations, labels
-    )
+    data = build_scheduler_service_spec(name, scheduler_spec.get("service"), annotations, labels)
    kopf.adopt(data)
    scheduler_service = await Service(data, namespace=namespace)
    if not await scheduler_service.exists():
@@ -389,6 +376,92 @@ async def daskcluster_create_components(

    patch.status["phase"] = "Pending"

+@kopf.on.update("daskcluster.kubernetes.dask.org")
+async def daskcluster_update(
+    spec: kopf.Spec,
+    status: kopf.Status,
+    meta: kopf.Meta,
+    name: str | None,
+    namespace: str | None,
+    diff: kopf.Diff,
+    patch: kopf.Patch,
+    logger: kopf.Logger,
+    **__: Any,
+):
+    """When the DaskCluster resource is updated, update all of its components."""
+    assert name
+    assert namespace
+    logger.info(f"Handling update for DaskCluster '{name}'")
+
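+    # kopf passes a Diff of (operation, field, old, new) items; field is a tuple
+    # of path segments such as ("spec", "scheduler", "service")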
+    scheduler_changed = any(op.field[:2] == ("spec", "scheduler") for op in diff)
+    worker_changed = any(op.field[:2] == ("spec", "worker") for op in diff)
+
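+    # Start from the cluster-level metadata and layer per-component overrides on top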
+    base_annotations = _get_annotations(meta)
+    base_labels = _get_labels(meta)
+
+    if scheduler_changed:
+        logger.info("Scheduler spec changed, reconciling scheduler components.")
+        scheduler_spec_part = spec.get("scheduler", {})
+
+        scheduler_annotations = base_annotations.copy()
+        scheduler_labels = base_labels.copy()
+        if "metadata" in scheduler_spec_part:
+            scheduler_annotations.update(scheduler_spec_part.get("metadata", {}).get("annotations", {}))
+            scheduler_labels.update(scheduler_spec_part.get("metadata", {}).get("labels", {}))
+
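+        # Build the desired scheduler Deployment and patch it onto the live object,
+        # recreating it if it has gone missing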
+        desired_dep_spec = build_scheduler_deployment_spec(
+            name, scheduler_spec_part.get("spec"), scheduler_annotations, scheduler_labels
+        )
+        scheduler_deployment = await Deployment(
+            SCHEDULER_NAME_TEMPLATE.format(cluster_name=name), namespace=namespace  # Look up by name
+        )
+        if await scheduler_deployment.exists():
+            await scheduler_deployment.patch(desired_dep_spec)
+            logger.info(f"Scheduler deployment {scheduler_deployment.name} patched.")
+        else:
+            logger.warning(f"Scheduler deployment {scheduler_deployment.name} not found. Recreating.")
+            kopf.adopt(desired_dep_spec)
+            scheduler_deployment = await Deployment(desired_dep_spec, namespace=namespace)
+            await scheduler_deployment.create()
+
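+        # Reconcile the scheduler Service in the same way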
+        desired_svc_spec = build_scheduler_service_spec(
+            name, scheduler_spec_part.get("service"), scheduler_annotations, scheduler_labels
+        )
+        scheduler_service = await Service(
+            SCHEDULER_NAME_TEMPLATE.format(cluster_name=name), namespace=namespace  # Look up by name
+        )
+        if await scheduler_service.exists():
+            await scheduler_service.patch(desired_svc_spec)
+            logger.info(f"Scheduler service {scheduler_service.name} patched.")
+        else:
+            logger.warning(f"Scheduler service {scheduler_service.name} not found. Recreating.")
+            kopf.adopt(desired_svc_spec)
+            scheduler_service = await Service(desired_svc_spec, namespace=namespace)
+            await scheduler_service.create()
+
+    if worker_changed:
+        logger.info("Worker spec changed, reconciling default worker group.")
+        worker_spec_part = spec.get("worker", {})
+
+        worker_annotations = base_annotations.copy()
+        worker_labels = base_labels.copy()
+        if "metadata" in worker_spec_part:
+            worker_annotations.update(worker_spec_part.get("metadata", {}).get("annotations", {}))
+            worker_labels.update(worker_spec_part.get("metadata", {}).get("labels", {}))
+
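+        # The default worker group is named "<cluster-name>-default"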
+        desired_wg_spec = build_default_worker_group_spec(
+            name, worker_spec_part, worker_annotations, worker_labels
+        )
+        worker_group = await DaskWorkerGroup(f"{name}-default", namespace=namespace)
+
+        if await worker_group.exists():
+            await worker_group.patch(desired_wg_spec)
+            logger.info(f"Worker group {worker_group.name} patched.")
+        else:
+            logger.warning(f"Worker group {worker_group.name} not found. Recreating.")
+            kopf.adopt(desired_wg_spec)
+            worker_group = await DaskWorkerGroup(desired_wg_spec, namespace=namespace)
+            await worker_group.create()
+
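+    # Record which generation of the spec this handler has reconciled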
+    patch.status["observedGeneration"] = meta.get("generation")
+    logger.info(f"Update handler finished for DaskCluster '{name}'.")


@kopf.on.field("service", field="status", labels={"dask.org/component": "scheduler"})
async def handle_scheduler_service_status(
@@ -400,23 +473,17 @@ async def handle_scheduler_service_status(
) -> None:
    assert namespace
    # If the Service is a LoadBalancer with no ingress endpoints mark the cluster as Pending
-    if spec["type"] == "LoadBalancer" and not len(
-        status.get("loadBalancer", {}).get("ingress", [])
-    ):
+    if spec["type"] == "LoadBalancer" and not len(status.get("loadBalancer", {}).get("ingress", [])):
        phase = "Pending"
    # Otherwise mark it as Running
    else:
        phase = "Running"
-    cluster = await DaskCluster.get(
-        labels["dask.org/cluster-name"], namespace=namespace
-    )
+    cluster = await DaskCluster.get(labels["dask.org/cluster-name"], namespace=namespace)
    await cluster.patch({"status": {"phase": phase}})


@kopf.on.create("daskworkergroup.kubernetes.dask.org")
-async def daskworkergroup_create(
-    body: kopf.Body, namespace: str | None, logger: kopf.Logger, **kwargs: Any
-) -> None:
+async def daskworkergroup_create(body: kopf.Body, namespace: str | None, logger: kopf.Logger, **kwargs: Any) -> None:
    assert namespace
    wg = await DaskWorkerGroup(body, namespace=namespace)
    cluster = await wg.cluster()
@@ -463,9 +530,7 @@ async def retire_workers(
    )

    # Otherwise try gracefully retiring via the RPC
-    logger.debug(
-        f"Scaling {worker_group_name} failed via the HTTP API, falling back to the Dask RPC"
-    )
+    logger.debug(f"Scaling {worker_group_name} failed via the HTTP API, falling back to the Dask RPC")
    # Dask version mismatches between the operator and scheduler may cause this to fail in any number of unexpected ways
    with suppress(Exception):
        comm_address = await get_scheduler_address(
@@ -499,9 +564,7 @@ def retire_workers_lifo(workers, n_workers: int) -> list[str]:
    return [w.name for w in workers[-n_workers:]]


-async def check_scheduler_idle(
-    scheduler_service_name: str, namespace: str | None, logger: kopf.Logger
-) -> float:
+async def check_scheduler_idle(scheduler_service_name: str, namespace: str | None, logger: kopf.Logger) -> float:
    assert namespace
    # Try getting idle time via HTTP API
    dashboard_address = await get_scheduler_address(
@@ -525,9 +588,7 @@ async def check_scheduler_idle(
    )

    # Otherwise try gracefully checking via the RPC
-    logger.debug(
-        f"Checking {scheduler_service_name} idleness failed via the HTTP API, falling back to the Dask RPC"
-    )
+    logger.debug(f"Checking {scheduler_service_name} idleness failed via the HTTP API, falling back to the Dask RPC")
    # Dask version mismatches between the operator and scheduler may cause this to fail in any number of unexpected ways
    with suppress(Exception):
        comm_address = await get_scheduler_address(
@@ -573,9 +634,7 @@ def idle_since_func(dask_scheduler: Scheduler) -> float:
    return float(idle_since)


-async def get_desired_workers(
-    scheduler_service_name: str, namespace: str | None
-) -> Any:
+async def get_desired_workers(scheduler_service_name: str, namespace: str | None) -> Any:
    assert namespace
    # Try gracefully retiring via the HTTP API
    dashboard_address = await get_scheduler_address(
@@ -602,9 +661,7 @@ async def get_desired_workers(
        async with rpc(comm_address) as scheduler_comm:
            return await scheduler_comm.adaptive_target()
    except Exception as e:
-        raise SchedulerCommError(
-            "Unable to get number of desired workers from scheduler"
-        ) from e
+        raise SchedulerCommError("Unable to get number of desired workers from scheduler") from e


worker_group_scale_locks: dict[str, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
@@ -669,13 +726,9 @@ async def daskworkergroup_replica_update(
            if "labels" in worker_spec["metadata"]:
                labels.update(**worker_spec["metadata"]["labels"])

-        batch_size = int(
-            dask.config.get("kubernetes.controller.worker-allocation.batch-size") or 0
-        )
+        batch_size = int(dask.config.get("kubernetes.controller.worker-allocation.batch-size") or 0)
        batch_size = min(workers_needed, batch_size) if batch_size else workers_needed
-        batch_delay = int(
-            dask.config.get("kubernetes.controller.worker-allocation.delay") or 0
-        )
+        batch_delay = int(dask.config.get("kubernetes.controller.worker-allocation.delay") or 0)
        if workers_needed > 0:
            for _ in range(batch_size):
                data = build_worker_deployment_spec(
@@ -701,9 +754,7 @@ async def daskworkergroup_replica_update(
        if workers_needed < 0:
            worker_ids = await retire_workers(
                n_workers=-workers_needed,
-                scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format(
-                    cluster_name=cluster_name
-                ),
+                scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format(cluster_name=cluster_name),
                worker_group_name=name,
                namespace=namespace,
                logger=logger,
@@ -712,15 +763,11 @@ async def daskworkergroup_replica_update(
            for wid in worker_ids:
                worker_deployment = await Deployment(wid, namespace=namespace)
                await worker_deployment.delete()
-            logger.info(
-                f"Scaled worker group {name} down to {desired_workers} workers."
-            )
+            logger.info(f"Scaled worker group {name} down to {desired_workers} workers.")


@kopf.on.delete("daskworkergroup.kubernetes.dask.org", optional=True)
-async def daskworkergroup_remove(
-    name: str | None, namespace: str | None, **__: Any
-) -> None:
+async def daskworkergroup_remove(name: str | None, namespace: str | None, **__: Any) -> None:
    assert name
    assert namespace
    lock_key = f"{name}/{namespace}"
@@ -742,9 +789,7 @@ async def daskjob_create(
    patch.status["jobStatus"] = "JobCreated"


-@kopf.on.field(
-    "daskjob.kubernetes.dask.org", field="status.jobStatus", new="JobCreated"
-)
+@kopf.on.field("daskjob.kubernetes.dask.org", field="status.jobStatus", new="JobCreated")
async def daskjob_create_components(
    spec: kopf.Spec,
    name: str | None,
@@ -776,9 +821,7 @@ async def daskjob_create_components(
    kopf.adopt(cluster_spec)
    cluster = await DaskCluster(cluster_spec, namespace=namespace)
    await cluster.create()
-    logger.info(
-        f"Cluster {cluster_spec['metadata']['name']} for job {name} created in {namespace}."
-    )
+    logger.info(f"Cluster {cluster_spec['metadata']['name']} for job {name} created in {namespace}.")

    labels = _get_labels(meta)
    annotations = _get_annotations(meta)
@@ -881,9 +924,7 @@ async def handle_runner_status_change_failed(


@kopf.on.create("daskautoscaler.kubernetes.dask.org")
-async def daskautoscaler_create(
-    body: kopf.Body, logger: kopf.Logger, **__: Any
-) -> None:
+async def daskautoscaler_create(body: kopf.Body, logger: kopf.Logger, **__: Any) -> None:
    """When an autoscaler is created make it a child of the associated cluster for cascade deletion."""
    autoscaler = await DaskAutoscaler(body)
    cluster = await autoscaler.cluster()
@@ -916,16 +957,10 @@ async def daskautoscaler_adapt(
        return

    autoscaler = await DaskAutoscaler.get(name, namespace=namespace)
-    worker_group = await DaskWorkerGroup.get(
-        f"{spec['cluster']}-default", namespace=namespace
-    )
+    worker_group = await DaskWorkerGroup.get(f"{spec['cluster']}-default", namespace=namespace)

    current_replicas = worker_group.replicas
-    cooldown_until = float(
-        autoscaler.annotations.get(
-            DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION, time.time()
-        )
-    )
+    cooldown_until = float(autoscaler.annotations.get(DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION, time.time()))

    # Cooldown autoscaling to prevent thrashing
    if time.time() < cooldown_until:
@@ -957,9 +992,7 @@ async def daskautoscaler_adapt(

        cooldown_until = time.time() + 15

-        await autoscaler.annotate(
-            {DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: str(cooldown_until)}
-        )
+        await autoscaler.annotate({DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: str(cooldown_until)})

        logger.info(
            "Autoscaler updated %s worker count from %d to %d",
@@ -968,9 +1001,7 @@ async def daskautoscaler_adapt(
            desired_workers,
        )
    else:
-        logger.debug(
-            "Not autoscaling %s with %d workers", spec["cluster"], current_replicas
-        )
+        logger.debug("Not autoscaling %s with %d workers", spec["cluster"], current_replicas)


@kopf.timer("daskcluster.kubernetes.dask.org", interval=5.0)
@@ -990,9 +1021,7 @@ async def daskcluster_autoshutdown(
                logger=logger,
            )
        except Exception:  # TODO: Not use broad "Exception" catch here
-            logger.warning(
-                "Unable to connect to scheduler, skipping autoshutdown check."
-            )
+            logger.warning("Unable to connect to scheduler, skipping autoshutdown check.")
            return
        if idle_since and time.time() > idle_since + idle_timeout:
            cluster = await DaskCluster.get(name, namespace=namespace)