Skip to content

Commit f928e74

Browse files
committed
trainer: add strict mode for get_job_logs when pod is missing
Signed-off-by: Shanhuizi Jiang <sjiang83@fordham.edu>
1 parent 43b9590 commit f928e74

File tree

7 files changed

+29
-2
lines changed

7 files changed

+29
-2
lines changed

kubeflow/optimizer/backends/kubernetes/backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def get_job_logs(
223223
name: str,
224224
trial_name: str | None = None,
225225
follow: bool = False,
226+
strict: bool = False,
226227
) -> Iterator[str]:
227228
"""Get the OptimizationJob logs from a Trial"""
228229
# Determine what trial to get logs from.
@@ -247,6 +248,8 @@ def get_job_logs(
247248
pod_name = c.pod_name
248249
break
249250
if pod_name is None:
251+
if strict:
252+
raise RuntimeError(f"No pod found for Trial {trial_name} step={step}")
250253
return
251254

252255
container_name = constants.METRICS_COLLECTOR_CONTAINER

kubeflow/trainer/api/trainer_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def get_job_logs(
185185
name: str,
186186
step: str = constants.NODE + "-0",
187187
follow: bool | None = False,
188+
strict: bool = False,
188189
) -> Iterator[str]:
189190
"""Get logs from a specific step of a TrainJob.
190191
@@ -200,6 +201,7 @@ def get_job_logs(
200201
name: Name of the TrainJob.
201202
step: Step of the TrainJob to collect logs from, like dataset-initializer or node-0.
202203
follow: Whether to stream logs in realtime as they are produced.
204+
strict: If True, raise an error when no pod is found for the requested step.
203205
204206
Returns:
205207
Iterator of log lines.
@@ -209,7 +211,7 @@ def get_job_logs(
209211
TimeoutError: Timeout to get a TrainJob.
210212
RuntimeError: Failed to get a TrainJob.
211213
"""
212-
return self.backend.get_job_logs(name=name, follow=follow, step=step)
214+
return self.backend.get_job_logs(name=name, follow=follow, step=step, strict=strict)
213215

214216
def get_job_events(self, name: str) -> list[types.Event]:
215217
"""Get events for a TrainJob.

kubeflow/trainer/backends/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def get_job_logs(
6464
name: str,
6565
follow: bool = False,
6666
step: str = constants.NODE + "-0",
67+
strict: bool = False,
6768
) -> Iterator[str]:
6869
raise NotImplementedError()
6970

kubeflow/trainer/backends/container/backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,7 @@ def get_job_logs(
773773
name: str,
774774
follow: bool = False,
775775
step: str = constants.NODE + "-0",
776+
strict: bool = False,
776777
) -> Iterator[str]:
777778
"""Get logs for a training job by querying container runtime."""
778779
containers = self._get_job_containers(name)

kubeflow/trainer/backends/kubernetes/backend.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737

3838
logger = logging.getLogger(__name__)
3939

40+
class PodNotFoundError(RuntimeError):
41+
"""No pod resolved for the requested job/step."""
42+
pass
4043

4144
class KubernetesBackend(RuntimeBackend):
4245
def __init__(self, cfg: KubernetesBackendConfig):
@@ -432,6 +435,7 @@ def get_job_logs(
432435
name: str,
433436
follow: bool = False,
434437
step: str = constants.NODE + "-0",
438+
strict: bool = False,
435439
) -> Iterator[str]:
436440
"""Get the TrainJob logs"""
437441
# Get the TrainJob Pod name.
@@ -440,7 +444,12 @@ def get_job_logs(
440444
if c.status != constants.POD_PENDING and c.name == step:
441445
pod_name = c.pod_name
442446
break
447+
443448
if pod_name is None:
449+
if strict:
450+
raise PodNotFoundError(
451+
f"No pod found for TrainJob {self.namespace}/{name} step={step}"
452+
)
444453
return
445454

446455
# Remove the number for the node step.

kubeflow/trainer/backends/kubernetes/backend_test.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import pytest
3333

3434
from kubeflow.common.types import KubernetesBackendConfig
35-
from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend
35+
from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend, PodNotFoundError
3636
import kubeflow.trainer.backends.kubernetes.utils as utils
3737
from kubeflow.trainer.constants import constants
3838
from kubeflow.trainer.options import (
@@ -1439,6 +1439,16 @@ def test_get_job_logs(kubernetes_backend, test_case):
14391439
print("test execution complete")
14401440

14411441

1442+
def test_get_job_logs_strict_raises_when_pod_missing(kubernetes_backend, monkeypatch):
1443+
tj = get_train_job_data_type(runtime_name=TORCH_RUNTIME, train_job_name=BASIC_TRAIN_JOB_NAME)
1444+
for s in tj.steps:
1445+
s.status = constants.POD_PENDING
1446+
1447+
monkeypatch.setattr(kubernetes_backend, "get_job", lambda name: tj)
1448+
1449+
with pytest.raises(PodNotFoundError):
1450+
list(kubernetes_backend.get_job_logs(BASIC_TRAIN_JOB_NAME, strict=True))
1451+
14421452
@pytest.mark.parametrize(
14431453
"test_case",
14441454
[

kubeflow/trainer/backends/localprocess/backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def get_job_logs(
195195
name: str,
196196
follow: bool = False,
197197
step: str = constants.NODE + "-0",
198+
strict: bool = False,
198199
) -> Iterator[str]:
199200
_job = [j for j in self.__local_jobs if j.name == name]
200201
if not _job:

0 commit comments

Comments
 (0)