# always_save_context=False, # Optional, defaults to False
# write_thread_count=1, # Optional, defaults to 1
# initial_write_buffer_size_bytes=DESIRED_NUM_BYTES, # Optional, defaults to 16 GB
)
Limitations:

1. You must use the `MegatronStrategy` as the strategy for your PyTorch Lightning Trainer.
   Other strategies have not been tested.
1. Ensure that the `base_container` for ML Flashpoint is job-specific (i.e. has a job ID in it), and on some ramdisk path (e.g. tmpfs).
   The job ID should be unique across jobs, but sticky (reused) when a job is interrupted and restarted/rescheduled (so it can recover from the latest checkpoint available for that particular job).
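As a minimal sketch of the `base_container` requirement above (the `JOB_ID` environment variable and the `/dev/shm` path are illustrative assumptions, not part of the library):

```python
import os

# Hypothetical sketch: derive a job-specific ramdisk path for the
# base_container. The job ID must be sticky: the scheduler should reuse
# the same value when this job is restarted, so the restarted job can
# find its own latest checkpoint.
job_id = os.environ.get("JOB_ID", "job-0")  # e.g. set by your scheduler
base_container = os.path.join("/dev/shm", f"ml_flashpoint_{job_id}")
```

Because the path embeds the job ID, two different jobs never share a container, while restarts of the same job land on the same one.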
### Megatron-LM

Code: See the `ml_flashpoint.adapter.megatron` package.

The Megatron strategies depend on the PyTorch DCP implementations.
Below are instructions for setting up ML Flashpoint checkpointing, which you should configure alongside regular checkpointing to long-term storage.
#### Save Strategy

First create a `MemoryStorageWriter` instance as outlined in [PyTorch DCP](#pytorch-dcp).
Then use that to instantiate the Megatron save strategy.

```python
from ml_flashpoint.adapter.pytorch.memory_storage_writer import MemoryStorageWriter
from ml_flashpoint.adapter.megatron.save_strategies import (
```

Because Megatron's `dist_checkpointing.save()` function writes "common" data only on global rank 0, which does not align with local checkpointing, you can orchestrate saves using the save strategy the same way it's done in [`MLFlashpointCheckpointIO.save_checkpoint()`](https://github.com/google/ml-flashpoint/blob/b9767583520106f59743b9e8050769523cfbef6e/src/ml_flashpoint/adapter/nemo/checkpoint_io.py#L137-L171) in the `ml_flashpoint.adapter.nemo` package.
You'll notice that the logic there aims to mimic `dist_checkpointing.save`, but it saves common data on each node (via local rank 0) as opposed to solely on the coordinator node (global rank 0).
Use this strategy on a more frequent interval than your regular long-term storage checkpointing strategy.
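A minimal cadence sketch for combining the two intervals (the names and step counts are illustrative, not the library API):

```python
# Hypothetical cadence: frequent in-memory (ML Flashpoint) saves,
# much less frequent saves to long-term storage.
LOCAL_SAVE_EVERY = 50      # steps between in-memory saves
REMOTE_SAVE_EVERY = 1000   # steps between long-term storage saves

def checkpoint_actions(step: int) -> list:
    """Return which checkpoint saves to trigger at this training step."""
    actions = []
    if step % LOCAL_SAVE_EVERY == 0:
        actions.append("local")
    if step % REMOTE_SAVE_EVERY == 0:
        actions.append("remote")
    return actions
```

Making the remote interval a multiple of the local one means every long-term checkpoint step also refreshes the fast in-memory copy.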
#### Load Strategy

Instantiate the singleton `ReplicationManager` with a singleton `CheckpointObjectManager`, and make sure to `initialize()` the `ReplicationManager` before using it.
Also create an `MLFlashpointCheckpointLoader` with those dependencies, and use these instances to create the load strategy:
```python
from ml_flashpoint.adapter.megatron.load_strategies import MLFlashpointMegatronLoadStrategy
from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager
from ml_flashpoint.core.checkpoint_loader import DefaultMLFlashpointCheckpointLoader
from ml_flashpoint.replication.replication_manager import ReplicationManager
```