Skip to content

Commit 3d3cd7d

Browse files
authored
refactor: recovery attributes in config (#329)
A previous commit introduced a limit on recovery attempts, but that limit was hard-coded as a constant in the code. This commit makes it configurable: the setting moves into the job configuration file, where a new `recovery` data structure with default values (`enable: true`, `attempts: 5`) replaces the former boolean `recover` flag, and all examples are updated to use the new structure.
1 parent 93e28fa commit 3d3cd7d

File tree

8 files changed

+34
-10
lines changed

8 files changed

+34
-10
lines changed

examples/llama3/auto/linear-no-recover.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ fwd_policy: "static"
1010
job_id: "job13"
1111
# maximum number of requests in flight at any given point in time
1212
max_inflight: 4
13-
recover: False
13+
recovery:
14+
enable: false # default: true
15+
attempts: 3 # default: 5
1416

1517
flow_graph:
1618
s-0:

examples/llama3/static/linear-no-recover.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ fwd_policy: "static"
1010
job_id: "job12"
1111
# maximum number of requests in flight at any given point in time
1212
max_inflight: 4
13-
recover: False
13+
recovery:
14+
enable: false # default: true
15+
attempts: 3 # default: 5
1416

1517
# Note: IP addresses should be agents'
1618
flow_graph:

examples/resnet152/auto/diamond-no-recover.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ fwd_policy: "rr"
1111
job_id: "job8"
1212
# maximum number of requests in flight at any given point in time
1313
max_inflight: 8
14-
recover: False
14+
recovery:
15+
enable: false # default: true
16+
attempts: 3 # default: 5
1517

1618
flow_graph:
1719
s-0:

examples/resnet152/auto/linear-no-recover.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ fwd_policy: "rr"
1111
job_id: "job9"
1212
# maximum number of requests in flight at any given point in time
1313
max_inflight: 8
14-
recover: False
14+
recovery:
15+
enable: false # default: true
16+
attempts: 3 # default: 5
1517

1618
flow_graph:
1719
s-0:

examples/resnet152/static/diamond-no-recover.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ fwd_policy: "rr"
1414
job_id: "job10"
1515
# maximum number of requests in flight at any given point in time
1616
max_inflight: 800
17-
recover: False
17+
recovery:
18+
enable: false # default: true
19+
attempts: 3 # default: 5
1820

1921
# Note: IP addresses should be agents'
2022
flow_graph:

examples/resnet152/static/linear-no-recover.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ fwd_policy: "rr"
1111
job_id: "job11"
1212
# maximum number of requests in flight at any given point in time
1313
max_inflight: 8
14-
recover: False
14+
recovery:
15+
enable: false # default: true
16+
attempts: 3 # default: 5
1517

1618
# Note: IP addresses should be agents'
1719
flow_graph:

infscale/configs/job.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,14 @@ class WorkerData:
8585
recover: bool = False
8686

8787

88+
@dataclass
89+
class Recovery:
90+
"""Specification about recovery."""
91+
92+
enable: bool = True
93+
attempts: int = 5
94+
95+
8896
@dataclass
8997
class ServeConfig:
9098
"""Class for keeping config values of serve specification."""
@@ -228,18 +236,23 @@ class JobConfig:
228236
flow_graph: dict[str, list[WorldInfo]]
229237
dataset: Dataset
230238
job_id: str
239+
recovery: Recovery | dict | None = None
231240
nfaults: int = 0
232241
micro_batch_size: int = 8
233242
fwd_policy: str = "random"
234243
max_inflight: int = 1
235-
recover: bool = True
236244
force_terminate: bool = False
237245

238246
# this will be set by controller based on its configuration
239247
reqgen_config: GenConfig | None = None
240248

241249
def __post_init__(self) -> None:
242250
"""Handle post init class variables."""
251+
if self.recovery is None:
252+
self.recovery = Recovery()
253+
elif not isinstance(self.recovery, Recovery):
254+
self.recovery = Recovery(**self.recovery)
255+
243256
for k in list(self.flow_graph.keys()):
244257
for i, item in enumerate(self.flow_graph[k]):
245258
world_info = item if isinstance(item, WorldInfo) else WorldInfo(**item)

infscale/controller/job_context.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
from infscale.controller.controller import Controller
5151

5252
MAX_RES_RECOVER_RETRIES = 8 # max resources retries with exponential backoff.
53-
MAX_DEPLOY_RECOVER_RETRIES = 5 # maximum deploy recovery retries
5453

5554

5655
logger = None
@@ -653,7 +652,7 @@ def cond_stopped(self):
653652

654653
async def cond_recovery(self):
655654
"""Handle the transition to failed."""
656-
if self._recovery_count == MAX_DEPLOY_RECOVER_RETRIES:
655+
if self._recovery_count == self.context._cur_cfg.recovery.attempts:
657656
failed_wrk_ids = self._get_failed_wrk_ids()
658657
await self._remove_pipeline_n_update(failed_wrk_ids, self.context._cur_cfg)
659658

@@ -755,7 +754,7 @@ async def do_wrk_cond(self, wid: str, status: WorkerStatus) -> None:
755754
await self.send_check_loop_command()
756755

757756
# if failure happens while starting, current config is None
758-
if self._cur_cfg is not None and self._cur_cfg.recover:
757+
if self._cur_cfg is not None and self._cur_cfg.recovery.enable:
759758
await self.cond_recovery()
760759

761760
return

0 commit comments

Comments
 (0)