diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py
index cd36f7a2d..6528c320c 100644
--- a/areal/api/cli_args.py
+++ b/areal/api/cli_args.py
@@ -8,6 +8,8 @@
 import uvloop
 import yaml
 
+from areal.utils.pkg_version import is_version_less
+
 uvloop.install()
 from hydra import compose as hydra_compose
 from hydra import initialize as hydra_init
@@ -162,13 +164,31 @@ class OptimizerConfig:
     type: str = field(
         default="adam",
-        metadata={"help": "Optimizer type", "choices": ["adam"]},
+        metadata={
+            "help": "Optimizer type. adam_bf16 is currently only supported by the FSDP engine.",
+            "choices": ["adam", "sgd", "adam_bf16"],
+        },
     )
     lr: float = field(default=2e-5, metadata={"help": "Learning rate"})
     weight_decay: float = field(default=0.05, metadata={"help": "Weight decay"})
-    beta1: float = field(default=0.9, metadata={"help": "Adam beta1 parameter"})
-    beta2: float = field(default=0.95, metadata={"help": "Adam beta2 parameter"})
-    eps: float = field(default=1e-5, metadata={"help": "Adam epsilon parameter"})
+    beta1: float = field(
+        default=0.9,
+        metadata={
+            "help": "Adam beta1 parameter. Only effective when the optimizer type is adam or adam_bf16."
+        },
+    )
+    beta2: float = field(
+        default=0.95,
+        metadata={
+            "help": "Adam beta2 parameter. Only effective when the optimizer type is adam or adam_bf16."
+        },
+    )
+    eps: float = field(
+        default=1e-5,
+        metadata={
+            "help": "Adam epsilon parameter. Only effective when the optimizer type is adam or adam_bf16."
+        },
+    )
     min_lr_ratio: float = field(
         default=0.0,
         metadata={
@@ -632,6 +652,8 @@ def build_cmd(
     # convert to flags
     flags = []
     for k, v in args.items():
+        if is_version_less("sglang", "0.4.10.post2") and "max_loaded_loras" in k:
+            continue
         if v is None or v is False or v == "":
             continue
         if v is True:
diff --git a/areal/engine/base_hf_engine.py b/areal/engine/base_hf_engine.py
index f6095a18a..31b2fcc17 100644
--- a/areal/engine/base_hf_engine.py
+++ b/areal/engine/base_hf_engine.py
@@ -14,14 +14,11 @@
     PretrainedConfig,
     PreTrainedTokenizerFast,
     ProcessorMixin,
-    get_constant_schedule_with_warmup,
-    get_linear_schedule_with_warmup,
 )
 
 from areal.api.alloc_mode import ParallelStrategy
 from areal.api.cli_args import TrainEngineConfig
 from areal.api.engine_api import TrainEngine
-from areal.api.io_struct import FinetuneSpec
 from areal.platforms import current_platform
 from areal.utils import logging
 from areal.utils.data import (
@@ -35,7 +32,6 @@
     unpack_sequence,
     unsqueeze_mb_list,
 )
-from areal.utils.fsdp import get_cosine_schedule_with_warmup
 from areal.utils.hf_utils import load_hf_processor_and_tokenizer, load_hf_tokenizer
 from areal.utils.model import (
     disable_dropout_in_model,
@@ -219,57 +215,6 @@ def _create_llm_actor_or_critic(self):
         )
         return model
 
-    def create_optimizer(self, ft_spec: FinetuneSpec):
-        if self.optimizer_config is None:
-            return
-        assert self.model is not None
-        # Set up optimizer
-        tik = time.perf_counter()
-        assert (
-            self.optimizer_config.type == "adam"
-        ), "Only AdamW optimizer is supported in this engine."
-        lr = self.optimizer_config.lr
-        weight_decay = self.optimizer_config.weight_decay
-        beta1 = self.optimizer_config.beta1
-        beta2 = self.optimizer_config.beta2
-        eps = self.optimizer_config.eps
-
-        self.optimizer = torch.optim.AdamW(
-            self.model.parameters(),
-            lr=lr,
-            weight_decay=weight_decay,
-            betas=(beta1, beta2),
-            eps=eps,
-        )
-        total_train_steps = ft_spec.total_train_steps
-        num_warmup_steps = int(
-            self.optimizer_config.warmup_steps_proportion * total_train_steps
-        )
-
-        if self.optimizer_config.lr_scheduler_type == "cosine":
-            self.lr_scheduler = get_cosine_schedule_with_warmup(
-                self.optimizer,
-                num_warmup_steps,
-                total_train_steps,
-                min_lr_ratio=self.optimizer_config.min_lr_ratio,
-            )
-        elif self.optimizer_config.lr_scheduler_type == "linear":
-            self.lr_scheduler = get_linear_schedule_with_warmup(
-                self.optimizer,
-                num_warmup_steps,
-                total_train_steps,
-            )
-        elif self.optimizer_config.lr_scheduler_type == "constant":
-            self.lr_scheduler = get_constant_schedule_with_warmup(
-                self.optimizer,
-                num_warmup_steps,
-            )
-        else:
-            raise ValueError(
-                f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
-            )
-        self.logger.info(f"Create optimizer time: {time.perf_counter() - tik}")
-
     def destroy(self):
         """Destroy the engine and release GPU memory."""
         if hasattr(self, "optimizer"):
diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py
index 8b2b3fb11..f0bf97aa3 100644
--- a/areal/engine/fsdp_engine.py
+++ b/areal/engine/fsdp_engine.py
@@ -22,7 +22,13 @@
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy
 from torch.distributed.tensor import DTensor
-from transformers import PreTrainedTokenizerFast, ProcessorMixin
+from transformers import (
+    PreTrainedTokenizerFast,
+    ProcessorMixin,
+    get_constant_schedule_with_warmup,
+    get_cosine_schedule_with_warmup,
+    get_linear_schedule_with_warmup,
+)
 
 from areal.api.alloc_mode import FSDPParallelStrategy, ParallelStrategy
 from areal.api.cli_args import TrainEngineConfig
@@ -41,6 +47,7 @@
 from areal.utils.distributed import init_custom_process_group
 from areal.utils.fsdp import fsdp2_load_full_state_dict
 from areal.utils.fsdp.grad import fsdp2_clip_grad_norm
+from areal.utils.fsdp.optimizer import AnyPrecisionAdamW
 from areal.utils.fsdp.parallel import ParallelHelper, parallelize_model
 from areal.utils.nccl import NCCL_DEFAULT_TIMEOUT
 from areal.utils.save_load import get_state_dict_from_repo_id_or_path
@@ -715,3 +722,77 @@ def forward(
         unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
         reordered = reorder_list(unpacked, mb_list.backward_indices)
         return pad_and_stack_tensors_along_first_dim(reordered)
+
+    def create_optimizer(self, ft_spec: FinetuneSpec):
+        if self.optimizer_config is None:
+            return
+        assert self.model is not None
+        # Set up optimizer
+        tik = time.perf_counter()
+        assert self.optimizer_config.type in [
+            "adam",
+            "adam_bf16",
+            "sgd",
+        ], "Only adam/adam_bf16/sgd optimizer is supported in this engine."
+        if self.optimizer_config.type in ["sgd", "adam_bf16"]:
+            self.logger.warning(
+                f"Using the '{self.optimizer_config.type}' optimizer with FSDP may be less stable. Consider using the 'adam' (AdamW) optimizer for improved stability and performance."
+            )
+        lr = self.optimizer_config.lr
+        weight_decay = self.optimizer_config.weight_decay
+        beta1 = self.optimizer_config.beta1
+        beta2 = self.optimizer_config.beta2
+        eps = self.optimizer_config.eps
+        if self.optimizer_config.type == "adam":
+            self.optimizer = torch.optim.AdamW(
+                self.model.parameters(),
+                lr=lr,
+                weight_decay=weight_decay,
+                betas=(beta1, beta2),
+                eps=eps,
+                fused=True,
+            )
+        elif self.optimizer_config.type == "adam_bf16":
+            self.optimizer = AnyPrecisionAdamW(
+                self.model.parameters(),
+                lr=lr,
+                weight_decay=weight_decay,
+                betas=(beta1, beta2),
+                eps=eps,
+                momentum_dtype="bfloat16",
+                variance_dtype="bfloat16",
+            )
+        else:
+            self.optimizer = torch.optim.SGD(
+                self.model.parameters(),
+                lr=lr,
+                weight_decay=weight_decay,
+            )
+        total_train_steps = ft_spec.total_train_steps
+        num_warmup_steps = int(
+            self.optimizer_config.warmup_steps_proportion * total_train_steps
+        )
+
+        if self.optimizer_config.lr_scheduler_type == "cosine":
+            self.lr_scheduler = get_cosine_schedule_with_warmup(
+                self.optimizer,
+                num_warmup_steps,
+                total_train_steps,
+                min_lr_ratio=self.optimizer_config.min_lr_ratio,
+            )
+        elif self.optimizer_config.lr_scheduler_type == "linear":
+            self.lr_scheduler = get_linear_schedule_with_warmup(
+                self.optimizer,
+                num_warmup_steps,
+                total_train_steps,
+            )
+        elif self.optimizer_config.lr_scheduler_type == "constant":
+            self.lr_scheduler = get_constant_schedule_with_warmup(
+                self.optimizer,
+                num_warmup_steps,
+            )
+        else:
+            raise ValueError(
+                f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
+            )
+        self.logger.info(f"Create optimizer time: {time.perf_counter() - tik}")
diff --git a/areal/experimental/megatron_engine.py b/areal/experimental/megatron_engine.py
index 5b14d4f7a..ffbddaeba 100644
--- a/areal/experimental/megatron_engine.py
+++ b/areal/experimental/megatron_engine.py
@@ -162,7 +162,6 @@ def initialize(
         ):
             model_config.param_sync_func = self.model.start_param_sync
         model_config.finalize_model_grads_func = finalize_model_grads
 
-        self.create_optimizer(ft_spec)
 
     def _make_parallel_strategy(
@@ -241,13 +240,18 @@ def create_optimizer(self, ft_spec: FinetuneSpec):
             return
         assert self.model is not None
 
-        assert (
-            self.optimizer_config.type == "adam"
-        ), "Only AdamW optimizer is supported in this engine."
+        assert self.optimizer_config.type in [
+            "adam",
+            "sgd",
+        ], "Only AdamW/sgd optimizer is supported in this engine."
+        if self.optimizer_config.type == "sgd":
+            self.logger.warning(
+                f"Using the 'sgd' optimizer with Megatron may be less stable. Consider using the 'adam' (AdamW) optimizer for improved stability."
+            )
 
         # Make megatron optimizer config
         mcore_opt_config = MCoreOptimizerConfig(
-            optimizer="adam",
+            optimizer=self.optimizer_config.type,
             lr=self.optimizer_config.lr,
             min_lr=self.optimizer_config.min_lr_ratio * self.optimizer_config.lr,
             weight_decay=self.optimizer_config.weight_decay,
diff --git a/areal/utils/fsdp/optimizer.py b/areal/utils/fsdp/optimizer.py
new file mode 100644
index 000000000..211acc528
--- /dev/null
+++ b/areal/utils/fsdp/optimizer.py
@@ -0,0 +1,179 @@
+from typing import List, Tuple
+
+import torch
+
+
+def to_precision_dtype(dtype_str: str) -> torch.dtype:
+    """
+    Convert string to corresponding torch dtype, only supports bfloat16 and float32.
+
+    Args:
+        dtype_str: Data type string, supports "bfloat16" or "float32"
+
+    Returns:
+        Corresponding torch dtype
+
+    Raises:
+        ValueError: If the input dtype is not supported
+    """
+    dtype_str = dtype_str.lower()
+    if dtype_str in ["bfloat16", "bf16"]:
+        return torch.bfloat16
+    elif dtype_str in ["float32", "fp32"]:
+        return torch.float32
+    else:
+        raise ValueError(
+            f"Unsupported dtype: {dtype_str}. Only 'bfloat16' and 'float32' are supported."
+        )
+
+
+# https://github.com/meta-llama/llama-cookbook/blob/v0.0.5/src/llama_cookbook/policies/anyprecision_optimizer.py
+class AnyPrecisionAdamW(torch.optim.Optimizer):
+    def __init__(
+        self,
+        params: List[torch.Tensor],
+        lr: float = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 0.0,
+        use_kahan_summation: bool = True,
+        momentum_dtype: str = "bfloat16",
+        variance_dtype: str = "bfloat16",
+        compensation_buffer_dtype: str = "bfloat16",
+    ):
+        """
+        AnyPrecisionAdamW: a flexible precision AdamW optimizer
+        with optional Kahan summation for high precision weight updates.
+        Allows direct control over momentum, variance and auxiliary compensation buffer dtypes.
+        Optional Kahan summation is used to offset precision reduction for the weight updates.
+        This allows full training in BFloat16 (equal or better than FP32 results in many cases)
+        due to high precision weight updates.
+
+        Args:
+            params (iterable): iterable of parameters to optimize or dicts defining parameter groups
+            lr (float, optional): learning rate (default: 1e-3)
+            betas (Tuple[float, float], optional): coefficients used for computing
+                running averages of gradient and its square (default: (0.9, 0.999))
+            eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
+            weight_decay (float, optional): weight decay coefficient (default: 0.0)
+
+            # Any Precision specific
+            use_kahan_summation = creates auxiliary buffer to ensure high precision
+                model param updates (default: True)
+            momentum_dtype = dtype for momentum (default: bfloat16)
+            variance_dtype = dtype for uncentered variance (default: bfloat16)
+            compensation_buffer_dtype = dtype for Kahan summation buffer (default: bfloat16)
+
+            # Usage
+            This optimizer implements optimizer states, and Kahan summation
+            for high precision updates, all in user controlled dtypes.
+            Defaults are variance in BF16, Momentum in BF16.
+            This can be run in FSDP mixed precision, amp, or full precision,
+            depending on what training pipeline you wish to work with.
+
+            Setting to use_kahan_summation = False, and changing momentum and
+            variance dtypes to FP32, reverts this to a standard AdamW optimizer.
+
+        """
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "use_kahan_summation": use_kahan_summation,
+            "momentum_dtype": momentum_dtype,
+            "variance_dtype": variance_dtype,
+            "compensation_buffer_dtype": compensation_buffer_dtype,
+        }
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """
+        Performs a single optimization step.
+
+        Args:
+            closure (callable, optional): A closure that reevaluates the model and returns the loss.
+        """
+
+        if closure is not None:
+            with torch.enable_grad():
+                closure()
+
+        for group in self.param_groups:
+            beta1, beta2 = group["betas"]
+            lr = group["lr"]
+            weight_decay = group["weight_decay"]
+            eps = group["eps"]
+            use_kahan_summation = group["use_kahan_summation"]
+
+            momentum_dtype = to_precision_dtype(group["momentum_dtype"])
+            variance_dtype = to_precision_dtype(group["variance_dtype"])
+            compensation_buffer_dtype = to_precision_dtype(
+                group["compensation_buffer_dtype"]
+            )
+            for p in group["params"]:
+                assert isinstance(p, torch.Tensor)  # lint
+                if p.grad is None:
+                    continue
+
+                if p.grad.is_sparse:
+                    raise RuntimeError(
+                        "AnyPrecisionAdamW does not support sparse gradients."
+                    )
+
+                state = self.state[p]
+                # State initialization
+                if len(state) == 0:
+                    state["step"] = torch.tensor(0.0)
+
+                    # momentum - EMA of gradient values
+                    state["exp_avg"] = torch.zeros_like(p, dtype=momentum_dtype)
+
+                    # variance uncentered - EMA of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(p, dtype=variance_dtype)
+
+                    # optional Kahan summation - accumulated error tracker
+                    if use_kahan_summation:
+                        state["compensation"] = torch.zeros_like(
+                            p, dtype=compensation_buffer_dtype
+                        )
+
+                # Main processing
+                # update the steps for each param group update
+                state["step"] += 1
+                step = state["step"]
+
+                exp_avg = state["exp_avg"]
+                exp_avg_sq = state["exp_avg_sq"]
+                grad = p.grad
+
+                if weight_decay:  # weight decay, AdamW style
+                    p.data.mul_(1 - lr * weight_decay)
+
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)  # update momentum
+                exp_avg_sq.mul_(beta2).addcmul_(
+                    grad, grad, value=1 - beta2
+                )  # update uncentered variance
+
+                bias_correction1 = 1 - beta1**step  # adjust using bias1
+                step_size = lr / bias_correction1
+
+                denom_correction = (
+                    1 - beta2**step
+                ) ** 0.5  # adjust using bias2 and avoids math import
+                centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_(
+                    eps, alpha=1
+                )
+
+                if use_kahan_summation:  # lr update to compensation
+                    compensation = state["compensation"]
+                    compensation.addcdiv_(exp_avg, centered_variance, value=-step_size)
+
+                    # update weights with compensation (Kahan summation)
+                    # save error back to compensation for next iteration
+                    temp_buffer = p.detach().clone()
+                    p.data.add_(compensation)
+                    compensation.add_(temp_buffer.sub_(p.data))
+                else:  # usual AdamW updates
+                    p.data.addcdiv_(exp_avg, centered_variance, value=-step_size)
diff --git a/docs/cli_reference.md b/docs/cli_reference.md
index 39293808f..adb2119fa 100644
--- a/docs/cli_reference.md
+++ b/docs/cli_reference.md
@@ -284,23 +284,23 @@
 Configuration for reward/advantage normalization.
 
 Configuration for model optimization during training.
 
-| Parameter | Type | Default | Description |
-| ------------------------- | ------- | ------------ | ------------------------------------------------------------------------ |
-| `type` | string | `"adam"` | Optimizer type **Choices:** `adam` |
-| `lr` | float | `2e-05` | Learning rate |
-| `weight_decay` | float | `0.05` | Weight decay |
-| `beta1` | float | `0.9` | Adam beta1 parameter |
-| `beta2` | float | `0.95` | Adam beta2 parameter |
-| `eps` | float | `1e-05` | Adam epsilon parameter |
-| `min_lr_ratio` | float | `0.0` | Minimum learning rate ratio after annealing |
-| `lr_scheduler_type` | string | `"constant"` | Learning rate scheduler type **Choices:** `linear`, `cosine`, `constant` |
-| `warmup_steps_proportion` | float | `0.001` | Proportion of training steps for warmup |
-| `offload` | boolean | `False` | Enable optimizer state offloading |
-| `initial_loss_scale` | float | `4294967296` | Initial loss scaling factor |
-| `min_loss_scale` | float | `1.0` | Minimum loss scaling factor |
-| `loss_scale_window` | float | `5` | Window size for loss scaling adjustment |
-| `hysteresis` | integer | `2` | Hysteresis (scaling factor) for loss scaling |
-| `gradient_clipping` | float | `1.0` | Gradient clipping threshold |
+| Parameter | Type | Default | Description |
+| ------------------------- | ------- | ------------ | ---------------------------------------------------------------------------------------------------------------- |
+| `type` | string | `"adam"` | Optimizer type. adam_bf16 is currently only supported by the FSDP engine. **Choices:** `adam`, `sgd`, `adam_bf16` |
+| `lr` | float | `2e-05` | Learning rate |
+| `weight_decay` | float | `0.05` | Weight decay |
+| `beta1` | float | `0.9` | Adam beta1 parameter. Only effective when the optimizer type is adam or adam_bf16. |
+| `beta2` | float | `0.95` | Adam beta2 parameter. Only effective when the optimizer type is adam or adam_bf16. |
+| `eps` | float | `1e-05` | Adam epsilon parameter. Only effective when the optimizer type is adam or adam_bf16. |
+| `min_lr_ratio` | float | `0.0` | Minimum learning rate ratio after annealing |
+| `lr_scheduler_type` | string | `"constant"` | Learning rate scheduler type **Choices:** `linear`, `cosine`, `constant` |
+| `warmup_steps_proportion` | float | `0.001` | Proportion of training steps for warmup |
+| `offload` | boolean | `False` | Enable optimizer state offloading |
+| `initial_loss_scale` | float | `4294967296` | Initial loss scaling factor |
+| `min_loss_scale` | float | `1.0` | Minimum loss scaling factor |
+| `loss_scale_window` | float | `5` | Window size for loss scaling adjustment |
+| `hysteresis` | integer | `2` | Hysteresis (scaling factor) for loss scaling |
+| `gradient_clipping` | float | `1.0` | Gradient clipping threshold |
 
 (section-ppo-actor)=
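
Below is a minimal usage sketch of the AnyPrecisionAdamW optimizer introduced in areal/utils/fsdp/optimizer.py above. The constructor keywords mirror the signature in this patch; the toy model, hyperparameter values, and training loop are illustrative assumptions and are not part of the patch. Within AReaL itself, the optimizer is instead wired up by FSDPEngine.create_optimizer when optimizer.type is set to "adam_bf16".

import torch

from areal.utils.fsdp.optimizer import AnyPrecisionAdamW

# Hypothetical toy model used only for illustration.
model = torch.nn.Linear(16, 4)

optimizer = AnyPrecisionAdamW(
    model.parameters(),
    lr=2e-5,
    betas=(0.9, 0.95),
    eps=1e-5,
    weight_decay=0.05,
    momentum_dtype="bfloat16",   # exp_avg is kept in bf16
    variance_dtype="bfloat16",   # exp_avg_sq is kept in bf16
    use_kahan_summation=True,    # bf16 compensation buffer offsets rounding error
)

for _ in range(3):
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With use_kahan_summation=False and both state dtypes set to "float32", the update reverts to standard AdamW (as noted in the class docstring), which is why plain adam remains the recommended choice in the warning added to FSDPEngine.create_optimizer.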