
Commit 7235a00

[FlexCheckPoint]adapt fc to sharding stage3 (#76538)
* adapt fc to sharding stage3
* add test and fix bug
* fix bug
* fix bug
* add test
1 parent 468171d commit 7235a00

4 files changed, +324 -3 lines changed


python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py

Lines changed: 201 additions & 2 deletions
@@ -24,6 +24,11 @@
 from paddle.autograd import PyLayer
 from paddle.base.framework import EagerParamBase
 from paddle.distributed import collective
+from paddle.distributed.flex_checkpoint.dcp.sharded_weight import (
+    ShardedStateDict,
+    ShardedWeight,
+    create_sharded_weight_with_new_local,
+)
 from paddle.framework import core
 from paddle.nn import ClipGradByGlobalNorm
 
@@ -182,7 +187,11 @@ def __init__(
                 "Multiple optimizers are not supported now."
             )
         self._optim = _OptimizerWrapper(
-            optimizer, self._offload, self._group, self._update_params_slice
+            optimizer,
+            self._offload,
+            self._group,
+            self._update_params_slice,
+            self._sharded_state_dict,
         )
         self._ori_parameter_list = self._optim._parameter_list
         self._ori_param_groups = self._optim._param_groups
@@ -850,6 +859,193 @@ def _opt_clear(self):
 
         self._optim.clear_grad = MethodType(_opt_clear, self._optim)
 
+    def init_slice_param(self):
+        for layer_id, params in self._trainable_params.items():
+            for param in params:
+                value = paddle.zeros(param.shape, dtype=param.dtype)
+                value._share_buffer_to(param)
+
+    def align_param_to_buffer_and_clear_slice_param(self):
+        for layer_id, params in self._trainable_params.items():
+            for param in params:
+                param_shape = param.shape
+                origin_state = param.stop_gradient
+                param.stop_gradient = True
+                start, end = self._param2buffer[param.name][self._rank]
+                param.flatten_()
+                param.stop_gradient = origin_state
+                param_numel = param.numel().item()
+                start = min(start, param_numel)
+                end = min(end, param_numel)
+                if end > start:
+                    tmp_tensor = param._slice(start, end).detach()
+                    buffer_slice = param.fw_storage._slice(
+                        0, end - start
+                    ).detach()
+                    buffer_slice.set_value(tmp_tensor)
+                    del buffer_slice
+                param.get_tensor()._set_dims(param_shape)
+                param._clear_data()
+
+    def init_optimizer_for_slice_param(self):
+        local_param_list = []
+        for param in self._optim._parameter_list:
+            if hasattr(param, "fw_storage"):
+                var = param.fw_storage
+                tmp_param = EagerParamBase(
+                    shape=var.shape, dtype=var.dtype, name="slice@" + param.name
+                )
+                local_param_list.append(tmp_param)
+            else:
+                local_param_list.append(param)
+        self._optim._parameter_list = local_param_list
+
+    def _sharded_state_dict(
+        self,
+        model_sharded_state_dict: ShardedStateDict,
+    ) -> ShardedStateDict:
+        """
+        Convert the optimizer state dict to a sharded state dict based on model sharding information.
+
+        Args:
+            model_sharded_state_dict (dict): Sharded state dict of the model, containing tensor metadata.
+
+        Returns:
+            dict: A new optimizer state dict whose values are wrapped as ShardedWeight.
+        """
+
+        _FP32_MASTER = "fp32_master_0"
+        _MOMENT_NAME = "moment"
+        _optimizer_scalar_name = [
+            "beta1_pow_acc_0",
+            "beta2_pow_acc_0",
+        ]
+        _optimizer_non_scaler_name = [
+            "moment1_0",
+            "moment2_0",
+            "velocity_0",
+        ]
+
+        param_to_slice = {}
+        for param in self._ori_parameter_list:
+            if hasattr(param, "fw_storage"):
+                param_to_slice[param.name] = True
+            else:
+                param_to_slice[param.name] = False
+
+        def _create_sharded_weight(
+            unified_name, tensor, sharded_param, static_name
+        ):
+            if param_to_slice[static_name]:
+                padding_begin = sharded_param.local_tensor.numel().item()
+                slice_begin = min(padding_begin, self._rank * tensor.shape[0])
+                slice_end = min(
+                    padding_begin, (self._rank + 1) * tensor.shape[0]
+                )
+                if slice_begin == padding_begin or slice_end == padding_begin:
+                    local_tensor = paddle.slice(
+                        tensor,
+                        axes=[0],
+                        starts=[0],
+                        ends=[slice_end - slice_begin],
+                    )
+                else:
+                    local_tensor = tensor
+                return ShardedWeight(
+                    key=unified_name,
+                    local_tensor=local_tensor,
+                    local_shape=sharded_param.local_shape,
+                    global_shape=sharded_param.global_shape,
+                    global_offset=sharded_param.global_offset,
+                    is_flattened=True,
+                    flattened_range=slice(slice_begin, slice_end),
+                )
+            else:
+                return create_sharded_weight_with_new_local(
+                    unified_name, tensor, sharded_param
+                )
+
+        def _generate_base_static_name(vname):
+            if _FP32_MASTER in vname:
+                return tuple(vname.split("_" + _FP32_MASTER + "_", 1))
+            for name in _optimizer_scalar_name + _optimizer_non_scaler_name:
+                if vname.endswith(name):
+                    return vname[: -(len(name) + 1)], name
+            raise ValueError(f"Cannot split variable name: {vname}.")
+
+        optimizer_sharded_state_dict = {}
+        optimizer_state_dict = self._optim.state_dict()
+        # Build name mapping and remove non-tensor entries from optimizer state
+        static_to_struct_mapping = {}
+        model_sharded_state_dict = dict(
+            sorted(model_sharded_state_dict.items())
+        )
+        for k, v in model_sharded_state_dict.items():
+            # When shared weights exist, v.local_tensor.name is identical for
+            # the shared parameters, but only the first parameter has
+            # optimizer states, so only the first occurrence in the shared
+            # parameter group is retained.
+            if v.local_tensor.name not in static_to_struct_mapping:
+                static_to_struct_mapping[v.local_tensor.name] = k
+
+        master_weights = optimizer_state_dict.pop("master_weights", None)
+        optimizer_state_dict.pop("LR_Scheduler", None)
+        # Process main optimizer states
+        for key, tensor in optimizer_state_dict.items():
+            static_name, optim_state_type = _generate_base_static_name(key)
+            static_name = static_name.replace("slice@", "")
+            struct_name = static_to_struct_mapping[static_name]
+            sharded_weight = model_sharded_state_dict[struct_name]
+
+            unified_name = f"{struct_name}.{optim_state_type}"
+
+            # Determine tensor partitioning scheme
+            if _MOMENT_NAME in optim_state_type:
+                if tensor.is_dist():
+                    optimizer_sharded_state_dict[unified_name] = ShardedWeight(
+                        key=unified_name,
+                        local_tensor=tensor,
+                        local_shape=tensor.shape,
+                        global_shape=tensor.shape,
+                        global_offset=sharded_weight.global_offset,
+                    )
+                else:
+                    optimizer_sharded_state_dict[unified_name] = (
+                        _create_sharded_weight(
+                            unified_name, tensor, sharded_weight, static_name
+                        )
+                    )
+
+            else:  # Non-moment states (scalar accumulators)
+                optimizer_sharded_state_dict[unified_name] = ShardedWeight(
+                    key=unified_name,
+                    local_tensor=tensor,
+                    local_shape=(1,),
+                    global_shape=(1,),
+                    global_offset=(0,),
+                )
+
+        # Process master weights if using mixed precision
+        if master_weights is not None:
+            for key, tensor in master_weights.items():
+                key = key.replace("slice@", "")
+                struct_name = static_to_struct_mapping[key]
+                sharded_weight = model_sharded_state_dict[struct_name]
+                unified_name = f"{struct_name}.w_0"
+                if tensor.is_dist():
+                    optimizer_sharded_state_dict[unified_name] = ShardedWeight(
+                        key=unified_name,
+                        local_tensor=tensor,
+                        local_shape=tensor.shape,
+                        global_shape=tensor.shape,
+                        global_offset=sharded_weight.global_offset,
+                    )
+                else:
+                    optimizer_sharded_state_dict[unified_name] = (
+                        _create_sharded_weight(
+                            unified_name, tensor, sharded_weight, key
+                        )
+                    )
+
+        return optimizer_sharded_state_dict
+
 
 def ForwardPreHooks(
     layer,
@@ -1200,13 +1396,16 @@ def _TensorWrapper(param):
     return tmp_param
 
 
-def _OptimizerWrapper(optimizer, offload, group, update_params_slice):
+def _OptimizerWrapper(
+    optimizer, offload, group, update_params_slice, sharded_state_dict
+):
     if not hasattr(optimizer, "_optim"):
         optimizer._optim = optimizer
         optimizer.offload = offload
         optimizer._group = group
         optimizer.update_scaler = None
         optimizer.update_slice = update_params_slice
+        optimizer.sharded_state_dict = sharded_state_dict
     return optimizer
 
 
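Taken together, the new stage-3 hooks are intended to be driven in roughly the order the new test exercises them. Below is a minimal sketch (not part of this diff) of that flow, assuming an already-initialized fleet environment with sharding_degree=2; build_model_and_optimizer() is a hypothetical helper.

    import paddle.distributed.fleet as fleet
    from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage3 import (
        GroupShardedStage3,
    )

    model, opt = build_model_and_optimizer()  # hypothetical helper
    model = fleet.distributed_model(model)
    wrapped_model = GroupShardedStage3(model, opt, segment_size=2**12)

    # Re-register sliced parameters with the optimizer so their shapes match fw_storage.
    wrapped_model.init_optimizer_for_slice_param()

    # The model sharded state dict carries per-parameter sharding metadata.
    model_sharded_state_dict = model.sharded_state_dict()

    # sharded_state_dict is attached to the optimizer by _OptimizerWrapper in this
    # commit; it wraps optimizer states as ShardedWeight entries keyed by structural names.
    opt_sharded_state_dict = opt.sharded_state_dict(model_sharded_state_dict)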

python/paddle/distributed/flex_checkpoint/dcp/load_state_dict.py

Lines changed: 1 addition & 1 deletion
@@ -716,7 +716,7 @@ def _handle_aoa(
     src_desc_to_postprocess_list = {}
     force_gc = []
 
-    for param_name, tgt_shard in load_dict.items():
+    for param_name, tgt_shard in sorted(load_dict.items()):
         tgt_desc = build_shard_desc(tgt_shard)
         shard_mappings = aoa_engine.find_shard_sources(tgt_desc)
         for mapping in shard_mappings:
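The one-line change above makes shard matching iterate in a deterministic order. A small illustration of the presumed motivation (an assumption, not code from this commit): plain dict iteration follows insertion order, which can differ between ranks, while sorting by key gives every rank the same traversal.

    # Two ranks built load_dict in different orders; sorting restores a common order.
    load_dict_rank0 = {"linear_1.w_0": "shard_b", "linear_0.w_0": "shard_a"}
    load_dict_rank1 = {"linear_0.w_0": "shard_a", "linear_1.w_0": "shard_b"}
    assert list(load_dict_rank0) != list(load_dict_rank1)  # insertion order differs
    assert sorted(load_dict_rank0.items()) == sorted(load_dict_rank1.items())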

test/flex_checkpoint/sharded_state_dict_logic.py

Lines changed: 113 additions & 0 deletions
@@ -30,6 +30,9 @@
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
     GroupShardedOptimizerStage2,
 )
+from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage3 import (
+    GroupShardedStage3,
+)
 from paddle.distributed.fleet.utils.sequence_parallel_utils import (
     ColumnSequenceParallelLinear,
     RowSequenceParallelLinear,
@@ -410,6 +413,116 @@ def run_optimizer_test(self):
                         assert tuple(
                             opt_sharded_state_dict[opt__var_name].global_offset
                         ) == tuple(value.global_offset)
+
+        elif self.layer_type == "GroupShardedStage3":
+            model = fleet.distributed_model(model)
+            wrapped_model = GroupShardedStage3(
+                model, opt, segment_size=2**12
+            )  # slice the linear1 and linear2 weights
+            for param in opt._parameter_list:
+                if hasattr(param, "fw_storage"):
+                    assert len(param.shape) != 1
+            wrapped_model.init_optimizer_for_slice_param()
+            for param in opt._parameter_list:
+                if hasattr(param, "fw_storage"):
+                    assert len(param.shape) == 1
+            model_sharded_state_dict = model.sharded_state_dict()
+            for k, v in model_sharded_state_dict.items():
+                if (
+                    k == "_layers.linear1.weight"
+                    or k == "_layers.linear2.weight"
+                ):
+                    assert not v.local_tensor._is_initialized()
+            wrapped_model.init_slice_param()
+            for k, v in model_sharded_state_dict.items():
+                if (
+                    k == "_layers.linear1.weight"
+                    or k == "_layers.linear2.weight"
+                ):
+                    assert v.local_tensor._is_initialized()
+            wrapped_model.align_param_to_buffer_and_clear_slice_param()
+            for k, v in model_sharded_state_dict.items():
+                if (
+                    k == "_layers.linear1.weight"
+                    or k == "_layers.linear2.weight"
+                ):
+                    assert not v.local_tensor._is_initialized()
+            model.train()
+            x = paddle.randint(
+                low=0,
+                high=self.vocab_size,
+                shape=[self.batch_size, self.seq_len, self.hidden_size],
+                dtype='int64',
+            )
+            rank = paddle.distributed.get_rank()
+            sharding_x = (
+                x[0 : self.batch_size // 2]
+                if rank == 0
+                else x[self.batch_size // 2 :]
+            )
+            y = model(sharding_x).mean()
+            y.backward()
+            opt.step()
+            opt.clear_grad()
+            model_sharded_state_dict = model.sharded_state_dict()
+            for k, v in model_sharded_state_dict.items():
+                if (
+                    k == "_layers.linear1.weight"
+                    or k == "_layers.linear2.weight"
+                ):
+                    assert not v.local_tensor._is_initialized()
+            wrapped_model.get_all_parameters()
+            opt_sharded_state_dict = opt.sharded_state_dict(
+                model_sharded_state_dict
+            )
+
+            for k, v in model_sharded_state_dict.items():
+                if (
+                    k == "_layers.linear1.weight"
+                    or k == "_layers.linear2.weight"
+                ):
+                    assert v.local_tensor._is_initialized()
+
+            for key, value in model_sharded_state_dict.items():
+                for state_name in self.optimizer_var_suffix:
+                    opt__var_name = key + state_name
+                    if opt__var_name in opt_sharded_state_dict:
+                        if hasattr(
+                            value.local_tensor, "fw_storage"
+                        ):  # check optimizer vars that are stored as fragments
+                            opt_var_global_flattened_range = []
+                            paddle.distributed.all_gather_object(
+                                opt_var_global_flattened_range,
+                                opt_sharded_state_dict[
+                                    opt__var_name
+                                ].flattened_range,
+                            )
+
+                            first_fragment = opt_var_global_flattened_range[0]
+                            second_fragment = opt_var_global_flattened_range[1]
+                            assert (
+                                first_fragment.stop == second_fragment.start
+                            )  # the first fragment stops where the second starts
+                            opt_var_global_size_flattened = (
+                                second_fragment.stop - first_fragment.start
+                            )
+                            model_var_global_size_flattened = math.prod(
+                                value.local_shape
+                            )
+                            assert (
+                                opt_var_global_size_flattened
+                                == model_var_global_size_flattened
+                            )
+
+                        assert tuple(
+                            opt_sharded_state_dict[opt__var_name].local_shape
+                        ) == tuple(value.local_shape)
+                        assert tuple(
+                            opt_sharded_state_dict[opt__var_name].global_shape
+                        ) == tuple(value.global_shape)
+                        assert tuple(
+                            opt_sharded_state_dict[opt__var_name].global_offset
+                        ) == tuple(value.global_offset)
         else:
             raise ValueError(f"Unknown layer_type: {self.layer_type}")
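The fragment checks in the new branch assert that the per-rank flattened_range slices of a sliced parameter's optimizer state tile its flattened buffer without gaps. A standalone illustration of the same invariant (hypothetical numbers, not taken from the test):

    # A flattened parameter with 8 elements, sharded evenly across 2 ranks.
    param_numel = 8
    world_size = 2
    per_rank = param_numel // world_size
    fragments = [slice(r * per_rank, (r + 1) * per_rank) for r in range(world_size)]
    # Rank 0's fragment ends exactly where rank 1's begins ...
    assert fragments[0].stop == fragments[1].start
    # ... and together they cover the whole flattened parameter.
    assert fragments[1].stop - fragments[0].start == param_numel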

test/flex_checkpoint/test_sharded_state_dict.py

Lines changed: 9 additions & 0 deletions
@@ -141,6 +141,15 @@
             "has_bias": "True",
             "master_weight": "True",
         },
+        {
+            "test_type": "optimizer",
+            "layer_type": "GroupShardedStage3",
+            "world_size": 2,
+            "tp": 1,
+            "sharding_degree": 2,
+            "has_bias": "True",
+            "master_weight": "True",
+        },
     ],
     "4_card_tests": [
         {
