Add support for tracking optimizer states in Model Delta Tracker.
Summary:
### Overview
This diff adds support for tracking optimizer states in the Model Delta Tracker system. It introduces a new tracking mode called `MOMENTUM_LAST` that enables tracking of momentum values from optimizers to support approximate top-k delta-row selection.
### Key Changes
#### 1. Optimizer State Tracking Support
* Added an `optim_state_tracker_fn` attribute to the `GroupedEmbeddingsLookup` and `GroupedPooledEmbeddingsLookup` classes, which are responsible for traversing the BatchedFused modules.
* Implemented a `register_optim_state_tracker_fn()` method in both classes to register the tracking callable.
* Tracking calls are invoked after each lookup operation.
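
The registration pattern described above can be sketched roughly as follows. This is a toy illustration only, not the TorchRec code: `ToyEmbeddingLookup`, the `TrackerFn` signature, and the list-based tables are all hypothetical stand-ins, and the real callable in this diff may take different arguments.

```python
from typing import Callable, Dict, List, Optional

# Hypothetical callback signature (an assumption, not the TorchRec one):
# the tracker receives the table name and the ids that were just looked up.
TrackerFn = Callable[[str, List[int]], None]


class ToyEmbeddingLookup:
    """Toy stand-in for a grouped lookup module; tables are plain lists."""

    def __init__(self, tables: Dict[str, List[List[float]]]) -> None:
        self.tables = tables
        # No tracker by default, so lookups behave as before until one is registered.
        self.optim_state_tracker_fn: Optional[TrackerFn] = None

    def register_optim_state_tracker_fn(self, fn: TrackerFn) -> None:
        self.optim_state_tracker_fn = fn

    def lookup(self, table: str, ids: List[int]) -> List[List[float]]:
        rows = [self.tables[table][i] for i in ids]
        # The tracking call is invoked after the lookup itself has run.
        if self.optim_state_tracker_fn is not None:
            self.optim_state_tracker_fn(table, ids)
        return rows


calls = []
lookup = ToyEmbeddingLookup({"t0": [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]})
lookup.lookup("t0", [0])  # nothing is recorded: no tracker registered yet
lookup.register_optim_state_tracker_fn(
    lambda table, ids: calls.append((table, list(ids)))
)
rows = lookup.lookup("t0", [2, 1])
```

Keeping the attribute optional means modules without a registered tracker pay only a `None` check per lookup.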
#### 2. Model Delta Tracker Changes
* Added a `record_momentum()` method to track momentum values from optimizer states, and added support for it in `record_lookup()`.
* Added validation and optimizer-state tracker function logic to support the new `MOMENTUM_LAST` mode.
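
As a rough sketch of how mode validation and momentum recording might fit together: the class, enum, and argument names below are hypothetical, and the specific checks are assumptions rather than the validation this diff performs. The momentum state is modeled as one float per table row.

```python
from enum import Enum
from typing import List, Optional, Tuple


class ToyTrackingMode(Enum):
    # Hypothetical enum; only MOMENTUM_LAST mirrors a name from this diff.
    ID_ONLY = "id_only"
    MOMENTUM_LAST = "momentum_last"


class ToyDeltaTracker:
    def __init__(self, mode: ToyTrackingMode) -> None:
        self.mode = mode
        # Each entry: (table name, looked-up ids, momentum values or None).
        self.batches: List[Tuple[str, List[int], Optional[List[float]]]] = []

    def record_momentum(
        self, table: str, ids: List[int], momentum_state: List[float]
    ) -> None:
        # Assumed validation: momentum is only recorded in MOMENTUM_LAST mode.
        if self.mode is not ToyTrackingMode.MOMENTUM_LAST:
            raise ValueError("record_momentum requires MOMENTUM_LAST mode")
        # Gather the momentum entries for the rows touched by this lookup.
        values = [momentum_state[i] for i in ids]
        self.batches.append((table, list(ids), values))

    def record_lookup(
        self,
        table: str,
        ids: List[int],
        momentum_state: Optional[List[float]] = None,
    ) -> None:
        if self.mode is ToyTrackingMode.MOMENTUM_LAST:
            if momentum_state is None:
                raise ValueError("MOMENTUM_LAST mode needs optimizer state")
            self.record_momentum(table, ids, momentum_state)
        else:
            self.batches.append((table, list(ids), None))


tracker = ToyDeltaTracker(ToyTrackingMode.MOMENTUM_LAST)
tracker.record_lookup("t0", [2, 0], momentum_state=[0.5, 0.1, 0.9])
```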
#### 3. New Tracking Mode
* Added `TrackingMode.MOMENTUM_LAST` to [`types.py`](command:code-compose.open?%5B%22%2Ffbcode%2Ftorchrec%2Fdistributed%2Fmodel_tracker%2Ftypes.py%22%2Cnull%5D "/fbcode/torchrec/distributed/model_tracker/types.py").
* Maps to `EmbdUpdateMode.LAST` to capture the most recent momentum values.
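
To illustrate what "last" semantics could mean for top-k delta-row selection, here is a minimal sketch under stated assumptions: `reduce_last` and `top_k_rows` are hypothetical helpers, momentum is one scalar per row, and ranking by absolute momentum is an assumed heuristic rather than something this diff specifies.

```python
import heapq
from typing import Dict, List, Tuple


def reduce_last(batches: List[Tuple[List[int], List[float]]]) -> Dict[int, float]:
    """Keep only the most recently recorded momentum value for each row id."""
    latest: Dict[int, float] = {}
    for ids, values in batches:
        for row_id, value in zip(ids, values):
            latest[row_id] = value  # later batches overwrite earlier ones
    return latest


def top_k_rows(latest: Dict[int, float], k: int) -> List[int]:
    """Pick the k row ids with the largest absolute momentum (assumed heuristic)."""
    return heapq.nlargest(k, latest, key=lambda row_id: abs(latest[row_id]))


batches = [
    ([0, 1, 2], [0.9, 0.2, 0.1]),  # first lookup batch
    ([0, 3], [0.05, 0.7]),         # row 0 is seen again with a smaller value
]
latest = reduce_last(batches)
selected = top_k_rows(latest, 2)  # [3, 1]
```

Row 0 had the largest momentum in the first batch but is not selected, because only its most recent value (0.05) is kept.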
Differential Revision: D76868111