
Commit 587fcfe

aliafzal authored and facebook-github-bot committed
Adding support for tracking optimizer states in Model Delta Tracker.
Summary:

### Overview

This diff adds support for tracking optimizer states in the Model Delta Tracker system. It introduces a new tracking mode called `MOMENTUM_LAST` that enables tracking of momentum values from optimizers to support approximate top-k delta-row selection.

### Key Changes

#### 1. Optimizer State Tracking Support

* To support tracking of optimizer states, I added an `optim_state_tracker_fn` attribute to the `GroupedEmbeddingsLookup` and `GroupedPooledEmbeddingsLookup` classes, which are responsible for traversing the BatchedFused modules.
* Implemented a `register_optim_state_tracker_fn()` method in both classes to register the trackable callable.
* Tracking calls are invoked after each lookup operation.

#### 2. Model Delta Tracker Changes

* Added a `record_momentum()` method to track momentum values from optimizer states, along with support for it in the `record_lookup` function.
* Added validation and optimizer-tracker-function logic to support the new `MOMENTUM_LAST` mode.

#### 3. New Tracking Mode

* Added `TrackingMode.MOMENTUM_LAST` to `torchrec/distributed/model_tracker/types.py`.
* Maps to `EmbdUpdateMode.LAST` to capture the most recent momentum values.

Differential Revision: D76868111
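The approximate top-k delta-row selection mentioned above amounts to ranking embedding rows by the norm of their momentum. Below is a minimal, self-contained sketch with a synthetic `momentum` tensor and an arbitrary `k`; it is illustrative only and not code from this diff.

```python
import torch

# Synthetic stand-in for the momentum state of a fused embedding table:
# one momentum row per embedding row.
num_rows, dim, k = 100, 16, 10
momentum = torch.rand(num_rows, dim)

# Rank rows by the magnitude of their accumulated updates and keep the top k,
# mirroring the momentum.norm().topk() selection described in the summary.
topk_norms, topk_row_ids = momentum.norm(dim=1).topk(k)
print(topk_row_ids)  # ids of the rows with the largest accumulated updates
```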
1 parent 643d221 commit 587fcfe

File tree

7 files changed: +279 -18 lines


torchrec/distributed/embedding.py

Lines changed: 1 addition & 1 deletion

@@ -1515,7 +1515,7 @@ def compute_and_output_dist(
         ):
             embs = lookup(features)
             if self.post_lookup_tracker_fn is not None:
-                self.post_lookup_tracker_fn(features, embs)
+                self.post_lookup_tracker_fn(self, features, embs)

             with maybe_annotate_embedding_event(
                 EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type

torchrec/distributed/embedding_lookup.py

Lines changed: 48 additions & 3 deletions

@@ -10,7 +10,7 @@
 import logging
 from abc import ABC
 from collections import OrderedDict
-from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist

@@ -206,6 +206,10 @@ def __init__(
         )

         self.grouped_configs = grouped_configs
+        # Model tracker function to track optimizer state
+        self.optim_state_tracker_fn: Optional[
+            Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
+        ] = None

     def _create_embedding_kernel(
         self,

@@ -305,7 +309,13 @@ def forward(
             self._feature_splits,
         )
         for emb_op, features in zip(self._emb_modules, features_by_group):
-            embeddings.append(emb_op(features).view(-1))
+            lookup = emb_op(features).view(-1)
+            embeddings.append(lookup)
+
+            # Model tracker optimizer state function; only called when the
+            # model tracker is configured to track optimizer state.
+            if self.optim_state_tracker_fn is not None:
+                self.optim_state_tracker_fn(emb_op, features, lookup)

         return embeddings_cat_empty_rank_handle(embeddings, self._dummy_embs_tensor)

@@ -409,6 +419,19 @@ def purge(self) -> None:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             emb_module.purge()

+    def register_optim_state_tracker_fn(
+        self,
+        record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
+    ) -> None:
+        """
+        Register a model tracker function to track optimizer state.
+
+        Args:
+            record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+
+        """
+        self.optim_state_tracker_fn = record_fn
+

 class CommOpGradientScaling(torch.autograd.Function):
     @staticmethod

@@ -481,6 +504,10 @@ def __init__(
             if scale_weight_gradients and get_gradient_division()
             else 1
         )
+        # Model tracker function to track optimizer state
+        self.optim_state_tracker_fn: Optional[
+            Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
+        ] = None

     def _create_embedding_kernel(
         self,

@@ -608,7 +635,12 @@ def forward(
                     features._weights, self._scale_gradient_factor
                 )

-            embeddings.append(emb_op(features))
+            lookup = emb_op(features)
+            embeddings.append(lookup)
+            # Model tracker optimizer state function; only called when the
+            # model tracker is configured to track optimizer state.
+            if self.optim_state_tracker_fn is not None:
+                self.optim_state_tracker_fn(emb_op, features, lookup)

         if features.variable_stride_per_key() and len(self._emb_modules) > 1:
             stride_per_rank_per_key = list(

@@ -738,6 +770,19 @@ def purge(self) -> None:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             emb_module.purge()

+    def register_optim_state_tracker_fn(
+        self,
+        record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
+    ) -> None:
+        """
+        Register a model tracker function to track optimizer state.
+
+        Args:
+            record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+
+        """
+        self.optim_state_tracker_fn = record_fn
+

 class MetaInferGroupedEmbeddingsLookup(
     BaseEmbeddingLookup[KeyedJaggedTensor, torch.Tensor], TBEToRegisterMixIn
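For illustration, here is a small sketch of how a record function matching the new hook's signature could be wired into one of these lookup modules. `record_optim_state` and `lookup_module` are hypothetical names introduced for the example; only the `register_optim_state_tracker_fn` call reflects the method added in this file.

```python
import torch
from torch import nn
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def record_optim_state(
    emb_module: nn.Module, kjt: KeyedJaggedTensor, lookup_out: torch.Tensor
) -> None:
    # Invoked once per grouped embedding op, right after its forward lookup.
    # `emb_module` gives the tracker a handle for reaching fused optimizer
    # state (e.g. momentum) for the ids carried by `kjt`.
    print(f"{type(emb_module).__name__}: tracked {kjt.values().numel()} ids")


# `lookup_module` would be a GroupedEmbeddingsLookup / GroupedPooledEmbeddingsLookup
# instance, e.g. an entry of a sharded module's _lookups list:
# lookup_module.register_optim_state_tracker_fn(record_optim_state)
```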

torchrec/distributed/embedding_types.py

Lines changed: 3 additions & 3 deletions

@@ -373,7 +373,7 @@ def __init__(
         self._lookups: List[nn.Module] = []
         self._output_dists: List[nn.Module] = []
         self.post_lookup_tracker_fn: Optional[
-            Callable[[KeyedJaggedTensor, torch.Tensor], None]
+            Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
         ] = None
         self.post_odist_tracker_fn: Optional[Callable[..., None]] = None

@@ -426,14 +426,14 @@ def train(self, mode: bool = True):  # pyre-ignore[3]

     def register_post_lookup_tracker_fn(
         self,
-        record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None],
+        record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
     ) -> None:
         """
         Register a function to be called after lookup is done. This is used for
         tracking the lookup results and optimizer states.

         Args:
-            record_fn (Callable[[KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+            record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.

         """
         if self.post_lookup_tracker_fn is not None:

torchrec/distributed/embeddingbag.py

Lines changed: 1 addition & 1 deletion

@@ -1459,7 +1459,7 @@ def compute_and_output_dist(
         ):
             embs = lookup(features)
             if self.post_lookup_tracker_fn is not None:
-                self.post_lookup_tracker_fn(features, embs)
+                self.post_lookup_tracker_fn(self, features, embs)

             with maybe_annotate_embedding_event(
                 EmbeddingEvent.OUTPUT_DIST,

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 75 additions & 7 deletions

@@ -13,7 +13,12 @@
 import torch

 from torch import nn
+
 from torchrec.distributed.embedding import ShardedEmbeddingCollection
+from torchrec.distributed.embedding_lookup import (
+    GroupedEmbeddingsLookup,
+    GroupedPooledEmbeddingsLookup,
+)
 from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
 from torchrec.distributed.model_tracker.delta_store import DeltaStore
 from torchrec.distributed.model_tracker.types import (

@@ -27,9 +32,16 @@
     # Only IDs are tracked, no additional state is stored.
     TrackingMode.ID_ONLY: EmbdUpdateMode.NONE,
     # TrackingMode.EMBEDDING utilizes EmbdUpdateMode.FIRST to ensure that
-    # the earliest embedding values are stored since the last checkpoint or snapshot.
-    # This mode is used for computing topk delta rows, which is currently achieved by running (new_emb - old_emb).norm().topk().
+    # the earliest embedding values are stored since the last checkpoint
+    # or snapshot. This mode is used for computing topk delta rows, which
+    # is currently achieved by running (new_emb - old_emb).norm().topk().
     TrackingMode.EMBEDDING: EmbdUpdateMode.FIRST,
+    # TrackingMode.MOMENTUM_LAST utilizes EmbdUpdateMode.LAST to ensure that
+    # the most recent momentum values (capturing the accumulated gradient
+    # direction and magnitude) are stored since the last batch.
+    # This mode supports approximate top-k delta-row selection, which can be
+    # obtained by running momentum.norm().topk().
+    TrackingMode.MOMENTUM_LAST: EmbdUpdateMode.LAST,
 }

 # Tracking is currently only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.

@@ -141,7 +153,9 @@ def trigger_compaction(self) -> None:
         # Update the current compact index to the end index to avoid duplicate compaction.
         self.curr_compact_index = end_idx

-    def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
+    def record_lookup(
+        self, emb_module: nn.Module, kjt: KeyedJaggedTensor, states: torch.Tensor
+    ) -> None:
         """
         Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.

@@ -152,6 +166,7 @@ def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
         (in ID_ONLY mode) or both IDs and their corresponding embeddings (in EMBEDDING mode).

         Args:
+            emb_module (nn.Module): The embedding module in which the lookup was performed.
             kjt (KeyedJaggedTensor): The KeyedJaggedTensor containing IDs to record.
             states (torch.Tensor): The embeddings or states corresponding to the IDs in the kjt.
         """

@@ -162,7 +177,9 @@ def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
         # In EMBEDDING mode, we track per feature IDs and corresponding embeddings received in the current batch.
         elif self._mode == TrackingMode.EMBEDDING:
             self.record_embeddings(kjt, states)
-
+        # In MOMENTUM_LAST mode, we track per feature IDs and corresponding momentum values received in the current batch.
+        elif self._mode == TrackingMode.MOMENTUM_LAST:
+            self.record_momentum(emb_module, kjt)
         else:
             raise NotImplementedError(f"Tracking mode {self._mode} is not supported")

@@ -228,6 +245,39 @@ def record_embeddings(
                 states=torch.cat(per_table_emb[table_fqn]),
             )

+    def record_momentum(
+        self,
+        emb_module: nn.Module,
+        kjt: KeyedJaggedTensor,
+    ) -> None:
+        # FIXME: this is the momentum from the last iteration; use momentum from
+        # the current iteration for correctness.
+        # pyre-ignore Undefined attribute [16]:
+        momentum = emb_module._emb_module.momentum1_dev
+        # FIXME: support multiple tables per group; the information can be extracted
+        # from module._config (i.e., GroupedEmbeddingConfig).
+        # pyre-ignore Undefined attribute [16]:
+        states = momentum.view(-1, emb_module._config.embedding_dims()[0])[
+            kjt.values()
+        ].norm(dim=1)
+
+        offsets: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            torch.tensor(kjt.length_per_key(), dtype=torch.int64)
+        )
+        assert (
+            kjt.values().numel() == states.numel()
+        ), f"number of ids and states mismatch, expected {kjt.values().numel()} but got {states.numel()}"
+
+        for i, key in enumerate(kjt.keys()):
+            fqn = self.feature_to_fqn[key]
+            per_key_states = states[offsets[i] : offsets[i + 1]]
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                table_fqn=fqn,
+                ids=kjt[key].values(),
+                states=per_key_states,
+            )
+
     def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
         """
         Return a dictionary of hit local IDs for each sparse feature. Ids are

@@ -380,13 +430,31 @@ def _clean_fqn_fn(self, fqn: str) -> str:
     def _validate_and_init_tracker_fns(self) -> None:
         "To validate the mode is supported for the given module"
         for module in self.tracked_modules.values():
+            # EMBEDDING mode is only supported for ShardedEmbeddingCollection
             assert not (
                 isinstance(module, ShardedEmbeddingBagCollection)
                 and self._mode == TrackingMode.EMBEDDING
             ), "EBC's lookup returns pooled embeddings and currently, we do not support tracking raw embeddings."
-            # register post lookup function
-            # pyre-ignore[29]
-            module.register_post_lookup_tracker_fn(self.record_lookup)
+
+            if (
+                self._mode == TrackingMode.ID_ONLY
+                or self._mode == TrackingMode.EMBEDDING
+            ):
+                # register post lookup function
+                # pyre-ignore[29]
+                module.register_post_lookup_tracker_fn(self.record_lookup)
+            elif self._mode == TrackingMode.MOMENTUM_LAST:
+                # pyre-ignore[29]:
+                for lookup in module._lookups:
+                    assert isinstance(
+                        lookup,
+                        (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
+                    )
+                    lookup.register_optim_state_tracker_fn(self.record_lookup)
+            else:
+                raise NotImplementedError(
+                    f"Tracking mode {self._mode} is not supported"
+                )
             # register auto compaction function at odist
             if self._auto_compact:
                 # pyre-ignore[29]
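The per-key slicing in `record_momentum` relies on a complete cumulative sum over `kjt.length_per_key()` to partition the flat tensor of momentum norms into one slice per feature. Below is a standalone sketch with synthetic tensors; plain `torch.cumsum` stands in for the `fbgemm` op used in the diff, and the feature names and sizes are made up for illustration.

```python
import torch

embedding_dim = 4
momentum = torch.rand(10, embedding_dim)     # one momentum row per embedding row
ids = torch.tensor([1, 3, 3, 7, 2])          # flattened ids across two features
length_per_key = [3, 2]                      # "f1" contributed 3 ids, "f2" 2 ids
keys = ["f1", "f2"]

# One norm per looked-up id, matching record_momentum's `states` tensor.
states = momentum[ids].norm(dim=1)

# Complete cumsum: [0, 3, 5]; slice i covers the ids of feature keys[i].
offsets = torch.cat(
    [
        torch.zeros(1, dtype=torch.int64),
        torch.tensor(length_per_key, dtype=torch.int64).cumsum(0),
    ]
)

for i, key in enumerate(keys):
    per_key_states = states[offsets[i] : offsets[i + 1]]
    print(key, per_key_states.shape)  # f1 -> 3 norms, f2 -> 2 norms
```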
