
Commit 8a51e23

Added update listener to dask cluster
1 parent 883f889 commit 8a51e23
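
The commit introduces a kopf `on.update` handler (`daskcluster_update`, shown in the diff below) that reconciles the scheduler deployment, scheduler service, and default worker group whenever the DaskCluster spec changes. As a minimal sketch of the change-detection idea it relies on, assuming diff entries are dicts carrying a JSON-pointer style "path" key (the shape implied by the handler's `op["path"]` access; kopf's own `Diff` objects may differ):

# Standalone sketch (not part of the commit) of the path-based change detection
# used by daskcluster_update. Diff entries are assumed to be dicts with a
# JSON-pointer style "path", matching the op["path"] access in the handler.

def subtree_changed(diff_ops: list[dict], prefix: str) -> bool:
    """Return True if any changed path falls under the given spec subtree."""
    return any(op["path"].startswith(prefix) for op in diff_ops)

# Hypothetical update: resize the default worker group and bump the scheduler image.
example_diff = [
    {"op": "replace", "path": "/spec/worker/replicas", "value": 5},
    {"op": "replace", "path": "/spec/scheduler/spec/containers/0/image", "value": "ghcr.io/dask/dask:2024.1.0"},
]

assert subtree_changed(example_diff, "/spec/scheduler")  # scheduler deployment/service reconciled
assert subtree_changed(example_diff, "/spec/worker")     # default worker group patched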

File tree

1 file changed (+113, -84)

dask_kubernetes/operator/controller/controller.py

Lines changed: 113 additions & 84 deletions
@@ -40,9 +40,7 @@

 KUBERNETES_DATETIME_FORMAT: Final[str] = "%Y-%m-%dT%H:%M:%SZ"

-DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: Final[
-    str
-] = "kubernetes.dask.org/cooldown-until"
+DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: Final[str] = "kubernetes.dask.org/cooldown-until"

 # Load operator plugins from other packages
 PLUGINS: list[Any] = []
@@ -59,20 +57,15 @@ def _get_annotations(meta: kopf.Meta) -> dict[str, str]:
     return {
         annotation_key: annotation_value
         for annotation_key, annotation_value in meta.annotations.items()
-        if not any(
-            annotation_key.startswith(namespace)
-            for namespace in _ANNOTATION_NAMESPACES_TO_IGNORE
-        )
+        if not any(annotation_key.startswith(namespace) for namespace in _ANNOTATION_NAMESPACES_TO_IGNORE)
     }


 def _get_labels(meta: kopf.Meta) -> dict[str, str]:
     return {
         label_key: label_value
         for label_key, label_value in meta.labels.items()
-        if not any(
-            label_key.startswith(namespace) for namespace in _LABEL_NAMESPACES_TO_IGNORE
-        )
+        if not any(label_key.startswith(namespace) for namespace in _LABEL_NAMESPACES_TO_IGNORE)
     }

@@ -351,21 +344,15 @@ async def daskcluster_create_components(
             annotations.update(**scheduler_spec["metadata"]["annotations"])
         if "labels" in scheduler_spec["metadata"]:
             labels.update(**scheduler_spec["metadata"]["labels"])
-    data = build_scheduler_deployment_spec(
-        name, scheduler_spec.get("spec"), annotations, labels
-    )
+    data = build_scheduler_deployment_spec(name, scheduler_spec.get("spec"), annotations, labels)
     kopf.adopt(data)
     scheduler_deployment = await Deployment(data, namespace=namespace)
     if not await scheduler_deployment.exists():
         await scheduler_deployment.create()
-    logger.info(
-        f"Scheduler deployment {scheduler_deployment.name} created in {namespace}."
-    )
+    logger.info(f"Scheduler deployment {scheduler_deployment.name} created in {namespace}.")

     # Create scheduler service
-    data = build_scheduler_service_spec(
-        name, scheduler_spec.get("service"), annotations, labels
-    )
+    data = build_scheduler_service_spec(name, scheduler_spec.get("service"), annotations, labels)
     kopf.adopt(data)
     scheduler_service = await Service(data, namespace=namespace)
     if not await scheduler_service.exists():
@@ -389,6 +376,92 @@ async def daskcluster_create_components(

     patch.status["phase"] = "Pending"

+@kopf.on.update("daskcluster.kubernetes.dask.org")
+async def daskcluster_update(
+    spec: kopf.Spec,
+    status: kopf.Status,
+    meta: kopf.Meta,
+    name: str | None,
+    namespace: str | None,
+    diff: kopf.Diff,
+    patch: kopf.Patch,
+    logger: kopf.Logger,
+    **__: Any
+):
+    """When the DaskCluster resource is updated update all the components."""
+    assert name
+    assert namespace
+    logger.info(f"Handling update for DaskCluster '{name}'")
+
+    scheduler_changed = any(op['path'].startswith('/spec/scheduler') for op in diff)
+    worker_changed = any(op['path'].startswith('/spec/worker') for op in diff)
+
+    base_annotations = _get_annotations(meta)
+    base_labels = _get_labels(meta)
+
+    if scheduler_changed:
+        logger.info("Scheduler spec changed, reconciling scheduler components.")
+        scheduler_spec_part = spec.get("scheduler", {})
+
+        scheduler_annotations = base_annotations.copy()
+        scheduler_labels = base_labels.copy()
+        if "metadata" in scheduler_spec_part:
+            scheduler_annotations.update(scheduler_spec_part.get("metadata", {}).get("annotations", {}))
+            scheduler_labels.update(scheduler_spec_part.get("metadata", {}).get("labels", {}))
+
+        desired_dep_spec = build_scheduler_deployment_spec(
+            name, scheduler_spec_part.get("spec"), scheduler_annotations, scheduler_labels
+        )
+        scheduler_deployment = await Deployment(
+            SCHEDULER_NAME_TEMPLATE.format(cluster_name=name), namespace=namespace # Use name
+        )
+        if await scheduler_deployment.exists():
+            await scheduler_deployment.patch(desired_dep_spec)
+            logger.info(f"Scheduler deployment {scheduler_deployment.name} patched.")
+        else:
+            logger.warning(f"Scheduler deployment {scheduler_deployment.name} not found. Recreating.")
+            kopf.adopt(desired_dep_spec, owner=meta)
+            await scheduler_deployment.create(desired_dep_spec)
+
+        desired_svc_spec = build_scheduler_service_spec(
+            name, scheduler_spec_part.get("service"), scheduler_annotations, scheduler_labels
+        )
+        scheduler_service = await Service(
+            SCHEDULER_NAME_TEMPLATE.format(cluster_name=name), namespace=namespace # Use name
+        )
+        if await scheduler_service.exists():
+            await scheduler_service.patch(desired_svc_spec)
+            logger.info(f"Scheduler service {scheduler_service.name} patched.")
+        else:
+            logger.warning(f"Scheduler service {scheduler_service.name} not found. Recreating.")
+            kopf.adopt(desired_svc_spec, owner=meta)
+            await scheduler_service.create(desired_svc_spec)
+
+    if worker_changed:
+        logger.info("Worker spec changed, reconciling default worker group.")
+        worker_spec_part = spec.get("worker", {})
+
+        worker_annotations = base_annotations.copy()
+        worker_labels = base_labels.copy()
+        if "metadata" in worker_spec_part:
+            worker_annotations.update(worker_spec_part.get("metadata", {}).get("annotations", {}))
+            worker_labels.update(worker_spec_part.get("metadata", {}).get("labels", {}))
+
+        desired_wg_spec = build_default_worker_group_spec(
+            name, worker_spec_part, worker_annotations, worker_labels
+        )
+        worker_group = await DaskWorkerGroup.get(f"{name}-default", namespace=namespace)
+
+        if await worker_group.exists():
+            await worker_group.patch(desired_wg_spec)
+            logger.info(f"Worker group {worker_group.name} patched.")
+        else:
+            logger.warning(f"Worker group {worker_group.name} not found. Recreating.")
+            kopf.adopt(desired_wg_spec, owner=meta)
+            await worker_group.create(desired_wg_spec)
+
+    patch.status["observedGeneration"] = meta.generation
+    logger.info(f"Update handler finished for DaskCluster '{name}'.")

 @kopf.on.field("service", field="status", labels={"dask.org/component": "scheduler"})
 async def handle_scheduler_service_status(
@@ -400,23 +473,17 @@ async def handle_scheduler_service_status(
 ) -> None:
     assert namespace
     # If the Service is a LoadBalancer with no ingress endpoints mark the cluster as Pending
-    if spec["type"] == "LoadBalancer" and not len(
-        status.get("loadBalancer", {}).get("ingress", [])
-    ):
+    if spec["type"] == "LoadBalancer" and not len(status.get("loadBalancer", {}).get("ingress", [])):
         phase = "Pending"
     # Otherwise mark it as Running
     else:
         phase = "Running"
-    cluster = await DaskCluster.get(
-        labels["dask.org/cluster-name"], namespace=namespace
-    )
+    cluster = await DaskCluster.get(labels["dask.org/cluster-name"], namespace=namespace)
     await cluster.patch({"status": {"phase": phase}})


 @kopf.on.create("daskworkergroup.kubernetes.dask.org")
-async def daskworkergroup_create(
-    body: kopf.Body, namespace: str | None, logger: kopf.Logger, **kwargs: Any
-) -> None:
+async def daskworkergroup_create(body: kopf.Body, namespace: str | None, logger: kopf.Logger, **kwargs: Any) -> None:
     assert namespace
     wg = await DaskWorkerGroup(body, namespace=namespace)
     cluster = await wg.cluster()
@@ -463,9 +530,7 @@ async def retire_workers(
     )

     # Otherwise try gracefully retiring via the RPC
-    logger.debug(
-        f"Scaling {worker_group_name} failed via the HTTP API, falling back to the Dask RPC"
-    )
+    logger.debug(f"Scaling {worker_group_name} failed via the HTTP API, falling back to the Dask RPC")
     # Dask version mismatches between the operator and scheduler may cause this to fail in any number of unexpected ways
     with suppress(Exception):
         comm_address = await get_scheduler_address(
@@ -499,9 +564,7 @@ def retire_workers_lifo(workers, n_workers: int) -> list[str]:
     return [w.name for w in workers[-n_workers:]]


-async def check_scheduler_idle(
-    scheduler_service_name: str, namespace: str | None, logger: kopf.Logger
-) -> float:
+async def check_scheduler_idle(scheduler_service_name: str, namespace: str | None, logger: kopf.Logger) -> float:
     assert namespace
     # Try getting idle time via HTTP API
     dashboard_address = await get_scheduler_address(
@@ -525,9 +588,7 @@ async def check_scheduler_idle(
     )

     # Otherwise try gracefully checking via the RPC
-    logger.debug(
-        f"Checking {scheduler_service_name} idleness failed via the HTTP API, falling back to the Dask RPC"
-    )
+    logger.debug(f"Checking {scheduler_service_name} idleness failed via the HTTP API, falling back to the Dask RPC")
     # Dask version mismatches between the operator and scheduler may cause this to fail in any number of unexpected ways
     with suppress(Exception):
         comm_address = await get_scheduler_address(
@@ -573,9 +634,7 @@ def idle_since_func(dask_scheduler: Scheduler) -> float:
     return float(idle_since)


-async def get_desired_workers(
-    scheduler_service_name: str, namespace: str | None
-) -> Any:
+async def get_desired_workers(scheduler_service_name: str, namespace: str | None) -> Any:
     assert namespace
     # Try gracefully retiring via the HTTP API
     dashboard_address = await get_scheduler_address(
@@ -602,9 +661,7 @@ async def get_desired_workers(
         async with rpc(comm_address) as scheduler_comm:
             return await scheduler_comm.adaptive_target()
     except Exception as e:
-        raise SchedulerCommError(
-            "Unable to get number of desired workers from scheduler"
-        ) from e
+        raise SchedulerCommError("Unable to get number of desired workers from scheduler") from e


 worker_group_scale_locks: dict[str, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
@@ -669,13 +726,9 @@ async def daskworkergroup_replica_update(
             if "labels" in worker_spec["metadata"]:
                 labels.update(**worker_spec["metadata"]["labels"])

-        batch_size = int(
-            dask.config.get("kubernetes.controller.worker-allocation.batch-size") or 0
-        )
+        batch_size = int(dask.config.get("kubernetes.controller.worker-allocation.batch-size") or 0)
         batch_size = min(workers_needed, batch_size) if batch_size else workers_needed
-        batch_delay = int(
-            dask.config.get("kubernetes.controller.worker-allocation.delay") or 0
-        )
+        batch_delay = int(dask.config.get("kubernetes.controller.worker-allocation.delay") or 0)
         if workers_needed > 0:
             for _ in range(batch_size):
                 data = build_worker_deployment_spec(
@@ -701,9 +754,7 @@ async def daskworkergroup_replica_update(
         if workers_needed < 0:
             worker_ids = await retire_workers(
                 n_workers=-workers_needed,
-                scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format(
-                    cluster_name=cluster_name
-                ),
+                scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format(cluster_name=cluster_name),
                 worker_group_name=name,
                 namespace=namespace,
                 logger=logger,
@@ -712,15 +763,11 @@ async def daskworkergroup_replica_update(
             for wid in worker_ids:
                 worker_deployment = await Deployment(wid, namespace=namespace)
                 await worker_deployment.delete()
-            logger.info(
-                f"Scaled worker group {name} down to {desired_workers} workers."
-            )
+            logger.info(f"Scaled worker group {name} down to {desired_workers} workers.")


 @kopf.on.delete("daskworkergroup.kubernetes.dask.org", optional=True)
-async def daskworkergroup_remove(
-    name: str | None, namespace: str | None, **__: Any
-) -> None:
+async def daskworkergroup_remove(name: str | None, namespace: str | None, **__: Any) -> None:
     assert name
     assert namespace
     lock_key = f"{name}/{namespace}"
@@ -742,9 +789,7 @@ async def daskjob_create(
     patch.status["jobStatus"] = "JobCreated"


-@kopf.on.field(
-    "daskjob.kubernetes.dask.org", field="status.jobStatus", new="JobCreated"
-)
+@kopf.on.field("daskjob.kubernetes.dask.org", field="status.jobStatus", new="JobCreated")
 async def daskjob_create_components(
     spec: kopf.Spec,
     name: str | None,
@@ -776,9 +821,7 @@ async def daskjob_create_components(
     kopf.adopt(cluster_spec)
     cluster = await DaskCluster(cluster_spec, namespace=namespace)
     await cluster.create()
-    logger.info(
-        f"Cluster {cluster_spec['metadata']['name']} for job {name} created in {namespace}."
-    )
+    logger.info(f"Cluster {cluster_spec['metadata']['name']} for job {name} created in {namespace}.")

     labels = _get_labels(meta)
     annotations = _get_annotations(meta)
@@ -881,9 +924,7 @@ async def handle_runner_status_change_failed(


 @kopf.on.create("daskautoscaler.kubernetes.dask.org")
-async def daskautoscaler_create(
-    body: kopf.Body, logger: kopf.Logger, **__: Any
-) -> None:
+async def daskautoscaler_create(body: kopf.Body, logger: kopf.Logger, **__: Any) -> None:
     """When an autoscaler is created make it a child of the associated cluster for cascade deletion."""
     autoscaler = await DaskAutoscaler(body)
     cluster = await autoscaler.cluster()
@@ -916,16 +957,10 @@ async def daskautoscaler_adapt(
         return

     autoscaler = await DaskAutoscaler.get(name, namespace=namespace)
-    worker_group = await DaskWorkerGroup.get(
-        f"{spec['cluster']}-default", namespace=namespace
-    )
+    worker_group = await DaskWorkerGroup.get(f"{spec['cluster']}-default", namespace=namespace)

     current_replicas = worker_group.replicas
-    cooldown_until = float(
-        autoscaler.annotations.get(
-            DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION, time.time()
-        )
-    )
+    cooldown_until = float(autoscaler.annotations.get(DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION, time.time()))

     # Cooldown autoscaling to prevent thrashing
     if time.time() < cooldown_until:
@@ -957,9 +992,7 @@ async def daskautoscaler_adapt(

         cooldown_until = time.time() + 15

-        await autoscaler.annotate(
-            {DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: str(cooldown_until)}
-        )
+        await autoscaler.annotate({DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: str(cooldown_until)})

         logger.info(
             "Autoscaler updated %s worker count from %d to %d",
@@ -968,9 +1001,7 @@ async def daskautoscaler_adapt(
             desired_workers,
         )
     else:
-        logger.debug(
-            "Not autoscaling %s with %d workers", spec["cluster"], current_replicas
-        )
+        logger.debug("Not autoscaling %s with %d workers", spec["cluster"], current_replicas)


 @kopf.timer("daskcluster.kubernetes.dask.org", interval=5.0)
@@ -990,9 +1021,7 @@ async def daskcluster_autoshutdown(
                 logger=logger,
             )
         except Exception: # TODO: Not use broad "Exception" catch here
-            logger.warning(
-                "Unable to connect to scheduler, skipping autoshutdown check."
-            )
+            logger.warning("Unable to connect to scheduler, skipping autoshutdown check.")
             return
         if idle_since and time.time() > idle_since + idle_timeout:
             cluster = await DaskCluster.get(name, namespace=namespace)
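
One design note on the new handler: it finishes by writing `status.observedGeneration` from `metadata.generation`, the usual Kubernetes signal that the controller has processed the latest spec revision. A small illustrative sketch (hypothetical resource dict, not the operator's API) of how a client could use it:

# Sketch: checking whether the controller has caught up with the latest spec change,
# based on the observedGeneration written by daskcluster_update. The resource dict
# here is illustrative; a real client would read the DaskCluster object from the API.

def update_processed(resource: dict) -> bool:
    """True once status.observedGeneration has caught up with metadata.generation."""
    generation = resource["metadata"]["generation"]
    observed = resource.get("status", {}).get("observedGeneration", 0)
    return observed >= generation

cluster = {
    "metadata": {"generation": 3},
    "status": {"observedGeneration": 3},  # set by the update handler via patch.status
}
assert update_processed(cluster)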
