 # pyre-strict
 import logging as logger
 from collections import Counter, OrderedDict
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional, Tuple
 
 import torch
+from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
+from fbgemm_gpu.split_table_batched_embeddings_ops import (
+    SplitTableBatchedEmbeddingBagsCodegen,
+)
 
 from torch import nn
+from torchrec.distributed.batched_embedding_kernel import BatchedFusedEmbedding
 
 from torchrec.distributed.embedding import ShardedEmbeddingCollection
 from torchrec.distributed.embedding_lookup import (
+    BatchedFusedEmbeddingBag,
     GroupedEmbeddingsLookup,
     GroupedPooledEmbeddingsLookup,
 )
     EmbdUpdateMode,
     TrackingMode,
 )
+from torchrec.distributed.utils import none_throws
+
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 UPDATE_MODE_MAP: Dict[TrackingMode, EmbdUpdateMode] = {
     # This mode supports approximate top-k delta-row selection, can be
     # obtained by running momentum.norm().topk().
     TrackingMode.MOMENTUM_LAST: EmbdUpdateMode.LAST,
+    # MOMENTUM_DIFF keeps a running sum of the squared gradients per row.
+    # Within each publishing interval, we track the starting value of this running
+    # sum for every used row and then do a lookup when ``get_delta`` is called to
+    # query the latest sum. The delta between the two values is then returned
+    # together with the row ids.
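+    # Illustrative example (made-up numbers): if a row's running sum of squared
+    # gradients is 3.0 when the interval starts and 7.5 when ``get_delta`` is
+    # called, the state reported for that row is 7.5 - 3.0 = 4.5.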
+    TrackingMode.MOMENTUM_DIFF: EmbdUpdateMode.FIRST,
+    # Same as MOMENTUM_DIFF; added for backward compatibility.
+    TrackingMode.ROWWISE_ADAGRAD: EmbdUpdateMode.FIRST,
 }
 
 # Tracking is currently only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
@@ -99,6 +115,7 @@ def __init__(
 
         # from module FQN to ShardedEmbeddingCollection/ShardedEmbeddingBagCollection
         self.tracked_modules: Dict[str, nn.Module] = {}
+        self.table_to_fqn: Dict[str, str] = {}
         self.feature_to_fqn: Dict[str, str] = {}
         # Generate the mapping from FQN to feature names.
         self.fqn_to_feature_names()
@@ -180,6 +197,11 @@ def record_lookup(
         # In MOMENTUM_LAST mode, we track per feature IDs and corresponding momentum values received in the current batch.
         elif self._mode == TrackingMode.MOMENTUM_LAST:
             self.record_momentum(emb_module, kjt)
+        elif (
+            self._mode == TrackingMode.MOMENTUM_DIFF
+            or self._mode == TrackingMode.ROWWISE_ADAGRAD
+        ):
+            self.record_rowwise_optim_state(emb_module, kjt)
         else:
             raise NotImplementedError(f"Tracking mode {self._mode} is not supported")
 
@@ -278,6 +300,60 @@ def record_momentum(
                 states=per_key_states,
             )
 
+    def record_rowwise_optim_state(
+        self,
+        emb_module: nn.Module,
+        kjt: KeyedJaggedTensor,
+    ) -> None:
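+        """
+        Record the current row-wise optimizer state (e.g. the rowwise AdaGrad
+        running sum of squared gradients) for every id that appears in ``kjt``,
+        keyed by table FQN. Used by MOMENTUM_DIFF / ROWWISE_ADAGRAD tracking.
+        """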
+        opt_states: List[List[torch.Tensor]] = (
+            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
+            # `split_optimizer_states`.
+            emb_module._emb_module.split_optimizer_states()
+        )
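+        # Concatenate the first optimizer state (one value per embedding row) of
+        # every table in this grouped module into a single flat tensor, then
+        # gather the rows hit in this batch via the ids in ``kjt.values()``.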
+        proxy: torch.Tensor = torch.cat([state[0] for state in opt_states])
+        states = proxy[kjt.values()]
+        assert (
+            kjt.values().numel() == states.numel()
+        ), f"number of ids and states mismatch: expected {kjt.values().numel()} states, but got {states.numel()}"
+        offsets: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            torch.tensor(kjt.length_per_key(), dtype=torch.int64)
+        )
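+        # ``length_per_key`` holds the number of ids per feature, so its complete
+        # cumulative sum delimits each feature's slice of ``states`` below.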
+        for i, key in enumerate(kjt.keys()):
+            fqn = self.feature_to_fqn[key]
+            per_key_states = states[offsets[i] : offsets[i + 1]]
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                table_fqn=fqn,
+                ids=kjt[key].values(),
+                states=per_key_states,
+            )
+
+    def get_latest(self) -> Dict[str, torch.Tensor]:
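+        """
+        Return the latest first optimizer state for every tracked fused embedding
+        table, keyed by table FQN (for row-wise AdaGrad this is the running sum
+        of squared gradients per row).
+        """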
+        ret: Dict[str, torch.Tensor] = {}
+        for module in self.tracked_modules.values():
+            # pyre-fixme[29]:
+            for lookup in module._lookups:
+                for embs_module in lookup._emb_modules:
+                    assert isinstance(
+                        embs_module, (BatchedFusedEmbeddingBag, BatchedFusedEmbedding)
+                    ), f"expected BatchedFusedEmbeddingBag or BatchedFusedEmbedding, but got {type(embs_module)}"
+                    tbe = embs_module._emb_module
+
+                    assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen)
+                    table_names = [t.name for t in embs_module._config.embedding_tables]
+                    opt_states = tbe.split_optimizer_states()
+                    assert len(table_names) == len(opt_states)
+
+                    for i, table_name in enumerate(table_names):
+                        emb_fqn = self.table_to_fqn[table_name]
+                        table_state = opt_states[i][0]
+                        assert (
+                            emb_fqn not in ret
+                        ), f"a table with {emb_fqn} already exists"
+                        ret[emb_fqn] = table_state
+
+        return ret
+
     def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
         """
         Return a dictionary of hit local IDs for each sparse feature. Ids are
@@ -289,7 +365,13 @@ def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tenso
         per_table_delta_rows = self.get_delta(consumer)
         return {fqn: delta_rows.ids for fqn, delta_rows in per_table_delta_rows.items()}
 
-    def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
+    def get_delta(
+        self,
+        consumer: Optional[str] = None,
+        top_percentage: Optional[float] = 1.0,
+        per_table_percentage: Optional[Dict[str, Tuple[float, str]]] = None,
+        sorted_by_indices: Optional[bool] = True,
+    ) -> Dict[str, DeltaRows]:
         """
         Return a dictionary of hit local IDs and parameter states / embeddings for each sparse feature. The values are first keyed by submodule FQN.
 
@@ -314,6 +396,65 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
         self.per_consumer_batch_idx[consumer] = index_end
         if self._delete_on_read:
             self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values()))
+
+        if self._mode in (TrackingMode.MOMENTUM_DIFF, TrackingMode.ROWWISE_ADAGRAD):
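+            # Latest running sum of squared gradients per table FQN, used as the
+            # end-of-interval value for the delta computed below.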
+            square_sum_map = self.get_latest()
+            for fqn, rows in tracker_rows.items():
+                assert (
+                    fqn in square_sum_map
+                ), f"{fqn} not found in {square_sum_map.keys()}"
+                # compute delta sum_t(g^2) for t in [t1, t2] through
+                # sum_t2(g^2) - sum_t1(g^2)
+                # pyre-fixme[58]: `-` is not supported for operand types `Tensor`
+                # and `Optional[Tensor]`.
+                rows.states = square_sum_map[fqn][rows.ids] - rows.states
+
+                if rows.states is not None:
+                    default_k = rows.states.size(-1)
+                    top_k = (
+                        int(top_percentage * default_k)
+                        if top_percentage is not None
+                        else default_k
+                    )
+
+                    if (
+                        per_table_percentage is not None
+                        and per_table_percentage.get(fqn) is not None
+                    ):
+                        per_table_k = int(per_table_percentage[fqn][0] * default_k)
+                        policy = per_table_percentage[fqn][1]
+
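+                        # Combine the global and per-table budgets: MIN keeps the
+                        # smaller of the two, MAX the larger, and OVERRIDE uses the
+                        # per-table value unconditionally.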
+                        if policy == "MIN":
+                            top_k = min(top_k, per_table_k)
+                        elif policy == "MAX":
+                            top_k = max(top_k, per_table_k)
+                        elif policy == "OVERRIDE":
+                            top_k = per_table_k
+                        else:
+                            logger.warning(
+                                f"Unknown policy {policy}, will keep using original top_k {top_k}"
+                            )
+
+                    logger.info(f"get_unique {fqn=} {top_k=} {default_k=}")
+
+                    if top_k >= default_k:
+                        continue
+
+                    if sorted_by_indices:
+                        sorted_indices, _ = torch.sort(
+                            torch.topk(
+                                none_throws(rows.states), top_k, sorted=False
+                            ).indices,
+                            stable=False,
+                        )
+                        rows.ids = rows.ids[sorted_indices]
+                        rows.states = none_throws(rows.states)[sorted_indices]
+                    else:
+                        rows.states, indices = torch.topk(
+                            none_throws(rows.states), top_k, sorted=False
+                        )
+                        rows.ids = rows.ids[indices]
+
         return tracker_rows
 
     def get_tracked_modules(self) -> Dict[str, nn.Module]:
@@ -330,7 +471,6 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
             return self._fqn_to_feature_map
 
         table_to_feature_names: Dict[str, List[str]] = OrderedDict()
-        table_to_fqn: Dict[str, str] = OrderedDict()
         for fqn, named_module in self._model.named_modules():
             split_fqn = fqn.split(".")
             # Skipping partial FQNs present in fqns_to_skip
@@ -356,13 +496,13 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
                 # will incorrectly match fqn with all the table names that have the same prefix
                 if table_name in split_fqn:
                     embedding_fqn = self._clean_fqn_fn(fqn)
-                    if table_name in table_to_fqn:
+                    if table_name in self.table_to_fqn:
                         # Sanity check for validating that we don't have more than one table mapping to the same fqn.
                         logger.warning(
-                            f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
+                            f"Override {self.table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
                         )
-                    table_to_fqn[table_name] = embedding_fqn
-        logger.info(f"Table to fqn: {table_to_fqn}")
+                    self.table_to_fqn[table_name] = embedding_fqn
+        logger.info(f"Table to fqn: {self.table_to_fqn}")
         flatten_names = [
             name for names in table_to_feature_names.values() for name in names
         ]
@@ -375,15 +515,15 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
 
         fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
         for table_name in table_to_feature_names:
-            if table_name not in table_to_fqn:
+            if table_name not in self.table_to_fqn:
                 # This is likely unexpected, where we can't locate the FQN associated with this table.
                 logger.warning(
-                    f"Table {table_name} not found in {table_to_fqn}, skipping"
+                    f"Table {table_name} not found in {self.table_to_fqn}, skipping"
                 )
                 continue
-            fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[
-                table_name
-            ]
+            fqn_to_feature_names[self.table_to_fqn[table_name]] = (
+                table_to_feature_names[table_name]
+            )
         self._fqn_to_feature_map = fqn_to_feature_names
         return fqn_to_feature_names
 
@@ -451,6 +591,24 @@ def _validate_and_init_tracker_fns(self) -> None:
                        (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
                    )
                    lookup.register_optim_state_tracker_fn(self.record_lookup)
+            elif (
+                self._mode == TrackingMode.ROWWISE_ADAGRAD
+                or self._mode == TrackingMode.MOMENTUM_DIFF
+            ):
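+                # MOMENTUM_DIFF / ROWWISE_ADAGRAD tracking reads optimizer state via
+                # split_optimizer_states(), so only these fused row-wise optimizer
+                # types are accepted here.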
+                # pyre-ignore[29]:
+                for lookup in module._lookups:
+                    assert isinstance(
+                        lookup,
+                        (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
+                    ) and all(
+                        # TorchRec maps ROWWISE_ADAGRAD to EXACT_ROWWISE_ADAGRAD
+                        # pyre-ignore[16]:
+                        emb._emb_module.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
+                        # pyre-ignore[16]:
+                        or emb._emb_module.optimizer == OptimType.PARTIAL_ROWWISE_ADAM
+                        for emb in lookup._emb_modules
+                    )
+                    lookup.register_optim_state_tracker_fn(self.record_lookup)
             else:
                 raise NotImplementedError(
                     f"Tracking mode {self._mode} is not supported"