diff --git a/mindone/trainers/muon.py b/mindone/trainers/muon.py new file mode 100644 index 0000000000..cbc14207a6 --- /dev/null +++ b/mindone/trainers/muon.py @@ -0,0 +1,309 @@ +"""Modified from https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py""" +import math +from typing import List, Optional, Tuple, Union + +import mindspore as ms +import mindspore.mint as mint +import mindspore.ops as ops +from mindspore import Parameter, ParameterTuple, Tensor +from mindspore.experimental.optim.optimizer import Optimizer + +_muon_opt = ops.MultitypeFuncGraph("muon_opt") + + +@_muon_opt.register( + "Float", + "Float", + "Float", + "Float", + "Bool", + "Int", + "Float", + "Tensor", + "Tensor", + "Tensor", + "Tensor", + "Tensor", + "Tensor", + "Float", + "Bool", +) +def _update_run_op( + mu: float, + beta1: float, + beta2: float, + eps: float, + nesterov: bool, + ns_steps: int, + weight_decay: float, + lr: Parameter, + denom: Parameter, + param: Parameter, + m: Parameter, + v: Parameter, + g: Tensor, + ratio: float, + use_muon: bool, +) -> bool: + if weight_decay != 0: + param.mul_(1 - lr * weight_decay) + + if use_muon: + m.mul_(mu).add_(g) + if nesterov: + g = g.add(m, alpha=mu) + else: + g = m + g = zeropower_via_newtonschulz5(g, steps=ns_steps) + param.add_(lr * g, alpha=-ratio) + else: + m_next = mint.lerp(g, m, beta1) + v_next = mint.lerp(mint.square(g), v, beta2) + g = m_next / (eps + mint.sqrt(v_next)) + param.add_(-(lr / denom) * g) + ops.assign(m, m_next) + ops.assign(v, v_next) + return True + + +_qk_clip_opt = ops.MultitypeFuncGraph("qk_clip_opt") + + +@_qk_clip_opt.register("Float", "Int", "Tensor", "Tensor", "Tensor") +def _update_clip_op( + clip_value: float, qk_nope_head_dim: int, qk: Tensor, q_b_projs: Parameter, kv_b_projs: Parameter +) -> bool: + qk = mint.transpose(qk, 0, 1).flatten(start_dim=1) + qk_max, _ = mint.max(qk, dim=1) + num_head = qk_max.shape[0] + scale = mint.clip(clip_value / qk_max, max=1.0) + scale = scale[:, None, None] + scale_sqrt = mint.sqrt(scale) + # clip Q projection + outdim, _ = q_b_projs.shape + head_dim = outdim // num_head + scale_q_b_nope = mint.tile(scale_sqrt, (1, qk_nope_head_dim, 1)) + scale_q_b_rope = mint.tile(scale, (1, head_dim - qk_nope_head_dim, 1)) + scale_q_b = mint.cat([scale_q_b_nope, scale_q_b_rope], dim=1) + q_b_projs.mul_(scale_q_b.view(-1, 1)) + # clip K projection + outdim, _ = kv_b_projs.shape + head_dim = outdim // num_head + scale_kv_b_nope = mint.tile(scale_sqrt, (1, qk_nope_head_dim, 1)) + scale_kv_b_rope = mint.ones((num_head, head_dim - qk_nope_head_dim, 1), dtype=scale_sqrt.dtype) + scale_kv_b = mint.cat([scale_kv_b_nope, scale_kv_b_rope], dim=1) + kv_b_projs.mul_(scale_kv_b.view(-1, 1)) + return True + + +def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + shape = G.shape + + if len(shape) > 2: + G = G.view(G.shape[0], -1) + assert len(shape) == 2 + + a, b, c = 3.4445, -4.7750, 2.0315 + X = G.bfloat16() + if G.shape[0] > G.shape[1]: + X = mint.t(X) + + # Ensure spectral norm is at most 1 + X = X / (mint.norm(X) + 1e-7) + # Perform the NS iterations + for _ in range(steps): + A = mint.matmul(X, X.T) + B = mint.addmm(A, A, A, beta=b, alpha=c) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + X = mint.addmm(X, B, X, beta=a) + + if G.shape[0] > G.shape[1]: + X = mint.t(X) + + if len(shape) > 2: + X = X.view(*shape) + return X + + +class Muon(Optimizer): + """Following https://github.com/MoonshotAI/Moonlight""" + + def __init__( + self, + lr: Union[float, Tensor] = 1e-3, + wd: float = 0.1, + muon_params: Optional[List[Parameter]] = None, + momentum: float = 0.95, + nesterov: bool = True, + ns_steps: int = 5, + adamw_params: Optional[List[Parameter]] = None, + adamw_betas: Tuple[float, float] = (0.9, 0.95), + adamw_eps: float = 1e-8, + clip_value: Optional[float] = 100.0, + qk_nope_head_dim: int = 64, + ) -> None: + defaults = dict( + lr=lr, + wd=wd, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + ) + params = list(muon_params) + adamw_params = list(adamw_params) if adamw_params is not None else [] + params.extend(adamw_params) + super().__init__(params, defaults) + self.clip_value = clip_value + self.qk_nope_head_dim = qk_nope_head_dim + # Sort parameters into those for which we will use Muon, and those for which we will not + use_muon = list() + for p in muon_params: + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim >= 2, p.ndim + use_muon.append(True) + + for p in adamw_params: + # Do not use Muon for parameters in adamw_params + use_muon.append(False) + self.use_muon = tuple(use_muon) + + self.exp_avg = self.parameters.clone("exp_avg", init="zeros") + self.exp_avg_sq = ParameterTuple( + [ + ( + Parameter(mint.zeros(x.shape, dtype=x.dtype), name="exp_avg_sq." + x.name) + if not use_muon + else Parameter([], name="exp_avg_sq." 
+ x.name) + ) + for x, use_muon in zip(self.parameters, self.use_muon) + ] + ) + + self.lr_ratio = tuple([self._cal_lr_ratio(x, use_muon) for x, use_muon in zip(self.parameters, self.use_muon)]) + + self.state_step = Parameter(Tensor(0, dtype=ms.int32)) + self.increase_tensor = Tensor(1, dtype=ms.int32) + self.denom = Parameter(Tensor(1.0, dtype=ms.float32)) + + if self.clip_value is not None: + # group the Q and KV projection first for easier updating in QK-clip + # TODO: it should be extracted from optimizer as extra inputs + q_b_projs = [] + kv_b_projs = [] + for x in self.parameters: + if x.name.endswith("q_b_proj.weight"): + layer_idx = int(x.name.split(".")[2]) + q_b_projs.append((layer_idx, x)) + elif x.name.endswith("kv_b_proj.weight"): + layer_idx = int(x.name.split(".")[2]) + kv_b_projs.append((layer_idx, x)) + q_b_projs = sorted(q_b_projs, key=lambda x: x[0]) + kv_b_projs = sorted(kv_b_projs, key=lambda x: x[0]) + self.q_b_projs = ParameterTuple([x[1] for x in q_b_projs]) + self.kv_b_projs = ParameterTuple([x[1] for x in kv_b_projs]) + assert len(self.q_b_projs) > 0 and len(self.kv_b_projs) > 0 + + def _cal_lr_ratio(self, param: Parameter, use_muon: bool, rms_scale: float = 0.2) -> float: + if not use_muon: + return 1.0 + + A, B = param.shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = rms_scale * math.sqrt(max(A, B)) + return adjusted_ratio + + @ms.jit(jit_level="O1") + def muon( + self, + momentum: float, + beta1: float, + beta2: float, + eps: float, + nesterov: bool, + ns_steps: int, + weight_decay: float, + lr: Parameter, + gradients: Tuple[Tensor, ...], + ratio: Tuple[float, ...], + use_muon: Tuple[bool, ...], + start_id: int, + end_id: int, + ) -> bool: + bias_correction1 = 1 - beta1**self.state_step + bias_correction2 = 1 - beta2**self.state_step + ops.assign(self.denom, bias_correction1 / bias_correction2**0.5) + + optim_result = self.hyper_map( + ops.partial( + _muon_opt, + momentum, + beta1, + beta2, + eps, + nesterov, + ns_steps, + weight_decay, + lr, + self.denom, + ), + self.parameters[start_id:end_id], + self.exp_avg[start_id:end_id], + self.exp_avg_sq[start_id:end_id], + gradients[start_id:end_id], + ratio[start_id:end_id], + use_muon[start_id:end_id], + ) + return optim_result + + @ms.jit(jit_level="O1") + def qk_clip(self, qk_products: Tuple[Tensor, ...]) -> bool: + optim_result = self.hyper_map( + ops.partial(_qk_clip_opt, self.clip_value, self.qk_nope_head_dim), + qk_products, + self.q_b_projs, + self.kv_b_projs, + ) + return optim_result + + def construct(self, gradients: Tuple[Tensor, ...], qk_products: Optional[Tuple[Tensor, ...]] = None) -> bool: + if self.clip_value is not None: + assert qk_products is not None + + self.state_step.add_(self.increase_tensor) + for group_id, group in enumerate(self.param_groups): + beta1, beta2 = group["adamw_betas"] + start_id = self.group_start_id[group_id] + end_id = self.group_start_id[group_id + 1] + + self.muon( + group["momentum"], + beta1, + beta2, + group["adamw_eps"], + group["nesterov"], + group["ns_steps"], + group["weight_decay"], + group["lr"], + gradients, + self.lr_ratio, + self.use_muon, + start_id, + end_id, + ) + + if self.clip_value is None: + return True + else: + optim_result = self.qk_clip(qk_products) + return optim_result diff --git a/mindone/transformers/models/qwen2/modeling_qwen2.py b/mindone/transformers/models/qwen2/modeling_qwen2.py index bf02d54ad7..86e6c50532 100644 --- 
a/mindone/transformers/models/qwen2/modeling_qwen2.py +++ b/mindone/transformers/models/qwen2/modeling_qwen2.py @@ -12,6 +12,7 @@ """ import math +import os from typing import Callable, List, Optional, Tuple, Union from transformers import Qwen2Config, logging @@ -21,6 +22,7 @@ from mindspore import Parameter, Tensor, mint, nn, ops from mindspore.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from mindone.models.utils import normal_, zeros_ from mindone.transformers.cache_utils import Cache, get_max_length, get_seq_length, update from mindone.transformers.generation import GenerationMixin from mindone.transformers.mindspore_adapter import str_to_dtype @@ -80,6 +82,9 @@ def __init__(self, config: Qwen2Config, device=None): rope_type = "default" rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] + if os.environ.get("USE_MLA", None) == "1": + logger.info("Use MLA attention.") + config.head_dim = config.hidden_size // config.num_attention_heads // 2 inv_freq, self.attention_scaling = rope_init_fn(config) self.inv_freq = inv_freq @@ -174,7 +179,8 @@ def eager_attention_forward( key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = mint.matmul(query, key_states.swapaxes(2, 3)) / mint.sqrt(ms.tensor(module.head_dim)) + qk_product = mint.matmul(query, key_states.swapaxes(2, 3)) + attn_weights = qk_product / mint.sqrt(ms.tensor(module.head_dim)) if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask @@ -184,7 +190,7 @@ def eager_attention_forward( attn_output = mint.matmul(attn_weights, value_states) attn_output = attn_output.swapaxes(1, 2).contiguous() - return attn_output, attn_weights + return attn_output, attn_weights, qk_product class Qwen2Attention(nn.Cell): @@ -297,6 +303,139 @@ def construct( return attn_output, attn_weights, past_key_value +class Qwen2MLAAttention(nn.Cell): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + self.q_lora_rank = config.intermediate_size // 14 + self.qk_nope_head_dim = self.head_dim + self.qk_rope_head_dim = self.head_dim // 2 + self.v_head_dim = self.head_dim + self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + self.kv_lora_rank = config.hidden_size // 14 + + self.q_a_proj = nn.Dense(config.hidden_size, self.q_lora_rank, has_bias=True) + self.q_a_layernorm = Qwen2RMSNorm(self.q_lora_rank) + self.q_b_proj = nn.Dense(self.q_lora_rank, self.num_heads * self.qk_head_dim, has_bias=True) + + self.kv_a_proj_with_mqa = nn.Dense( + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + has_bias=True, + ) + self.kv_a_layernorm = Qwen2RMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Dense( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + has_bias=True, + ) + + self.o_proj = nn.Dense(self.num_heads * self.v_head_dim, self.hidden_size, has_bias=False) + + self.rotary_emb = Qwen2RotaryEmbedding(config) + + self.scale = self.head_dim**-0.5 + + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[ms.Tensor] = None, + position_embeddings: Optional[Tuple[ms.Tensor, ms.Tensor]] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[ms.Tensor, Optional[ms.Tensor], Optional[Tuple[ms.Tensor]]]: + bsz, q_len, _ = hidden_states.shape + + query_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + query_states = query_states.view(bsz, q_len, self.num_heads, self.qk_head_dim).transpose(1, 2) + query_pass, query_rot = mint.split(query_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + key_pass, key_rot = mint.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + key_pass = self.kv_b_proj(self.kv_a_layernorm(key_pass)) + key_pass = key_pass.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2) + key_pass, value_states = mint.split(key_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + key_rot = key_rot.view(bsz, 1, q_len, self.qk_rope_head_dim) + + cos, sin = position_embeddings + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + key_rot = key_rot.expand((*key_pass.shape[:-1], -1)) + + query_states = mint.cat((query_pass, query_rot), dim=-1) + key_states = mint.cat((key_pass, key_rot), dim=-1) + + if past_key_value is not None: + key_states, value_states = update(past_key_value, key_states, value_states, cache_position) + past_key_value = (key_states, value_states) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and not output_attentions: + logger.warning_once( + "`mindspore.ops.scaled_dot_product_attention` does not support `output_attentions=True`. 
Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + sliding_window = None + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + + attn_output, attn_weights, qk_product = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scale, + sliding_window=sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value, qk_product + + class Qwen2PageAttention(Qwen2Attention): """ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays @@ -377,11 +516,13 @@ def __init__(self, config: Qwen2Config, layer_idx: int): f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " "unexpected results may be encountered." ) - self.self_attn = ( - Qwen2Attention(config, layer_idx) - if not config._attn_implementation == "paged_attention" - else Qwen2PageAttention(config=config, layer_idx=layer_idx) - ) + + if config._attn_implementation == "paged_attention": + self.self_attn = Qwen2PageAttention(config=config, layer_idx=layer_idx) + elif os.environ.get("USE_MLA", None) == "1": + self.self_attn = Qwen2MLAAttention(config=config, layer_idx=layer_idx) + else: + self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -432,7 +573,7 @@ def construct( # Self Attention if block_tables is None: - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, present_key_value, qk_product = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -476,6 +617,8 @@ def construct( if use_cache: outputs += (present_key_value,) + outputs += (qk_product,) + return outputs @@ -508,16 +651,15 @@ class Qwen2PreTrainedModel(MSPreTrainedModel): _supports_attention_backend = True def _init_weights(self, module): - # std = self.config.initializer_range - # if isinstance(module, nn.Dense): - # module.weight.data.normal_(mean=0.0, std=std) - # if module.bias is not None: - # module.bias.data.zero_() - # elif isinstance(module, nn.Embedding): - # module.weight.data.normal_(mean=0.0, std=std) - # if module.padding_idx is not None: - # module.weight.data[module.padding_idx].zero_() - pass + std = self.config.initializer_range + if isinstance(module, nn.Dense): + normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + zeros_(module.bias) + elif isinstance(module, nn.Embedding): + normal_(module.embedding_table, mean=0.0, std=std) + if module.padding_idx is not None: + module.embedding_table[module.padding_idx] = 0 QWEN2_INPUTS_DOCSTRING = r""" @@ -686,6 +828,7 @@ def construct( position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers + all_qk_products = () all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None 
next_caches = () if use_cache else None @@ -717,6 +860,8 @@ def construct( if output_attentions: all_self_attns += (layer_outputs[1],) + all_qk_products += (layer_outputs[-1],) + hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer @@ -724,7 +869,11 @@ def construct( all_hidden_states += (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, next_caches, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v + for v in [hidden_states, next_caches, all_hidden_states, all_self_attns, all_qk_products] + if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, diff --git a/tests/trainer_tests/muon/mindspore/toy_train_ms.py b/tests/trainer_tests/muon/mindspore/toy_train_ms.py new file mode 100644 index 0000000000..3d22f73346 --- /dev/null +++ b/tests/trainer_tests/muon/mindspore/toy_train_ms.py @@ -0,0 +1,186 @@ +"""Modified from https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py""" + +import os +import time +from functools import partial + +import numpy as np +from datasets import load_dataset +from loguru import logger +from tqdm import tqdm +from transformers import Qwen2Config, Qwen2Tokenizer +from transformers.optimization import _get_cosine_schedule_with_warmup_lr_lambda + +import mindspore as ms +import mindspore.mint as mint +from mindspore.dataset import GeneratorDataset +from mindspore.experimental.optim import AdamW, Optimizer +from mindspore.experimental.optim.lr_scheduler import LambdaLR + +from mindone.trainers.muon import Muon +from mindone.transformers import Qwen2ForCausalLM + + +def get_cosine_schedule_with_warmup( + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float = 0.5, + last_epoch: int = -1, +): + lr_lambda = partial( + _get_cosine_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +class MoonDataset: + def __init__(self, dataset_name, dataset, tokenizer, max_length=512): + self.dataset_name = dataset_name + self.dataset = dataset + self.tokenizer = tokenizer + self.texts = dataset["train"]["text"] + self.max_length = max_length + self.tokens = [] + self._tokenize_texts() + + def _tokenize_texts(self): + if os.path.exists(f"{self.dataset_name}.npy"): + self.tokens = np.load(f"{self.dataset_name}.npy") + else: + for text in tqdm(self.texts, desc="Tokenizing texts"): + encoded = self.tokenizer.encode(text, add_special_tokens=True) + self.tokens.extend(encoded) + np.save(f"{self.dataset_name}.npy", self.tokens) + + def __len__(self): + return len(self.tokens) // self.max_length + + def __getitem__(self, idx): + start_idx = idx * (self.max_length) + end_idx = start_idx + (self.max_length) + token_slice = self.tokens[start_idx:end_idx] + data = np.asarray(token_slice, dtype=np.int32) + return data + + +def get_model_and_dataloader(model_name, dataset_name, hidden_size): + name2path = { + "openwebtext-100k": "Elriggs/openwebtext-100k", + } + train_dataset = load_dataset(name2path[dataset_name]) + if model_name == "qwen": + tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", trust_remote_code=True) + else: + assert 0, f"model {model_name} not supported" + train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer) + # mike: default shuffle = True, for comparison set it to be False + train_loader = GeneratorDataset(train_dataset, 
column_names="input_ids", shuffle=True).batch(8) + + if model_name == "qwen": + config = Qwen2Config( + attention_dropout=0.0, + bos_token_id=151643, + eos_token_id=151643, + hidden_act="silu", + hidden_size=hidden_size, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=513, + max_window_layers=12, + model_type="qwen2", + num_attention_heads=16, + num_hidden_layers=12, + num_key_value_heads=16, + rms_norm_eps=1e-06, + rope_theta=1000000.0, + sliding_window=1024, + tie_word_embeddings=True, + torch_dtype="bfloat16", + use_cache=True, + use_mrope=False, + use_sliding_window=False, + vocab_size=151936, + ) + model = Qwen2ForCausalLM(config) + else: + assert 0, f"model {model_name} not supported" + return model, train_loader + + +def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1, clip_value=None): + if optimizer_name == "adamw": + return AdamW(model.get_parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95)) + elif optimizer_name == "muon": + muon_params = [ + p + for name, p in model.parameters_and_names() + if p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name + ] + adamw_params = [ + p + for name, p in model.parameters_and_names() + if not (p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name) + ] + + return Muon( + lr=lr, + wd=wd, + muon_params=muon_params, + adamw_params=adamw_params, + clip_value=clip_value, + ) + else: + assert 0, "optimizer not supported" + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="qwen") + parser.add_argument("--optimizer", type=str, default="adamw") + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--wd", type=float, default=0.1) + parser.add_argument("--dataset", type=str, default="openwebtext-100k") + parser.add_argument("--hidden_size", type=int, default=1024) + parser.add_argument("--clip_value", type=float, default=None) + args = parser.parse_args() + logger.add(f"logs/train_{args.model}_{args.optimizer}_lr{args.lr}.log") + + ms.set_seed(0) + model, train_loader = get_model_and_dataloader(args.model, args.dataset, args.hidden_size) + optimizer = get_optimizer(args.optimizer, model, lr=args.lr, clip_value=args.clip_value) + + model.set_train(True) + epoch = 1 + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=100, + num_training_steps=len(train_loader) * epoch, + num_cycles=0.5, + ) + + total_train_params = sum([x.numel() for x in optimizer.parameters]) + logger.info(f"Total number of trainable parameters: {total_train_params:,}") + + grad_fn = ms.value_and_grad(model, grad_position=None, weights=optimizer.parameters, has_aux=True) + for epoch in range(epoch): + for step, batch in enumerate(train_loader.create_tuple_iterator()): + (input_ids,) = batch + (loss, _, qk_products), grads = grad_fn(input_ids=input_ids, labels=input_ids, return_dict=False) + qk_products_max = max([mint.max(x).item() for x in qk_products]) + logger.info(f"QK max value: {qk_products_max:.3f}") + ms.synchronize() + start = time.time() + optimizer(grads, qk_products) + ms.synchronize() + duration = time.time() - start + lr_scheduler.step() + logger.info( + f"Epoch: {epoch} Step: {step} LR: {optimizer.param_groups[0]['lr'].item():.5f} " + f"Optimizer update time: {duration:.3f} Training loss: {loss.item()}" + ) diff --git a/tests/trainer_tests/muon/torch/toy_train.py b/tests/trainer_tests/muon/torch/toy_train.py new file mode 100644 index 0000000000..905ec30218 --- 
/dev/null +++ b/tests/trainer_tests/muon/torch/toy_train.py @@ -0,0 +1,349 @@ +"""Copied from https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py""" + +import math +import os +import time + +import torch +from datasets import load_dataset +from loguru import logger +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm +from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2Tokenizer, get_cosine_schedule_with_warmup + + +class MoonDataset(Dataset): + def __init__(self, dataset_name, dataset, tokenizer, max_length=512): + self.dataset_name = dataset_name + self.dataset = dataset + self.tokenizer = tokenizer + self.texts = dataset["train"]["text"] + self.max_length = max_length + self.tokens = [] + self._tokenize_texts() + + def _tokenize_texts(self): + if os.path.exists(f"{self.dataset_name}.bin"): + self.tokens = torch.load(f"{self.dataset_name}.bin") + else: + for text in tqdm(self.texts, desc="Tokenizing texts"): + encoded = self.tokenizer.encode(text, add_special_tokens=True) + self.tokens.extend(encoded) + torch.save(self.tokens, f"{self.dataset_name}.bin") + + def __len__(self): + return len(self.tokens) // self.max_length + + def __getitem__(self, idx): + start_idx = idx * (self.max_length) + end_idx = start_idx + (self.max_length) + token_slice = self.tokens[start_idx:end_idx] + data = torch.tensor(token_slice, dtype=torch.long) + return data + + +# This code snippet is a modified version adapted from the following GitHub repository: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +@torch.compile +def zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + X = a * X + B @ X + + if G.size(0) > G.size(1): + X = X.T + return X + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. + """ + + def __init__( + self, + lr=1e-3, + wd=0.1, + muon_params=None, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_params=None, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + ): + defaults = dict( + lr=lr, + wd=wd, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + ) + + params = list(muon_params) + adamw_params = list(adamw_params) if adamw_params is not None else [] + params.extend(adamw_params) + super().__init__(params, defaults) + # Sort parameters into those for which we will use Muon, and those for which we will not + for p in muon_params: + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + for p in adamw_params: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False + + def adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + # import pdb; pdb.set_trace() + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + # generate weight updates in distributed fashion + for p in params: + # sanity check + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss + + +def get_model_and_dataloader(model_name, dataset_name, hidden_size): + name2path = { + "openwebtext-100k": "Elriggs/openwebtext-100k", + } + train_dataset = load_dataset(name2path[dataset_name]) + if model_name == "qwen": + tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", trust_remote_code=True) + else: + assert 0, f"model {model_name} not supported" + train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer) + # mike: default shuffle = True, for comparison set it to be False + train_loader = DataLoader(train_dataset, batch_size=8, shuffle=False) + + if model_name == "qwen": + config = Qwen2Config( + attention_dropout=0.0, + bos_token_id=151643, + eos_token_id=151643, + hidden_act="silu", + hidden_size=hidden_size, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=513, + max_window_layers=12, + model_type="qwen2", + num_attention_heads=16, + num_hidden_layers=12, + num_key_value_heads=16, + rms_norm_eps=1e-06, + rope_theta=1000000.0, + sliding_window=1024, + tie_word_embeddings=True, + torch_dtype="bfloat16", + use_cache=True, + use_mrope=False, + use_sliding_window=False, + vocab_size=151936, + ) + model = Qwen2ForCausalLM(config) + else: + assert 0, f"model {model_name} not supported" + return model, train_loader + + +def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): + if optimizer_name == "adamw": + return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95)) + elif optimizer_name == "muon": + muon_params = [ + p + for name, p in model.named_parameters() + if 
p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name + ] + adamw_params = [ + p + for name, p in model.named_parameters() + if not (p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name) + ] + + return Muon( + lr=lr, + wd=wd, + muon_params=muon_params, + adamw_params=adamw_params, + ) + else: + assert 0, "optimizer not supported" + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="qwen") + parser.add_argument("--optimizer", type=str, default="adamw") + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--wd", type=float, default=0.1) + parser.add_argument("--dataset", type=str, default="openwebtext-100k") + parser.add_argument("--hidden_size", type=int, default=1024) + args = parser.parse_args() + logger.add(f"logs/train_{args.model}_{args.optimizer}_lr{args.lr}.log") + + model, train_loader = get_model_and_dataloader(args.model, args.dataset, args.hidden_size) + optimizer = get_optimizer(args.optimizer, model, lr=args.lr) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + + model.train() + epoch = 1 + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=len(train_loader) * epoch, + num_cycles=0.5, + ) + for epoch in range(epoch): + for step, batch in enumerate(train_loader): + batch = batch.to(device) + input_ids = batch + outputs = model(input_ids=input_ids, labels=input_ids) + loss = outputs.loss + loss.backward() + torch.cuda.synchronize() + start = time.time() + optimizer.step() + torch.cuda.synchronize() + duration = time.time() - start + lr_scheduler.step() + optimizer.zero_grad() + logger.info( + f"Epoch: {epoch} Step: {step} LR: {optimizer.param_groups[0]['lr']:.5f} Optimizer update time: {duration:.3f} Training loss: {loss.item()}" + )
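# ===============================================================================================
# Supplementary illustrations for the patch above (not part of the diff). Each snippet is a
# self-contained Python sketch; any name, size, or value that does not appear in the patch is an
# assumption made for illustration only.
# ===============================================================================================

# -----------------------------------------------------------------------------------------------
# Sketch 1: a NumPy restatement of the quintic Newton-Schulz iteration implemented by
# `zeropower_via_newtonschulz5`, run in float64 so the orthogonalization described in its
# docstring is easy to verify. The coefficients, the norm pre-scaling, the transpose trick, and
# the step count are taken from the patch; the matrix size and seed are arbitrary.
import numpy as np

def newton_schulz5_reference(G: np.ndarray, steps: int = 5) -> np.ndarray:
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + 1e-7)      # Frobenius norm >= spectral norm, so sigma_max <= 1
    transposed = X.shape[0] > X.shape[1]
    if transposed:                          # keep the Gram matrix X @ X.T small, as in the patch
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X

rng = np.random.default_rng(0)
G = rng.standard_normal((64, 256))
O = newton_schulz5_reference(G)
s = np.linalg.svd(O, compute_uv=False)
# After 5 steps every singular value should sit near 1 (roughly the (0.5, 1.5) band the
# docstring mentions), i.e. O is an approximate orthogonalization of G.
print(s.min(), s.max())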
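# -----------------------------------------------------------------------------------------------
# Sketch 2: a quick numeric check of the per-matrix learning-rate adjustment shared by
# `Muon._cal_lr_ratio` (MindSpore, rms_scale=0.2) and `adjust_lr_for_muon` (PyTorch reference):
# ratio = 0.2 * sqrt(max(A, B)) applied on top of the base lr. The shapes are taken from the toy
# Qwen2 config in the tests (hidden_size=1024, intermediate_size=4864); the base lr is arbitrary.
import math

def muon_lr_ratio(shape, rms_scale=0.2):
    A, B = shape[:2]
    return rms_scale * math.sqrt(max(A, B))

base_lr = 1e-3
for shape in [(1024, 1024), (4864, 1024), (1024, 4864)]:
    ratio = muon_lr_ratio(shape)
    # e.g. a 1024 x 4864 MLP weight gets ratio ~= 13.95, so an effective Muon step of ~1.4e-2
    print(shape, round(ratio, 2), f"{base_lr * ratio:.2e}")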
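# -----------------------------------------------------------------------------------------------
# Sketch 3: the MindSpore `Muon.muon()` precomputes a single scalar
# denom = (1 - beta1**t) / sqrt(1 - beta2**t) outside the hyper_map and applies
# p -= (lr / denom) * m / (eps + sqrt(v)); this plain-NumPy check (arbitrary data) shows the same
# `scale` used by the PyTorch `step()` and that it agrees with the textbook bias-corrected AdamW
# update up to where eps enters the denominator.
import numpy as np

rng = np.random.default_rng(1)
p_patch = rng.standard_normal(4)
p_textbook = p_patch.copy()
m = np.zeros(4)
v = np.zeros(4)
b1, b2, eps, lr = 0.9, 0.95, 1e-8, 1e-3

for t in range(1, 6):
    g = rng.standard_normal(4)
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g ** 2
    denom = (1 - b1 ** t) / (1 - b2 ** t) ** 0.5          # shared scalar in muon.py / toy_train.py
    p_patch = p_patch - (lr / denom) * m / (eps + np.sqrt(v))
    m_hat, v_hat = m / (1 - b1 ** t), v / (1 - b2 ** t)   # textbook Adam bias correction
    p_textbook = p_textbook - lr * m_hat / (np.sqrt(v_hat) + eps)

print(np.max(np.abs(p_patch - p_textbook)))               # negligible: only the eps placement differs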
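# -----------------------------------------------------------------------------------------------
# Sketch 4: the QK-clip rule in `_qk_clip_opt`, collapsed to a single head with arbitrary sizes in
# plain NumPy. When a head's maximum Q.K^T entry exceeds `clip_value`, the factor
# scale = clip_value / qk_max is split as sqrt(scale) on the Q and K "nope" projection rows and a
# full `scale` on the Q "rope" rows; the rope key comes from kv_a_proj_with_mqa, which the clip
# leaves untouched, so every Q.K^T entry shrinks by exactly `scale`.
import numpy as np

clip_value = 100.0
rng = np.random.default_rng(0)
q_nope, k_nope = rng.standard_normal((2, 8, 64)) * 3.0
q_rope, k_rope = rng.standard_normal((2, 8, 32)) * 3.0

logits = q_nope @ k_nope.T + q_rope @ k_rope.T
scale = min(clip_value / logits.max(), 1.0)

q_nope, k_nope = q_nope * np.sqrt(scale), k_nope * np.sqrt(scale)   # q_b / kv_b "nope" rows
q_rope = q_rope * scale                                             # q_b "rope" rows take the full factor
                                                                    # (k_rope and the V rows stay as-is)
clipped = q_nope @ k_nope.T + q_rope @ k_rope.T
print(round(float(logits.max()), 2), round(float(clipped.max()), 2))  # second value capped at clip_value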
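# -----------------------------------------------------------------------------------------------
# Sketch 5: shape bookkeeping for `Qwen2MLAAttention` under the toy test config
# (hidden_size=1024, num_attention_heads=16, intermediate_size=4864). The //14 low-rank choice and
# the rope sub-head of head_dim // 2 mirror the patch; with USE_MLA=1 the rotary embedding is also
# rebuilt with config.head_dim = hidden_size // num_attention_heads // 2, i.e. exactly the rope
# sub-head below. Tuples are printed as (in_features, out_features) of each nn.Dense.
hidden_size, num_heads, intermediate_size = 1024, 16, 4864

head_dim = hidden_size // num_heads              # 64  -> qk_nope_head_dim and v_head_dim
qk_rope_head_dim = head_dim // 2                 # 32  -> matches the patched rotary head_dim
qk_head_dim = head_dim + qk_rope_head_dim        # 96  per query head
q_lora_rank = intermediate_size // 14            # 347
kv_lora_rank = hidden_size // 14                 # 73

print("q_a_proj           :", (hidden_size, q_lora_rank))                           # (1024, 347)
print("q_b_proj           :", (q_lora_rank, num_heads * qk_head_dim))               # (347, 1536)
print("kv_a_proj_with_mqa :", (hidden_size, kv_lora_rank + qk_rope_head_dim))       # (1024, 105)
print("kv_b_proj          :", (kv_lora_rank, num_heads * (head_dim + head_dim)))    # (73, 2048)
print("o_proj             :", (num_heads * head_dim, hidden_size))                  # (1024, 1024)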