diff --git a/deepspeed/ops/adam/__init__.py b/deepspeed/ops/adam/__init__.py index 5c657db4f270..4f021f05136a 100755 --- a/deepspeed/ops/adam/__init__.py +++ b/deepspeed/ops/adam/__init__.py @@ -6,4 +6,4 @@ from .cpu_adam import DeepSpeedCPUAdam from .fused_adam import FusedAdam from .zenflow_cpu_adam import ZenFlowCPUAdam -from .zenflow_torch_adam import ZenFlowSelectiveAdamW +from .zenflow_torch_adam import ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3 diff --git a/deepspeed/ops/adam/zenflow_torch_adam.py b/deepspeed/ops/adam/zenflow_torch_adam.py index 4f55c58c81c9..1d55210d6edc 100644 --- a/deepspeed/ops/adam/zenflow_torch_adam.py +++ b/deepspeed/ops/adam/zenflow_torch_adam.py @@ -53,30 +53,20 @@ def __init__(self, *args, offload=False, bucket_size=5e8, **kwargs): if offload: self.step = self._step_with_offload - self.temp_copy_param = self._temp_copy_param_with_offload - self.group_step = self._group_step_with_offload self.bucket_size = bucket_size else: self.step = self._step_without_offload - self.temp_copy_param = self._temp_copy_param_without_offload - self.group_step = self._group_step_without_offload - @torch.no_grad() - def _temp_copy_param_with_offload(self, group_to_paramlist): + def temp_copy_param(self, group_to_paramlist): for group_id, params in group_to_paramlist.items(): for param in params: if hasattr(param, "selected_grad"): temp_selected_param = param.data[:, param.selected_indices].clone().detach() if len( param.shape) != 1 else param.data.clone().detach() - param.temp_selected_param = temp_selected_param.cpu() - - @torch.no_grad() - def _temp_copy_param_without_offload(self, group_to_paramlist): - for group_id, params in group_to_paramlist.items(): - for param in params: - if hasattr(param, "selected_grad"): - param.temp_selected_param = param.data[:, param.selected_indices].clone().detach() if len( - param.shape) != 1 else param.data.clone().detach() + if self.offload: + param.temp_selected_param = temp_selected_param.cpu() + else: + param.temp_selected_param = temp_selected_param def copy_mv_from_cpu(self, params): for param in params: @@ -167,6 +157,13 @@ def _step_without_offload(self): @torch.no_grad() def _step_with_offload(self): + """ + Performs parameter updates in offload mode. + + In this mode, group_step() calls adamw() on each pre-partitioned param bucket, + so memory can be released after each bucket update to reduce GPU overhead. + Without offload, adamw() is called directly for speed. 
+ """ for group_id, group in enumerate(self.param_groups): params = group["params"] @@ -197,35 +194,48 @@ def flush_bucket(): flush_bucket() @torch.no_grad() - def _group_step_without_offload(self, group_to_paramlist): + def group_step(self, group_to_paramlist): for group_id, params in group_to_paramlist.items(): group = self.param_groups[group_id] + if self.offload: + self.copy_mv_from_cpu(params) + params_with_grad: List[Tensor] = [] grads: List[Tensor] = [] exp_avgs: List[Tensor] = [] exp_avg_sqs: List[Tensor] = [] max_exp_avg_sqs: List[Tensor] = [] state_steps: List[Tensor] = [] + amsgrad: bool = group["amsgrad"] beta1, beta2 = cast(Tuple[float, float], group["betas"]) for param in params: if hasattr(param, "selected_grad"): - selected_param = param.data[:, param.selected_indices] if len(param.shape) != 1 else param.data + is_2d = (len(param.shape) != 1) + selected_param = param.data[:, param.selected_indices] if is_2d else param.data state = self.state.setdefault(param, {}) if len(state) == 0: state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device) - state["exp_avg"] = torch.zeros_like(selected_param) - state["exp_avg_sq"] = torch.zeros_like(selected_param) if amsgrad: state["max_exp_avg_sq"] = torch.zeros_like(selected_param) + if not self.offload: + state["exp_avg"] = torch.zeros_like(selected_param) + state["exp_avg_sq"] = torch.zeros_like(selected_param) + + if self.offload: + exp_avg_t = param.exp_avg.view_as(selected_param) + exp_avg_sq_t = param.exp_avg_sq.view_as(selected_param) + else: + exp_avg_t = state["exp_avg"] + exp_avg_sq_t = state["exp_avg_sq"] params_with_grad.append(selected_param) grads.append(param.selected_grad) - exp_avgs.append(state["exp_avg"]) - exp_avg_sqs.append(state["exp_avg_sq"]) + exp_avgs.append(exp_avg_t) + exp_avg_sqs.append(exp_avg_sq_t) if amsgrad: max_exp_avg_sqs.append(state["max_exp_avg_sq"]) state_steps.append(state["step"]) @@ -247,42 +257,200 @@ def _group_step_without_offload(self, group_to_paramlist): ) for i, param in enumerate(params): - if hasattr(param, "selected_grad"): - if len(param.shape) != 1: - param.data[:, param.selected_indices] = params_with_grad[i] + if hasattr(param, "selected_grad") and len(param.shape) != 1: + param.data[:, param.selected_indices] = params_with_grad[i] + + if self.offload: + self.copy_mv_to_cpu(params) for param in params: param.selected_grad = None + +class ZenFlowSelectiveAdamW_stage3(torch.optim.AdamW): + + def __init__(self, *args, offload=False, bucket_size=5e8, **kwargs): + super(ZenFlowSelectiveAdamW_stage3, self).__init__(*args, **kwargs) + self.offload = offload + + if offload: + self.step = self._step_with_offload + self.bucket_size = bucket_size + else: + self.step = self._step_without_offload + @torch.no_grad() - def _group_step_with_offload(self, group_to_paramlist): - for group_id, params in group_to_paramlist.items(): + def temp_copy_param(self, paramlist): + for param in paramlist: + if hasattr(param, "selected_grad"): + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + + if num_row != 1: + param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, param.complete_numel).view( + param.complete_numel // num_row, num_row) + temp_selected_param = param_2d[param.selected_indices, :].clone().detach() + else: + temp_selected_param = param.ds_tensor.data.clone().detach() + + if self.offload: + param.temp_selected_param = temp_selected_param.cpu() + else: + param.temp_selected_param = temp_selected_param + + def 
clear_selected_mv(self): + print("Zenflow: clearing selective optimizer states...") + for group in self.param_groups: + for param in group['params']: + state = self.state.setdefault(param, {}) + if len(state) == 0: + continue + if self.offload: + param.exp_avg_cpu_data.zero_() + param.exp_avg_sq_cpu_data.zero_() + else: + state["exp_avg"].zero_() + state["exp_avg_sq"].zero_() + + @torch.no_grad() + def _step_without_offload(self): + for group in self.param_groups: + + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + max_exp_avg_sqs: List[Tensor] = [] + state_steps: List[Tensor] = [] + amsgrad: bool = group["amsgrad"] + beta1, beta2 = cast(Tuple[float, float], group["betas"]) + for param in group["params"]: + if hasattr(param, "selected_grad"): + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + if num_row != 1: + param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, + param.complete_numel).view( + param.complete_numel // num_row, num_row) + selected_param = param_2d[param.selected_indices, :] + else: + selected_param = param.ds_tensor.data + if hasattr(param, 'temp_selected_param') and param.temp_selected_param is not None: + selected_param.copy_(param.temp_selected_param) + + state = self.state.setdefault(param, {}) + if len(state) == 0: + state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device) + state["exp_avg"] = torch.zeros_like(selected_param) + state["exp_avg_sq"] = torch.zeros_like(selected_param) + if amsgrad: + state["max_exp_avg_sq"] = torch.zeros_like(selected_param) + + params_with_grad.append(selected_param) + grads.append(param.selected_grad) + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + if amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + state_steps.append(state["step"]) + adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=False, + ) + for i, param in enumerate(group["params"]): + if hasattr(param, "selected_grad"): + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + if num_row != 1: + param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, + param.complete_numel).view( + param.complete_numel // num_row, num_row) + param_2d[param.selected_indices, :] = params_with_grad[i] + + for param in group["params"]: + if hasattr(param, "temp_selected_param"): + param.temp_selected_param = None + param.selected_grad = None + + def copy_mv_from_cpu(self, params): + for param in params: + param.exp_avg = param.exp_avg_cpu_data.to(param.device, non_blocking=True) + param.exp_avg_sq = param.exp_avg_sq_cpu_data.to(param.device, non_blocking=True) + + def copy_mv_to_cpu(self, params): + for param in params: + param.exp_avg_cpu_data.copy_(param.exp_avg.data, non_blocking=True) + param.exp_avg_sq_cpu_data.copy_(param.exp_avg_sq.data, non_blocking=True) + param.exp_avg = None + param.exp_avg_sq = None + + @torch.no_grad() + def group_step(self, paramlist): + + group_to_paramlist = {} + for param in paramlist: + group_id = param.group_id + if group_id not in group_to_paramlist: + group_to_paramlist[group_id] = [] + group_to_paramlist[group_id].append(param) + + for group_id in sorted(group_to_paramlist.keys()): + params = 
group_to_paramlist[group_id] group = self.param_groups[group_id] - self.copy_mv_from_cpu(params) + if self.offload: + self.copy_mv_from_cpu(params) + params_with_grad: List[Tensor] = [] grads: List[Tensor] = [] exp_avgs: List[Tensor] = [] exp_avg_sqs: List[Tensor] = [] max_exp_avg_sqs: List[Tensor] = [] state_steps: List[Tensor] = [] + amsgrad: bool = group["amsgrad"] beta1, beta2 = cast(Tuple[float, float], group["betas"]) for param in params: if hasattr(param, "selected_grad"): - selected_param = param.data[:, param.selected_indices] if len(param.shape) != 1 else param.data + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + + if num_row != 1: + param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, + param.complete_numel).view( + param.complete_numel // num_row, num_row) + selected_param = param_2d[param.selected_indices, :] + else: + selected_param = param.ds_tensor.data state = self.state.setdefault(param, {}) if len(state) == 0: state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device) if amsgrad: state["max_exp_avg_sq"] = torch.zeros_like(selected_param) + if not self.offload: + state["exp_avg"] = torch.zeros_like(selected_param) + state["exp_avg_sq"] = torch.zeros_like(selected_param) + + if self.offload: + exp_avg_t = param.exp_avg.view_as(selected_param) + exp_avg_sq_t = param.exp_avg_sq.view_as(selected_param) + else: + exp_avg_t = state["exp_avg"] + exp_avg_sq_t = state["exp_avg_sq"] params_with_grad.append(selected_param) grads.append(param.selected_grad) - exp_avgs.append(param.exp_avg.view_as(selected_param)) - exp_avg_sqs.append(param.exp_avg_sq.view_as(selected_param)) + exp_avgs.append(exp_avg_t) + exp_avg_sqs.append(exp_avg_sq_t) if amsgrad: max_exp_avg_sqs.append(state["max_exp_avg_sq"]) state_steps.append(state["step"]) @@ -305,14 +473,64 @@ def _group_step_with_offload(self, group_to_paramlist): for i, param in enumerate(params): if hasattr(param, "selected_grad"): - if len(param.shape) != 1: - param.data[:, param.selected_indices] = params_with_grad[i] + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + if num_row != 1: + param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, + param.complete_numel).view( + param.complete_numel // num_row, num_row) + param_2d[param.selected_indices, :] = params_with_grad[i] - self.copy_mv_to_cpu(params) + if self.offload: + self.copy_mv_to_cpu(params) for param in params: param.selected_grad = None + @torch.no_grad() + def _step_with_offload(self): + """ + Performs parameter updates in offload mode. + + In this mode, group_step() calls adamw() on each pre-partitioned param bucket, + so memory can be released after each bucket update to reduce GPU overhead. + Without offload, adamw() is called directly for speed. 
+ """ + + for group_id, group in enumerate(self.param_groups): + params = group["params"] + + bucket = [] + bucket_numel = 0 + + def flush_bucket(): + if not bucket: + return + for param in bucket: + if hasattr(param, "temp_selected_param") and param.temp_selected_param is not None: + temp_selected_param = param.temp_selected_param.to(param.device, non_blocking=True) + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + if num_row != 1: + param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, + param.complete_numel).view( + param.complete_numel // num_row, num_row) + param_2d[param.selected_indices, :] = temp_selected_param + else: + param.ds_tensor.data.copy_(temp_selected_param) + param.temp_selected_param = None + + self.group_step(bucket) + bucket.clear() + + for param in params: + if hasattr(param, "selected_grad"): + bucket.append(param) + bucket_numel += param.numel() + if bucket_numel >= self.bucket_size: + flush_bucket() + bucket_numel = 0 + + flush_bucket() + def _single_tensor_adamw( params: List[Tensor], diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index ea7c2e9d3d62..6ca9542f1a42 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1868,6 +1868,7 @@ def _configure_zero_optimizer(self, optimizer): overlap_comm=self.zero_overlap_comm(), offload_optimizer_config=self.zero_offload_optimizer(), offload_param_config=self.zero_offload_param(), + zenflow_config=self.zenflow_config(), sub_group_size=self.zero_sub_group_size(), offload_ratio=self.zero_partial_offload(), mpu=self.mpu, diff --git a/deepspeed/runtime/zenflow/engine_stage3.py b/deepspeed/runtime/zenflow/engine_stage3.py new file mode 100644 index 000000000000..0ceb59c8ca73 --- /dev/null +++ b/deepspeed/runtime/zenflow/engine_stage3.py @@ -0,0 +1,641 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed.runtime.zero.partition_parameters import * + +import torch +import math +from deepspeed import comm as dist +from deepspeed.utils import logger +from deepspeed.ops.adam import ZenFlowSelectiveAdamW_stage3 +from deepspeed.runtime.utils import see_memory_usage +from typing import List +from deepspeed.accelerator import get_accelerator +from typing import TYPE_CHECKING +from deepspeed.runtime.zenflow.zenflow_utils import start_optimizer_process + +if TYPE_CHECKING: + from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 + +OPTIMIZER_SWAP_IN_STATE_TIMER = 'optimizer_swap_in_state' +INIT_OPTIMIZER_TIMER = 'init_optimizer_state' +OPTIMIZER_SWAP_OUT_STATE_TIMER = 'optimizer_swap_out_state' +OPTIMIZER_STEP_TIMER = 'optimizer_step' + + +def configure_zenflow(optimizer_z3, zenflow_config): + + optimizer_z3.select_strategy = zenflow_config.select_strategy + if optimizer_z3.select_strategy == 'auto': + optimizer_z3.select_strategy = "epoch" + if isinstance(zenflow_config.select_interval, int): + raise Warning( + "If use auto select strategy, select_interval will be set to 1 and select_strategy will be set to epoch, thus select_interval would be overwritten." 
+ ) + optimizer_z3.select_interval = 1 + else: + if isinstance(zenflow_config.select_interval, str): + raise ValueError("If don't use auto select strategy, select_interval must be a number.") + optimizer_z3.select_interval = int(zenflow_config.select_interval) + + if isinstance(zenflow_config.update_interval, str): + optimizer_z3.auto_update = True + optimizer_z3.update_interval = 0 + else: + optimizer_z3.auto_update = False + optimizer_z3.update_interval = int(zenflow_config.update_interval) + + if optimizer_z3.select_strategy == 'epoch': + if zenflow_config.steps_per_epoch is not None: + optimizer_z3.select_interval = optimizer_z3.select_interval * zenflow_config.steps_per_epoch + else: + optimizer_z3.select_interval = 0 + + if not optimizer_z3.auto_update and optimizer_z3.select_interval != 0 and optimizer_z3.select_interval < optimizer_z3.update_interval: + raise ValueError("Select interval must be greater or equal to update interval") + + optimizer_z3.topk_ratio = zenflow_config.topk_ratio + + optimizer_z3.param_id_grad_sum_buffer_offset = {} + + optimizer_z3.zf_stage3 = True + + if optimizer_z3.auto_update: + optimizer_z3.param_id_sum_buffer_offset = {} + optimizer_z3.auto_ratio = zenflow_config.auto_ratio + optimizer_z3.zenflow_need_update = [False, False] + optimizer_z3.zenflow_state = 0 + optimizer_z3.num_need_update = 0 + + +def _initialize_zenflow_stage3_prologue(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3", + module, + zenflow_config: dict = None): + + optimizer_z3.zenflow = True if zenflow_config is not None else False + + if not optimizer_z3.zenflow: + return + + optimizer_z3.pt_reserved_cores_perc = zenflow_config.pt_reserved_cores_perc + + for p in module.parameters(): + p.data = p.data.t().contiguous() if len(p.shape) != 1 else p.data + + +def _initialize_zenflow_stage3_epilogue(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3", + zenflow_config: dict = None, + overlap_comm: bool = False): + + if not optimizer_z3.zenflow: + return + + optimizer_z3.micro_step = -1 + optimizer_z3.full_warm_up_rounds = zenflow_config.full_warm_up_rounds + optimizer_z3.offload_selective_optimizer = zenflow_config.offload + optimizer_z3.zenflow_overlap_step = zenflow_config.overlap_step + + if optimizer_z3.offload_selective_optimizer: + assert overlap_comm, "offload selective optimizer should be used with overlap_comm" + + if optimizer_z3.zenflow_overlap_step: + optimizer_z3.process_optimizer_established = False + optimizer_z3.first_update_round_after_warmup = True + optimizer_z3.initialize_optimizer_states = lambda: initialize_optimizer_states(optimizer_z3) + optimizer_z3.step = lambda closure=None: step(optimizer_z3, closure) + optimizer_z3.zenflow_cpu_optimizer_overlap_step = lambda now_state, scaled_global_grad_norm: zenflow_cpu_optimizer_overlap_step( + optimizer_z3, now_state, scaled_global_grad_norm) + optimizer_z3.wait_last_update_and_copy = lambda timer_names: wait_last_update_and_copy( + optimizer_z3, timer_names) + optimizer_z3.partition_grads = lambda params_to_release, grad_partitions: partition_grads( + optimizer_z3, params_to_release, grad_partitions) + optimizer_z3.get_overlap_step_state = lambda: get_overlap_step_state(optimizer_z3) + optimizer_z3.start_optimizer_process = lambda: start_optimizer_process(optimizer_z3) + optimizer_z3.unscale_and_clip_grads = lambda sub_group_id, total_norm, now_state: unscale_and_clip_grads( + optimizer_z3, sub_group_id, total_norm, now_state) + + configure_zenflow(optimizer_z3, zenflow_config) + optimizer_z3.selective_optimizer = 
ZenFlowSelectiveAdamW_stage3([{ + k: v + for k, v in group.items() if k != "params" + } | { + "params": group["params"] + } for group in optimizer_z3.optimizer.param_groups], + offload=optimizer_z3.offload_selective_optimizer) + optimizer_z3.num_total_param = sum( + sum(1 for param in group["params"] if len(param.ds_shape) != 1) + for group in optimizer_z3.optimizer.param_groups) + + +def zenflow_cpu_optimizer_step(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"): + return optimizer_z3.optimizer.step(step_id=optimizer_z3.micro_step + 1) + + +def _sync_selective_optimizer_lr(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"): + for group_selected, group in zip(optimizer_z3.selective_optimizer.param_groups, + optimizer_z3.optimizer.param_groups): + group_selected["lr"] = group["lr"] + + +def selective_optimizer_step(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"): + optimizer_z3.selective_optimizer.step() + + +def is_zenflow_select_boundary(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3") -> bool: + return optimizer_z3.zenflow and (optimizer_z3.micro_step - optimizer_z3.full_warm_up_rounds) >= 0 and ( + (optimizer_z3.micro_step - optimizer_z3.full_warm_up_rounds) == 0 or + (optimizer_z3.select_interval != 0 and optimizer_z3.micro_step % optimizer_z3.select_interval == 0)) + + +def update_selected_channels(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3", params_to_update, grad_partitions): + src_rk = dist.get_rank(optimizer_z3.dp_process_group) + total_rk = dist.get_world_size(optimizer_z3.dp_process_group) + + total_chunk_size = 0 + param_local_offset = [0 for _ in range(total_rk)] + + for param, grad_partition in zip(params_to_update, grad_partitions): + param_max_chunk_size = 0 + param_rk_offset = 0 + for rk in range(total_rk): + contains_real_data = param.partition_numel() * rk < param.ds_numel + if not contains_real_data: + param.grad = None + continue + + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + if num_row == 1: + continue + + partition_size = param.partition_numel() + start = partition_size * rk + end = min(start + partition_size, param.ds_numel) + + start_idx = math.ceil(start / num_row) + end_idx = end // num_row + num_cols = end_idx - start_idx + + if param.ds_id not in optimizer_z3.param_id_grad_sum_buffer_offset: + optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id] = [] + + optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id].append( + (param_local_offset[rk], num_cols, param_rk_offset)) + + param_max_chunk_size = max(param_max_chunk_size, num_cols) + param_rk_offset += num_cols + param_local_offset[rk] += num_cols + + total_chunk_size += param_max_chunk_size + + optimizer_z3.grad_sum_buffer = torch.zeros(total_chunk_size, dtype=optimizer_z3.dtype, device='cuda') + + for param, grad_partition in zip(params_to_update, grad_partitions): + contains_real_data = param.partition_numel() * src_rk < param.ds_numel + if not contains_real_data: + # this grad partition is empty - don't need to do anything + param.grad = None + continue + + #ds_shape is the transposed shape, it should not be same as param.shape + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + + if num_row == 1: + continue + + partition_size = param.partition_numel() + start = partition_size * src_rk + end = min(start + partition_size, param.ds_numel) + + start_idx = math.ceil(start / num_row) + end_idx = end // num_row + + num_elements = (end_idx - start_idx) * num_row + + param.complete_column_offset = start_idx * num_row - 
start + param.complete_numel = (end_idx - start_idx) * num_row + + sum_per_column = grad_partition.narrow(0, param.complete_column_offset, num_elements) + sum_per_column = sum_per_column.view(end_idx - start_idx, num_row) + sum_array = sum_per_column.abs().sum(dim=1) + + offset, length, _ = optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id][src_rk] + optimizer_z3.grad_sum_buffer.narrow(0, offset, length).copy_(sum_array) + + gathered_chunks = [torch.zeros_like(optimizer_z3.grad_sum_buffer) for _ in range(total_rk)] + dist.all_gather(gathered_chunks, optimizer_z3.grad_sum_buffer, group=optimizer_z3.dp_process_group) + + for param in params_to_update: + + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + + if num_row == 1: + continue + + param_column_sum = [] + for rk in range(total_rk): + offset, length, _ = optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id][rk] + param_column_sum.append(gathered_chunks[rk].narrow(0, offset, length)) + global_param_column_sum = torch.cat(param_column_sum, dim=0) + + num_select = max(1, int(global_param_column_sum.numel() * optimizer_z3.topk_ratio)) + _, global_topk_indices = torch.topk(global_param_column_sum, num_select, largest=True) + + _, length, rk_offset = optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id][src_rk] + local_indices = [(idx.item() - rk_offset) for idx in global_topk_indices + if rk_offset <= idx < rk_offset + length] + param.selected_indices = torch.tensor(local_indices, device='cuda') + optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id] = [] + + optimizer_z3.grad_sum_buffer = None + + +def _process_selected_fp32_groups_grad(optimizer_z3, params_to_update, grad_partitions): + + if optimizer_z3.auto_update: + optimizer_z3.sum_buffer = torch.zeros(optimizer_z3.num_total_param, dtype=optimizer_z3.dtype, device='cuda') + optimizer_z3.critic_sum_buffer = torch.zeros(optimizer_z3.num_total_param, + dtype=optimizer_z3.dtype, + device='cuda') + curr_buffer_idx = 0 + + for param, grad_partition in zip(params_to_update, grad_partitions): + + rk = dist.get_rank(optimizer_z3.dp_process_group) + + contains_real_data = param.partition_numel() * rk < param.ds_numel + if not contains_real_data: + # this grad partition is empty - don't need to do anything + param.grad = None + continue + + #ds_shape is the transposed shape, it should not be same as param.shape + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + + if num_row == 1: + param.selected_grad = grad_partition.clone().detach() + else: + grad_2d = grad_partition.narrow(0, param.complete_column_offset, + param.complete_numel).view(param.complete_numel // num_row, num_row) + param.selected_grad = grad_2d[param.selected_indices, :].clone().detach() + + if optimizer_z3.auto_update: + optimizer_z3.sum_buffer[curr_buffer_idx] = grad_partition.abs().sum() + optimizer_z3.critic_sum_buffer[curr_buffer_idx] = param.selected_grad.abs().sum() + curr_buffer_idx += 1 + + if optimizer_z3.offload_selective_optimizer and not hasattr(param, 'exp_avg_cpu_data'): + buffer = torch.zeros(param.selected_grad.numel(), dtype=param.dtype, device=optimizer_z3.device) + param.exp_avg_cpu_data = get_accelerator().pin_memory( + buffer) if optimizer_z3.offload_optimizer_pin_memory else buffer + param.exp_avg_sq_cpu_data = get_accelerator().pin_memory( + buffer.clone()) if optimizer_z3.offload_optimizer_pin_memory else buffer.clone() + + if optimizer_z3.auto_update: + total_rk = 
dist.get_world_size(optimizer_z3.dp_process_group) + sum_gather_list = [torch.zeros_like(optimizer_z3.sum_buffer) for _ in range(total_rk)] + critic_gather_list = [torch.zeros_like(optimizer_z3.critic_sum_buffer) for _ in range(total_rk)] + curr_buffer_idx = 0 + + dist.all_gather(sum_gather_list, optimizer_z3.sum_buffer, group=optimizer_z3.dp_process_group) + dist.all_gather(critic_gather_list, optimizer_z3.critic_sum_buffer, group=optimizer_z3.dp_process_group) + + for param in params_to_update: + if len(param.ds_shape) == 1: + continue + + if not hasattr(param, 'non_critic_sum'): + param.non_critic_sum = 0 + if not hasattr(param, 'avg_critic_sum'): + param.avg_critic_sum = 0 + + grad_total_sum = sum(sum_gather_list[rk][curr_buffer_idx] for rk in range(total_rk)) + grad_critic_sum = sum(critic_gather_list[rk][curr_buffer_idx] for rk in range(total_rk)) + + param.avg_critic_sum = (param.avg_critic_sum * (optimizer_z3.update_interval - 1) + + grad_critic_sum) / optimizer_z3.update_interval / (optimizer_z3.topk_ratio * 10) + param.non_critic_sum += (grad_total_sum - grad_critic_sum) / ((1 - optimizer_z3.topk_ratio) * 10) + if param.non_critic_sum >= param.avg_critic_sum: + optimizer_z3.num_need_update += 1 + if optimizer_z3.num_need_update >= int(optimizer_z3.auto_ratio * optimizer_z3.num_total_param): + optimizer_z3.zenflow_need_update[optimizer_z3.zenflow_state] = True + + curr_buffer_idx += 1 + + if not optimizer_z3.is_gradient_accumulation_boundary: + optimizer_z3.selective_optimizer.group_step(params_to_update) + else: + optimizer_z3.selective_optimizer.temp_copy_param(params_to_update) + + if optimizer_z3.auto_update: + optimizer_z3.sum_buffer = None + optimizer_z3.critic_sum_buffer = None + + +def sync_fp32_param_from_gpu(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"): + + if optimizer_z3.micro_step == 0: + return + + for fp16_partitions, fp32_partition in zip(optimizer_z3.fp16_partitioned_groups_flat, + optimizer_z3.fp32_partitioned_groups_flat): + fp32_partition.data.copy_(fp16_partitions.data) + + +def zenflow_backward_prologue(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"): + optimizer_z3.micro_step += 1 + if optimizer_z3.auto_update: + optimizer_z3.zenflow_need_update[optimizer_z3.zenflow_state] = False + optimizer_z3.num_need_update = 0 + if optimizer_z3.zenflow_need_update[optimizer_z3.zenflow_state ^ 1]: + optimizer_z3.update_interval = 0 + for group in optimizer_z3.fp16_groups: + for p in group: + p.non_critic_sum = 0 + optimizer_z3.update_interval += 1 + if optimizer_z3.is_zenflow_select_boundary(): + sync_fp32_param_from_gpu(optimizer_z3) + optimizer_z3.selective_optimizer.clear_selected_mv() + + +def zenflow_backward_epilogue(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"): + optimizer_z3._partition_all_parameters() + + +def log_selective_optimizer_timers(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"): + pass + + +def initialize_optimizer_states(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"): + num_subgroups = len(optimizer_z3.fp16_groups) + + largest_numel = max([sum([p.ds_numel for p in psg]) for psg in optimizer_z3.fp16_partitioned_groups]) + gradient_dtype = optimizer_z3.fp32_partitioned_groups_flat[0].dtype + gradient_buffer = torch.zeros(int(largest_numel), dtype=gradient_dtype, device=optimizer_z3.device) + + timer_names = set() + + # State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers + # which do lazy initialization of the state at the first call to step. 
+ is_adagrad = isinstance(optimizer_z3.optimizer, torch.optim.Adagrad) + + if optimizer_z3.swap_optimizer: + optimizer_z3.optimizer_swapper.init_timers() + + timer_names.add(INIT_OPTIMIZER_TIMER) + optimizer_z3.timers(INIT_OPTIMIZER_TIMER).start() + + for i, group in enumerate(optimizer_z3.fp16_groups): + swappable_optimizer_subgroup = optimizer_z3._swappable_optimizer_subgroup(i) + swappable_param_subgroup = optimizer_z3.fp16_partitioned_groups_flat[i] is None + + num_elements = int(optimizer_z3.fp16_partitioned_groups_flat_numel[i]) + + see_memory_usage( + f'[Begin] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', + force=False) + + if swappable_optimizer_subgroup: + optimizer_z3._optimizer_states_and_gradient_swap_in(i, timer_names) + + if optimizer_z3.offload_optimizer and not swappable_optimizer_subgroup: + subgroup_gradient_buffer = torch.zeros(num_elements, dtype=gradient_dtype, device=optimizer_z3.device) + if optimizer_z3.offload_optimizer_pin_memory: + subgroup_gradient_buffer = get_accelerator().pin_memory(subgroup_gradient_buffer) + + optimizer_z3.fp32_partitioned_groups_flat[i].grad = None + optimizer_z3.fp32_partitioned_groups_flat[i].overlap_grad = [ + subgroup_gradient_buffer.to(optimizer_z3.subgroup_to_device[i]), + subgroup_gradient_buffer.clone().to(optimizer_z3.subgroup_to_device[i]) + ] + else: + optimizer_z3.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(0, 0, num_elements) + + if swappable_param_subgroup: + optimizer_z3._partitioned_params_swap_out(i) + + if swappable_optimizer_subgroup: + optimizer_z3._optimizer_states_and_gradient_swap_out(i, timer_names) + + see_memory_usage( + f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', + force=False) + + # Initialize the optimizer states with the flattened fp32 partition. 
+ if is_adagrad: + optimizer_z3.optimizer = torch.optim.Adagrad(optimizer_z3.fp32_partitioned_groups_flat, + **optimizer_z3.optimizer.defaults) + + optimizer_z3.timers(INIT_OPTIMIZER_TIMER).stop() + optimizer_z3.timers.log(timer_names) + + if optimizer_z3.swap_optimizer: + optimizer_z3.optimizer_swapper.log_timers() + + if not optimizer_z3.offload_optimizer: + for group in optimizer_z3.fp32_partitioned_groups_flat: + group.grad = None + + # Reset steps + return + + +def get_overlap_step_state(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3") -> int: + if optimizer_z3.micro_step < optimizer_z3.full_warm_up_rounds: + return optimizer_z3.micro_step & 1 + else: + if not optimizer_z3.auto_update: + return (optimizer_z3.micro_step // optimizer_z3.update_interval) & 1 + else: + return optimizer_z3.zenflow_state + + +@instrument_w_nvtx +def partition_grads(optimizer_z3, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None: + offload_fp32_gradients = {} + offload_fp32_offsets = {} + buffers = [] + for param, grad_partition in zip(params_to_release, grad_partitions): + + contains_real_data = param.partition_numel() * dist.get_rank(optimizer_z3.dp_process_group) < param.ds_numel + if not contains_real_data: + # this grad partition is empty - don't need to do anything + param.grad = None + continue + + # move or accumulate gradient partition to target buffer + param_id_to_grad_partition = getattr(optimizer_z3, + f"_{optimizer_z3.__class__.__name__}__param_id_to_grad_partition") + grad_buffer = param_id_to_grad_partition[param.ds_id].narrow(0, 0, grad_partition.numel()) + buffers.append(grad_buffer) + if optimizer_z3.micro_step_id == 0: # don't accumulate + grad_buffer.copy_(grad_partition, non_blocking=True) + # ensure grad buffer is a CUDA buffer to speed up the next few + # operations and so it can be used asynchronously + grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) + elif get_accelerator().on_accelerator(grad_buffer): + grad_buffer.add_(grad_partition.to(optimizer_z3.gradient_accumulation_dtype).view(grad_buffer.shape)) + else: + # if dst is CPU, copy first to src device, do the addition + # there, then move back to dst. 
adding directly to cpu is very slow + cuda_grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) + cuda_grad_buffer.add_( + grad_partition.to(optimizer_z3.gradient_accumulation_dtype).view(cuda_grad_buffer.shape)) + grad_buffer.copy_(cuda_grad_buffer, non_blocking=True) + # ensure grad buffer is a CUDA buffer to speed up the next few + # operations and so it can be used asynchronously + grad_buffer = cuda_grad_buffer + + # offload the gradient partition if applicable + if optimizer_z3.offload_optimizer: + i, dest_offset, _ = optimizer_z3.grad_position[optimizer_z3.get_param_id(param)] + now_state = optimizer_z3.get_overlap_step_state() + + if optimizer_z3.is_gradient_accumulation_boundary: + optimizer_z3.norm_for_param_grads[optimizer_z3.get_param_id( + param)] = optimizer_z3._constant_buffered_norm2(grad_buffer) + + if optimizer_z3._swappable_optimizer_subgroup(i): + if not i in offload_fp32_gradients.keys(): + offload_fp32_gradients[i] = [] + offload_fp32_offsets[i] = [] + + offload_fp32_gradients[i].append(grad_buffer.float()) + offload_fp32_offsets[i].append(dest_offset) + else: + fp32_grad_tensor = optimizer_z3.fp32_partitioned_groups_flat[i].overlap_grad[now_state].narrow( + 0, dest_offset, grad_buffer.numel()) + fp32_grad_tensor.copy_(grad_buffer.float()) + + # free the gradient + if not get_accelerator().is_synchronized_device(): + if param.grad is not None: + param.grad.record_stream(get_accelerator().current_stream()) + param.grad = None + + if optimizer_z3.offload_optimizer and optimizer_z3.swap_optimizer: + for i in offload_fp32_gradients.keys(): + optimizer_z3.optimizer_swapper.swap_out_gradients(parameter=optimizer_z3.fp32_partitioned_groups_flat[i], + gradient_offsets=offload_fp32_offsets[i], + gradient_tensors=offload_fp32_gradients[i]) + return buffers + + +@instrument_w_nvtx +def unscale_and_clip_grads(self, sub_group_id, total_norm, now_state): + # compute combined scale factor for this group + combined_scale = self.loss_scale + if self.clip_grad > 0.: + # norm is in fact norm*scale + clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad + clip = torch.clamp(clip, min=1.0) + combined_scale = clip * self.loss_scale + + self.fp32_partitioned_groups_flat[sub_group_id].overlap_grad[now_state].mul_(1. 
/ combined_scale) + + +def zenflow_cpu_optimizer_overlap_step(optimizer_z3, now_state, scaled_global_grad_norm): + + if not optimizer_z3.process_optimizer_established: + optimizer_z3.start_optimizer_process() + + group_infos = [] + for group_no, group in enumerate(optimizer_z3.fp16_groups): + optimizer_z3.unscale_and_clip_grads(group_no, scaled_global_grad_norm, now_state) + param_group_id = optimizer_z3.sub_group_to_group_id[group_no] + + group_info = { + "lr": optimizer_z3.optimizer.param_groups[param_group_id]["lr"], + "betas": optimizer_z3.optimizer.param_groups[param_group_id]["betas"], + "eps": optimizer_z3.optimizer.param_groups[param_group_id]["eps"], + "weight_decay": optimizer_z3.optimizer.param_groups[param_group_id]["weight_decay"], + "bias_correction": optimizer_z3.optimizer.param_groups[param_group_id]["bias_correction"], + } + + group_infos.append(group_info) + + optimizer_z3.parent_conn.send({ + "type": "step", + "now_state": now_state, + "micro_step": optimizer_z3.micro_step, + "group_infos": group_infos + }) + + +def wait_last_update_and_copy(optimizer_z3, timer_names): + + if not hasattr(optimizer_z3, 'parent_conn'): + return + + if optimizer_z3.micro_step + 1 > optimizer_z3.full_warm_up_rounds and optimizer_z3.first_update_round_after_warmup: + optimizer_z3.first_update_round_after_warmup = False + return + + msg = optimizer_z3.parent_conn.recv() + assert msg["type"] == "done", "Optimizer process did not finish stepping correctly." + + for sub_group_id, group in enumerate(optimizer_z3.fp16_groups): + if optimizer_z3.fp16_partitioned_groups_flat[sub_group_id] is not None: + optimizer_z3.fp16_partitioned_groups_flat[sub_group_id].data.copy_( + optimizer_z3.fp32_partitioned_groups_flat[sub_group_id].stale_param.data) + + #unflatten fp16 parameter subgroup + optimizer_z3._unflatten_partitioned_parameters(sub_group_id) + else: + optimizer_z3._partitioned_params_swap_out(sub_group_id) + + optimizer_z3._post_step(timer_names) + + # warn user about caching allocator flushes + memory_stats = get_accelerator().memory_stats() + alloc_retries = memory_stats.get("num_alloc_retries") + if alloc_retries is None: + alloc_retries = 0 + if alloc_retries > optimizer_z3.n_caching_allocator_flushes: + if dist.get_rank() == 0: + logger.warning( + "%d pytorch allocator cache flushes since last step. this happens " + "when there is high memory pressure and is detrimental to " + "performance. if this is happening frequently consider adjusting " + "settings to reduce memory consumption. If you are unable to " + "make the cache flushes go away consider adding " + "get_accelerator().empty_cache() calls in your training loop to ensure " + "that all ranks flush their caches at the same time", + alloc_retries - optimizer_z3.n_caching_allocator_flushes) + optimizer_z3.n_caching_allocator_flushes = alloc_retries + + +@instrument_w_nvtx +def step(optimizer_z3, closure=None): + """ + Not supporting closure. 
+ """ + optimizer_z3._pre_step() + optimizer_z3._partition_all_parameters() + + #checks for overflow, adjust the loss scale accordingly + if optimizer_z3._overflow_check_and_loss_scale_update(): + if optimizer_z3.swap_optimizer: + optimizer_z3.optimizer_swapper.log_timers() + return + + norm_groups = optimizer_z3._get_norm_groups() + scaled_global_grad_norm = torch.linalg.vector_norm(torch.stack(norm_groups)) + + # Stash unscaled gradient norm + optimizer_z3._global_grad_norm = scaled_global_grad_norm / optimizer_z3.loss_scale + + if optimizer_z3.micro_step < optimizer_z3.full_warm_up_rounds: + optimizer_z3.zenflow_cpu_optimizer_overlap_step(optimizer_z3.get_overlap_step_state(), scaled_global_grad_norm) + + timer_names = set() + + timer_names.add(OPTIMIZER_STEP_TIMER) + + optimizer_z3.wait_last_update_and_copy(timer_names) + + if optimizer_z3.micro_step >= optimizer_z3.full_warm_up_rounds: + optimizer_z3.zenflow_cpu_optimizer_overlap_step(optimizer_z3.get_overlap_step_state(), scaled_global_grad_norm) + + return diff --git a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py index 5d44bf6b8a41..daec04772158 100644 --- a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py +++ b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py @@ -3,14 +3,11 @@ # DeepSpeed Team -import os -import math -import psutil import torch from deepspeed import comm as dist -import torch.multiprocessing as mp from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer +from deepspeed.runtime.zenflow.zenflow_utils import start_optimizer_process from deepspeed.runtime.utils import (see_memory_usage) from deepspeed.ops.adam import ZenFlowSelectiveAdamW @@ -97,6 +94,8 @@ def __init__(self, self.full_warm_up_rounds = zenflow_config.full_warm_up_rounds self.offload_selective_optimizer = zenflow_config.offload self.pt_reserved_cores_perc = zenflow_config.pt_reserved_cores_perc + self.start_optimizer_process = lambda: start_optimizer_process(self) + self.zf_stage3 = False if self.offload_selective_optimizer: assert overlap_comm, "offload selective optimizer should be used with overlap_comm" @@ -636,64 +635,10 @@ def zenflow_cpu_optimizer_step(self, group_no): self.optimizer.step(step_id=self.micro_step + 1) -def disable_accelerator(): - accelerator = get_accelerator() - accelerator.is_available = lambda: False - accelerator.device_count = lambda: 0 - accelerator.current_device = lambda: -1 - # Optionally mark it as initialized if needed - if hasattr(accelerator, "_initialized"): - accelerator._initialized = True - - -def zenflow_optimizer_process(pipe, curr_rank, total_rank, param_groups, shared_overlap_grad_map, - shared_stale_param_map, zf_affinity): - disable_accelerator() - - current_process = psutil.Process() - current_process.cpu_affinity(zf_affinity) - os.environ['OMP_NUM_THREADS'] = str(len(zf_affinity)) - - from deepspeed.ops.adam import ZenFlowCPUAdam - optimizer = ZenFlowCPUAdam(param_groups, overlap_step=True) - - pipe.send({"type": "ready"}) - - # TODO: replace this with rpc - - while True: - cmd = pipe.recv() - if cmd["type"] == "step": - now_state = cmd["now_state"] - micro_step = cmd["micro_step"] - group_infos = cmd["group_infos"] - - for group_no, group_info in enumerate(group_infos): - original_param_groups = optimizer.param_groups - optimizer.param_groups = [original_param_groups[group_no]] - group = optimizer.param_groups[0] - - for param_idx, param in enumerate(group["params"]): - key = (group_no, param_idx) - if key in 
shared_overlap_grad_map: - param.overlap_grad = shared_overlap_grad_map[key] - if key in shared_stale_param_map: - param.stale_param = shared_stale_param_map[key] - - optimizer.step(step_id=micro_step + 1, now_state=now_state, group_info=group_info) - - optimizer.param_groups = original_param_groups - - pipe.send({"type": "done"}) - elif cmd["type"] == "exit": - break - - class ZenFlowZeroOptimizerParallel(ZenFlowZeroOptimizer): def __init__(self, *args, **kwargs): super(ZenFlowZeroOptimizerParallel, self).__init__(*args, **kwargs) - self.process_pool = mp.Pool(1) self.process_optimizer_established = False self.first_update_round_after_warmup = True @@ -759,85 +704,6 @@ def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param): dest_tensor.copy_(src_tensor, non_blocking=True) param.grad = None #offload only - # check if all tensors in the list are equal to each other - def all_tensors_equal(self, tensor_list): - first_tensor = tensor_list[0] - for tensor in tensor_list[1:]: - if not torch.equal(first_tensor, tensor): - return False - return True - - def start_optimizer_process(self): - from multiprocessing import Pipe, get_context, Manager - - ctx = get_context("spawn") - self.parent_conn, self.child_conn = Pipe() - - manager = Manager() - self.shared_overlap_grad_map = manager.dict() - self.shared_stale_param_map = manager.dict() - - for group_no, group in enumerate(self.optimizer.param_groups): - for param_idx, param in enumerate(group['params']): - param.data.share_memory_() - if not hasattr(param, 'stale_param'): - param.stale_param = torch.zeros_like(param.data, dtype=param.dtype, device=param.device) - param.stale_param.data.share_memory_() - key = (group_no, param_idx) - self.shared_stale_param_map[key] = param.stale_param - if param.overlap_grad is not None: - param.overlap_grad[0].data.share_memory_() - param.overlap_grad[1].data.share_memory_() - key = (group_no, param_idx) - self.shared_overlap_grad_map[key] = param.overlap_grad - - param_groups_data = self.optimizer.param_groups - curr_rank = dist.get_rank() - total_rank = dist.get_world_size() - - current_process = psutil.Process() - current_affinity = current_process.cpu_affinity() - all_affinities = [ - torch.zeros(len(current_affinity), - dtype=type(current_affinity[0]), - device=get_accelerator().current_device_name()) for _ in range(total_rank) - ] - dist.all_gather( - all_affinities, - torch.tensor(current_affinity, - dtype=type(current_affinity[0]), - device=get_accelerator().current_device_name())) - # When affinity across all ranks are the same, the workers are not binded. 
Do a soft bind here - if self.all_tensors_equal(all_affinities): - num_phy_cores = psutil.cpu_count(logical=False) - available_phy_cores = [i for i in current_affinity if i < num_phy_cores] - num_available_phy_cores = len(available_phy_cores) - my_rank = curr_rank - my_size = total_rank - cores_per_rank = num_available_phy_cores // my_size - current_affinity = available_phy_cores[my_rank * cores_per_rank:(my_rank + 1) * cores_per_rank] - pt_num_cores = math.ceil(self.pt_reserved_cores_perc * len(current_affinity)) - if pt_num_cores > 0 and pt_num_cores < len(current_affinity): - zf_affinity = current_affinity[pt_num_cores:] - pt_affinity = current_affinity[:pt_num_cores] - else: - zf_affinity = current_affinity - pt_affinity = current_affinity - self.process = ctx.Process( - target=zenflow_optimizer_process, - args=(self.child_conn, curr_rank, total_rank, param_groups_data, self.shared_overlap_grad_map, - self.shared_stale_param_map, zf_affinity), - ) - self.process.daemon = True - self.process.start() - current_process.cpu_affinity(pt_affinity) - os.environ['OMP_NUM_THREADS'] = str(len(pt_affinity)) - - msg = self.parent_conn.recv() - assert msg["type"] == "ready", "Optimizer process did not initialize correctly." - - self.process_optimizer_established = True - def wait_last_update_and_copy(self): if not hasattr(self, 'parent_conn'): diff --git a/deepspeed/runtime/zenflow/zenflow_utils.py b/deepspeed/runtime/zenflow/zenflow_utils.py index 4d2fcaaa4b86..f238b3626506 100644 --- a/deepspeed/runtime/zenflow/zenflow_utils.py +++ b/deepspeed/runtime/zenflow/zenflow_utils.py @@ -3,7 +3,12 @@ # DeepSpeed Team +import os +import math import torch +import psutil +from deepspeed import comm as dist +from deepspeed.accelerator import get_accelerator def _flatten_dense_tensors(tensors): @@ -40,3 +45,147 @@ def _unflatten_dense_tensors(flat, tensors): transposed_tensors = [t.transpose(0, 1) if t.dim() == 2 else t for t in tensors] unflat = torch._C._nn.unflatten_dense_tensors(flat, transposed_tensors) return [t.transpose(0, 1) if t.dim() == 2 else t for t in unflat] + + +def disable_accelerator(): + accelerator = get_accelerator() + accelerator.is_available = lambda: False + accelerator.device_count = lambda: 0 + accelerator.current_device = lambda: -1 + # Optionally mark it as initialized if needed + if hasattr(accelerator, "_initialized"): + accelerator._initialized = True + + +def zenflow_optimizer_process(pipe, param_groups, shared_overlap_grad_map, shared_stale_param_map, zf_affinity): + disable_accelerator() + + current_process = psutil.Process() + current_process.cpu_affinity(zf_affinity) + os.environ['OMP_NUM_THREADS'] = str(len(zf_affinity)) + + from deepspeed.ops.adam import ZenFlowCPUAdam + optimizer = ZenFlowCPUAdam(param_groups, overlap_step=True) + + pipe.send({"type": "ready"}) + + # TODO: replace this with rpc + + while True: + cmd = pipe.recv() + if cmd["type"] == "step": + now_state = cmd["now_state"] + micro_step = cmd["micro_step"] + group_infos = cmd["group_infos"] + + for group_no, group_info in enumerate(group_infos): + original_param_groups = optimizer.param_groups + optimizer.param_groups = [original_param_groups[group_no]] + group = optimizer.param_groups[0] + + for param_idx, param in enumerate(group["params"]): + key = (group_no, param_idx) + if key in shared_overlap_grad_map: + param.overlap_grad = shared_overlap_grad_map[key] + if key in shared_stale_param_map: + param.stale_param = shared_stale_param_map[key] + + optimizer.step(step_id=micro_step + 1, now_state=now_state, 
group_info=group_info) + + optimizer.param_groups = original_param_groups + + pipe.send({"type": "done"}) + elif cmd["type"] == "exit": + break + + +def all_tensors_equal(tensor_list): + first_tensor = tensor_list[0] + for tensor in tensor_list[1:]: + if not torch.equal(first_tensor, tensor): + return False + return True + + +def start_optimizer_process(zf_optimizer): + from multiprocessing import Pipe, get_context, Manager + + ctx = get_context("spawn") + zf_optimizer.parent_conn, zf_optimizer.child_conn = Pipe() + + manager = Manager() + zf_optimizer.shared_overlap_grad_map = manager.dict() + zf_optimizer.shared_stale_param_map = manager.dict() + + if zf_optimizer.zf_stage3: + params_iter = [((group_no, 0), param) + for group_no, param in enumerate(zf_optimizer.fp32_partitioned_groups_flat)] + else: + params_iter = [((group_no, param_idx), param) + for group_no, group in enumerate(zf_optimizer.optimizer.param_groups) + for param_idx, param in enumerate(group["params"])] + + for key, param in params_iter: + param.data.share_memory_() + + if not hasattr(param, "stale_param"): + param.stale_param = torch.zeros_like(param.data, dtype=param.dtype, device=param.device) + param.stale_param.data.share_memory_() + zf_optimizer.shared_stale_param_map[key] = param.stale_param + + if getattr(param, "overlap_grad", None) is not None: + param.overlap_grad[0].data.share_memory_() + param.overlap_grad[1].data.share_memory_() + zf_optimizer.shared_overlap_grad_map[key] = param.overlap_grad + + param_groups_data = ([{ + "params": [param] + } for param in zf_optimizer.fp32_partitioned_groups_flat] + if zf_optimizer.zf_stage3 else zf_optimizer.optimizer.param_groups) + + curr_rank = dist.get_rank() + total_rank = dist.get_world_size() + + current_process = psutil.Process() + current_affinity = current_process.cpu_affinity() + all_affinities = [ + torch.zeros(len(current_affinity), + dtype=type(current_affinity[0]), + device=get_accelerator().current_device_name()) for _ in range(total_rank) + ] + dist.all_gather( + all_affinities, + torch.tensor(current_affinity, dtype=type(current_affinity[0]), + device=get_accelerator().current_device_name())) + # When affinity across all ranks are the same, the workers are not binded. Do a soft bind here + if all_tensors_equal(all_affinities): + num_phy_cores = psutil.cpu_count(logical=False) + available_phy_cores = [i for i in current_affinity if i < num_phy_cores] + num_available_phy_cores = len(available_phy_cores) + my_rank = curr_rank + my_size = total_rank + cores_per_rank = num_available_phy_cores // my_size + current_affinity = available_phy_cores[my_rank * cores_per_rank:(my_rank + 1) * cores_per_rank] + pt_num_cores = math.ceil(zf_optimizer.pt_reserved_cores_perc * len(current_affinity)) + if pt_num_cores > 0 and pt_num_cores < len(current_affinity): + zf_affinity = current_affinity[pt_num_cores:] + pt_affinity = current_affinity[:pt_num_cores] + else: + zf_affinity = current_affinity + pt_affinity = current_affinity + + zf_optimizer.process = ctx.Process( + target=zenflow_optimizer_process, + args=(zf_optimizer.child_conn, param_groups_data, zf_optimizer.shared_overlap_grad_map, + zf_optimizer.shared_stale_param_map, zf_affinity), + ) + zf_optimizer.process.daemon = True + zf_optimizer.process.start() + + current_process.cpu_affinity(pt_affinity) + os.environ['OMP_NUM_THREADS'] = str(len(pt_affinity)) + + msg = zf_optimizer.parent_conn.recv() + assert msg["type"] == "ready", "Optimizer process did not initialize correctly." 
+ + zf_optimizer.process_optimizer_established = True diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 000d0ebde1c7..cff24feb2dab 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -93,6 +93,7 @@ def __init__( module, timers, ds_config, + zenflow=False, overlap_comm=True, prefetch_bucket_size=50000000, max_reuse_distance=1000000000, @@ -115,6 +116,7 @@ def __init__( self.module = module self.timers = timers + self.zenflow = zenflow self.dtype = list(module.parameters())[0].dtype self.dp_process_group = dp_process_group self.offload_device = None @@ -472,6 +474,11 @@ def pre_sub_module_forward_function(self, sub_module): param_coordinator.record_module(sub_module) param_coordinator.fetch_sub_module(sub_module, forward=True) + if self.zenflow: + params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module))) + for param in params_to_fetch: + param.data = param.data.t() if len(param.ds_shape) != 1 else param.data + see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False) @torch.no_grad() @@ -480,6 +487,11 @@ def post_sub_module_forward_function(self, sub_module): f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} before release", force=False) + if self.zenflow: + params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module))) + for param in params_to_fetch: + param.data = param.data.t() if len(param.ds_shape) != 1 else param.data + param_coordinator = self.get_param_coordinator() param_coordinator.release_sub_module(sub_module, forward=True) @@ -496,6 +508,11 @@ def pre_sub_module_backward_function(self, sub_module): param_coordinator.record_module(sub_module) param_coordinator.fetch_sub_module(sub_module, forward=False) + if self.zenflow: + params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module))) + for param in params_to_fetch: + param.data = param.data.t() if len(param.ds_shape) != 1 else param.data + @torch.no_grad() def post_sub_module_backward_function(self, sub_module): # assert sub_module.training, "backward pass is invalid for module in evaluation mode" @@ -503,6 +520,11 @@ def post_sub_module_backward_function(self, sub_module): f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} before release", force=False) + if self.zenflow: + params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module))) + for param in params_to_fetch: + param.data = param.data.t() if len(param.ds_shape) != 1 else param.data + self.get_param_coordinator().release_sub_module(sub_module, forward=False) see_memory_usage( diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 76790223a49b..82c406b016e2 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -26,6 +26,7 @@ from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload +import deepspeed.runtime.zenflow.engine_stage3 as zf_engine_stage3 from deepspeed.runtime.zero.utils import get_mapping_to_flat_buffer from deepspeed.runtime.zero.offload_states import offload_adam_states, reload_adam_states from deepspeed.ops.adam import DeepSpeedCPUAdam @@ -160,6 +161,7 @@ def __init__( overlap_comm=False, offload_optimizer_config=None, 
         offload_param_config=None,
+        zenflow_config=None,
         sub_group_size=1000000000000,
         offload_ratio=0.0,
         mpu=None,
@@ -226,6 +228,9 @@ def __init__(
         self.partial_offload = offload_ratio
         self.enable_sanity_checks = enable_sanity_checks

+        self.create_zenflow_hooks()
+        self._initialize_zenflow_stage3_prologue(module, zenflow_config)
+
         #num of ranks in a ZeRO param partitioning group
         self.zero_hpz_partition_size = zero_hpz_partition_size

@@ -241,6 +246,7 @@ def __init__(
             module=module,
             timers=timers,
             ds_config=ds_config,
+            zenflow=self.zenflow,
             overlap_comm=overlap_comm,
             prefetch_bucket_size=prefetch_bucket_size,
             max_reuse_distance=max_reuse_distance,
@@ -276,6 +282,8 @@ def __init__(
             for i in range(1, len(self.optimizer.param_groups)):
                 self.backup_optimizer.add_param_group(self.optimizer.param_groups[i])

+        self._initialize_zenflow_stage3_epilogue(zenflow_config, overlap_comm)
+
         self.module = module
         self.elastic_checkpoint = elastic_checkpoint

@@ -476,11 +484,32 @@ def destroy(self):
         print_rank_0("Removed grad acc hooks", force=False)
         self.ipg_buckets.clear()

+    def create_zenflow_hooks(self):
+        from functools import partial
+        hook_names = [
+            "_initialize_zenflow_stage3_prologue",
+            "_initialize_zenflow_stage3_epilogue",
+            "zenflow_cpu_optimizer_step",
+            "_sync_selective_optimizer_lr",
+            "selective_optimizer_step",
+            "is_zenflow_select_boundary",
+            "update_selected_channels",
+            "_process_selected_fp32_groups_grad",
+            "zenflow_backward_prologue",
+            "zenflow_backward_epilogue",
+            "log_selective_optimizer_timers",
+        ]
+
+        for name in hook_names:
+            fn = getattr(zf_engine_stage3, name)
+            setattr(self, name, partial(fn, self))
+
     def initialize_ds_offload(
         self,
         module,
         timers,
         ds_config,
+        zenflow,
         overlap_comm,
         prefetch_bucket_size,
         max_reuse_distance,
@@ -499,6 +528,7 @@ def initialize_ds_offload(
         return DeepSpeedZeRoOffload(module=module,
                                     timers=timers,
                                     ds_config=ds_config,
+                                    zenflow=zenflow,
                                     overlap_comm=overlap_comm,
                                     prefetch_bucket_size=prefetch_bucket_size,
                                     max_reuse_distance=max_reuse_distance,
@@ -738,6 +768,10 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups):
                 self.fp16_groups.append(sub_group)
                 self.fp16_partitioned_groups.append([param.ds_tensor for param in sub_group])

+                if self.zenflow:
+                    for param in sub_group:
+                        param.group_id = param_group_idx
+
                 # record sub group -> group mapping
                 self.sub_group_to_group_id[sub_group_idx] = param_group_idx

@@ -997,7 +1031,10 @@ def step_with_gradscaler(optimizer):
                 self.torch_autocast_gradscaler.step(optimizer)
                 self.torch_autocast_gradscaler.update()
             else:
-                optimizer.step()
+                if not self.zenflow:
+                    optimizer.step()
+                else:
+                    self.zenflow_cpu_optimizer_step()

         if self.offload_optimizer:
             cur_device = self.subgroup_to_device[sub_group_id]
@@ -1267,8 +1304,14 @@ def __add_grad_to_ipg_bucket(self, param: Parameter) -> None:
            # move the gradient to a contiguous buffer
            with get_accelerator().stream(self.reduce_and_partition_stream):
                # move the parameter's gradient to the contiguous flat buffer
-                new_grad_tensor = bucket.buffer.narrow(0, bucket.elements, param.grad.numel()).view_as(param.grad)
-                new_grad_tensor.copy_(param.grad, non_blocking=True)
+                if self.zenflow and len(param.ds_shape) != 1:
+                    transposed_shape = param.grad.t().shape
+                    new_grad_tensor = bucket.buffer.narrow(0, bucket.elements,
+                                                           param.grad.numel()).view(transposed_shape)
+                    new_grad_tensor.copy_(param.grad.t().contiguous(), non_blocking=True)
+                else:
+                    new_grad_tensor = bucket.buffer.narrow(0, bucket.elements, param.grad.numel()).view_as(param.grad)
+                    new_grad_tensor.copy_(param.grad, non_blocking=True)
                if not get_accelerator().is_synchronized_device():
                    param.grad.record_stream(get_accelerator().current_stream())
                param.grad.data = new_grad_tensor
@@ -1308,6 +1351,12 @@ def __reduce_and_partition_ipg_grads(self, communication_data_type: torch.dtype)
             params_in_bucket.sort(key=lambda p: p.ds_id)
             grad_partitions = self.__avg_scatter_grads(params_in_bucket, communication_data_type)

+            if self.is_zenflow_select_boundary():
+                self.update_selected_channels(params_in_bucket, grad_partitions)
+
+            if self.zenflow and self.micro_step >= self.full_warm_up_rounds:
+                self._process_selected_fp32_groups_grad(params_in_bucket, grad_partitions)
+
             self.partition_grads(params_in_bucket, grad_partitions)

             params_in_bucket.clear()
@@ -2272,6 +2321,9 @@ def backward(self, loss, retain_graph=False):
         if self.swap_optimizer:
             self.optimizer_swapper.pre_backward()

+        if self.zenflow:
+            self.zenflow_backward_prologue()
+
         see_memory_usage("Before backward", force=False)

         if self.custom_loss_scaler:
@@ -2282,6 +2334,9 @@ def backward(self, loss, retain_graph=False):
         else:
             self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)

+        if self.zenflow:
+            self.zenflow_backward_epilogue()
+
         if self.swap_optimizer:
             self.optimizer_swapper.post_backward()

diff --git a/tests/unit/ops/adam/test_zf_torch_adam.py b/tests/unit/ops/adam/test_zf_torch_adam.py
index c7163ffe2f09..faa44fe9f853 100644
--- a/tests/unit/ops/adam/test_zf_torch_adam.py
+++ b/tests/unit/ops/adam/test_zf_torch_adam.py
@@ -3,40 +3,64 @@

 # DeepSpeed Team

+import pytest
 import torch
 import numpy as np
 from torch.nn import Parameter
-from deepspeed.ops.adam import ZenFlowSelectiveAdamW
+from deepspeed.ops.adam import ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3


-def make_param(shape, selected_indices=None):
+def make_param(Opt, shape, selected_indices=None):
     param = Parameter(torch.randn(*shape))
+
+    if Opt is ZenFlowSelectiveAdamW_stage3:
+        if param.dim() == 2:
+            param.ds_shape = (param.shape[1], param.shape[0])
+            param.ds_tensor = param.clone().T.contiguous().view(-1)
+        else:
+            param.ds_shape = tuple(param.shape)
+            param.ds_tensor = param.clone()
+
+        param.complete_column_offset = 0
+        param.complete_numel = param.numel()
+        param.group_id = 0
+
     if selected_indices is not None:
         param.selected_indices = selected_indices
-        param.selected_grad = torch.randn(param.shape[0], len(selected_indices))
-        param.temp_selected_param = param.data[:, selected_indices].clone()
+        if param.dim() == 2:
+            param.selected_grad = torch.randn(
+                param.shape[0], len(selected_indices)) if Opt is not ZenFlowSelectiveAdamW_stage3 else torch.randn(
+                    len(selected_indices), param.ds_shape[1])
+            param.temp_selected_param = param.data[:, selected_indices].clone(
+            ) if Opt is not ZenFlowSelectiveAdamW_stage3 else param.ds_tensor.view(
+                param.ds_shape)[selected_indices, :].clone()
+        else:
+            param.selected_grad = torch.randn_like(param.data)
+            param.temp_selected_param = param.data.clone()
     return param


-def test_init_methods():
-    opt1 = ZenFlowSelectiveAdamW([torch.nn.Parameter(torch.randn(2, 4))], lr=1e-3, offload=False)
+@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
+def test_init_methods(Opt):
+    opt1 = Opt([torch.nn.Parameter(torch.randn(2, 4))], lr=1e-3, offload=False)
     assert opt1.step == opt1._step_without_offload
-    assert opt1.group_step == opt1._group_step_without_offload
-    opt2 = ZenFlowSelectiveAdamW([torch.nn.Parameter(torch.randn(2, 4))], lr=1e-3, offload=True)
+    opt2 = Opt([torch.nn.Parameter(torch.randn(2, 4))], lr=1e-3, offload=True)
     assert opt2.step == opt2._step_with_offload
-    assert opt2.group_step == opt2._group_step_with_offload


-def test_step_without_offload():
-    param = make_param((4, 6), torch.tensor([1, 3, 4]))
+@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
+def test_step_without_offload(Opt):
+    param = make_param(Opt, (4, 6), torch.tensor([1, 3, 4]))
     param.requires_grad_(True)
-    opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False)
-
-    old_selected = param.data[:, param.selected_indices].clone()
+    opt = Opt([param], lr=1e-3, offload=False)
+    old_selected = param.data[:, param.selected_indices].clone(
+    ) if Opt is not ZenFlowSelectiveAdamW_stage3 else param.ds_tensor.view(
+        param.ds_shape)[param.selected_indices, :].clone()

     opt.step()
-
-    new_selected = param.data[:, param.selected_indices]
+    new_selected = param.data[:, param.
+                              selected_indices] if Opt is not ZenFlowSelectiveAdamW_stage3 else param.ds_tensor.view(
+                                  param.ds_shape)[param.selected_indices, :]

     diff_norm = (old_selected - new_selected).abs().sum().item()
     assert diff_norm > 1e-5, "param was not updated"
@@ -44,9 +68,10 @@ def test_step_without_offload():
     assert param.selected_grad is None


-def test_step_with_offload_bucket_flush():
-    param1 = make_param((2, 4), torch.tensor([1, 2]))
-    param2 = make_param((2, 4), torch.tensor([0, 3]))
+@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
+def test_step_with_offload_bucket_flush(Opt):
+    param1 = make_param(Opt, (2, 4), torch.tensor([1, 2]))
+    param2 = make_param(Opt, (2, 4), torch.tensor([0, 3]))

     param1.exp_avg = torch.zeros_like(param1.temp_selected_param)
     param1.exp_avg_sq = torch.zeros_like(param1.temp_selected_param)
@@ -58,15 +83,16 @@ def test_step_with_offload_bucket_flush():
     param2.exp_avg_cpu_data = param2.exp_avg.clone().cpu()
     param2.exp_avg_sq_cpu_data = param2.exp_avg_sq.clone().cpu()

-    opt = ZenFlowSelectiveAdamW([param1, param2], lr=1e-3, offload=True, bucket_size=1)
+    opt = Opt([param1, param2], lr=1e-3, offload=True, bucket_size=1)
     opt.step()

     assert param1.temp_selected_param is None
     assert param2.temp_selected_param is None


-def test_clear_selected_mv():
-    param = make_param((2, 4), torch.tensor([0, 2]))
-    opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False)
+@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
+def test_clear_selected_mv(Opt):
+    param = make_param(Opt, (2, 4), torch.tensor([0, 2]))
+    opt = Opt([param], lr=1e-3, offload=False)
     opt.step()
     state = opt.state[param]
     assert "exp_avg" in state
@@ -74,17 +100,19 @@
     assert state["exp_avg"].abs().sum() == 0


-def test_group_step_without_offload():
-    param = make_param((2, 6), torch.tensor([0, 1, 3]))
-    opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False)
-    group_to_paramlist = {0: [param]}
-    opt._group_step_without_offload(group_to_paramlist)
+@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
+def test_group_step_without_offload(Opt):
+    param = make_param(Opt, (2, 6), torch.tensor([0, 1, 3]))
+    opt = Opt([param], lr=1e-3, offload=False)
+    group_to_paramlist = {0: [param]} if not Opt is ZenFlowSelectiveAdamW_stage3 else [param]
+    opt.group_step(group_to_paramlist)
     assert param.selected_grad is None


-def test_group_step_with_offload():
-    param = make_param((2, 6), torch.tensor([0, 1, 3]))
-    opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=True)
+@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
+def test_group_step_with_offload(Opt):
+    param = make_param(Opt, (2, 6), torch.tensor([0, 1, 3]))
+    opt = Opt([param], lr=1e-3, offload=True)

     state = opt.state.setdefault(param, {})
     state["step"] = torch.zeros((), dtype=param.dtype, device=param.device)
@@ -93,33 +121,30 @@ def test_group_step_with_offload():
     param.exp_avg_cpu_data = param.exp_avg.clone().cpu()
     param.exp_avg_sq_cpu_data = param.exp_avg_sq.clone().cpu()

-    group_to_paramlist = {0: [param]}
-    opt._group_step_with_offload(group_to_paramlist)
+    group_to_paramlist = {0: [param]} if Opt is not ZenFlowSelectiveAdamW_stage3 else [param]
+    opt.group_step(group_to_paramlist)
     assert param.selected_grad is None


-def test_1d_param_support():
-    param = Parameter(torch.randn(10))
-    param.selected_grad = torch.randn(10)
-    param.temp_selected_param = param.data.clone()
-    opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False)
+@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
+def test_1d_param_support(Opt):
+    param = make_param(Opt, (10, ), torch.arange(10))
+    opt = Opt([param], lr=1e-3, offload=False)
     opt.step()
     assert param.temp_selected_param is None
     assert param.selected_grad is None


-def test_state_increment():
-    param = torch.nn.Parameter(torch.randn(2, 4))
-    param.selected_indices = torch.arange(4)
-    param.selected_grad = torch.randn(2, 4)
-    param.temp_selected_param = param.data.clone()
+@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
+def test_state_increment(Opt):
+    param = make_param(Opt, (2, 4), torch.arange(4))

-    opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False)
+    opt = Opt([param], lr=1e-3, offload=False)
     opt.step()
     step1 = opt.state[param]['step'].item()

-    param.selected_grad = torch.randn(2, 4)
-    param.temp_selected_param = param.data.clone()
+    param.selected_grad = torch.randn(2, 4) if Opt is not ZenFlowSelectiveAdamW_stage3 else torch.randn(4, 2)
+    param.temp_selected_param = param.data.clone() if Opt is not ZenFlowSelectiveAdamW_stage3 else torch.randn(4, 2)
     param.selected_indices = torch.arange(4)
     opt.step()
@@ -134,22 +159,29 @@ def _compare_with_torch_adamw(param, zenflow_opt, atol=1e-4):
     for _ in range(10):
         grad = torch.randn_like(param)
         param.selected_indices = torch.arange(param.shape[1])
-        param.selected_grad = grad
-        param.temp_selected_param = param.data.clone()
+        param.selected_grad = grad if not isinstance(zenflow_opt, ZenFlowSelectiveAdamW_stage3) else grad.T
+        param.temp_selected_param = param.data.clone() if not isinstance(
+            zenflow_opt, ZenFlowSelectiveAdamW_stage3) else param.ds_tensor.view(param.ds_shape).clone()
         torch_param.grad = grad.clone()

         zenflow_opt.step()
         torch_opt.step()

-    np.testing.assert_allclose(torch_param.data.cpu().numpy(),
-                               param.data.cpu().numpy(),
-                               atol=atol,
-                               err_msg="Mismatch with torch.AdamW")
-
-
-def test_against_torch_adamw():
-    param = torch.nn.Parameter(torch.randn(2, 4))
-    param.selected_indices = torch.arange(4)
-    opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False)
+        if not isinstance(zenflow_opt, ZenFlowSelectiveAdamW_stage3):
+            np.testing.assert_allclose(torch_param.data.cpu().numpy(),
+                                       param.data.cpu().numpy(),
+                                       atol=atol,
+                                       err_msg="Mismatch with torch.AdamW")
+        else:
+            np.testing.assert_allclose(torch_param.data.cpu().numpy(),
+                                       param.ds_tensor.view(param.ds_shape).T.clone().data.cpu().numpy(),
+                                       atol=atol,
+                                       err_msg="Mismatch with torch.AdamW")
+
+
+@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
+def test_against_torch_adamw(Opt):
+    param = make_param(Opt, (2, 4), torch.arange(4))
+    opt = Opt([param], lr=1e-3, offload=False)
     _compare_with_torch_adamw(param, opt)
diff --git a/tests/unit/runtime/zenflow/test_zf.py b/tests/unit/runtime/zenflow/test_zf.py
index 3294902bef67..7adcdb784972 100644
--- a/tests/unit/runtime/zenflow/test_zf.py
+++ b/tests/unit/runtime/zenflow/test_zf.py
@@ -74,7 +74,7 @@ def run_training_distributed(self, config_dict):
         model.destroy()


-@pytest.mark.parametrize("stage", [1, 2])
+@pytest.mark.parametrize("stage", [1, 2, 3])
 @pytest.mark.parametrize("full_warm_up_rounds", [0, 3])
 @pytest.mark.parametrize("offload_selective_optimizer", [True, False])
 @pytest.mark.parametrize("select_strategy,select_interval,update_interval", [
@@ -93,7 +93,7 @@ def test_zenflow_single_gpu(self, stage, offload_selective_optimizer, select_str
     tester.run_training_distributed(config_dict)


-@pytest.mark.parametrize("stage", [1, 2])
+@pytest.mark.parametrize("stage", [1, 2, 3])
 @pytest.mark.parametrize("full_warm_up_rounds", [0, 3])
 @pytest.mark.parametrize("offload_selective_optimizer", [True, False])
 @pytest.mark.parametrize("select_strategy,select_interval,update_interval", [