
Commit b3be7ad

tushar00jain authored and meta-codesync[bot] committed
setup basic otel logging (meta-pytorch#281)
Summary:
Pull Request resolved: meta-pytorch#281

- setup 3 basic structured logs
  - quorums: every time a rank changes quorum id
  - commits: every time a rank commits a step
  - errors: every time a rank calls abort on process group
- allow `otel.py` to initialize loggers in multiple namespaces, one for each structured logger
- pass `replica_id` to process group so that it can be logged
- flag to enable or disable otel logging

Reviewed By: d4l3k

Differential Revision: D84571482

fbshipit-source-id: 18dcfef40d8e9cfc1a11bb90d167dd3bf0271dd2
1 parent f60f063 commit b3be7ad
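
For orientation before the per-file diffs: the pattern this commit introduces is plain stdlib logging with namespaced loggers, where each event carries its data as structured fields via the `extra` argument rather than in the message text. A minimal sketch (the stream handler below is only a stand-in for whatever exporter `torchft/otel.py` attaches, which is not part of this diff; the replica id and step values are illustrative):

import logging
import os

# One namespaced logger per event type, mirroring torchft/__init__.py below.
for name in ("torchft_quorums", "torchft_commits", "torchft_errors"):
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())  # stand-in for the otel.py handler

# Emission mirrors Manager._async_quorum: the message is empty and all data
# travels as attributes on the LogRecord via `extra`.
logging.getLogger("torchft_quorums").info(
    "",
    extra={
        "job_id": os.environ.get("JOB_ID", "unknown"),
        "replica_id": "replica_0",
        "rank": 0,
        "quorum_id": 7,
        "step": 42,
    },
)

With the default formatter this prints only the empty message; the structured fields live on the `LogRecord` and are what an exporter-backed handler would ship (see the formatter sketch after the `torchft/manager.py` diff).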

10 files changed: +176 additions, -107 deletions


torchft/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -8,6 +8,7 @@
 from torchft.ddp import DistributedDataParallel
 from torchft.manager import Manager
 from torchft.optim import OptimizerWrapper as Optimizer
+from torchft.otel import setup_logger
 from torchft.process_group import (
     ProcessGroupBabyNCCL,
     ProcessGroupBabyXCCL,
@@ -16,6 +17,10 @@
     ProcessGroupXCCL,
 )
 
+setup_logger("torchft_quorums")
+setup_logger("torchft_commits")
+setup_logger("torchft_errors")
+
 __all__ = (
     "DistributedDataParallel",
     "DistributedSampler",

torchft/checkpointing/pg_transport_bench.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ def run(rank: int) -> None:
 
     with _timeit("init_pg"):
         pg = ProcessGroupBabyNCCL(timeout=timeout)
-        pg.configure(store_addr=store_addr, rank=rank, world_size=2)
+        pg.configure(store_addr=store_addr, replica_id="0", rank=rank, world_size=2)
 
     t = torch.zeros(10, device=device, dtype=torch.float32)
     pg.allreduce([t], dist.ReduceOp.SUM).wait(timeout=timeout)

torchft/checkpointing/pg_transport_test.py

Lines changed: 3 additions & 0 deletions
@@ -27,6 +27,7 @@ def init(rank: int, world_size: int) -> CheckpointTransport[dict[str, object]]:
         pg = ProcessGroupGloo()
         pg.configure(
             store_addr=f"localhost:{store.port}/prefix",
+            replica_id="0",
             rank=rank,
             world_size=world_size,
         )
@@ -52,6 +53,7 @@ def init(rank: int, world_size: int) -> CheckpointTransport[dict[str, object]]:
         pg = ProcessGroupBabyNCCL(timeout=timeout)
         pg.configure(
             store_addr=f"localhost:{store.port}/prefix",
+            replica_id="0",
             rank=rank,
             world_size=world_size,
         )
@@ -78,6 +80,7 @@ def init(rank: int, world_size: int) -> CheckpointTransport[dict[str, object]]:
         pg = ProcessGroupBabyNCCL(timeout=timeout)
         pg.configure(
             store_addr=f"localhost:{store.port}/prefix",
+            replica_id="0",
             rank=rank,
             world_size=world_size,
         )

torchft/device_mesh.py

Lines changed: 40 additions & 40 deletions
@@ -62,19 +62,19 @@ def __init__(
             raise ValueError(
                 "ManagedDeviceMesh doesn't support both mesh and parent are None."
             )
-        self.mesh = mesh
-        self.mesh_dim_names = mesh_dim_names
+        self._mesh = mesh
+        self._mesh_dim_names = mesh_dim_names
         self.replicate_pg = replicate_pg
         self.replicate_dim = replicate_dim
         self.replicate_dim_name: str = mesh_dim_names[replicate_dim]
         self.parent = parent
         self.flatten_meshes: Dict[str, DeviceMesh] = {}
-        self.device_type: str
+        self._device_type: str
         if mesh is not None:
-            self.device_type = mesh.device_type
+            self._device_type = mesh.device_type
         else:
             assert parent is not None
-            self.device_type = parent.device_type
+            self._device_type = parent.device_type
         self._flatten_mesh_list: tuple[DeviceMesh, ...] = tuple()
         self._thread_id: Optional[int] = None
         self._hash: Optional[int] = None
@@ -102,20 +102,20 @@ def __getitem__(self, mesh_dim_names: Union[str, tuple[str, ...]]) -> DeviceMesh
             elif mesh_dim_names in self.flatten_meshes:
                 res_submesh = self.flatten_meshes[mesh_dim_names]
             else:
-                assert self.mesh is not None
-                res_submesh = self.mesh[mesh_dim_names]
+                assert self._mesh is not None
+                res_submesh = self._mesh[mesh_dim_names]
         else:
             assert isinstance(mesh_dim_names, tuple)
             if self.replicate_dim_name not in mesh_dim_names:
-                assert self.mesh is not None
-                res_submesh = self.mesh[mesh_dim_names]
+                assert self._mesh is not None
+                res_submesh = self._mesh[mesh_dim_names]
             else:
                 mesh_dim_names_wo_replicate = tuple(
                     n for n in mesh_dim_names if n != self.replicate_dim_name
                 )
-                assert self.mesh is not None
+                assert self._mesh is not None
                 res_submesh = ManagedDeviceMesh(
-                    self.mesh[mesh_dim_names_wo_replicate],
+                    self._mesh[mesh_dim_names_wo_replicate],
                     mesh_dim_names,
                     self.replicate_pg,
                     mesh_dim_names.index(self.replicate_dim_name),
@@ -125,7 +125,7 @@ def __getitem__(self, mesh_dim_names: Union[str, tuple[str, ...]]) -> DeviceMesh
         # TODO: find a better way to do this that doesn't depend on device mesh
         # internals
         root = _mesh_resources.get_root_mesh(self)
-        _mesh_resources.child_to_root_mapping[res_submesh] = root
+        res_submesh._root_mesh = root
 
         return res_submesh
 
@@ -134,7 +134,7 @@ def _real_mesh_dim(self, mesh_dim: int) -> int:
 
     def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGroup:
         if isinstance(mesh_dim, str):
-            dim = self.mesh_dim_names.index(mesh_dim)
+            dim = self._mesh_dim_names.index(mesh_dim)
         else:
             dim = 0 if mesh_dim is None else int(mesh_dim)
 
@@ -143,8 +143,8 @@ def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGr
         elif dim == self.replicate_dim:
             return self.replicate_pg
         else:
-            assert self.mesh is not None
-            return self.mesh.get_group(self._real_mesh_dim(dim))
+            assert self._mesh is not None
+            return self._mesh.get_group(self._real_mesh_dim(dim))
 
     def _flatten(
         self,
@@ -168,64 +168,64 @@ def size(self, mesh_dim: Optional[int] = None) -> int:
         # This is possible during the initialization stage of training.
         replicate_pg_size = 1 if replicate_pg_size == 0 else replicate_pg_size
         if mesh_dim is None:
-            if self.mesh is None:
+            if self._mesh is None:
                 return replicate_pg_size
             else:
-                assert self.mesh is not None
-                return self.mesh.size() * replicate_pg_size
+                assert self._mesh is not None
+                return self._mesh.size() * replicate_pg_size
         elif mesh_dim == self.replicate_dim:
             return replicate_pg_size
         else:
-            assert self.mesh is not None
-            return self.mesh.size(self._real_mesh_dim(mesh_dim))
+            assert self._mesh is not None
+            return self._mesh.size(self._real_mesh_dim(mesh_dim))
 
     @property
     def ndim(self) -> int:
-        assert self.mesh is not None
-        return self.mesh.ndim + 1
+        assert self._mesh is not None
+        return self._mesh.ndim + 1
 
     @property
     def shape(self) -> tuple[int, ...]:
-        assert self.mesh is not None
-        ret: list[int] = list(self.mesh.shape)
+        assert self._mesh is not None
+        ret: list[int] = list(self._mesh.shape)
         ret.insert(self.replicate_dim, self.replicate_pg.size())
         return tuple(ret)
 
     def get_rank(self) -> int:
-        assert self.mesh is not None
-        return self.mesh.get_rank()
+        assert self._mesh is not None
+        return self._mesh.get_rank()
 
     def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
         if isinstance(mesh_dim, str):
-            dim = self.mesh_dim_names.index(mesh_dim)
+            dim = self._mesh_dim_names.index(mesh_dim)
         else:
             dim = 0 if mesh_dim is None else int(mesh_dim)
 
         if mesh_dim is None:
-            if self.mesh is None:
+            if self._mesh is None:
                 return get_rank(self.replicate_pg)
 
             assert self.replicate_dim == 0, "replicate_dim must be the first one"
-            assert self.mesh is not None
-            other_dim_size = self.mesh.size()
-            assert self.mesh is not None
-            other_dim_rank = self.mesh.get_local_rank()
+            assert self._mesh is not None
+            other_dim_size = self._mesh.size()
+            assert self._mesh is not None
+            other_dim_rank = self._mesh.get_local_rank()
             replicate_pg_rank = get_rank(self.replicate_pg)
             return other_dim_size * replicate_pg_rank + other_dim_rank
         elif dim == self.replicate_dim:
            return get_rank(self.replicate_pg)
         else:
-            assert self.mesh is not None
-            return self.mesh.get_local_rank(self._real_mesh_dim(dim))
+            assert self._mesh is not None
+            return self._mesh.get_local_rank(self._real_mesh_dim(dim))
 
     def get_coordinate(self) -> Optional[list[int]]:
         """
         Return the relative indices of this rank relative to all
         dimensions of the mesh. If this rank is not part of the mesh, return None.
         """
-        assert self.mesh is not None
+        assert self._mesh is not None
         coordinate = (
-            self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
+            self._mesh._coordinate_on_dim if self._mesh._coordinate_on_dim else None
         )
         if not coordinate:
             return coordinate
@@ -239,20 +239,20 @@ def get_all_groups(self) -> list[BaseProcessGroup]:
         raise NotImplementedError
 
     def __repr__(self) -> str:
-        return f"ManagedDeviceMesh(mesh={self.mesh})"
+        return f"ManagedDeviceMesh(mesh={self._mesh})"
 
     def __hash__(self) -> int:
         # lazily compute hash
         if not self._hash:
             self._hash = hash(
                 (
-                    self.mesh,
-                    self.mesh_dim_names,
+                    self._mesh,
+                    self._mesh_dim_names,
                     self.replicate_pg,
                     self.replicate_dim,
                     self.replicate_dim_name,
                     self.parent,
-                    self.device_type,
+                    self._device_type,
                 )
             )
         return self._hash

torchft/manager.py

Lines changed: 32 additions & 1 deletion
@@ -216,9 +216,15 @@ def __init__(
                 before raising an exception. If None, will retry indefinitely.
             quorum_retries: the number of times to retry the quorum before crashing
         """
+        self.quorum_logger: logging.Logger = logging.getLogger("torchft_quorums")
+        self.commits_logger: logging.Logger = logging.getLogger("torchft_commits")
+        self.errors_logger: logging.Logger = logging.getLogger("torchft_errors")
+
         self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
         self._user_state_dicts: Dict[str, Callable[[], object]] = {}
 
+        self._replica_id = replica_id
+
         # Protects state dict
         self._state_dict_lock = RWLock(timeout=timeout.total_seconds())
 
@@ -642,6 +648,16 @@ def _async_quorum(
             self._participating_replica_rank = None
 
         if quorum_id != self._quorum_id:
+            self.quorum_logger.info(
+                "",
+                extra={
+                    "job_id": os.environ.get("JOB_ID", "unknown"),
+                    "replica_id": self._replica_id,
+                    "rank": self._group_rank,
+                    "quorum_id": quorum_id,
+                    "step": max_step,
+                },
+            )
             store_prefixed_addr = (
                 f"{store_address}/torchft/{quorum_id}/{self._group_rank}"
             )
@@ -653,7 +669,10 @@
             if torch.accelerator.is_available():
                 torch.accelerator.synchronize()
             self._pg.configure(
-                store_prefixed_addr, replica_rank, replica_world_size
+                store_prefixed_addr,
+                self._replica_id if self._replica_id is not None else "0",
+                replica_rank,
+                replica_world_size,
             )
             self._quorum_id = quorum_id
         except Exception as e:
@@ -817,6 +836,18 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
             f"should_commit={should_commit} enough_replicas={enough_replicas}, errored={self._errored}"
         )
 
+        self.commits_logger.info(
+            "",
+            extra={
+                "job_id": os.environ.get("JOB_ID", "unknown"),
+                "replica_id": self._replica_id,
+                "rank": self._group_rank,
+                "quorum_id": self._quorum_id,
+                "step": self._step,
+                "commit_result": should_commit,
+            },
+        )
+
         self._checkpoint_transport.disallow_checkpoint()
 
         # decide whether we're in a healthy state to increase the step count
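
Because the manager logs an empty message and puts everything into `extra`, each key becomes an attribute on the `logging.LogRecord`, which is how a handler (stdlib or OTel) can turn these events into queryable structured logs. A minimal sketch using only the standard library, with the same fields as the `should_commit` call above (the replica id, quorum id, and step values are illustrative):

import json
import logging

COMMIT_FIELDS = ("job_id", "replica_id", "rank", "quorum_id", "step", "commit_result")


class CommitJsonFormatter(logging.Formatter):
    """Serialize the structured commit fields into one JSON line."""

    def format(self, record: logging.LogRecord) -> str:
        return json.dumps({k: getattr(record, k, None) for k in COMMIT_FIELDS})


logger = logging.getLogger("torchft_commits")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(CommitJsonFormatter())
logger.addHandler(handler)

# Mirrors Manager.should_commit: empty message, data in `extra`.
logger.info(
    "",
    extra={
        "job_id": "unknown",
        "replica_id": "replica_0",
        "rank": 0,
        "quorum_id": 3,
        "step": 128,
        "commit_result": True,
    },
)
# -> {"job_id": "unknown", "replica_id": "replica_0", "rank": 0, "quorum_id": 3, "step": 128, "commit_result": true}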
