118 changes: 117 additions & 1 deletion torchx/schedulers/kubernetes_scheduler.py
@@ -27,6 +27,76 @@
See the
`Volcano Quickstart <https://github.com/volcano-sh/volcano>`_
for more information.

Pod Overlay
===========

You can overlay arbitrary Kubernetes Pod fields on generated pods by setting
the ``kubernetes`` key in your role's ``metadata``. The value can be:

- A dict with the overlay structure
- A resource URI pointing to a YAML file (e.g. ``file://``, ``s3://``, ``gs://``)

Merge semantics:

- **dict**: recursive merge (upsert)
- **list**: append by default; replace if given as a tuple (Python) or with
  the ``!!python/tuple`` tag (YAML)
- **primitives**: replace

.. code:: python

    from torchx.specs import Role

    # Dict overlay - lists append, tuples replace
    role = Role(
        name="trainer",
        image="my-image:latest",
        entrypoint="train.py",
        metadata={
            "kubernetes": {
                "spec": {
                    "nodeSelector": {"gpu": "true"},
                    "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}],  # appends
                    "volumes": ({"name": "my-volume", "emptyDir": {}},)  # replaces
                }
            }
        }
    )

    # File URI overlay
    role = Role(
        name="trainer",
        image="my-image:latest",
        entrypoint="train.py",
        metadata={
            "kubernetes": "file:///path/to/pod_overlay.yaml"
        }
    )

CLI usage with builtin components:

.. code:: bash

    $ torchx run --scheduler kubernetes dist.ddp \\
        --metadata kubernetes=file:///path/to/pod_overlay.yaml \\
        --script train.py

Example ``pod_overlay.yaml``:

.. code:: yaml

    spec:
      nodeSelector:
        node.kubernetes.io/instance-type: p4d.24xlarge
      tolerations:
        - key: nvidia.com/gpu
          operator: Exists
          effect: NoSchedule
      volumes: !!python/tuple
        - name: my-volume
          emptyDir: {}

The overlay is deep-merged with the generated pod, preserving existing fields
and adding or overriding specified ones.
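
For illustration, here is a minimal sketch of how the merge rules combine
(hypothetical values, shown on plain dicts rather than a full pod):

.. code:: python

    base = {
        "spec": {
            "tolerations": [{"key": "a"}],
            "volumes": [{"name": "v0", "emptyDir": {}}],
        }
    }
    overlay = {
        "spec": {
            "tolerations": [{"key": "b"}],  # list -> appended
            "volumes": ({"name": "v1", "emptyDir": {}},),  # tuple -> replaces
            "hostNetwork": True,  # primitive -> added/replaced
        }
    }
    # After the deep merge, base == {
    #     "spec": {
    #         "tolerations": [{"key": "a"}, {"key": "b"}],
    #         "volumes": [{"name": "v1", "emptyDir": {}}],
    #         "hostNetwork": True,
    #     }
    # }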
"""

import json
@@ -45,6 +115,7 @@
    Tuple,
    TYPE_CHECKING,
    TypedDict,
    Union,
)

import torchx
@@ -97,6 +168,40 @@
RESERVED_MILLICPU = 100
RESERVED_MEMMB = 1024


def _apply_pod_overlay(pod: "V1Pod", overlay: Dict[str, Any]) -> None:
    """Apply overlay dict to V1Pod object, merging nested fields.

    Merge semantics:
    - dict: upsert (recursive merge)
    - list: append by default, replace if tuple
    - primitives: replace
    """
    from kubernetes import client

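    # sanitize_for_serialization turns the typed V1Pod into a plain JSON-style
    # dict (camelCase keys) so it can be merged with the user-supplied overlay.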
    api = client.ApiClient()
    pod_dict = api.sanitize_for_serialization(pod)

    def deep_merge(base: Dict[str, Any], overlay: Dict[str, Any]) -> None:
        for key, value in overlay.items():
            if isinstance(value, dict) and key in base and isinstance(base[key], dict):
                deep_merge(base[key], value)
            elif isinstance(value, tuple):
                base[key] = list(value)
            elif (
                isinstance(value, list) and key in base and isinstance(base[key], list)
            ):
                base[key].extend(value)
            else:
                base[key] = value

    deep_merge(pod_dict, overlay)

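    # ApiClient exposes no public dict -> model helper, so use its
    # name-mangled private deserializer to rebuild a typed V1Pod.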
    merged_pod = api._ApiClient__deserialize(pod_dict, "V1Pod")
    pod.spec = merged_pod.spec
    pod.metadata = merged_pod.metadata


RETRY_POLICIES: Mapping[str, Iterable[Mapping[str, str]]] = {
    RetryPolicy.REPLICA: [],
    RetryPolicy.APPLICATION: [
@@ -402,6 +507,17 @@ def app_to_resource(
            replica_role.env["TORCHX_IMAGE"] = replica_role.image

            pod = role_to_pod(name, replica_role, service_account)
            if k8s_metadata := role.metadata.get("kubernetes"):
                if isinstance(k8s_metadata, str):
                    import fsspec  # pyre-ignore[21]

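                    # unsafe_load (rather than safe_load) is required so the
                    # !!python/tuple tag can mark lists for replacement.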
                    with fsspec.open(k8s_metadata, "r") as f:
                        k8s_metadata = yaml.unsafe_load(f)
                elif not isinstance(k8s_metadata, dict):
                    raise ValueError(
                        f"metadata['kubernetes'] must be a dict or resource URI, got {type(k8s_metadata)}"
                    )
                _apply_pod_overlay(pod, k8s_metadata)
            pod.metadata.labels.update(
                pod_labels(
                    app=app,
@@ -636,7 +752,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[KubernetesJob]) -> str:
            else:
                raise

-        return f'{namespace}:{resp["metadata"]["name"]}'
+        return f"{namespace}:{resp['metadata']['name']}"

    def _submit_dryrun(
        self, app: AppDef, cfg: KubernetesOpts