diff --git a/examples/finetune.py b/examples/finetune.py index 8feb7ce17..8756f21df 100644 --- a/examples/finetune.py +++ b/examples/finetune.py @@ -13,7 +13,7 @@ foo = ClassFoo() bar = foo.FunctionBar() """ - +import os import sys from transformers import HfArgumentParser @@ -52,15 +52,31 @@ def main(): pipeline_args=pipeline_args, ) dataset = Dataset(data_args) - model = AutoModel.get_model(model_args) + + model = AutoModel.get_model( + model_args, + lang=data_args.lang, + forced_bos_token=data_args.forced_bos_token, + source_prefix = data_args.source_prefix, + streaming = data_args.streaming, + preprocessing_num_workers = data_args.preprocessing_num_workers, + overwrite_cache = data_args.overwrite_cache, + max_source_length = data_args.max_source_length, + max_target_length = data_args.max_target_length, + pad_to_max_length = data_args.pad_to_max_length + ) # Tokenization and text grouping must be done in the main process with pipeline_args.main_process_first(desc="dataset map tokenization"): tokenized_dataset = model.tokenize(dataset) - lm_dataset = finetuner.group_text( - tokenized_dataset, - model_max_length=model.get_max_length(), - ) + if model_args.arch_type == "encoder_decoder": + # encoder-decoder model does not need group text + lm_dataset = tokenized_dataset + else: + lm_dataset = finetuner.group_text( + tokenized_dataset, + model_max_length=model.get_max_length(), + ) # Finetuning tuned_model = finetuner.tune(model=model, lm_dataset=lm_dataset) diff --git a/requirements.txt b/requirements.txt index a1c11267b..ac99b5e52 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,8 @@ wandb==0.14.0 deepspeed==0.8.3 trl @ git+https://github.com/lvwerra/trl.git#egg=trl-0.4.1 sentencepiece +icetk==0.0.7 +cpm_kernels==1.0.11 transformers @ git+https://github.com/huggingface/transformers@c612628 flask -flask_cors +flask_cors \ No newline at end of file diff --git a/scripts/run_chatbot_seq2seq.sh b/scripts/run_chatbot_seq2seq.sh new file mode 100755 index 000000000..5b729f934 --- /dev/null +++ b/scripts/run_chatbot_seq2seq.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +model=THUDM/chatglm-6b +lora_args="" +if [ $# -ge 1 ]; then + model=$1 +fi +if [ $# -ge 2 ]; then + lora_args="--lora_model_path $2" +fi + +CUDA_VISIBLE_DEVICES=0 \ + deepspeed examples/chatbot.py \ + --arch_type encoder_decoder \ + --deepspeed configs/ds_config_chatbot.json \ + --model_name_or_path ${model} \ + ${lora_args} diff --git a/src/lmflow/args.py b/src/lmflow/args.py index 7d71c4d88..46fc35c12 100644 --- a/src/lmflow/args.py +++ b/src/lmflow/args.py @@ -13,13 +13,14 @@ """ from dataclasses import dataclass, field -from typing import Optional +from pathlib import Path +from typing import Optional, Union from transformers.utils.versions import require_version - +from transformers.generation.configuration_utils import GenerationConfig from transformers import ( MODEL_FOR_CAUSAL_LM_MAPPING, - TrainingArguments, + TrainingArguments ) MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) @@ -99,6 +100,10 @@ class ModelArguments: default=None, metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, ) + arch_type: Optional[str] = field( + default="decoder_only", + metadata={"help": "The architecture type of the model. Currently supported decoder_only or encoder_decoder"} + ) config_overrides: Optional[str] = field( default=None, metadata={ @@ -165,6 +170,15 @@ class ModelArguments: default=True, metadata={"help": "Whether use disk mapping when memory is not enough."} ) + resize_position_embeddings: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "Whether to automatically resize the position embeddings if `max_source_length` exceeds " + "the model's position embeddings." + ) + }, + ) def __post_init__(self): if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None): @@ -225,6 +239,8 @@ class DatasetArguments: each parameter, such as a help message. """ + lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."}) + dataset_path: Optional[str] = field( default=None, metadata={"help": "The path of the dataset to use."} ) @@ -309,6 +325,83 @@ class DatasetArguments: default=None, metadata={"help": "Evaluation File Path"}, ) + max_source_length: Optional[int] = field( + default=1024, + metadata={ + "help": ( + "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + ) + }, + ) + max_target_length: Optional[int] = field( + default=128, + metadata={ + "help": ( + "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + ) + }, + ) + val_max_target_length: Optional[int] = field( + default=None, + metadata={ + "help": ( + "The maximum total sequence length for validation target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." + "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " + "during ``evaluate`` and ``predict``." + ) + }, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": ( + "Whether to pad all samples to model maximum sentence length. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " + "efficient on GPU but very bad for TPU." + ) + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + ) + }, + ) + num_beams: Optional[int] = field( + default=None, + metadata={ + "help": ( + "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " + "which is used during ``evaluate`` and ``predict``." + ) + }, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={ + "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." + }, + ) + source_prefix: Optional[str] = field( + default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} + ) + + forced_bos_token: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The token to force as the first generated token after the decoder_start_token_id." + "Useful for multilingual models like mBART where the first generated token" + "needs to be the target language token (Usually it is the target language token)" + ) + }, + ) def __post_init__(self): if self.streaming: @@ -330,7 +423,54 @@ class FinetunerArguments(TrainingArguments): """ Adapt transformers.TrainingArguments """ - pass + + """ + Args: + sortish_sampler (`bool`, *optional*, defaults to `False`): + Whether to use a *sortish sampler* or not. Only possible if the underlying datasets are *Seq2SeqDataset* + for now but will become generally available in the near future. + + It sorts the inputs according to lengths in order to minimize the padding size, with a bit of randomness + for the training set. + predict_with_generate (`bool`, *optional*, defaults to `False`): + Whether to use generate to calculate generative metrics (ROUGE, BLEU). + generation_max_length (`int`, *optional*): + The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default to the + `max_length` value of the model configuration. + generation_num_beams (`int`, *optional*): + The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default to the + `num_beams` value of the model configuration. + """ + + sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."}) + predict_with_generate: bool = field( + default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."} + ) + generation_max_length: Optional[int] = field( + default=None, + metadata={ + "help": ( + "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default " + "to the `max_length` value of the model configuration." + ) + }, + ) + generation_num_beams: Optional[int] = field( + default=None, + metadata={ + "help": ( + "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default " + "to the `num_beams` value of the model configuration." + ) + }, + ) + generation_config: Optional[Union[str, Path, GenerationConfig]] = field( + default=None, + metadata={ + "help": "Model id, file path or url pointing to a GenerationConfig json file, to use during prediction." + }, + ) + @dataclass diff --git a/src/lmflow/datasets/dataset.py b/src/lmflow/datasets/dataset.py index 394f143b3..4654a56f8 100644 --- a/src/lmflow/datasets/dataset.py +++ b/src/lmflow/datasets/dataset.py @@ -169,7 +169,7 @@ def from_dict(self, dict_obj: dict, *args, **kwargs): return self else: raise NotImplementedError( - f'Currently .from_dict is not supported for backend "{backend}"' + f'Currently .from_dict is not supported for backend "{self.backend}"' ) @@ -222,7 +222,7 @@ def to_dict(self): return dict_obj else: raise NotImplementedError( - f'Current .to_dict is not supported for backend "{backend}"' + f'Current .to_dict is not supported for backend "{self.backend}"' ) @@ -251,7 +251,7 @@ def map(self, *args, **kwargs): else: # If the backend is not Hugging Face, raise a NotImplementedError raise NotImplementedError( - f'Currently .map is not supported for backend "{backend}"' + f'Currently .map is not supported for backend "{self.backend}"' ) diff --git a/src/lmflow/models/auto_model.py b/src/lmflow/models/auto_model.py index 522b5aa53..459d304a4 100644 --- a/src/lmflow/models/auto_model.py +++ b/src/lmflow/models/auto_model.py @@ -4,11 +4,16 @@ """ from lmflow.models.hf_decoder_model import HFDecoderModel - +from lmflow.models.hf_encoder_decoder_model import HFEncoderDecoderModel class AutoModel: @classmethod def get_model(self, model_args, *args, **kwargs): # TODO (add new models) - return HFDecoderModel(model_args, *args, **kwargs) + if model_args.arch_type == "encoder_decoder": + return HFEncoderDecoderModel(model_args, *args, **kwargs) + elif model_args.arch_type == "decoder_only": + return HFDecoderModel(model_args, *args, **kwargs) + else: + raise NotImplementedError(f"Model type \"{model_args.arch_type}\" is not implemented.") diff --git a/src/lmflow/models/encoder_decoder_model.py b/src/lmflow/models/encoder_decoder_model.py new file mode 100644 index 000000000..9db0fc4a5 --- /dev/null +++ b/src/lmflow/models/encoder_decoder_model.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +# coding=utf-8 +"""A one-line summary of the module or program, terminated by a period. + +Leave one blank line. The rest of this docstring should contain an +overall description of the module or program. Optionally, it may also +contain a brief desription of exported classes and funcctions and/or usage +examples. + +Typical usage example: + + foo = ClassFoo() + bar = foo.FunctionBar() +""" + +from lmflow.models.base_model import BaseModel + + +class EncoderDecoderModel(BaseModel): + + def __init__(self, *args, **kwargs): + pass diff --git a/src/lmflow/models/hf_decoder_model.py b/src/lmflow/models/hf_decoder_model.py index b20509dc7..570ecb8f6 100644 --- a/src/lmflow/models/hf_decoder_model.py +++ b/src/lmflow/models/hf_decoder_model.py @@ -49,10 +49,8 @@ from lmflow.models.decoder_model import DecoderModel from lmflow.models.interfaces.tunable import Tunable - logger = logging.getLogger(__name__) - class HFDecoderModel(DecoderModel, Tunable): r""" Initializes a HFDecoderModel instance. diff --git a/src/lmflow/models/hf_encoder_decoder_model.py b/src/lmflow/models/hf_encoder_decoder_model.py new file mode 100644 index 000000000..9ccec0508 --- /dev/null +++ b/src/lmflow/models/hf_encoder_decoder_model.py @@ -0,0 +1,443 @@ +#!/usr/bin/env python +# coding=utf-8 +"""This is a class called HFEncoderDecoder which is a wrapper around transformers model and +tokenizer classes. It has several methods such as __init__, tokenize, and train that are +used for training and fine-tuning the model. The __init__ method takes in several arguments +such as model_args, tune_strategy, and ds_config, which are used to load the pretrained +model and tokenizer, and initialize the training settings. + +The tokenize method is used to tokenize the input text and return the input IDs and attention +masks that can be fed to the model for training or inference. + +This class supports different tune_strategy options such as 'normal', 'none', 'lora', and +'adapter', which allow for different fine-tuning settings of the model. However, the 'lora' +and 'adapter' strategies are not yet implemented. + +Overall, this class provides a convenient interface for loading and fine-tuning transformer +models and can be used for various NLP tasks such as language modeling, text classification, +and question answering. +""" +import logging +from typing import List, Union + +import deepspeed +from peft import ( + LoraConfig, + PeftModel, + TaskType, + get_peft_config, + get_peft_model, + prepare_model_for_int8_training, +) +import torch +import transformers +from transformers.deepspeed import HfDeepSpeedConfig + +from transformers.testing_utils import CaptureLogger + +from transformers import ( + CONFIG_MAPPING, + AutoConfig, + AutoModel, + AutoModelForSeq2SeqLM, + AutoTokenizer, + MBart50Tokenizer, + MBart50TokenizerFast, + MBartTokenizer, + MBartTokenizerFast, +) +from transformers.utils import check_min_version, is_offline_mode, send_example_telemetry +from lmflow.datasets.dataset import Dataset +from lmflow.models.encoder_decoder_model import EncoderDecoderModel +from lmflow.models.interfaces.tunable import Tunable + +logger = logging.getLogger(__name__) + +class HFEncoderDecoderModel(EncoderDecoderModel, Tunable): + r""" + Initializes a HFEncoderDecoderModel instance. + + Parameters + ------------ + + model_args : + Model arguments such as model name, path, revision, etc. + + tune_strategy : str or none, default="normal". + A string representing the dataset backend. Defaults to "huggingface". + + ds_config : + Deepspeed configuations. + + args : Optional. + Positional arguments. + + kwargs : Optional. + Keyword arguments. + """ + + def __init__( + self, + model_args, + tune_strategy='normal', + ds_config=None, + *args, + **kwargs + ): + """ + Initializes a HFEncoderDecoderModel instance. + :param model_args: dictionary with model arguments such as model name, path, revision, etc. + :param tune_strategy: tuning strategy: normal, none, lora or adapter + :param ds_config: deepspeed configuration for distributed training + """ + + # See more about loading any type of standard or custom dataset (from + # files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # Distributed training: The .from_pretrained methods guarantee that + # only one local process can concurrently download model & vocab. + + data_args = kwargs + self.data_args = data_args + self.model_args = model_args + + if tune_strategy == 'normal': + config_kwargs = { + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "use_auth_token": True if model_args.use_auth_token else None, + } + if model_args.config_name: + config = AutoConfig.from_pretrained(model_args.config_name, trust_remote_code=True, **config_kwargs) + elif model_args.model_name_or_path: + config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, **config_kwargs) + else: + config = CONFIG_MAPPING[model_args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + if model_args.config_overrides is not None: + logger.info(f"Overriding config: {model_args.config_overrides}") + config.update_from_string(model_args.config_overrides) + logger.info(f"New config: {config}") + + tokenizer_kwargs = { + "cache_dir": model_args.cache_dir, + "use_fast": model_args.use_fast_tokenizer, + "revision": model_args.model_revision, + "use_auth_token": True if model_args.use_auth_token else None, + } + if model_args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, trust_remote_code=True, **tokenizer_kwargs) + elif model_args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, **tokenizer_kwargs) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is" + " not supported by this script. You can do it from another" + " script, save it, and load it from here, using" + " --tokenizer_name." + ) + + if model_args.model_name_or_path: + torch_dtype = ( + model_args.torch_dtype + if model_args.torch_dtype in ["auto", None] + else getattr(torch, model_args.torch_dtype) + ) + model = AutoModelForSeq2SeqLM.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + torch_dtype=torch_dtype, + trust_remote_code=True + ) + else: + model = AutoModelForSeq2SeqLM.from_config(config) + n_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values()) + logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params") + + if model_args.use_lora: + peft_config = LoraConfig( + task_type=TaskType.SEQ_2_SEQ_LM, + inference_mode=False, + r=model_args.lora_r, + lora_alpha=model_args.lora_alpha, + lora_dropout=model_args.lora_dropout + ) + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + + + # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch + # on a small vocab and want a smaller embedding size, remove this test. + embedding_size = model.get_input_embeddings().weight.shape[0] + if len(tokenizer) > embedding_size: + model.resize_token_embeddings(len(tokenizer)) + + if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): + if isinstance(tokenizer, MBartTokenizer): + model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args["lang"]] + else: + model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args["lang"]) + + if model.config.decoder_start_token_id is None: + raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") + + max_source_length = data_args["max_source_length"] + if ( + hasattr(model.config, "max_position_embeddings") + and model.config.max_position_embeddings < max_source_length + ): + if model_args.resize_position_embeddings is None: + logger.warning( + "Increasing the model's number of position embedding vectors from" + f" {model.config.max_position_embeddings} to {max_source_length}." + ) + model.resize_position_embeddings(max_source_length) + elif model_args.resize_position_embeddings: + model.resize_position_embeddings(max_source_length) + else: + raise ValueError( + f"`--max_source_length` is set to {max_source_length}, but the model only has" + f" {model.config.max_position_embeddings} position encodings. Consider either reducing" + f" `--max_source_length` to {model.config.max_position_embeddings} or to automatically resize the" + " model's position encodings by passing `--resize_position_embeddings`." + ) + + self.config = config + self.backend_model = model + self.tokenizer = tokenizer + self.tune_strategy = tune_strategy + + elif tune_strategy == 'none': + dschf = HfDeepSpeedConfig(ds_config) + if model_args.model_name_or_path == 'THUDM/chatglm-6b': + self.backend_model = AutoModel.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) + self.tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) + else: + self.backend_model = AutoModelForSeq2SeqLM.from_pretrained(model_args.model_name_or_path) + self.tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path) + peft_model_id = model_args.lora_model_path + if peft_model_id is not None: + self.backend_model = PeftModel.from_pretrained( + self.backend_model, peft_model_id + ) + + deepspeed.init_distributed() + self.ds_engine = deepspeed.initialize(model=self.backend_model, config_params=ds_config)[0] + self.ds_engine.module.eval() + + elif tune_strategy == 'adapter': + raise NotImplementedError('adapter tune strategy not implemented') + + + def tokenize(self, dataset, *args, **kwargs): + """ + Tokenize the full dataset. + + Parameters + ------------ + dataset : + Text dataset. + + args : Optional. + Positional arguments. + + kwargs : Optional. + Keyword arguments. + + Returns + ------------ + tokenized_datasets : + The tokenized dataset. + """ + model_args = self.model_args + data_args = self.data_args + text_column = "input" + summary_column = "output" + prefix = data_args["source_prefix"] if data_args["source_prefix"] is not None else "" + + # A list of all multilingual tokenizer which require lang attribute. + MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast] + + if isinstance(self.tokenizer, tuple(MULTILINGUAL_TOKENIZERS)): + assert ( + data_args["lang"] is not None + ), f"{self.tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument" + + self.tokenizer.src_lang = data_args["lang"] + self.tokenizer.tgt_lang = data_args["lang"] + + # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token + # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument. + forced_bos_token_id = ( + self.tokenizer.lang_code_to_id[data_args["forced_bos_token"]] if data_args["forced_bos_token"] is not None else None + ) + self.model.config.forced_bos_token_id = forced_bos_token_id + + # Temporarily set max_target_length for training. + max_target_length = data_args["max_target_length"] + padding = "max_length" if data_args["pad_to_max_length"] else False + + # Preprocessing the datasets. + # First we tokenize all the texts. + if dataset.get_backend() != "huggingface": + raise NotImplementedError( + "tokenization of datasets with non-huggingface backend are" + "not supported yet" + ) + + # TODO: DO WE NEED THIS? + # since this will be pickled to avoid _LazyModule error in Hasher force + # logger loading before tokenize_function + tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base") + if model_args.use_lora: + self.tokenizer.pad_token = 1 + + raw_datasets = dataset + hf_raw_datasets = dataset.get_backend_dataset() + column_names = list(hf_raw_datasets.features) + text_column_name = "text" if "text" in column_names else column_names[0] + + def preprocess_function(examples): + # remove pairs where at least one record is None + + inputs, targets = [], [] + for i in range(len(examples[text_column])): + if examples[text_column][i] and examples[summary_column][i]: + inputs.append(examples[text_column][i]) + targets.append(examples[summary_column][i]) + + inputs = [prefix + inp for inp in inputs] + model_inputs = self.tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) + + # Tokenize targets with the `text_target` keyword argument + labels = self.tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. + if padding == "max_length" and data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != self.tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + ] + + model_inputs["labels"] = labels["input_ids"] + return model_inputs + + data_args = raw_datasets.get_data_args() + + tokenized_datasets = raw_datasets.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on train dataset", + ) + + return tokenized_datasets + + + def encode(self, input: Union[str, List[str]], *args, **kwargs ) -> List[int]: + """ + Perform encoding process of the tokenizer. + + Parameters + ------------ + inputs : str or list. + The text sequence. + + args : Optional. + Positional arguments. + + kwargs : Optional. + Keyword arguments. + + Returns + ------------ + outputs : + The tokenized inputs. + """ + return self.tokenizer.encode(text=input, *args, **kwargs) + + + def decode(self, input, *args, **kwargs ) -> List[int]: + """ + Perform decoding process of the tokenizer. + + Parameters + ------------ + inputs : list. + The token sequence. + + args : Optional. + Positional arguments. + + kwargs : Optional. + Keyword arguments. + + Returns + ------------ + outputs : + The text decoded from the token inputs. + """ + return self.tokenizer.decode(input, *args, **kwargs) + + + def inference(self, inputs, *args, **kwargs): + """ + Perform generation process of the model. + + Parameters + ------------ + inputs : + The sequence used as a prompt for the generation or as model inputs to the model. + + args : Optional. + Positional arguments. + + kwargs : Optional. + Keyword arguments. + + Returns + ------------ + outputs : + The generated sequence output + """ + + + with torch.no_grad(): + outputs = self.ds_engine.module.generate( + input_ids=inputs, + synced_gpus=True, + pad_token_id=self.tokenizer.eos_token_id, + *args, + **kwargs + ) + return outputs + + + def get_max_length(self): + """ + Return max acceptable input length in terms of tokens. + """ + return self.tokenizer.model_max_length + + + def get_tokenizer(self): + """ + Return the tokenizer of the model. + """ + return self.tokenizer + + + def get_backend_model(self): + """ + Return the backend model. + """ + return self.backend_model diff --git a/src/lmflow/pipeline/finetuner.py b/src/lmflow/pipeline/finetuner.py index b6350aeee..6a00da16a 100644 --- a/src/lmflow/pipeline/finetuner.py +++ b/src/lmflow/pipeline/finetuner.py @@ -3,6 +3,8 @@ """The Finetuner class simplifies the process of running finetuning process on a language model for a TunableModel instance with given dataset. """ +from __future__ import absolute_import + import logging import os import sys @@ -13,6 +15,8 @@ from itertools import chain from transformers import ( Trainer, + Seq2SeqTrainer, + DataCollatorForSeq2Seq, default_data_collator, set_seed, ) @@ -21,7 +25,6 @@ from lmflow.datasets.dataset import Dataset from lmflow.pipeline.base_tuner import BaseTuner - logger = logging.getLogger(__name__) @@ -104,7 +107,6 @@ def __init__(self, model_args, data_args, finetuner_args, *args, **kwargs): # Set seed before initializing model. set_seed(finetuner_args.seed) - def group_text(self, tokenized_datasets, model_max_length): """ Groups texts together to form blocks of maximum length `model_max_length` and returns the processed data as @@ -200,6 +202,12 @@ def tune(self, model, lm_dataset): data_args = self.data_args finetuner_args = self.finetuner_args + if finetuner_args.label_smoothing_factor > 0 and not hasattr(model.get_backend_model(), "prepare_decoder_input_ids_from_labels"): + logger.warning( + "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" + f"`{model.get_backend_model().__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" + ) + train_dataset = lm_dataset.get_backend_dataset() if finetuner_args.do_train: @@ -209,17 +217,55 @@ def tune(self, model, lm_dataset): # Initialize our Trainer training_args = finetuner_args - trainer = Trainer( - model=model.get_backend_model(), - args=training_args, - train_dataset=train_dataset if training_args.do_train else None, - eval_dataset=None, - tokenizer=model.get_tokenizer(), - # Data collator will default to DataCollatorWithPadding, so we change it. - data_collator=default_data_collator, - compute_metrics=None, - preprocess_logits_for_metrics=None, - ) + + if model_args.arch_type == "encoder_decoder": + # Data collator + tokenizer = model.get_tokenizer() + label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model.get_backend_model(), + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=8 if training_args.fp16 else None, + ) + + # Override the decoding parameters of Seq2SeqTrainer + training_args.generation_max_length = ( + training_args.generation_max_length + if training_args.generation_max_length is not None + else data_args.val_max_target_length + ) + training_args.generation_num_beams = ( + data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams + ) + # Initialize our Trainer + trainer = Seq2SeqTrainer( + model=model.get_backend_model(), + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=None, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=None, + ) + + elif model_args.arch_type == "decoder_only": + trainer = Trainer( + model=model.get_backend_model(), + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=None, + tokenizer=model.get_tokenizer(), + # Data collator will default to DataCollatorWithPadding, so we change it. + data_collator=default_data_collator, + compute_metrics=None, + preprocess_logits_for_metrics=None, + ) + + else: + raise NotImplementedError( + f"Model type \"{model_args.arch_type}\" is not implemented." + ) # Training if training_args.do_train: diff --git a/src/lmflow/pipeline/inferencer.py b/src/lmflow/pipeline/inferencer.py index 2457df86f..c8687abd9 100644 --- a/src/lmflow/pipeline/inferencer.py +++ b/src/lmflow/pipeline/inferencer.py @@ -50,7 +50,7 @@ def __init__(self, model_args, data_args, inferencer_args): torch.cuda.set_device(self.local_rank) # NOTE: cpu-only machine will have error deepspeed.init_distributed() - self.config = AutoConfig.from_pretrained(model_args.model_name_or_path) + self.config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) try: self.model_hidden_size = self.config.hidden_size except: