Skip to content

Commit c8000cf

Browse files
tushar00jain authored and facebook-github-bot committed
handle exception waiting for work (meta-pytorch#287)
Summary: work.wait() can throw, so wrap that call in a try/catch to handle it gracefully by reporting the error to the manager, causing should_commit to fail. Reviewed By: d4l3k Differential Revision: D84880993
1 parent e4d99b5 commit c8000cf

File tree

3 files changed

+21
-6
lines changed

3 files changed

+21
-6
lines changed

.github/workflows/docs.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ jobs:
2525
2626
sudo apt-get install -y protobuf-compiler
2727
28+
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
2829
pip install .[dev] -v
2930
3031
pip install -r docs/requirements.txt

torchft/device_mesh.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
init_device_mesh,
1111
ProcessGroup as BaseProcessGroup,
1212
)
13+
from torch.distributed._mesh_layout import _MeshLayout
1314
from torch.distributed.tensor.device_mesh import _mesh_resources
1415

1516
from torchft.manager import Manager
@@ -69,12 +70,20 @@ def __init__(
6970
self.replicate_dim_name: str = mesh_dim_names[replicate_dim]
7071
self.parent = parent
7172
self.flatten_meshes: Dict[str, DeviceMesh] = {}
73+
self._flatten_mapping: Dict[str, "DeviceMesh"] = {}
7274
self._device_type: str
7375
if mesh is not None:
7476
self._device_type = mesh.device_type
77+
mesh_tensor = (
78+
mesh.detach().to(dtype=torch.int).contiguous()
79+
if isinstance(mesh, torch.Tensor)
80+
else torch.tensor(mesh, device="cpu", dtype=torch.int)
81+
)
82+
self._layout = _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
7583
else:
7684
assert parent is not None
7785
self._device_type = parent.device_type
86+
self._layout = parent._layout
7887
self._flatten_mesh_list: tuple[DeviceMesh, ...] = tuple()
7988
self._thread_id: Optional[int] = None
8089
self._hash: Optional[int] = None

torchft/manager.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,14 +1253,19 @@ def _assert_same_stream(self) -> None:
12531253
def wait(self, timeout: Optional[timedelta] = None) -> bool:
12541254
self._assert_same_stream()
12551255

1256-
with get_stream_context(self._stream):
1257-
self._work.wait()
1258-
self._set_future_callback()
1256+
try:
1257+
with get_stream_context(self._stream):
1258+
self._work.wait()
1259+
self._set_future_callback()
12591260

1260-
with get_stream_context(self._stream):
1261-
self._managed_fut_tail.wait()
1261+
with get_stream_context(self._stream):
1262+
self._managed_fut_tail.wait()
12621263

1263-
return True
1264+
return True
1265+
except Exception as e:
1266+
self._manager._logger.exception(f"got exception waiting for work {e}")
1267+
self._manager.report_error(e)
1268+
return False
12641269

12651270
def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
12661271
self._assert_same_stream()

0 commit comments

Comments (0)