
Commit f013062

aliafzal authored and facebook-github-bot committed
Adding support for MOMENTUM_DIFF and ROWWISE_ADAGRAD optimizer states (#3144)
Summary:
Pull Request resolved: #3144

This diff extends the Model Delta Tracker to support two new tracking modes: `MOMENTUM_DIFF` and `ROWWISE_ADAGRAD`, which enable tracking of rowwise optimizer states for more sophisticated gradient analysis.

Differential Revision: D76918891
1 parent b5f15b4 commit f013062
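
As context for the change, a minimal, hypothetical usage sketch follows. Only the two `TrackingMode` values and the new `get_delta` keyword arguments come from this diff; the tracker construction, the consumer name, and the table FQN are illustrative assumptions, not API details confirmed by this commit.

```python
# Hypothetical sketch, not part of this commit: assumes a delta tracker has
# already been attached to a sharded TorchRec model with one of the new modes,
# e.g. TrackingMode.MOMENTUM_DIFF or TrackingMode.ROWWISE_ADAGRAD.

# tracker = ...  # construction elided; configured with the new tracking mode

# After some training iterations, query the per-row delta of the rowwise
# optimizer state (running sum of squared gradients) accumulated since the
# last publish. The keyword arguments below are the ones added in this diff.
delta = tracker.get_delta(
    consumer="publisher",          # illustrative consumer name
    top_percentage=0.10,           # keep the top 10% of touched rows per table
    per_table_percentage={
        "sparse.ebc.table_0": (0.05, "MIN"),  # hypothetical FQN and per-table cap
    },
    sorted_by_indices=True,        # keep selected rows in their original order
)
for fqn, rows in delta.items():
    # rows.ids are the touched row ids; rows.states hold the squared-gradient delta
    print(fqn, rows.ids.shape, rows.states.shape)
```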

File tree

3 files changed: +344 -12 lines changed


torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 170 additions & 12 deletions
@@ -8,14 +8,20 @@
 # pyre-strict
 import logging as logger
 from collections import Counter, OrderedDict
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional, Tuple

 import torch
+from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
+from fbgemm_gpu.split_table_batched_embeddings_ops import (
+    SplitTableBatchedEmbeddingBagsCodegen,
+)

 from torch import nn
+from torchrec.distributed.batched_embedding_kernel import BatchedFusedEmbedding

 from torchrec.distributed.embedding import ShardedEmbeddingCollection
 from torchrec.distributed.embedding_lookup import (
+    BatchedFusedEmbeddingBag,
     GroupedEmbeddingsLookup,
     GroupedPooledEmbeddingsLookup,
 )
@@ -26,6 +32,8 @@
     EmbdUpdateMode,
     TrackingMode,
 )
+from torchrec.distributed.utils import none_throws
+
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

 UPDATE_MODE_MAP: Dict[TrackingMode, EmbdUpdateMode] = {
@@ -42,6 +50,14 @@
     # This mode supports approximate top-k delta-row selection, which can be
     # obtained by running momentum.norm().topk().
     TrackingMode.MOMENTUM_LAST: EmbdUpdateMode.LAST,
+    # MOMENTUM_DIFF keeps a running sum of the squared gradients per row.
+    # Within each publishing interval, we track the starting value of this running
+    # sum on all used rows and then do a lookup when ``get_delta`` is called to query
+    # the latest sum. We can then compute the delta between the two values and return
+    # it together with the row ids.
+    TrackingMode.MOMENTUM_DIFF: EmbdUpdateMode.FIRST,
+    # The same as MOMENTUM_DIFF. Added for backward compatibility.
+    TrackingMode.ROWWISE_ADAGRAD: EmbdUpdateMode.FIRST,
 }

 # Tracking is currently only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
@@ -99,6 +115,7 @@ def __init__(

         # from module FQN to ShardedEmbeddingCollection/ShardedEmbeddingBagCollection
         self.tracked_modules: Dict[str, nn.Module] = {}
+        self.table_to_fqn: Dict[str, str] = {}
         self.feature_to_fqn: Dict[str, str] = {}
         # Generate the mapping from FQN to feature names.
         self.fqn_to_feature_names()
@@ -180,6 +197,11 @@ def record_lookup(
         # In MOMENTUM_LAST mode, we track per feature IDs and corresponding momentum values received in the current batch.
         elif self._mode == TrackingMode.MOMENTUM_LAST:
             self.record_momentum(emb_module, kjt)
+        elif (
+            self._mode == TrackingMode.MOMENTUM_DIFF
+            or self._mode == TrackingMode.ROWWISE_ADAGRAD
+        ):
+            self.record_rowwise_optim_state(emb_module, kjt)
         else:
             raise NotImplementedError(f"Tracking mode {self._mode} is not supported")

@@ -278,6 +300,60 @@ def record_momentum(
                 states=per_key_states,
             )

+    def record_rowwise_optim_state(
+        self,
+        emb_module: nn.Module,
+        kjt: KeyedJaggedTensor,
+    ) -> None:
+        opt_states: List[List[torch.Tensor]] = (
+            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
+            #  `split_optimizer_states`.
+            emb_module._emb_module.split_optimizer_states()
+        )
+        proxy: torch.Tensor = torch.cat([state[0] for state in opt_states])
+        states = proxy[kjt.values()]
+        assert (
+            kjt.values().numel() == states.numel()
+        ), f"number of ids and states mismatch, expect {kjt.values()=}, {kjt.values().numel()}, but got {states.numel()}"
+        offsets: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            torch.tensor(kjt.length_per_key(), dtype=torch.int64)
+        )
+        for i, key in enumerate(kjt.keys()):
+            fqn = self.feature_to_fqn[key]
+            per_key_states = states[offsets[i] : offsets[i + 1]]
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                table_fqn=fqn,
+                ids=kjt[key].values(),
+                states=per_key_states,
+            )
+
+    def get_latest(self) -> Dict[str, torch.Tensor]:
+        ret: Dict[str, torch.Tensor] = {}
+        for module in self.tracked_modules.values():
+            # pyre-fixme[29]:
+            for lookup in module._lookups:
+                for embs_module in lookup._emb_modules:
+                    assert isinstance(
+                        embs_module, (BatchedFusedEmbeddingBag, BatchedFusedEmbedding)
+                    ), f"expect BatchedFusedEmbeddingBag or BatchedFusedEmbedding, but {type(embs_module)} found"
+                    tbe = embs_module._emb_module
+
+                    assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen)
+                    table_names = [t.name for t in embs_module._config.embedding_tables]
+                    opt_states = tbe.split_optimizer_states()
+                    assert len(table_names) == len(opt_states)
+
+                    for i, table_name in enumerate(table_names):
+                        emb_fqn = self.table_to_fqn[table_name]
+                        table_state = opt_states[i][0]
+                        assert (
+                            emb_fqn not in ret
+                        ), f"a table with {emb_fqn} already exists"
+                        ret[emb_fqn] = table_state
+
+        return ret
+
     def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
         """
         Return a dictionary of hit local IDs for each sparse feature. Ids are
@@ -289,7 +365,13 @@ def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
         per_table_delta_rows = self.get_delta(consumer)
         return {fqn: delta_rows.ids for fqn, delta_rows in per_table_delta_rows.items()}

-    def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
+    def get_delta(
+        self,
+        consumer: Optional[str] = None,
+        top_percentage: Optional[float] = 1.0,
+        per_table_percentage: Optional[Dict[str, Tuple[float, str]]] = None,
+        sorted_by_indices: Optional[bool] = True,
+    ) -> Dict[str, DeltaRows]:
         """
         Return a dictionary of hit local IDs and parameter states / embeddings for each sparse feature. The values are first keyed by submodule FQN.

@@ -314,6 +396,65 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
         self.per_consumer_batch_idx[consumer] = index_end
         if self._delete_on_read:
             self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values()))
+
+        if self._mode in (TrackingMode.MOMENTUM_DIFF, TrackingMode.ROWWISE_ADAGRAD):
+            square_sum_map = self.get_latest()
+            for fqn, rows in tracker_rows.items():
+                assert (
+                    fqn in square_sum_map
+                ), f"{fqn} not found in {square_sum_map.keys()}"
+                # compute delta sum_t(g^2) for t in [t1, t2] through
+                # sum_t2(g^2) - sum_t1(g^2)
+                # pyre-fixme[58]: `-` is not supported for operand types `Tensor`
+                #  and `Optional[Tensor]`.
+                rows.states = square_sum_map[fqn][rows.ids] - rows.states
+
+                if rows.states is not None:
+                    default_k = rows.states.size(-1)
+                    top_k = (
+                        int(top_percentage * default_k)
+                        if top_percentage is not None
+                        else default_k
+                    )
+
+                    if (
+                        per_table_percentage is not None
+                        and per_table_percentage.get(fqn) is not None
+                    ):
+                        per_table_k = int(per_table_percentage[fqn][0] * default_k)
+                        policy = per_table_percentage[fqn][1]
+
+                        if policy == "MIN":
+                            top_k = min(top_k, per_table_k)
+                        elif policy == "MAX":
+                            top_k = max(top_k, per_table_k)
+                        elif policy == "OVERRIDE":
+                            top_k = per_table_k
+                        else:
+                            logger.warning(
+                                f"Unknown policy {policy}, will keep using original top_k {top_k}"
+                            )
+
+                    logger.info(f"get_unique {fqn=} {top_k=} {default_k=}")
+
+                    if top_k >= default_k:
+                        continue
+
+                    if sorted_by_indices:
+                        sorted_indices, _ = torch.sort(
+                            torch.topk(
+                                none_throws(rows.states), top_k, sorted=False
+                            ).indices,
+                            stable=False,
+                        )
+                        rows.ids = rows.ids[sorted_indices]
+                        rows.states = none_throws(rows.states)[sorted_indices]
+                    else:
+                        rows.states, indices = torch.topk(
+                            none_throws(rows.states), top_k, sorted=False
+                        )
+                        rows.ids = rows.ids[indices]
+
         return tracker_rows

     def get_tracked_modules(self) -> Dict[str, nn.Module]:
@@ -330,7 +471,6 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
             return self._fqn_to_feature_map

         table_to_feature_names: Dict[str, List[str]] = OrderedDict()
-        table_to_fqn: Dict[str, str] = OrderedDict()
         for fqn, named_module in self._model.named_modules():
             split_fqn = fqn.split(".")
             # Skipping partial FQNs present in fqns_to_skip
@@ -356,13 +496,13 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
                 # will incorrectly match fqn with all the table names that have the same prefix
                 if table_name in split_fqn:
                     embedding_fqn = self._clean_fqn_fn(fqn)
-                    if table_name in table_to_fqn:
+                    if table_name in self.table_to_fqn:
                         # Sanity check for validating that we don't have more than one table mapping to the same fqn.
                         logger.warning(
-                            f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
+                            f"Override {self.table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
                         )
-                    table_to_fqn[table_name] = embedding_fqn
-        logger.info(f"Table to fqn: {table_to_fqn}")
+                    self.table_to_fqn[table_name] = embedding_fqn
+        logger.info(f"Table to fqn: {self.table_to_fqn}")
         flatten_names = [
             name for names in table_to_feature_names.values() for name in names
         ]
@@ -375,15 +515,15 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:

         fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
         for table_name in table_to_feature_names:
-            if table_name not in table_to_fqn:
+            if table_name not in self.table_to_fqn:
                 # This is likely unexpected, where we can't locate the FQN associated with this table.
                 logger.warning(
-                    f"Table {table_name} not found in {table_to_fqn}, skipping"
+                    f"Table {table_name} not found in {self.table_to_fqn}, skipping"
                 )
                 continue
-            fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[
-                table_name
-            ]
+            fqn_to_feature_names[self.table_to_fqn[table_name]] = (
+                table_to_feature_names[table_name]
+            )
         self._fqn_to_feature_map = fqn_to_feature_names
         return fqn_to_feature_names

@@ -451,6 +591,24 @@ def _validate_and_init_tracker_fns(self) -> None:
                         (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
                     )
                     lookup.register_optim_state_tracker_fn(self.record_lookup)
+            elif (
+                self._mode == TrackingMode.ROWWISE_ADAGRAD
+                or self._mode == TrackingMode.MOMENTUM_DIFF
+            ):
+                # pyre-ignore[29]:
+                for lookup in module._lookups:
+                    assert isinstance(
+                        lookup,
+                        (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
+                    ) and all(
+                        # TorchRec maps ROWWISE_ADAGRAD to EXACT_ROWWISE_ADAGRAD
+                        # pyre-ignore[16]:
+                        emb._emb_module.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
+                        # pyre-ignore[16]:
+                        or emb._emb_module.optimizer == OptimType.PARTIAL_ROWWISE_ADAM
+                        for emb in lookup._emb_modules
+                    )
+                    lookup.register_optim_state_tracker_fn(self.record_lookup)
             else:
                 raise NotImplementedError(
                     f"Tracking mode {self._mode} is not supported"
