@@ -24,6 +24,11 @@
 from paddle.autograd import PyLayer
 from paddle.base.framework import EagerParamBase
 from paddle.distributed import collective
+from paddle.distributed.flex_checkpoint.dcp.sharded_weight import (
+    ShardedStateDict,
+    ShardedWeight,
+    create_sharded_weight_with_new_local,
+)
 from paddle.framework import core
 from paddle.nn import ClipGradByGlobalNorm
 
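
The new import brings in the flex-checkpoint sharding types used throughout this patch. For orientation, here is a minimal sketch of the `ShardedWeight` fields the added code relies on; the dataclass is an illustration only (the real class ships with `paddle.distributed.flex_checkpoint.dcp.sharded_weight`), and the example key name is hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple


# Illustrative stand-in for ShardedWeight: the field names mirror the keyword
# arguments passed later in this diff; the real class is provided by Paddle.
@dataclass
class ShardedWeightSketch:
    key: str                        # unified name, e.g. "decoder.linear.weight.moment1_0" (hypothetical)
    local_tensor: Any               # the shard held by this rank
    local_shape: Tuple[int, ...]    # shape of the local shard
    global_shape: Tuple[int, ...]   # shape of the full, unsharded tensor
    global_offset: Tuple[int, ...]  # where the shard starts inside the full tensor
    is_flattened: bool = False      # True when the shard is a slice of a flattened buffer
    flattened_range: Optional[slice] = None  # element range inside that flattened buffer
```
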
@@ -182,7 +187,11 @@ def __init__( |
             "Multiple optimizers are not supported now."
         )
         self._optim = _OptimizerWrapper(
-            optimizer, self._offload, self._group, self._update_params_slice
+            optimizer,
+            self._offload,
+            self._group,
+            self._update_params_slice,
+            self._sharded_state_dict,
         )
         self._ori_parameter_list = self._optim._parameter_list
         self._ori_param_groups = self._optim._param_groups
@@ -850,6 +859,193 @@ def _opt_clear(self): |
 
         self._optim.clear_grad = MethodType(_opt_clear, self._optim)
 
+    def init_slice_param(self):
+        for layer_id, params in self._trainable_params.items():
+            for param in params:
+                value = paddle.zeros(param.shape, dtype=param.dtype)
+                value._share_buffer_to(param)
+
+    def align_param_to_buffer_and_clear_slice_param(self):
+        for layer_id, params in self._trainable_params.items():
+            for param in params:
+                param_shape = param.shape
+                origin_state = param.stop_gradient
+                param.stop_gradient = True
+                start, end = self._param2buffer[param.name][self._rank]
+                param.flatten_()
+                param.stop_gradient = origin_state
+                param_numel = param.numel().item()
+                start = min(start, param_numel)
+                end = min(end, param_numel)
+                if end > start:
+                    tmp_tensor = param._slice(start, end).detach()
+                    buffer_slice = param.fw_storage._slice(
+                        0, end - start
+                    ).detach()
+                    buffer_slice.set_value(tmp_tensor)
+                    del buffer_slice
+                param.get_tensor()._set_dims(param_shape)
+                param._clear_data()
+
+    def init_optimizer_for_slice_param(self):
+        local_param_list = []
+        for param in self._optim._parameter_list:
+            if hasattr(param, "fw_storage"):
+                var = param.fw_storage
+                tmp_param = EagerParamBase(
+                    shape=var.shape, dtype=var.dtype, name="slice@" + param.name
+                )
+                local_param_list.append(tmp_param)
+            else:
+                local_param_list.append(param)
+        self._optim._parameter_list = local_param_list
+
+    def _sharded_state_dict(
+        self,
+        model_sharded_state_dict: ShardedStateDict,
+    ) -> ShardedStateDict:
+        """
+        Convert the optimizer state dict into a sharded state dict based on the model's sharding information.
+
+        Args:
+            model_sharded_state_dict (dict): Sharded state dict of the model, containing tensor metadata.
+
+        Returns:
+            dict: A new optimizer state dict in which each state tensor is wrapped as a ShardedWeight.
+        """
+
+        _FP32_MASTER = "fp32_master_0"
+        _MOMENT_NAME = "moment"
+        _optimizer_scalar_name = [
+            "beta1_pow_acc_0",
+            "beta2_pow_acc_0",
+        ]
+        _optimizer_non_scaler_name = [
+            "moment1_0",
+            "moment2_0",
+            "velocity_0",
+        ]
+
+        param_to_slice = {}
+        for param in self._ori_parameter_list:
+            if hasattr(param, "fw_storage"):
+                param_to_slice[param.name] = True
+            else:
+                param_to_slice[param.name] = False
+
+        def _create_sharded_weight(
+            unified_name, tensor, sharded_param, static_name
+        ):
+            if param_to_slice[static_name]:
+                padding_begin = sharded_param.local_tensor.numel().item()
+                slice_begin = min(padding_begin, self._rank * tensor.shape[0])
+                slice_end = min(
+                    padding_begin, (self._rank + 1) * tensor.shape[0]
+                )
+                if slice_begin == padding_begin or slice_end == padding_begin:
+                    local_tensor = paddle.slice(
+                        tensor,
+                        axes=[0],
+                        starts=[0],
+                        ends=[slice_end - slice_begin],
+                    )
+                else:
+                    local_tensor = tensor
+                return ShardedWeight(
+                    key=unified_name,
+                    local_tensor=local_tensor,
+                    local_shape=sharded_param.local_shape,
+                    global_shape=sharded_param.global_shape,
+                    global_offset=sharded_param.global_offset,
+                    is_flattened=True,
+                    flattened_range=slice(slice_begin, slice_end),
+                )
+            else:
+                return create_sharded_weight_with_new_local(
+                    unified_name, tensor, sharded_param
+                )
+
+        def _generate_base_static_name(vname):
+            if _FP32_MASTER in vname:
+                return tuple(vname.split("_" + _FP32_MASTER + "_", 1))
+            for name in _optimizer_scalar_name + _optimizer_non_scaler_name:
+                if vname.endswith(name):
+                    return vname[: -(len(name) + 1)], name
+            raise ValueError(f"Cannot split variable name: {vname}.")
+
+        optimizer_sharded_state_dict = {}
+        optimizer_state_dict = self._optim.state_dict()
+        # Build name mapping and remove non-tensor entries from optimizer state
+        static_to_struct_mapping = {}
+        model_sharded_state_dict = dict(
+            sorted(model_sharded_state_dict.items())
+        )
+        for k, v in model_sharded_state_dict.items():
+            # When shared weights exist, the shared parameters all report the
+            # same v.local_tensor.name, but only the first of them owns
+            # optimizer states, so keep only the first occurrence of each name.
+            if v.local_tensor.name not in static_to_struct_mapping:
+                static_to_struct_mapping[v.local_tensor.name] = k
+
+        master_weights = optimizer_state_dict.pop("master_weights", None)
+        optimizer_state_dict.pop("LR_Scheduler", None)
+        # Process main optimizer states
+        for key, tensor in optimizer_state_dict.items():
+            static_name, optim_state_type = _generate_base_static_name(key)
+            static_name = static_name.replace("slice@", "")
+            struct_name = static_to_struct_mapping[static_name]
+            sharded_weight = model_sharded_state_dict[struct_name]
+
+            unified_name = f"{struct_name}.{optim_state_type}"
+
+            # Determine tensor partitioning scheme
+            if _MOMENT_NAME in optim_state_type:
+                if tensor.is_dist():
+                    optimizer_sharded_state_dict[unified_name] = ShardedWeight(
+                        key=unified_name,
+                        local_tensor=tensor,
+                        local_shape=tensor.shape,
+                        global_shape=tensor.shape,
+                        global_offset=sharded_weight.global_offset,
+                    )
+                else:
+                    optimizer_sharded_state_dict[unified_name] = (
+                        _create_sharded_weight(
+                            unified_name, tensor, sharded_weight, static_name
+                        )
+                    )
+
+            else:  # Non-momentum optimizer states
+                optimizer_sharded_state_dict[unified_name] = ShardedWeight(
+                    key=unified_name,
+                    local_tensor=tensor,
+                    local_shape=(1,),
+                    global_shape=(1,),
+                    global_offset=(0,),
+                )
+
+        # Process master weights if using mixed precision
+        if master_weights is not None:
+            for key, tensor in master_weights.items():
+                key = key.replace("slice@", "")
+                struct_name = static_to_struct_mapping[key]
+                sharded_weight = model_sharded_state_dict[struct_name]
+                unified_name = f"{struct_name}.w_0"
+                if tensor.is_dist():
+                    optimizer_sharded_state_dict[unified_name] = ShardedWeight(
+                        key=unified_name,
+                        local_tensor=tensor,
+                        local_shape=tensor.shape,
+                        global_shape=tensor.shape,
+                        global_offset=sharded_weight.global_offset,
+                    )
+                else:
+                    optimizer_sharded_state_dict[unified_name] = (
+                        _create_sharded_weight(
+                            unified_name, tensor, sharded_weight, key
+                        )
+                    )
+
+        return optimizer_sharded_state_dict
+
 
 def ForwardPreHooks(
     layer,
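
Each entry of the returned dict is keyed by a "unified name": the model's structured parameter name joined with the optimizer-state suffix, after stripping the `fp32_master_0` infix and the `slice@` prefix. Below is a self-contained sketch of that remapping; the parameter name, the `static_to_struct` mapping, and the helper name `split_static_name` are hypothetical stand-ins for `_generate_base_static_name` and `static_to_struct_mapping` in the hunk above.

```python
_FP32_MASTER = "fp32_master_0"
_STATE_SUFFIXES = [
    "beta1_pow_acc_0",
    "beta2_pow_acc_0",
    "moment1_0",
    "moment2_0",
    "velocity_0",
]


def split_static_name(vname: str):
    """Split an optimizer-state key into (static parameter name, state suffix)."""
    if _FP32_MASTER in vname:
        return tuple(vname.split("_" + _FP32_MASTER + "_", 1))
    for suffix in _STATE_SUFFIXES:
        if vname.endswith(suffix):
            return vname[: -(len(suffix) + 1)], suffix
    raise ValueError(f"Cannot split variable name: {vname}.")


# Hypothetical names for illustration only.
static_to_struct = {"linear_0.w_0": "decoder.linear.weight"}

static_name, state = split_static_name("linear_0.w_0_fp32_master_0_moment1_0")
unified_name = f"{static_to_struct[static_name]}.{state}"
assert (static_name, state) == ("linear_0.w_0", "moment1_0")
assert unified_name == "decoder.linear.weight.moment1_0"
```

State types that do not contain `moment` (for example `beta1_pow_acc_0`) take the same naming path but are stored as shape-`(1,)` entries, as shown in the hunk above.
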
@@ -1200,13 +1396,16 @@ def _TensorWrapper(param): |
     return tmp_param
 
 
-def _OptimizerWrapper(optimizer, offload, group, update_params_slice):
+def _OptimizerWrapper(
+    optimizer, offload, group, update_params_slice, sharded_state_dict
+):
     if not hasattr(optimizer, "_optim"):
         optimizer._optim = optimizer
         optimizer.offload = offload
         optimizer._group = group
         optimizer.update_scaler = None
         optimizer.update_slice = update_params_slice
+        optimizer.sharded_state_dict = sharded_state_dict
     return optimizer
 
 
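
`_OptimizerWrapper` decorates the existing optimizer object in place rather than subclassing it, and the new `sharded_state_dict` argument is attached as one more attribute, so callers can later invoke `optimizer.sharded_state_dict(model_sharded_state_dict)`. A minimal sketch of that decoration pattern, assuming a dummy optimizer and placeholder callables:

```python
class _DummyOptimizer:  # hypothetical stand-in for a real paddle optimizer
    pass


def wrap_optimizer(optimizer, offload, group, update_params_slice, sharded_state_dict):
    # Mirrors _OptimizerWrapper: attach attributes instead of subclassing.
    if not hasattr(optimizer, "_optim"):
        optimizer._optim = optimizer
        optimizer.offload = offload
        optimizer._group = group
        optimizer.update_scaler = None
        optimizer.update_slice = update_params_slice
        optimizer.sharded_state_dict = sharded_state_dict
    return optimizer


opt = wrap_optimizer(
    _DummyOptimizer(),
    offload=False,
    group=None,
    update_params_slice=lambda: None,
    sharded_state_dict=lambda model_ssd: {},  # placeholder for the bound method
)
assert callable(opt.sharded_state_dict)
```
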