Commit 9ca1b52
Support get/set the whole row of metaheader+weight+optimizer from backend for checkpoint saving/loading (meta-pytorch#3148)
Summary:
X-link: pytorch/FBGEMM#4429
X-link: facebookresearch/FBGEMM#1495
# Context
In our current KVZCH cp loading flow, we will keep hold of weight_id, weight, optimizer tensors throughout the checkpoint loading lifecycle, and at the end when all these tensors are downloaded in hand, we will explicitly call "apply_state_dict" to actually write them by chunk to the backend to ensure id->weight and id->opt are mapped correctly. The problem is when we have large number of weights, we will be short of memory since we need to hold all 3 tensors (double memory issue). To solve this challenge, we are going to save the whole row of (metaheader + weight + opt) as the same "weight" tensor during checkpoint saving, and when downloading the checkpoint, we will be able to extract the id from the header, and directly write the weight+opt part to the backend by id. When loading cp for optimizer, we added a no-op KVTensor, so it won't need to write to backend for optimizer states again.
# This diff
* added `backend_return_whole_row` flag in KVZCH params, with validation to make sure it's only True when opt_offloading is used
* added `read_only_` flag in KVTensorWrapper to be used for checkpoint calls. When read-only=True, all write operations to this KVT will be no-op
* added metadata recalc for optimizer state dict, because we are now returning read-only KVT for opt state dict, and model store will need to correct the global metadata before creating the save plan for KVZCH opt tensors
* updated dram backend and mem pool, so it can return the metaheader + weight + optimizer_state together, as well as set them back to backend (use pointers to skip metaheader part when write weight+opt to backend)
* by default the opt offloading and return whole row is False on trunk, so should not break existing KVZCH runs
Differential Revision: D776041581 parent da486e3 commit 9ca1b52
1 file changed
+11
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
954 | 954 | | |
955 | 955 | | |
956 | 956 | | |
957 | | - | |
| 957 | + | |
| 958 | + | |
| 959 | + | |
| 960 | + | |
| 961 | + | |
| 962 | + | |
| 963 | + | |
| 964 | + | |
| 965 | + | |
| 966 | + | |
| 967 | + | |
958 | 968 | | |
959 | 969 | | |
960 | 970 | | |
| |||
0 commit comments