diff --git a/lazyllm/common/common.py b/lazyllm/common/common.py index 1a798f79d..d9b95a64c 100644 --- a/lazyllm/common/common.py +++ b/lazyllm/common/common.py @@ -31,15 +31,19 @@ def absorb(self, item): class ArgsDict(dict): - def __init__(self, *args, **kwargs): + def __init__(self, *args, with_line=True, **kwargs): super(ArgsDict, self).__init__(*args, **kwargs) + self._with_line = with_line def check_and_update(self, kw): assert set(kw.keys()).issubset(set(self)), f'unexpected keys: {set(kw.keys()) - set(self)}' self.update(kw) def parse_kwargs(self): - string = ' '.join(f'--{k}={v}' if type(v) is not str else f'--{k}=\"{v}\"' for k, v in self.items()) + if self._with_line: + string = ' '.join(f'--{k}={v}' if type(v) is not str else f'--{k}=\"{v}\"' for k, v in self.items()) + else: + string = ' '.join(f'{k}={v}' if type(v) is not str else f'{k}=\"{v}\"' for k, v in self.items()) return string class CaseInsensitiveDict(dict): diff --git a/lazyllm/components/deploy/vllm.py b/lazyllm/components/deploy/vllm.py index 8a2c0bced..7526e0880 100644 --- a/lazyllm/components/deploy/vllm.py +++ b/lazyllm/components/deploy/vllm.py @@ -72,12 +72,17 @@ def __init__(self, trust_remote_code: bool = True, launcher: LazyLLMLaunchersBas ray_launcher[0], post_action=(lazyllm.parallel(*parall_launcher) if len(parall_launcher) else None)) def cmd(self, finetuned_model=None, base_model=None, master_ip=None): - if not os.path.exists(finetuned_model) or \ - not any(filename.endswith('.bin') or filename.endswith('.safetensors') - for filename in os.listdir(finetuned_model)): - if not finetuned_model: - LOG.warning(f"Note! That finetuned_model({finetuned_model}) is an invalid path, " - f"base_model({base_model}) will be used") + if not finetuned_model: + LOG.warning(f"Note! finetuned_model is empty, using base_model({base_model}) instead.") + finetuned_model = base_model + elif not os.path.exists(finetuned_model): + LOG.warning(f"Warning! The finetuned_model path does not exist: {finetuned_model}. " + f"Using base_model({base_model}) instead.") + finetuned_model = base_model + elif not any(filename.endswith(('.bin', '.safetensors', '.pt')) + for filename in os.listdir(finetuned_model)): + LOG.warning(f"Warning! No valid model files (.bin, .safetensors or .pt) found in: {finetuned_model}. 
" + f"Using base_model({base_model}) instead.") finetuned_model = base_model def impl(): diff --git a/lazyllm/components/finetune/__init__.py b/lazyllm/components/finetune/__init__.py index 8c52570b6..65d968801 100644 --- a/lazyllm/components/finetune/__init__.py +++ b/lazyllm/components/finetune/__init__.py @@ -3,6 +3,7 @@ from .collie import CollieFinetune from .llamafactory import LlamafactoryFinetune from .flagembedding import FlagembeddingFinetune +from .easyr1 import EasyR1Finetune __all__ = [ 'LazyLLMFinetuneBase', @@ -10,4 +11,5 @@ 'CollieFinetune', 'LlamafactoryFinetune', 'FlagembeddingFinetune', + 'EasyR1Finetune', ] diff --git a/lazyllm/components/finetune/easy_r1/config.yaml b/lazyllm/components/finetune/easy_r1/config.yaml new file mode 100644 index 000000000..ddab28279 --- /dev/null +++ b/lazyllm/components/finetune/easy_r1/config.yaml @@ -0,0 +1,103 @@ +data: + train_files: hiyouga/math12k@train + val_files: hiyouga/math12k@test + prompt_key: problem + answer_key: answer + image_key: images + video_key: videos + image_dir: null + video_fps: 2.0 + max_prompt_length: 2048 + max_response_length: 2048 + rollout_batch_size: 512 # equivalent to verl's data.train_batch_size + mini_rollout_batch_size: null # equivalent to verl's data.gen_batch_size + val_batch_size: 1024 + format_prompt: ./examples/format_prompt/math.jinja + override_chat_template: null + shuffle: true + seed: 1 + min_pixels: 262144 + max_pixels: 4194304 + filter_overlong_prompts: true + +algorithm: + adv_estimator: grpo + disable_kl: false + use_kl_loss: true + kl_penalty: low_var_kl + kl_coef: 1.0e-2 + online_filtering: false # dapo filter groups + filter_key: overall + filter_low: 0.01 + filter_high: 0.99 + +worker: + actor: + global_batch_size: 128 # equivalent to verl's actor.ppo_mini_batch_size + micro_batch_size_per_device_for_update: 4 # equivalent to verl's actor.ppo_micro_batch_size_per_gpu + micro_batch_size_per_device_for_experience: 16 # equivalent to verl's rollout.log_prob_micro_batch_size_per_gpu + max_grad_norm: 1.0 + padding_free: true + ulysses_size: 1 + model: + model_path: /mnt/lustre/share_data/sunxiaoye/models/qwen2.5-0.5b-instruct + enable_gradient_checkpointing: true + trust_remote_code: false + freeze_vision_tower: false + optim: + lr: 1.0e-6 + weight_decay: 1.0e-2 + strategy: adamw # {adamw, adamw_bf16} + lr_warmup_ratio: 0.0 + fsdp: + enable_full_shard: true + enable_cpu_offload: false + enable_rank0_init: true + offload: + offload_params: true # true: more CPU memory; false: more GPU memory + offload_optimizer: true # true: more CPU memory; false: more GPU memory + + rollout: + n: 5 + temperature: 1.0 + top_p: 0.99 + limit_images: 0 + gpu_memory_utilization: 0.6 + enforce_eager: false + enable_chunked_prefill: false + tensor_parallel_size: 1 + disable_tqdm: false + val_override_config: + temperature: 0.5 + n: 1 + + ref: + fsdp: + enable_full_shard: true + enable_cpu_offload: true # true: more CPU memory; false: more GPU memory + enable_rank0_init: true + offload: + offload_params: false + + reward: + reward_type: batch + reward_function: ./examples/reward_function/math.py:compute_score + +trainer: + total_epochs: 15 + max_steps: null + project_name: easy_r1 + experiment_name: qwen2_5_7b_math_grpo + logger: ["console"] + nnodes: 1 + n_gpus_per_node: 1 + max_try_make_batch: 20 # -1 means no limit + val_freq: 5 # -1 to disable + val_before_train: true + val_only: false + val_generations_to_log: 3 + save_freq: 5 # -1 to disable + save_limit: 3 # -1 to disable + save_model_only: false + 
save_checkpoint_path: null + load_checkpoint_path: null diff --git a/lazyllm/components/finetune/easy_r1/format_prompt/dapo.jinja b/lazyllm/components/finetune/easy_r1/format_prompt/dapo.jinja new file mode 100644 index 000000000..ea56a6a64 --- /dev/null +++ b/lazyllm/components/finetune/easy_r1/format_prompt/dapo.jinja @@ -0,0 +1 @@ +Solve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem.\n\n{{ content | trim }}\n\nRemember to put your answer on its own line after "Answer:". diff --git a/lazyllm/components/finetune/easy_r1/format_prompt/math.jinja b/lazyllm/components/finetune/easy_r1/format_prompt/math.jinja new file mode 100644 index 000000000..8d6aa2344 --- /dev/null +++ b/lazyllm/components/finetune/easy_r1/format_prompt/math.jinja @@ -0,0 +1 @@ +{{ content | trim }} You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}. diff --git a/lazyllm/components/finetune/easy_r1/format_prompt/r1v.jinja b/lazyllm/components/finetune/easy_r1/format_prompt/r1v.jinja new file mode 100644 index 000000000..0ecf6f471 --- /dev/null +++ b/lazyllm/components/finetune/easy_r1/format_prompt/r1v.jinja @@ -0,0 +1 @@ +{{ content | trim }} A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer> diff --git a/lazyllm/components/finetune/easy_r1/model_merger.py b/lazyllm/components/finetune/easy_r1/model_merger.py new file mode 100644 index 000000000..5053c3f67 --- /dev/null +++ b/lazyllm/components/finetune/easy_r1/model_merger.py @@ -0,0 +1,188 @@ +# flake8: noqa +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import os +import re +from concurrent.futures import ThreadPoolExecutor +from typing import Dict, List, Tuple + +import numpy as np +import torch +from torch.distributed._tensor import DTensor, Placement, Shard +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForTokenClassification, + AutoModelForVision2Seq, + PretrainedConfig, + PreTrainedModel, +) + + +def merge_by_placement(tensors: List[torch.Tensor], placement: Placement): + if placement.is_replicate(): + return tensors[0] + elif placement.is_partial(): + raise NotImplementedError("Partial placement is not supported yet") + elif placement.is_shard(): + return torch.cat(tensors, dim=placement.dim).contiguous() + else: + raise ValueError(f"Unsupported placement: {placement}") + + +def upload_model_to_huggingface(local_path: str, remote_path: str): + # Push to hugging face + from huggingface_hub import HfApi + + api = HfApi() + api.create_repo(repo_id=remote_path, private=False, exist_ok=True) + api.upload_folder(repo_id=remote_path, folder_path=local_path, repo_type="model") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--local_dir", required=True, type=str, help="The path for your saved model") + parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload") + args = parser.parse_args() + local_dir: str = args.local_dir + + assert not local_dir.endswith("huggingface"), "The local_dir should not end with huggingface." + + # copy rank zero to find the shape of (dp, fsdp) + rank = 0 + world_size = 0 + for filename in os.listdir(local_dir): + match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename) + if match: + world_size = match.group(1) + break + + assert world_size, "No model file with the proper format." + + rank0_weight_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt") + state_dict = torch.load(rank0_weight_path, map_location="cpu", weights_only=False) + pivot_key = sorted(state_dict.keys())[0] + weight = state_dict[pivot_key] + if isinstance(weight, DTensor): + # get sharding info + device_mesh = weight.device_mesh + mesh = device_mesh.mesh + mesh_dim_names = device_mesh.mesh_dim_names + else: + # for non-DTensor + mesh = np.array([int(world_size)], dtype=np.int64) + mesh_dim_names = ("fsdp",) + + print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") + + assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}." 
+ + if "tp" in mesh_dim_names: + # fsdp * tp + total_shards = mesh.shape[-1] * mesh.shape[-2] + mesh_shape = (mesh.shape[-2], mesh.shape[-1]) + else: + # fsdp + total_shards = mesh.shape[-1] + mesh_shape = (mesh.shape[-1],) + + print(f"Processing {total_shards} model shards in total.") + model_state_dict_lst = [] + model_state_dict_lst.append(state_dict) + model_state_dict_lst.extend([""] * (total_shards - 1)) + + def process_one_shard(rank, model_state_dict_lst): + model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt") + state_dict = torch.load(model_path, map_location="cpu", weights_only=False) + model_state_dict_lst[rank] = state_dict + return state_dict + + with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: + for rank in range(1, total_shards): + executor.submit(process_one_shard, rank, model_state_dict_lst) + + state_dict: Dict[str, List[torch.Tensor]] = {} + param_placements: Dict[str, List[Placement]] = {} + keys = set(model_state_dict_lst[0].keys()) + for key in keys: + state_dict[key] = [] + for model_state_dict in model_state_dict_lst: + try: + tensor = model_state_dict.pop(key) + except Exception: + print(f"Cannot find key {key} in rank {rank}.") + + if isinstance(tensor, DTensor): + state_dict[key].append(tensor._local_tensor.bfloat16()) + placements = tuple(tensor.placements) + # replicated placement at ddp dimension can be discarded + if mesh_dim_names[0] == "ddp": + placements = placements[1:] + + if key not in param_placements: + param_placements[key] = placements + else: + assert param_placements[key] == placements + else: + state_dict[key].append(tensor.bfloat16()) + + del model_state_dict_lst + + for key in sorted(state_dict): + if not isinstance(state_dict[key], list): + print(f"No need to merge key {key}") + continue + + if key in param_placements: + # merge shards + placements: Tuple[Shard] = param_placements[key] + if len(mesh_shape) == 1: + # 1-D list, FSDP without TP + assert len(placements) == 1 + shards = state_dict[key] + state_dict[key] = merge_by_placement(shards, placements[0]) + else: + # 2-D list, FSDP + TP + raise NotImplementedError("FSDP + TP is not supported yet.") + else: + state_dict[key] = torch.cat(state_dict[key], dim=0) + + print("Merge completed.") + hf_path = os.path.join(local_dir, "huggingface") + config: PretrainedConfig = AutoConfig.from_pretrained(hf_path) + architectures: List[str] = getattr(config, "architectures", ["Unknown"]) + + if "ForTokenClassification" in architectures[0]: + AutoClass = AutoModelForTokenClassification + elif "ForCausalLM" in architectures[0]: + AutoClass = AutoModelForCausalLM + elif "ForConditionalGeneration" in architectures[0]: + AutoClass = AutoModelForVision2Seq + else: + raise NotImplementedError(f"Unknown architecture {architectures}.") + + with torch.device("meta"): + model: PreTrainedModel = AutoClass.from_config(config, torch_dtype=torch.bfloat16) + + assert isinstance(model, PreTrainedModel) + model.to_empty(device="cpu") + + print(f"Saving model to {hf_path}...") + model.save_pretrained(hf_path, state_dict=state_dict) + del state_dict, model + + if args.hf_upload_path: + upload_model_to_huggingface(hf_path, args.hf_upload_path) diff --git a/lazyllm/components/finetune/easy_r1/reward_function/dapo.py b/lazyllm/components/finetune/easy_r1/reward_function/dapo.py new file mode 100644 index 000000000..35c791667 --- /dev/null +++ b/lazyllm/components/finetune/easy_r1/reward_function/dapo.py @@ -0,0 +1,59 @@ +# Copyright 2024 Bytedance Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List + +from mathruler.grader import extract_boxed_content, grade_answer + + +def accuracy_reward(response: str, ground_truth: str) -> float: + answer = extract_boxed_content(response) + return 1.0 if grade_answer(answer, ground_truth) else -1.0 + + +def soft_overlong_punishment(response_length: int, max_response_length: int, overlong_buffer_length: int): + expected_len = max_response_length - overlong_buffer_length + if response_length <= expected_len: + return 0.0 + elif response_length <= max_response_length: + return (expected_len - response_length) / overlong_buffer_length + else: + return -1.0 + + +def compute_score( + reward_inputs: List[Dict[str, Any]], + max_response_length: int, + overlong_buffer_length: int, + overlong_penalty_factor: float, +) -> List[Dict[str, float]]: + if not isinstance(reward_inputs, list): + raise ValueError("Please use `reward_type=batch` for dapo reward function.") + + scores = [] + for reward_input in reward_inputs: + accuracy_score = accuracy_reward(reward_input["response"], reward_input["ground_truth"]) + overlong_score = soft_overlong_punishment( + reward_input["response_length"], max_response_length, overlong_buffer_length + ) + scores.append( + { + "overall": accuracy_score + overlong_score * overlong_penalty_factor, + "accuracy": accuracy_score, + "overlong": overlong_score, + "accuracy_normalized": 0.5 * (accuracy_score + 1.0), + } + ) + + return scores diff --git a/lazyllm/components/finetune/easy_r1/reward_function/math.py b/lazyllm/components/finetune/easy_r1/reward_function/math.py new file mode 100644 index 000000000..f41ffe13f --- /dev/null +++ b/lazyllm/components/finetune/easy_r1/reward_function/math.py @@ -0,0 +1,49 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import re +from typing import Any, Dict, List + +from mathruler.grader import extract_boxed_content, grade_answer + + +def format_reward(response: str) -> float: + pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL) + format_match = re.fullmatch(pattern, response) + return 1.0 if format_match else 0.0 + + +def accuracy_reward(response: str, ground_truth: str) -> float: + answer = extract_boxed_content(response) + return 1.0 if grade_answer(answer, ground_truth) else 0.0 + + +def compute_score(reward_inputs: List[Dict[str, Any]], format_weight: float = 0.1) -> List[Dict[str, float]]: + if not isinstance(reward_inputs, list): + raise ValueError("Please use `reward_type=batch` for math reward function.") + + scores = [] + for reward_input in reward_inputs: + response = re.sub(r"\s*(<|>|/)\s*", r"\1", reward_input["response"]) # handle qwen2.5vl-32b format + format_score = format_reward(response) + accuracy_score = accuracy_reward(response, reward_input["ground_truth"]) + scores.append( + { + "overall": (1 - format_weight) * accuracy_score + format_weight * format_score, + "format": format_score, + "accuracy": accuracy_score, + } + ) + + return scores diff --git a/lazyllm/components/finetune/easy_r1/reward_function/r1v.py b/lazyllm/components/finetune/easy_r1/reward_function/r1v.py new file mode 100644 index 000000000..97b03c794 --- /dev/null +++ b/lazyllm/components/finetune/easy_r1/reward_function/r1v.py @@ -0,0 +1,50 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import re +from typing import Any, Dict + +from mathruler.grader import grade_answer + + +def format_reward(response: str) -> float: + pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL) + format_match = re.fullmatch(pattern, response) + return 1.0 if format_match else 0.0 + + +def accuracy_reward(response: str, ground_truth: str) -> float: + try: + content_match = re.search(r"<answer>(.*?)</answer>", response) + given_answer = content_match.group(1).strip() if content_match else response.strip() + if grade_answer(given_answer, ground_truth.strip()): + return 1.0 + + except Exception: + pass + + return 0.0 + + +def compute_score(reward_input: Dict[str, Any], format_weight: float = 0.5) -> Dict[str, float]: + if not isinstance(reward_input, dict): + raise ValueError("Please use `reward_type=sequential` for r1v reward function.") + + format_score = format_reward(reward_input["response"]) + accuracy_score = accuracy_reward(reward_input["response"], reward_input["ground_truth"]) + return { + "overall": (1 - format_weight) * accuracy_score + format_weight * format_score, + "format": format_score, + "accuracy": accuracy_score, + } diff --git a/lazyllm/components/finetune/easyr1.py b/lazyllm/components/finetune/easyr1.py new file mode 100644 index 000000000..b7f68d957 --- /dev/null +++ b/lazyllm/components/finetune/easyr1.py @@ -0,0 +1,126 @@ +import os +import copy +import json +import random +import subprocess +from datetime import datetime +from subprocess import CalledProcessError + +import lazyllm +from lazyllm import launchers, ArgsDict, thirdparty, LOG +from .base import LazyLLMFinetuneBase + + +class EasyR1Finetune(LazyLLMFinetuneBase): + defatult_kw = ArgsDict({ + 'data.max_prompt_length': 2048, + 'data.max_response_length': 2048, + 'data.rollout_batch_size': 128, + 'data.val_batch_size': 1024, + 'data.format_prompt': None, + 'worker.actor.global_batch_size': 128, + 'worker.actor.micro_batch_size_per_device_for_update': 4, + 'worker.actor.micro_batch_size_per_device_for_experience': 16, + 'worker.rollout.gpu_memory_utilization': 0.6, + 'worker.rollout.tensor_parallel_size': 1, + 'worker.reward.reward_function': None, + 'trainer.total_epochs': 2, + 'trainer.n_gpus_per_node': 1, + 'trainer.save_freq': 5, + 'trainer.save_checkpoint_path': None, + 'trainer.save_model_only': False, + }, with_line=False) + + def __init__(self, + base_model, + target_path, + merge_path=None, + launcher=launchers.remote(ngpus=1, sync=True), + **kw + ): + if not merge_path: + merge_path = target_path + os.makedirs(target_path, exist_ok=True) + os.makedirs(merge_path, exist_ok=True) + super().__init__( + base_model, + target_path, + launcher=launcher, + ) + self._folder_path = os.path.dirname(os.path.abspath(__file__)) + self.kw = copy.deepcopy(self.defatult_kw) + self.kw.check_and_update(kw) + + def cmd(self, trainset, valset=None) -> str: + thirdparty.check_packages(['verl', 'trl']) + if not os.path.exists(trainset): + default_path = os.path.join(lazyllm.config['data_path'], trainset) + if os.path.exists(default_path): + trainset = default_path + else: + raise FileNotFoundError(f"Trainset {trainset} does not exist, please check your path.") + if not os.path.exists(valset): + default_path = os.path.join(lazyllm.config['data_path'], valset) + if os.path.exists(default_path): + valset = default_path + else: + raise FileNotFoundError(f"Valset {valset} does not exist, please check your path.") + + formatted_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + random_value = random.randint(1000, 9999) + self.log_file_path = 
f'{self.target_path}/train_log_{formatted_date}_{random_value}.log' + + self.kw['data.train_files'] = trainset + self.kw['data.val_files'] = valset + self.kw['worker.actor.model.model_path'] = self.base_model + self.kw['trainer.n_gpus_per_node'] = self.launcher.ngpus + if not self.kw['trainer.save_checkpoint_path']: + self.kw['trainer.save_checkpoint_path'] = self.target_path + if not self.kw['worker.reward.reward_function']: + self.kw['worker.reward.reward_function'] = (f'{self._folder_path}/easy_r1/' + 'reward_function/math.py:compute_score') + if not self.kw['data.format_prompt']: + self.kw['data.format_prompt'] = f'{self._folder_path}/easy_r1/format_prompt/math.jinja' + + cmd = f'python -m verl.trainer.main config={self._folder_path}/easy_r1/config.yaml ' + cmd += self.kw.parse_kwargs() + cmd += f' 2>&1 | tee {self.log_file_path}' + + return cmd + + def __call__(self, *args, **kw): + save_path = super().__call__(*args, **kw) + ckpt_tracker_file = os.path.join(save_path, 'checkpoint_tracker.json') + if not os.path.exists(ckpt_tracker_file): + not_found_msg = 'Training failed, checkpoint_tracker.json not found.' + LOG.error(not_found_msg) + return not_found_msg + with open(ckpt_tracker_file, 'r') as f: + json_data = json.load(f) + actor_path = json_data.get('last_actor_path', None) + if not actor_path or not os.path.exists(actor_path): + not_found_msg = 'Training failed, last_actor_path not found in checkpoint_tracker.json.' + LOG.error(not_found_msg) + return not_found_msg + + self.merge_ckpt(actor_path) + huggingface_path = os.path.join(actor_path, 'huggingface') + return huggingface_path + + def merge_ckpt(self, path): + try: + script_path = f"{self._folder_path}/easy_r1/model_merger.py" + subprocess.run( + ["python", script_path, "--local_dir", path], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + except CalledProcessError as e: + LOG.error(f"Command failed with return code {e.returncode}: {e.stderr}") + return False + except Exception as e: + LOG.error(f"Error merging checkpoints: {e}") + return False + return True diff --git a/tests/advanced_tests/standard_test/test_finetune.py b/tests/advanced_tests/standard_test/test_finetune.py index 4dc41c8ed..8ef337613 100644 --- a/tests/advanced_tests/standard_test/test_finetune.py +++ b/tests/advanced_tests/standard_test/test_finetune.py @@ -16,6 +16,9 @@ def setup_method(self): self.llm_path = 'qwen1.5-0.5b-chat' self.vlm_data = 'ci_data/vqa_rad/train.json' self.vlm_path = 'qwen2.5-vl-3b-instruct' + self.grpo_train_data = 'ci_data/math-json-200/train200.json' + self.grpo_test_data = 'ci_data/math-json-200/test100.json' + self.grpo_llm = 'qwen2-0.5b-instruct' self.embed_data = os.path.join(lazyllm.config['data_path'], 'sft_embeding/embedding.json') self.embed_path = 'bge-m3' self.rerank_data = os.path.join(lazyllm.config['data_path'], 'sft_embeding/rerank.jsonl') @@ -91,3 +94,22 @@ def test_finetune_reranker(self): assert type(res) is list assert len(res) == 2 assert type(res[0]) is tuple + + def test_grpo_easyr1(self): + m = lazyllm.TrainableModule(self.grpo_llm, self.save_path)\ + .mode('finetune')\ + .trainset(lambda: lazyllm.package(self.grpo_train_data, self.grpo_test_data))\ + .finetune_method( + (lazyllm.finetune.easyr1, { + 'data.rollout_batch_size': 64, + 'data.val_batch_size': 32, + 'worker.actor.global_batch_size': 32, + 'trainer.save_model_only': True, + 'trainer.total_epochs': 1, + 'worker.rollout.tensor_parallel_size': 2, + 'launcher': lazyllm.launchers.remote(ngpus=2, sync=True), + })) + 
m.update() + assert self.has_bin_file(m.finetuned_model_path) + res = m('hi') + assert type(res) is str
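
For context on the `ArgsDict` change: the new `with_line` flag lets `parse_kwargs()` emit bare dotted `key=value` overrides, which is what the `python -m verl.trainer.main config=... key=value ...` command built by `EasyR1Finetune.cmd` expects, while the default `--key=value` form stays unchanged for the existing components. A minimal sketch of the two behaviours (the dictionary contents are illustrative only):

```python
from lazyllm import ArgsDict

# Default (with_line=True): GNU-style flags, as consumed by the existing deploy/finetune commands.
flags = ArgsDict({'trainer.total_epochs': 1, 'data.format_prompt': '/tmp/math.jinja'})
print(flags.parse_kwargs())
# --trainer.total_epochs=1 --data.format_prompt="/tmp/math.jinja"

# with_line=False: plain dotted overrides, matching verl's OmegaConf-style CLI used by EasyR1Finetune.
overrides = ArgsDict({'trainer.total_epochs': 1, 'data.format_prompt': '/tmp/math.jinja'}, with_line=False)
print(overrides.parse_kwargs())
# trainer.total_epochs=1 data.format_prompt="/tmp/math.jinja"
```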
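
The reward hook is resolved from the `worker.reward.reward_function` override as `<file.py>:<function>`; `EasyR1Finetune` defaults it to the bundled `easy_r1/reward_function/math.py:compute_score`. A custom reward only has to follow the same batch interface as `math.py`/`dapo.py`: accept a list of per-sample dicts and return one dict per sample with at least an `"overall"` score. A minimal exact-match sketch under that assumption; the file name is hypothetical, and the `"response"`/`"ground_truth"` keys are the ones the bundled reward functions read:

```python
# my_reward.py (hypothetical); pass worker.reward.reward_function=/abs/path/my_reward.py:compute_score
from typing import Any, Dict, List


def compute_score(reward_inputs: List[Dict[str, Any]]) -> List[Dict[str, float]]:
    if not isinstance(reward_inputs, list):
        raise ValueError("Please use `reward_type=batch` for this reward function.")

    scores = []
    for reward_input in reward_inputs:
        # Naive exact-match reward; the bundled math.py uses mathruler's grader instead.
        correct = reward_input["response"].strip() == reward_input["ground_truth"].strip()
        scores.append({"overall": 1.0 if correct else 0.0, "accuracy": 1.0 if correct else 0.0})
    return scores
```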
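
Outside CI, the new backend is driven the same way as in the added test. A usage sketch, assuming local JSON datasets whose records follow the `prompt_key: problem` / `answer_key: answer` defaults in `config.yaml`; the model name and paths below are placeholders:

```python
import lazyllm

# Each record is assumed to look like {"problem": "1 + 1 = ?", "answer": "2"},
# matching the prompt_key/answer_key defaults; adjust the config overrides if your data differs.
m = lazyllm.TrainableModule('qwen2-0.5b-instruct', '/tmp/easyr1_target')\
    .mode('finetune')\
    .trainset(lambda: lazyllm.package('/data/math/train.json', '/data/math/test.json'))\
    .finetune_method((lazyllm.finetune.easyr1, {
        'trainer.total_epochs': 1,
        'data.rollout_batch_size': 64,
        'launcher': lazyllm.launchers.remote(ngpus=1, sync=True),
    }))
m.update()  # launches verl.trainer.main; EasyR1Finetune then merges the FSDP shards into an HF-format checkpoint
```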