Skip to content

Commit b22119c

Browse files
tushar00jain authored and facebook-github-bot committed
handle exception waiting for work (#287)
Summary: work.wait() can throw, so wrap it in a try/except and handle the failure gracefully: the exception is logged and reported to the manager, which causes should_commit to fail. Differential Revision: D84880993
1 parent b3be7ad commit b22119c

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

torchft/device_mesh.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,7 @@ def __init__(
6969
self.replicate_dim_name: str = mesh_dim_names[replicate_dim]
7070
self.parent = parent
7171
self.flatten_meshes: Dict[str, DeviceMesh] = {}
72+
self._flatten_mapping: Dict[str, "DeviceMesh"] = {}
7273
self._device_type: str
7374
if mesh is not None:
7475
self._device_type = mesh.device_type

torchft/manager.py

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1244,14 +1244,19 @@ def _assert_same_stream(self) -> None:
12441244
def wait(self, timeout: Optional[timedelta] = None) -> bool:
12451245
self._assert_same_stream()
12461246

1247-
with get_stream_context(self._stream):
1248-
self._work.wait()
1249-
self._set_future_callback()
1247+
try:
1248+
with get_stream_context(self._stream):
1249+
self._work.wait()
1250+
self._set_future_callback()
12501251

1251-
with get_stream_context(self._stream):
1252-
self._managed_fut_tail.wait()
1252+
with get_stream_context(self._stream):
1253+
self._managed_fut_tail.wait()
12531254

1254-
return True
1255+
return True
1256+
except Exception as e:
1257+
self._manager._logger.exception(f"got exception waiting for work {e}")
1258+
self._manager.report_error(e)
1259+
return False
12551260

12561261
def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
12571262
self._assert_same_stream()

0 commit comments

Comments
 (0)