# pyre-strict
import logging as logger
from collections import Counter, OrderedDict
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional, Tuple

import torch
+from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
+from fbgemm_gpu.split_table_batched_embeddings_ops import (
+    SplitTableBatchedEmbeddingBagsCodegen,
+)

from torch import nn
+from torchrec.distributed.batched_embedding_kernel import BatchedFusedEmbedding

from torchrec.distributed.embedding import ShardedEmbeddingCollection
from torchrec.distributed.embedding_lookup import (
+    BatchedFusedEmbeddingBag,
    GroupedEmbeddingsLookup,
    GroupedPooledEmbeddingsLookup,
)
    EmbdUpdateMode,
    TrackingMode,
)
+from torchrec.distributed.utils import none_throws
+
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

UPDATE_MODE_MAP: Dict[TrackingMode, EmbdUpdateMode] = {
    # This mode supports approximate top-k delta-row selection, which can be
    # obtained by running momentum.norm().topk().
    TrackingMode.MOMENTUM_LAST: EmbdUpdateMode.LAST,
+    # MOMENTUM_DIFF keeps a running sum of the squared gradients per row.
+    # Within each publishing interval, we track the starting value of this running
+    # sum for all used rows, then look up the latest sum when ``get_delta`` is
+    # called. The delta between the two values is returned together with the
+    # row ids (see the sketch below this map).
+    TrackingMode.MOMENTUM_DIFF: EmbdUpdateMode.FIRST,
+    # The same as MOMENTUM_DIFF. Added for backward compatibility.
+    TrackingMode.ROWWISE_ADAGRAD: EmbdUpdateMode.FIRST,
}
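
# Illustrative sketch, not part of the tracker code in this diff: how EmbdUpdateMode.FIRST
# supports the MOMENTUM_DIFF delta. The optimizer keeps a per-row running sum of squared
# gradients; the tracker snapshots it at the start of the interval and subtracts that
# snapshot from the latest value at publish time. Names and values are hypothetical.
def _momentum_diff_sketch() -> torch.Tensor:
    running_sq_sum = torch.zeros(4)  # per-row sum_t(g^2) maintained by the optimizer
    snapshot = running_sq_sum.clone()  # EmbdUpdateMode.FIRST: value at interval start
    for grad in (
        torch.tensor([0.1, 0.0, 0.2, 0.0]),
        torch.tensor([0.3, 0.0, 0.0, 0.1]),
    ):
        running_sq_sum += grad * grad  # optimizer keeps accumulating in the interval
    return running_sq_sum - snapshot  # per-row delta that get_delta reports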

# Tracking is currently only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
@@ -99,6 +115,7 @@ def __init__(

        # from module FQN to ShardedEmbeddingCollection/ShardedEmbeddingBagCollection
        self.tracked_modules: Dict[str, nn.Module] = {}
+        self.table_to_fqn: Dict[str, str] = {}
        self.feature_to_fqn: Dict[str, str] = {}
        # Generate the mapping from FQN to feature names.
        self.fqn_to_feature_names()
@@ -180,6 +197,11 @@ def record_lookup(
        # In MOMENTUM_LAST mode, we track per-feature IDs and the corresponding momentum values received in the current batch.
        elif self._mode == TrackingMode.MOMENTUM_LAST:
            self.record_momentum(emb_module, kjt)
+        elif (
+            self._mode == TrackingMode.MOMENTUM_DIFF
+            or self._mode == TrackingMode.ROWWISE_ADAGRAD
+        ):
+            self.record_rowwise_optim_state(emb_module, kjt)
        else:
            raise NotImplementedError(f"Tracking mode {self._mode} is not supported")
@@ -278,6 +300,60 @@ def record_momentum(
                states=per_key_states,
            )

+    def record_rowwise_optim_state(
+        self,
+        emb_module: nn.Module,
+        kjt: KeyedJaggedTensor,
+    ) -> None:
+        opt_states: List[List[torch.Tensor]] = (
+            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
+            #  `split_optimizer_states`.
+            emb_module._emb_module.split_optimizer_states()
+        )
+        proxy: torch.Tensor = torch.cat([state[0] for state in opt_states])
+        states = proxy[kjt.values()]
+        assert (
+            kjt.values().numel() == states.numel()
+        ), f"number of ids and states mismatch, expect {kjt.values()=}, {kjt.values().numel()}, but got {states.numel()}"
+        offsets: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            torch.tensor(kjt.length_per_key(), dtype=torch.int64)
+        )
+        for i, key in enumerate(kjt.keys()):
+            fqn = self.feature_to_fqn[key]
+            per_key_states = states[offsets[i] : offsets[i + 1]]
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                table_fqn=fqn,
+                ids=kjt[key].values(),
+                states=per_key_states,
+            )
+
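# Illustrative sketch, not part of the tracker code in this diff; inputs are hypothetical
# and a plain cumsum stands in for torch.ops.fbgemm.asynchronous_complete_cumsum.
# The complete cumsum of kjt.length_per_key() yields offsets [0, l0, l0 + l1, ...],
# so the states for key i live in states[offsets[i] : offsets[i + 1]].
def _per_key_slicing_sketch() -> List[torch.Tensor]:
    lengths = torch.tensor([2, 3], dtype=torch.int64)  # e.g. kjt.length_per_key()
    offsets = torch.cat([torch.zeros(1, dtype=torch.int64), lengths.cumsum(dim=0)])
    states = torch.arange(5.0)  # flattened per-id optimizer states
    # -> [tensor([0., 1.]), tensor([2., 3., 4.])]
    return [states[offsets[i] : offsets[i + 1]] for i in range(len(lengths))]
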
+    def get_latest(self) -> Dict[str, torch.Tensor]:
+        ret: Dict[str, torch.Tensor] = {}
+        for module in self.tracked_modules.values():
+            # pyre-fixme[29]:
+            for lookup in module._lookups:
+                for embs_module in lookup._emb_modules:
+                    assert isinstance(
+                        embs_module, (BatchedFusedEmbeddingBag, BatchedFusedEmbedding)
+                    ), f"expected BatchedFusedEmbeddingBag or BatchedFusedEmbedding, but got {type(embs_module)}"
+                    tbe = embs_module._emb_module
+
+                    assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen)
+                    table_names = [t.name for t in embs_module._config.embedding_tables]
+                    opt_states = tbe.split_optimizer_states()
+                    assert len(table_names) == len(opt_states)
+
+                    for i, table_name in enumerate(table_names):
+                        emb_fqn = self.table_to_fqn[table_name]
+                        table_state = opt_states[i][0]
+                        assert (
+                            emb_fqn not in ret
+                        ), f"a table with {emb_fqn} already exists"
+                        ret[emb_fqn] = table_state
+
+        return ret
+

    def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
        """
        Return a dictionary of hit local IDs for each sparse feature. IDs are
@@ -289,7 +365,13 @@ def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tenso
        per_table_delta_rows = self.get_delta(consumer)
        return {fqn: delta_rows.ids for fqn, delta_rows in per_table_delta_rows.items()}

-    def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
+    def get_delta(
+        self,
+        consumer: Optional[str] = None,
+        top_percentage: Optional[float] = 1.0,
+        per_table_percentage: Optional[Dict[str, Tuple[float, str]]] = None,
+        sorted_by_indices: Optional[bool] = True,
+    ) -> Dict[str, DeltaRows]:
        """
        Return a dictionary of hit local IDs and parameter states / embeddings for each sparse feature. The values are first keyed by submodule FQN.
@@ -314,6 +396,65 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
        self.per_consumer_batch_idx[consumer] = index_end
        if self._delete_on_read:
            self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values()))
+
+        if self._mode in (TrackingMode.MOMENTUM_DIFF, TrackingMode.ROWWISE_ADAGRAD):
+            square_sum_map = self.get_latest()
+            for fqn, rows in tracker_rows.items():
+                assert (
+                    fqn in square_sum_map
+                ), f"{fqn} not found in {square_sum_map.keys()}"
+                # compute delta sum_t(g^2) for t in [t1, t2] through
+                # sum_t2(g^2) - sum_t1(g^2)
+                # pyre-fixme[58]: `-` is not supported for operand types `Tensor`
+                #  and `Optional[Tensor]`.
+                rows.states = square_sum_map[fqn][rows.ids] - rows.states
+
+                if rows.states is not None:
+                    default_k = rows.states.size(-1)
+                    top_k = (
+                        int(top_percentage * default_k)
+                        if top_percentage is not None
+                        else default_k
+                    )
+
+                    if (
+                        per_table_percentage is not None
+                        and per_table_percentage.get(fqn) is not None
+                    ):
+                        per_table_k = int(per_table_percentage[fqn][0] * default_k)
+                        policy = per_table_percentage[fqn][1]
+
+                        if policy == "MIN":
+                            top_k = min(top_k, per_table_k)
+                        elif policy == "MAX":
+                            top_k = max(top_k, per_table_k)
+                        elif policy == "OVERRIDE":
+                            top_k = per_table_k
+                        else:
+                            logger.warning(
+                                f"Unknown policy {policy}, will keep using original top_k {top_k}"
+                            )
+
+                    logger.info(f"get_unique {fqn=} {top_k=} {default_k=}")
+
+                    if top_k >= default_k:
+                        continue
+
+                    if sorted_by_indices:
+                        sorted_indices, _ = torch.sort(
+                            torch.topk(
+                                none_throws(rows.states), top_k, sorted=False
+                            ).indices,
+                            stable=False,
+                        )
+                        rows.ids = rows.ids[sorted_indices]
+                        rows.states = none_throws(rows.states)[sorted_indices]
+                    else:
+                        rows.states, indices = torch.topk(
+                            none_throws(rows.states), top_k, sorted=False
+                        )
+                        rows.ids = rows.ids[indices]
+
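# Illustrative sketch, not part of the tracker code in this diff, of the sorted_by_indices
# selection path above: keep only top_percentage of the rows by state magnitude while
# preserving id/state alignment. Values are hypothetical. Per-table overrides come from
# per_table_percentage, which maps a table FQN to (percentage, policy) with policy one
# of "MIN", "MAX", or "OVERRIDE", as handled above.
def _topk_selection_sketch() -> Tuple[torch.Tensor, torch.Tensor]:
    ids = torch.tensor([10, 11, 12, 13])
    states = torch.tensor([0.4, 0.1, 0.9, 0.3])
    top_k = int(0.5 * states.size(-1))  # top_percentage = 0.5 -> keep 2 rows
    sorted_indices, _ = torch.sort(torch.topk(states, top_k, sorted=False).indices)
    # -> ids tensor([10, 12]), states tensor([0.4000, 0.9000])
    return ids[sorted_indices], states[sorted_indices]
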
        return tracker_rows

    def get_tracked_modules(self) -> Dict[str, nn.Module]:
@@ -330,7 +471,6 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
            return self._fqn_to_feature_map

        table_to_feature_names: Dict[str, List[str]] = OrderedDict()
-        table_to_fqn: Dict[str, str] = OrderedDict()
        for fqn, named_module in self._model.named_modules():
            split_fqn = fqn.split(".")
            # Skipping partial FQNs present in fqns_to_skip
@@ -356,13 +496,13 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
                # will incorrectly match fqn with all the table names that have the same prefix
                if table_name in split_fqn:
                    embedding_fqn = self._clean_fqn_fn(fqn)
-                    if table_name in table_to_fqn:
+                    if table_name in self.table_to_fqn:
                        # Sanity check for validating that we don't have more than one table mapping to the same fqn.
                        logger.warning(
-                            f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
+                            f"Override {self.table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
                        )
-                    table_to_fqn[table_name] = embedding_fqn
-        logger.info(f"Table to fqn: {table_to_fqn}")
+                    self.table_to_fqn[table_name] = embedding_fqn
+        logger.info(f"Table to fqn: {self.table_to_fqn}")
        flatten_names = [
            name for names in table_to_feature_names.values() for name in names
        ]
@@ -375,15 +515,15 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:

        fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
        for table_name in table_to_feature_names:
-            if table_name not in table_to_fqn:
+            if table_name not in self.table_to_fqn:
                # This is likely unexpected, where we can't locate the FQN associated with this table.
                logger.warning(
-                    f"Table {table_name} not found in {table_to_fqn}, skipping"
+                    f"Table {table_name} not found in {self.table_to_fqn}, skipping"
                )
                continue
-            fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[
-                table_name
-            ]
+            fqn_to_feature_names[self.table_to_fqn[table_name]] = (
+                table_to_feature_names[table_name]
+            )
        self._fqn_to_feature_map = fqn_to_feature_names
        return fqn_to_feature_names
@@ -451,6 +591,24 @@ def _validate_and_init_tracker_fns(self) -> None:
                        (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
                    )
                    lookup.register_optim_state_tracker_fn(self.record_lookup)
+            elif (
+                self._mode == TrackingMode.ROWWISE_ADAGRAD
+                or self._mode == TrackingMode.MOMENTUM_DIFF
+            ):
+                # pyre-ignore[29]:
+                for lookup in module._lookups:
+                    assert isinstance(
+                        lookup,
+                        (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
+                    ) and all(
+                        # TorchRec maps ROWWISE_ADAGRAD to EXACT_ROWWISE_ADAGRAD
+                        # pyre-ignore[16]:
+                        emb._emb_module.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
+                        # pyre-ignore[16]:
+                        or emb._emb_module.optimizer == OptimType.PARTIAL_ROWWISE_ADAM
+                        for emb in lookup._emb_modules
+                    )
+                    lookup.register_optim_state_tracker_fn(self.record_lookup)
            else:
                raise NotImplementedError(
                    f"Tracking mode {self._mode} is not supported"