Commit 93cba72
Adding support for tracking optimizers states in Model Delta Tracker.
Summary:
### Overview
This diff adds support for tracking optimizer states in the Model Delta Tracker system. It introduces a new tracking mode called `MOMENTUM_LAST` that enables tracking of momentum values from optimizers to support approximate top-k delta-row selection.
### Key Changes
#### 1. Optimizer State Tracking Support
* To support tracking of optimizer states I have added `optim_state_tracker_fn` attribute to `GroupedEmbeddingsLookup` and `GroupedPooledEmbeddingsLookup` classes responsible for traversing over the BatchedFused modules.
* Implemented `register_optim_state_tracker_fn()` method in both classes to register the trackable callable
* Tracking calls are invoked after each lookup operation.
#### 2. Model Delta Tracker Changes
* Added `record_momentum()` method to track momentum values from optimizer states and its support in record_lookup function.
* Added validation and optim tracker function logic to support the new `MOMENTUM_LAST` mode
#### 3. New Tracking Mode
* Added `TrackingMode.MOMENTUM_LAST` to [`**types.py**`](command:code-compose.open?%5B%22%2Ffbcode%2Ftorchrec%2Fdistributed%2Fmodel_tracker%2Ftypes.py%22%2Cnull%5D "/fbcode/torchrec/distributed/model_tracker/types.py")
* Maps to `EmbdUpdateMode.LAST` to capture the most recent momentum values
Differential Revision: D768681111 parent 964f6be commit 93cba72
File tree
7 files changed
+279
-18
lines changed- torchrec/distributed
- model_tracker
- tests
7 files changed
+279
-18
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1540 | 1540 | | |
1541 | 1541 | | |
1542 | 1542 | | |
1543 | | - | |
| 1543 | + | |
1544 | 1544 | | |
1545 | 1545 | | |
1546 | 1546 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
10 | 10 | | |
11 | 11 | | |
12 | 12 | | |
13 | | - | |
| 13 | + | |
14 | 14 | | |
15 | 15 | | |
16 | 16 | | |
| |||
206 | 206 | | |
207 | 207 | | |
208 | 208 | | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
209 | 213 | | |
210 | 214 | | |
211 | 215 | | |
| |||
305 | 309 | | |
306 | 310 | | |
307 | 311 | | |
308 | | - | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
309 | 319 | | |
310 | 320 | | |
311 | 321 | | |
| |||
410 | 420 | | |
411 | 421 | | |
412 | 422 | | |
| 423 | + | |
| 424 | + | |
| 425 | + | |
| 426 | + | |
| 427 | + | |
| 428 | + | |
| 429 | + | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
| 433 | + | |
| 434 | + | |
| 435 | + | |
413 | 436 | | |
414 | 437 | | |
415 | 438 | | |
| |||
482 | 505 | | |
483 | 506 | | |
484 | 507 | | |
| 508 | + | |
| 509 | + | |
| 510 | + | |
| 511 | + | |
485 | 512 | | |
486 | 513 | | |
487 | 514 | | |
| |||
629 | 656 | | |
630 | 657 | | |
631 | 658 | | |
632 | | - | |
| 659 | + | |
| 660 | + | |
| 661 | + | |
| 662 | + | |
| 663 | + | |
| 664 | + | |
633 | 665 | | |
634 | 666 | | |
635 | 667 | | |
| |||
762 | 794 | | |
763 | 795 | | |
764 | 796 | | |
| 797 | + | |
| 798 | + | |
| 799 | + | |
| 800 | + | |
| 801 | + | |
| 802 | + | |
| 803 | + | |
| 804 | + | |
| 805 | + | |
| 806 | + | |
| 807 | + | |
| 808 | + | |
| 809 | + | |
765 | 810 | | |
766 | 811 | | |
767 | 812 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
376 | 376 | | |
377 | 377 | | |
378 | 378 | | |
379 | | - | |
| 379 | + | |
380 | 380 | | |
381 | 381 | | |
382 | 382 | | |
| |||
429 | 429 | | |
430 | 430 | | |
431 | 431 | | |
432 | | - | |
| 432 | + | |
433 | 433 | | |
434 | 434 | | |
435 | 435 | | |
436 | 436 | | |
437 | 437 | | |
438 | 438 | | |
439 | | - | |
| 439 | + | |
440 | 440 | | |
441 | 441 | | |
442 | 442 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1646 | 1646 | | |
1647 | 1647 | | |
1648 | 1648 | | |
1649 | | - | |
| 1649 | + | |
1650 | 1650 | | |
1651 | 1651 | | |
1652 | 1652 | | |
| |||
Lines changed: 75 additions & 7 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
13 | 13 | | |
14 | 14 | | |
15 | 15 | | |
| 16 | + | |
16 | 17 | | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
17 | 22 | | |
18 | 23 | | |
19 | 24 | | |
| |||
27 | 32 | | |
28 | 33 | | |
29 | 34 | | |
30 | | - | |
31 | | - | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
32 | 38 | | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
33 | 45 | | |
34 | 46 | | |
35 | 47 | | |
| |||
141 | 153 | | |
142 | 154 | | |
143 | 155 | | |
144 | | - | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
145 | 159 | | |
146 | 160 | | |
147 | 161 | | |
| |||
152 | 166 | | |
153 | 167 | | |
154 | 168 | | |
| 169 | + | |
155 | 170 | | |
156 | 171 | | |
157 | 172 | | |
| |||
162 | 177 | | |
163 | 178 | | |
164 | 179 | | |
165 | | - | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
166 | 183 | | |
167 | 184 | | |
168 | 185 | | |
| |||
228 | 245 | | |
229 | 246 | | |
230 | 247 | | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
| 279 | + | |
| 280 | + | |
231 | 281 | | |
232 | 282 | | |
233 | 283 | | |
| |||
380 | 430 | | |
381 | 431 | | |
382 | 432 | | |
| 433 | + | |
383 | 434 | | |
384 | 435 | | |
385 | 436 | | |
386 | 437 | | |
387 | | - | |
388 | | - | |
389 | | - | |
| 438 | + | |
| 439 | + | |
| 440 | + | |
| 441 | + | |
| 442 | + | |
| 443 | + | |
| 444 | + | |
| 445 | + | |
| 446 | + | |
| 447 | + | |
| 448 | + | |
| 449 | + | |
| 450 | + | |
| 451 | + | |
| 452 | + | |
| 453 | + | |
| 454 | + | |
| 455 | + | |
| 456 | + | |
| 457 | + | |
390 | 458 | | |
391 | 459 | | |
392 | 460 | | |
| |||
0 commit comments