Merged
30 changes: 26 additions & 4 deletions areal/api/cli_args.py
@@ -8,6 +8,8 @@
import uvloop
import yaml

from areal.utils.pkg_version import is_version_less

uvloop.install()
from hydra import compose as hydra_compose
from hydra import initialize as hydra_init
@@ -162,13 +164,31 @@ class OptimizerConfig:

type: str = field(
default="adam",
metadata={"help": "Optimizer type", "choices": ["adam"]},
metadata={
"help": "Optimizer type. Adam_bf16 currently only supported FSDP Engine.",
"choices": ["adam", "sgd", "adam_bf16"],
},
)
lr: float = field(default=2e-5, metadata={"help": "Learning rate"})
weight_decay: float = field(default=0.05, metadata={"help": "Weight decay"})
beta1: float = field(default=0.9, metadata={"help": "Adam beta1 parameter"})
beta2: float = field(default=0.95, metadata={"help": "Adam beta2 parameter"})
eps: float = field(default=1e-5, metadata={"help": "Adam epsilon parameter"})
beta1: float = field(
default=0.9,
metadata={
"help": "Adam beta1 parameter. Only effective when optimizer_type is adam/adam_bf16"
},
)
beta2: float = field(
default=0.95,
metadata={
"help": "Adam beta2 parameter. Only effective when optimizer_type is adam/adam_bf16"
},
)
eps: float = field(
default=1e-5,
metadata={
"help": "Adam epsilon parameter. Only effective when optimizer_type is adam/adam_bf16"
},
)
min_lr_ratio: float = field(
default=0.0,
metadata={
@@ -632,6 +652,8 @@ def build_cmd(
# convert to flags
flags = []
for k, v in args.items():
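# Version gate: max_loaded_loras-related flags are only passed when sglang >= 0.4.10.post2.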
if is_version_less("sglang", "0.4.10.post2") and "max_loaded_loras" in k:
continue
if v is None or v is False or v == "":
continue
if v is True:
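For reference, a minimal sketch of how the expanded OptimizerConfig could be filled in (illustrative values; assumes OptimizerConfig is importable from areal.api.cli_args and that the remaining fields keep their defaults):

from areal.api.cli_args import OptimizerConfig

# adam_bf16 keeps the AdamW optimizer state in bfloat16; per the help text above it is
# currently only wired up for the FSDP engine.
bf16_cfg = OptimizerConfig(
    type="adam_bf16",
    lr=2e-5,
    weight_decay=0.05,
    beta1=0.9,
    beta2=0.95,
    eps=1e-5,
)

# Plain SGD ignores the Adam-only beta1/beta2/eps fields.
sgd_cfg = OptimizerConfig(type="sgd", lr=1e-4)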
55 changes: 0 additions & 55 deletions areal/engine/base_hf_engine.py
@@ -14,14 +14,11 @@
PretrainedConfig,
PreTrainedTokenizerFast,
ProcessorMixin,
get_constant_schedule_with_warmup,
get_linear_schedule_with_warmup,
)

from areal.api.alloc_mode import ParallelStrategy
from areal.api.cli_args import TrainEngineConfig
from areal.api.engine_api import TrainEngine
from areal.api.io_struct import FinetuneSpec
from areal.platforms import current_platform
from areal.utils import logging
from areal.utils.data import (
@@ -35,7 +32,6 @@
unpack_sequence,
unsqueeze_mb_list,
)
from areal.utils.fsdp import get_cosine_schedule_with_warmup
from areal.utils.hf_utils import load_hf_processor_and_tokenizer, load_hf_tokenizer
from areal.utils.model import (
disable_dropout_in_model,
@@ -219,57 +215,6 @@ def _create_llm_actor_or_critic(self):
)
return model

def create_optimizer(self, ft_spec: FinetuneSpec):
if self.optimizer_config is None:
return
assert self.model is not None
# Set up optimizer
tik = time.perf_counter()
assert (
self.optimizer_config.type == "adam"
), "Only AdamW optimizer is supported in this engine."
lr = self.optimizer_config.lr
weight_decay = self.optimizer_config.weight_decay
beta1 = self.optimizer_config.beta1
beta2 = self.optimizer_config.beta2
eps = self.optimizer_config.eps

self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=lr,
weight_decay=weight_decay,
betas=(beta1, beta2),
eps=eps,
)
total_train_steps = ft_spec.total_train_steps
num_warmup_steps = int(
self.optimizer_config.warmup_steps_proportion * total_train_steps
)

if self.optimizer_config.lr_scheduler_type == "cosine":
self.lr_scheduler = get_cosine_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
min_lr_ratio=self.optimizer_config.min_lr_ratio,
)
elif self.optimizer_config.lr_scheduler_type == "linear":
self.lr_scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
)
elif self.optimizer_config.lr_scheduler_type == "constant":
self.lr_scheduler = get_constant_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
)
else:
raise ValueError(
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
)
self.logger.info(f"Create optimizer time: {time.perf_counter() - tik}")

def destroy(self):
"""Destroy the engine and release GPU memory."""
if hasattr(self, "optimizer"):
83 changes: 82 additions & 1 deletion areal/engine/fsdp_engine.py
@@ -22,7 +22,13 @@
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import CPUOffloadPolicy
from torch.distributed.tensor import DTensor
from transformers import PreTrainedTokenizerFast, ProcessorMixin
from transformers import (
PreTrainedTokenizerFast,
ProcessorMixin,
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup,
)

from areal.api.alloc_mode import FSDPParallelStrategy, ParallelStrategy
from areal.api.cli_args import TrainEngineConfig
@@ -41,6 +47,7 @@
from areal.utils.distributed import init_custom_process_group
from areal.utils.fsdp import fsdp2_load_full_state_dict
from areal.utils.fsdp.grad import fsdp2_clip_grad_norm
from areal.utils.fsdp.optimizer import AnyPrecisionAdamW
from areal.utils.fsdp.parallel import ParallelHelper, parallelize_model
from areal.utils.nccl import NCCL_DEFAULT_TIMEOUT
from areal.utils.save_load import get_state_dict_from_repo_id_or_path
@@ -715,3 +722,77 @@ def forward(
unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
reordered = reorder_list(unpacked, mb_list.backward_indices)
return pad_and_stack_tensors_along_first_dim(reordered)

def create_optimizer(self, ft_spec: FinetuneSpec):
if self.optimizer_config is None:
return
assert self.model is not None
# Set up optimizer
tik = time.perf_counter()
assert self.optimizer_config.type in [
"adam",
"adam_bf16",
"sgd",
], "Only adam/adam_bf16/sgd optimizer is supported in this engine."
if self.optimizer_config.type in ["sgd", "adam_bf16"]:
self.logger.warning(
f"Using the '{self.optimizer_config.type}' optimizer with FSDP may be less stable. Consider using the 'adam' (AdamW) optimizer for improved stability and performance."
)
lr = self.optimizer_config.lr
weight_decay = self.optimizer_config.weight_decay
beta1 = self.optimizer_config.beta1
beta2 = self.optimizer_config.beta2
eps = self.optimizer_config.eps
if self.optimizer_config.type == "adam":
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=lr,
weight_decay=weight_decay,
betas=(beta1, beta2),
eps=eps,
fused=True,
)
elif self.optimizer_config.type == "adam_bf16":
self.optimizer = AnyPrecisionAdamW(
self.model.parameters(),
lr=lr,
weight_decay=weight_decay,
betas=(beta1, beta2),
eps=eps,
momentum_dtype="bfloat16",
variance_dtype="bfloat16",
)
else:
self.optimizer = torch.optim.SGD(
self.model.parameters(),
lr=lr,
weight_decay=weight_decay,
)
total_train_steps = ft_spec.total_train_steps
num_warmup_steps = int(
self.optimizer_config.warmup_steps_proportion * total_train_steps
)

if self.optimizer_config.lr_scheduler_type == "cosine":
self.lr_scheduler = get_cosine_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
min_lr_ratio=self.optimizer_config.min_lr_ratio,
)
elif self.optimizer_config.lr_scheduler_type == "linear":
self.lr_scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
)
elif self.optimizer_config.lr_scheduler_type == "constant":
self.lr_scheduler = get_constant_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
)
else:
raise ValueError(
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
)
self.logger.info(f"Create optimizer time: {time.perf_counter() - tik}")
14 changes: 9 additions & 5 deletions areal/experimental/megatron_engine.py
@@ -162,7 +162,6 @@ def initialize(
):
model_config.param_sync_func = self.model.start_param_sync
model_config.finalize_model_grads_func = finalize_model_grads

self.create_optimizer(ft_spec)

def _make_parallel_strategy(
@@ -241,13 +240,18 @@ def create_optimizer(self, ft_spec: FinetuneSpec):
return
assert self.model is not None

assert (
self.optimizer_config.type == "adam"
), "Only AdamW optimizer is supported in this engine."
assert self.optimizer_config.type in [
"adam",
"sgd",
], "Only AdamW/sgd optimizer is supported in this engine."
if self.optimizer_config.type == "sgd":
self.logger.warning(
f"Using the 'sgd' optimizer with Megatron may be less stable. Consider using the 'adam' (AdamW) optimizer for improved stability."
)

# Make megatron optimizer config
mcore_opt_config = MCoreOptimizerConfig(
optimizer="adam",
optimizer=self.optimizer_config.type,
lr=self.optimizer_config.lr,
min_lr=self.optimizer_config.min_lr_ratio * self.optimizer_config.lr,
weight_decay=self.optimizer_config.weight_decay,