diff --git a/generate.py b/generate.py
index f8a24f82..d61c7d82 100644
--- a/generate.py
+++ b/generate.py
@@ -1,16 +1,13 @@
 import torch
-import torch.nn.functional as F
+import torch.nn.functional as functional
 import os
 import argparse
 from tqdm import trange
 from transformers import GPT2LMHeadModel, GPT2Config, BertTokenizer


-def is_word(word):
-    for item in list(word):
-        if item not in "qwertyuiopasdfghjklzxcvbnm":
-            return False
-    return True
+def is_word(word: str):
+    return word.isalpha()


 def _is_chinese_char(char):
@@ -22,17 +19,17 @@ def _is_chinese_char(char):
     # despite its name. The modern Korean Hangul alphabet is a different block,
     # as is Japanese Hiragana and Katakana. Those alphabets are used to write
     # space-separated words, so they are not treated specially and handled
-    # like the all of the other languages.
+    # like all the other languages.
     cp = ord(char)
     if (
-        (cp >= 0x4E00 and cp <= 0x9FFF)
-        or (cp >= 0x3400 and cp <= 0x4DBF)  #
-        or (cp >= 0x20000 and cp <= 0x2A6DF)  #
-        or (cp >= 0x2A700 and cp <= 0x2B73F)  #
-        or (cp >= 0x2B740 and cp <= 0x2B81F)  #
-        or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
-        or (cp >= 0xF900 and cp <= 0xFAFF)
-        or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
+        (0x4E00 <= cp <= 0x9FFF)
+        or (0x3400 <= cp <= 0x4DBF)  #
+        or (0x20000 <= cp <= 0x2A6DF)  #
+        or (0x2A700 <= cp <= 0x2B73F)  #
+        or (0x2B740 <= cp <= 0x2B81F)  #
+        or (0x2B820 <= cp <= 0x2CEAF)  #
+        or (0xF900 <= cp <= 0xFAFF)
+        or (0x2F800 <= cp <= 0x2FA1F)  #
     ):  #
         return True

@@ -41,15 +38,18 @@ def _is_chinese_char(char):

 def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
     """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
-    Args:
-        logits: logits distribution shape (vocabulary size)
-        top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+    :param top_k:
+    :param top_p:
+    :param filter_value:
+    :param logits:
+        logits distribution shape (vocabulary size)
+        top_k > 0: keep only top k tokens with the highest probability (top-k filtering).
         top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
-            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+            Nucleus filtering is described in Holtzman et al. (https://arxiv.org/abs/1904.09751)
     From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
     """
     assert (
-        logits.dim() == 1
+            logits.dim() == 1
     )  # batch size 1 for now - could be updated for more but the code would be less clear
     top_k = min(top_k, logits.size(-1))  # Safety check
     if top_k > 0:
@@ -59,7 +59,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")

     if top_p > 0.0:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+        cumulative_probs = torch.cumsum(functional.softmax(sorted_logits, dim=-1), dim=-1)

         # Remove tokens with cumulative probability above the threshold
         sorted_indices_to_remove = cumulative_probs > top_p
@@ -73,43 +73,43 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")


 def sample_sequence(
-    model,
-    context,
-    length,
-    n_ctx,
-    tokenizer,
-    temperature=1.0,
-    top_k=30,
-    top_p=0.0,
-    repitition_penalty=1.0,
-    device="cpu",
+        model,
+        context,
+        length,
+        n_ctx,
+        tokenizer,
+        temperature=1.0,
+        top_k=30,
+        top_p=0.0,
+        repetition_penalty=1.0,
+        device="cpu",
 ):
     context = torch.tensor(context, dtype=torch.long, device=device)
     context = context.unsqueeze(0)
     generated = context
     with torch.no_grad():
         for _ in trange(length):
-            inputs = {"input_ids": generated[0][-(n_ctx - 1) :].unsqueeze(0)}
+            inputs = {"input_ids": generated[0][-(n_ctx - 1):].unsqueeze(0)}
             outputs = model(
                 **inputs
-            )  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
+            )  # Note: we could also use 'past' with GPT-2/Transformer-XL/XLNet (cached hidden-states)
             next_token_logits = outputs[0][0, -1, :]
-            for id in set(generated):
-                next_token_logits[id] /= repitition_penalty
+            for idx in set(generated):
+                next_token_logits[idx] /= repetition_penalty
             next_token_logits = next_token_logits / temperature
             next_token_logits[tokenizer.convert_tokens_to_ids("[UNK]")] = -float("Inf")
             filtered_logits = top_k_top_p_filtering(
                 next_token_logits, top_k=top_k, top_p=top_p
             )
             next_token = torch.multinomial(
-                F.softmax(filtered_logits, dim=-1), num_samples=1
+                functional.softmax(filtered_logits, dim=-1), num_samples=1
             )
             generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
     return generated.tolist()[0]


 def fast_sample_sequence(
-    model, context, length, temperature=1.0, top_k=30, top_p=0.0, device="cpu"
+        model, context, length, temperature=1.0, top_k=30, top_p=0.0, device="cpu"
 ):
     inputs = torch.LongTensor(context).view(1, -1).to(device)
     if len(context) > 1:
@@ -120,7 +120,7 @@ def fast_sample_sequence(
     prev = inputs
     generate = [] + context
     with torch.no_grad():
-        for i in trange(length):
+        for _ in trange(length):
             output = model(prev, past=past)
             output, past = output[:2]
             output = output[-1].squeeze(0) / temperature
@@ -142,13 +142,13 @@ def main():
         "--batch_size", default=1, type=int, required=False, help="生成的batch size"
     )
     parser.add_argument(
-        "--nsamples", default=10, type=int, required=False, help="生成几个样本"
+        "--n_samples", default=10, type=int, required=False, help="生成几个样本"
     )
     parser.add_argument(
         "--temperature", default=1, type=float, required=False, help="生成温度"
     )
-    parser.add_argument("--topk", default=8, type=int, required=False, help="最高几选一")
-    parser.add_argument("--topp", default=0, type=float, required=False, help="最高积累概率")
+    parser.add_argument("--top_k", default=8, type=int, required=False, help="最高几选一")
+    parser.add_argument("--top_p", default=0, type=float, required=False, help="最高积累概率")
     parser.add_argument(
         "--model_config",
         default="config/model_config.json",
@@ -188,11 +188,11 @@ def main():
     os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # 此处设置程序使用哪些显卡
     length = args.length
     n_ctx = args.n_ctx
-    batch_size = args.batch_size
-    nsamples = args.nsamples
+    # batch_size = args.batch_size
+    n_samples = args.n_samples
     temperature = args.temperature
-    topk = args.topk
-    topp = args.topp
+    top_k = args.top_k
+    top_p = args.top_p
     repetition_penalty = args.repetition_penalty

     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -209,7 +209,7 @@ def main():
     model.to(device)
     model.eval()

-    for i in range(nsamples):
+    for i in range(n_samples):
         raw_text = args.prefix
         encoded = tokenizer.encode_plus(raw_text)["input_ids"][:-1]
         out = sample_sequence(
@@ -219,9 +219,9 @@ def main():
             n_ctx=n_ctx,
             tokenizer=tokenizer,
             temperature=temperature,
-            top_k=topk,
-            top_p=topp,
-            repitition_penalty=repetition_penalty,
+            top_k=top_k,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
             device=device,
         )
         print(tokenizer.decode(out))
diff --git a/gpt2c.py b/gpt2c.py
new file mode 100644
index 00000000..9eb6ebb9
--- /dev/null
+++ b/gpt2c.py
@@ -0,0 +1,45 @@
+from abc import ABC
+
+from transformers import GPT2Model
+from transformers.models.gpt2.modeling_gpt2 import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    torch,
+    Optional, Tuple, Union
+)
+
+
+class GPT2ModelWithAdditionalHiddenStates(GPT2Model, ABC):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def forward(self, additional_hidden_states: Optional[torch.FloatTensor] = None,
+                **kwargs) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        """
+        :param additional_hidden_states: shape = [batch, seq_length, embed_size]
+        """
+        if 'input_ids' in kwargs:
+            input_ids = kwargs.pop('input_ids')
+            inputs_embeds: torch.FloatTensor = self.wte(input_ids)
+        else:
+            inputs_embeds: torch.FloatTensor = kwargs.pop('inputs_embeds')
+
+        assert inputs_embeds.size() == additional_hidden_states.size()
+
+        kwargs['inputs_embeds'] = inputs_embeds + additional_hidden_states
+
+        return super().forward(**kwargs)
+
+
+class GPT2ModelWithContext(GPT2ModelWithAdditionalHiddenStates, ABC):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def context_to_embedding(self, context: dict) -> torch.FloatTensor:
+        """
+        :return: additional_hidden_states, shape = [batch, seq_length, embed_size]
+        """
+        raise NotImplementedError
+
+    def forward(self, context: dict = None, **kwargs) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        additional_hidden_states: torch.FloatTensor = self.context_to_embedding(context)
+        return super().forward(additional_hidden_states, **kwargs)
diff --git a/requirements.txt b/requirements.txt
index 9ad1f7a3..02282538 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
-pytorch-lightning==1.2.2
-transformers==4.2.1
-torch==1.8.0
-tqdm==4.56.0
\ No newline at end of file
+# python~=3.10.8
+pytorch-lightning~=1.8.4.post0
+torch~=1.12.0
+transformers~=4.25.1
+tqdm~=4.65.0
diff --git a/scripts/generate.sh b/scripts/generate.sh
index 2a97a247..1d9819f5 100644
--- a/scripts/generate.sh
+++ b/scripts/generate.sh
@@ -4,5 +4,5 @@ python generate.py \
     --tokenizer_path cache/vocab_small.txt \
     --model_path model/final_model \
     --prefix "[CLS][MASK]" \
-    --topp 1 \
+    --top_p 1 \
     --temperature 1.0
diff --git a/tests.py b/tests.py
new file mode 100644
index 00000000..5d348016
--- /dev/null
+++ b/tests.py
@@ -0,0 +1,77 @@
+import unittest
+from transformers import BertTokenizer, GPT2Config, GPT2LMHeadModel, TextGenerationPipeline
+from collections import OrderedDict
+import torch
+from typing import List
+
+
+class TestMain(unittest.TestCase):
+    def setUp(self) -> None:
+        # self.model_config_path: str = 'config/model_config_test.json'
+        # self.model_config_path: str = 'config/model_config_small.json'
+        self.model_config_path: str = 'config/model_config.json'
+        self.vocab_path: str = 'vocab/vocab.txt'
+        self.config: GPT2Config = GPT2Config.from_json_file(self.model_config_path)
+
+    def test_train(self):
+        import sys
+        from train import main
+        args = [
+            '--data_path', 'data/train.txt',
+            '--batch_size', '2',
+            # '--devices', '0'
+            '--config_path', self.model_config_path,
+            # '--config_path', 'config/model_config_test.json',
+            '--epochs', '2'
+        ]
+        sys.argv.extend(args)
+        main()
+        self.assertTrue(True)
+
+    @classmethod
+    def pipeline(cls, model: GPT2LMHeadModel, tokenizer: BertTokenizer, text: str) -> str:
+        pad_token_id = tokenizer('[PAD]')['input_ids'][1]
+        pipeline = TextGenerationPipeline(model, tokenizer)
+        result = pipeline(text, max_length=100, pad_token_id=pad_token_id, do_sample=True)
+        return result
+
+    @classmethod
+    def generate(cls, model: GPT2LMHeadModel, tokenizer: BertTokenizer, text: str) -> str:
+        pad_token_id = tokenizer('[PAD]')['input_ids'][1]
+        input_ids = tokenizer('[CLS]' + text, return_tensors='pt', padding=False, add_special_tokens=False)['input_ids']
+        output_ids: torch.Tensor = model.generate(input_ids, max_length=100, pad_token_id=pad_token_id, do_sample=True)
+        output_tokens: List[str] = tokenizer.convert_ids_to_tokens(output_ids[0])
+        output_text: str = ''.join(filter(lambda x: x not in ['[SEP]', '[PAD]', '[CLS]'], output_tokens))
+        return output_text
+
+    def load_from_checkpoint(self, checkpoint_path: str) -> GPT2LMHeadModel:
+        model: GPT2LMHeadModel = GPT2LMHeadModel(config=self.config)
+        checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        raw_state_dict: OrderedDict = checkpoint["state_dict"]
+        state_dict: OrderedDict = OrderedDict({k.replace('model.', ''): v for k, v in raw_state_dict.items()})
+        model.load_state_dict(state_dict)
+        return model
+
+    def test_generate(self):
+        # load from checkpoint
+        # checkpoint_path: str = 'model/epoch=1-step=862.ckpt'
+        # model = self.load_from_checkpoint(checkpoint_path)
+        # tokenizer: BertTokenizer = BertTokenizer(vocab_file=self.vocab_path)
+
+        # load using transformers api
+        model: GPT2LMHeadModel = GPT2LMHeadModel.from_pretrained('model/gpt2-chinese-cluecorpussmall')
+        tokenizer: BertTokenizer = BertTokenizer.from_pretrained('model/gpt2-chinese-cluecorpussmall')
+
+        model.eval()
+
+        text: str = '我叫'
+
+        output_text: str = self.generate(model, tokenizer, text)
+        # output_text: str = self.pipeline(model, tokenizer, text)
+        print(output_text)
+
+        self.assertTrue(True)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/train.py b/train.py
index e03cfe6d..13785e4c 100644
--- a/train.py
+++ b/train.py
@@ -1,235 +1,248 @@
-from transformers import GPT2LMHeadModel, GPT2Config
-from transformers import AdamW, get_linear_schedule_with_warmup, BertTokenizer
-from torch.utils.data import Dataset, DataLoader
-from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
-import pytorch_lightning as pl
-import torch
-import json
-import argparse
-
-# 11846807
-
-
-class DS(Dataset):
-    def __init__(self, lines, vocab_path="vocab/vocab.txt", max_length=1024):
-        self.data = lines
-        self.tok = BertTokenizer(vocab_file=vocab_path)
-        self.max_length = max_length
-
-    def __len__(self):
-        return len(self.data)
-
-    def __getitem__(self, index):
-        line = self.data[index]
-        line = self.tok.encode_plus(
-            line,
-            max_length=self.max_length,
-            truncation=True,
-            padding="max_length",
-            return_tensors="pt",
-        )
-        return line
-
-
-class Net(pl.LightningModule):
-    def __init__(
-        self,
-        batch_size,
-        epochs,
-        t_total=100000,
-        config_path="config/model_config.json",
-        data_path="data/train.json",
-        valid_examples=100,
-        vocab_path="vocab/vocab.txt",
-        max_length=1024,
-        warm_up_steps=0,
-        lr=1e-4,
-    ):
-        super(Net, self).__init__()
-        self.batch_size = batch_size
-        self.epochs = epochs
-        self.t_total = t_total
-        self.warm_up_steps = warm_up_steps
-        self.lr = lr
-        self.model_name = "bert_pretrained_model"
-        self.config = GPT2Config.from_json_file(config_path)
-        self.model = GPT2LMHeadModel(config=self.config)
-        self.data = [json.loads(line.strip()) for line in open(data_path)]
-        self.dataset_train = DS(
-            self.data[:-valid_examples], vocab_path=vocab_path, max_length=max_length
-        )
-        self.dataset_valid = DS(
-            self.data[-valid_examples:], vocab_path=vocab_path, max_length=max_length
-        )
-
-    def forward(self, input_ids, attention_mask):
-        input_ids = input_ids
-        attention_mask = attention_mask
-        r = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            labels=input_ids,
-            return_dict=True,
-        )
-        return r["loss"]
-
-    def train_dataloader(self):
-        return DataLoader(
-            self.dataset_train,
-            batch_size=self.batch_size,
-            num_workers=8,
-            shuffle=True,
-            drop_last=True,
-        )
-
-    def val_dataloader(self):
-        return DataLoader(
-            self.dataset_valid,
-            batch_size=self.batch_size,
-            num_workers=8,
-            shuffle=True,
-            drop_last=True,
-        )
-
-    def configure_optimizers(self):
-        optimizer = AdamW(self.parameters(), lr=self.lr, weight_decay=0.001)
-        scheduler = get_linear_schedule_with_warmup(
-            optimizer, self.warm_up_steps, self.t_total
-        )
-        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
-        return [optimizer], [scheduler]
-
-    def training_step(self, batch, batch_nb):
-        loss = self.forward(batch["input_ids"], batch["attention_mask"])
-
-        self.log(
-            "train_loss",
-            loss,
-            on_step=True,
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-        )
-        return loss
-
-    def validation_step(self, batch, batch_nb):
-        loss = self.forward(batch["input_ids"], batch["attention_mask"])
-        return loss
-
-    def validation_epoch_end(self, outputs):
-        avg_loss = torch.stack(outputs).mean()
-        self.log(
-            "val_loss",
-            avg_loss,
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-        )
-        return {"val_loss": avg_loss}
-
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--device", default="0", type=str, required=False, help="设置使用哪些显卡,用逗号分割"
-    )
-    parser.add_argument(
-        "--config_path",
-        default="config/model_config.json",
-        type=str,
-        required=False,
-        help="选择模型参数",
-    )
-    parser.add_argument(
-        "--vocab_path",
-        default="vocab/vocab.txt",
-        type=str,
-        required=False,
-        help="选择词库",
-    )
-    parser.add_argument(
-        "--data_path",
-        default="data/train.json",
-        type=str,
-        required=False,
-        help="原始训练语料",
-    )
-    parser.add_argument("--epochs", default=5, type=int, required=False, help="训练循环")
-    parser.add_argument(
-        "--batch_size", default=8, type=int, required=False, help="训练batch size"
-    )
-    parser.add_argument("--lr", default=1.5e-4, type=float, required=False, help="学习率")
-    parser.add_argument(
-        "--warmup_steps", default=2000, type=int, required=False, help="warm up步数"
-    )
-    parser.add_argument(
-        "--max_length", default=1024, type=int, required=False, help="单条文本最长长度"
-    )
-    parser.add_argument(
-        "--eval_interval", default=100, type=int, required=False, help="eval 步数"
-    )
-    parser.add_argument(
-        "--val_examples", default=100, type=int, required=False, help="选择多少验证集样本"
-    )
-    parser.add_argument(
-        "--t_total", default=100000, type=int, required=False, help="计划训练多少步"
-    )
-    parser.add_argument(
-        "--log_step", default=1, type=int, required=False, help="多少步汇报一次loss"
-    )
-    parser.add_argument(
-        "--output_dir", default="model/", type=str, required=False, help="模型输出路径"
-    )
-    args = parser.parse_args()
-
-    val_examples = args.val_examples
-    vocab_path = args.vocab_path
-    max_length = args.max_length
-    batch_size = args.batch_size
-    epochs = args.epochs
-    output_path = args.output_dir
-    eval_interval = args.eval_interval
-    lr = args.lr
-    warmup_steps = args.warmup_steps
-    data_path = args.data_path
-    config_path = args.config_path
-    t_total = args.t_total
-
-    checkpoint_callback = ModelCheckpoint(
-        dirpath=output_path,
-        verbose=True,
-        period=1,
-        save_top_k=1,
-        monitor="val_loss",
-        mode="min",
-    )
-    learning_rate_callback = LearningRateMonitor()
-    trainer = pl.Trainer(
-        default_root_dir=output_path,
-        gradient_clip_val=1,
-        max_epochs=epochs,
-        gpus=args.device,
-        distributed_backend="dp",
-        val_check_interval=eval_interval,
-        callbacks=[learning_rate_callback, checkpoint_callback],
-        precision=32,
-    )
-    net = Net(
-        batch_size,
-        epochs,
-        t_total=t_total,
-        config_path=config_path,
-        data_path=data_path,
-        valid_examples=val_examples,
-        vocab_path=vocab_path,
-        max_length=max_length,
-        warm_up_steps=warmup_steps,
-        lr=lr,
-    )
-    # d = torch.load('output_old/best.ckpt', map_location=torch.device("cpu"))["state_dict"]
-    # d.pop('model.classifier.bias')
-    # d.pop('model.classifier.weight')
-
-    # net.load_state_dict(d, strict=False)
-    trainer.fit(net)
+import argparse
+import json
+from typing import List, Dict
+
+import pytorch_lightning as pl
+import torch
+from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
+from torch.optim import AdamW
+from torch.utils.data import Dataset, DataLoader
+from transformers import GPT2LMHeadModel, GPT2Config, get_linear_schedule_with_warmup, BertTokenizer
+
+
+# 11846807
+
+
+class DS(Dataset):
+    def __init__(self, lines: List[str], tokenizer: BertTokenizer):
+        self.data = lines
+        self.tok = tokenizer
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        line = self.data[index]
+        line = self.tok.encode_plus(
+            line,
+            max_length=self.tok.model_max_length,
+            truncation=True,
+            padding="max_length",
+            return_tensors="pt",
+        )
+        return line
+
+
+class Net(pl.LightningModule):
+    def __init__(
+            self,
+            dataset: List[str],
+            batch_size,
+            epochs,
+            config_path="config/model_config.json",
+            valid_examples=100,
+            vocab_path="vocab/vocab.txt",
+            warm_up_steps=0,
+            lr=1e-4,
+            model: GPT2LMHeadModel = None,
+            tokenizer: BertTokenizer = None,
+            additional_special_tokens: Dict[str, str] = None,
+    ):
+        super(Net, self).__init__()
+        self.batch_size = batch_size
+        self.epochs = epochs
+        self.warm_up_steps = warm_up_steps
+        self.lr = lr
+        self.model_name = "bert_pretrained_model"
+        self.config = GPT2Config.from_json_file(config_path)
+        self.model = GPT2LMHeadModel(config=self.config) if model is None else model
+        self.tokenizer = BertTokenizer(vocab_file=vocab_path,
+                                       model_max_length=self.config.n_positions) if tokenizer is None else tokenizer
+        if additional_special_tokens:
+            self.tokenizer.add_special_tokens({'additional_special_tokens': list(additional_special_tokens.values())})
+        self.data: List[str] = dataset
+        self.t_total = len(self.data) * epochs
+        self.dataset_train = DS(self.data[:-valid_examples], self.tokenizer)
+        self.dataset_valid = DS(self.data[-valid_examples:], self.tokenizer)
+
+    def forward(self, input_ids, attention_mask):
+        input_ids = input_ids
+        attention_mask = attention_mask
+        r = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            labels=input_ids,
+            return_dict=True,
+        )
+        return r["loss"]
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.dataset_train,
+            batch_size=self.batch_size,
+            num_workers=8,
+            shuffle=True,
+            drop_last=True,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.dataset_valid,
+            batch_size=self.batch_size,
+            num_workers=8,
+            shuffle=False,
+            drop_last=True,
+        )
+
+    def configure_optimizers(self):
+        optimizer = AdamW(self.parameters(), lr=self.lr, weight_decay=0.001)
+        scheduler = get_linear_schedule_with_warmup(
+            optimizer, self.warm_up_steps, self.t_total
+        )
+        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
+        return [optimizer], [scheduler]
+
+    def training_step(self, batch, batch_nb):
+        loss = self.forward(batch["input_ids"], batch["attention_mask"])
+
+        self.log(
+            "train_loss",
+            loss,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
+        return loss
+
+    def validation_step(self, batch, batch_nb):
+        loss = self.forward(batch["input_ids"], batch["attention_mask"])
+        return loss
+
+    def validation_epoch_end(self, outputs):
+        avg_loss = torch.stack(outputs).mean()
+        self.log(
+            "val_loss",
+            avg_loss,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
+        return {"val_loss": avg_loss}
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--devices", default="", type=str, required=False, help="设置使用哪些显卡,用逗号分割"
+    )
+    parser.add_argument(
+        "--accelerator", default="auto", type=str, required=False,
+        help='使用的计算类型,"cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"',
+        choices=["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"]
+    )
+    parser.add_argument(
+        "--config_path",
+        default="config/model_config.json",
+        type=str,
+        required=False,
+        help="选择模型参数",
+    )
+    parser.add_argument(
+        "--vocab_path",
+        default="vocab/vocab.txt",
+        type=str,
+        required=False,
+        help="选择词库",
+    )
+    parser.add_argument(
+        "--data_path",
+        default="data/train.json",
+        type=str,
+        required=False,
+        help="原始训练语料",
+    )
+    parser.add_argument("--epochs", default=5, type=int, required=False, help="训练循环")
+    parser.add_argument(
+        "--batch_size", default=8, type=int, required=False, help="训练 batch size"
+    )
+    parser.add_argument("--lr", default=1.5e-4, type=float, required=False, help="学习率")
+    parser.add_argument(
+        "--warmup_steps", default=2000, type=int, required=False, help="warm up 步数"
+    )
+    parser.add_argument(
+        "--max_length", default=1024, type=int, required=False, help="单条文本最长长度"
+    )
+    parser.add_argument(
+        "--eval_interval", default=100, type=int, required=False, help="eval 步数"
+    )
+    parser.add_argument(
+        "--val_examples", default=100, type=int, required=False, help="选择多少验证集样本"
+    )
+    parser.add_argument(
+        "--log_step", default=1, type=int, required=False, help="多少步汇报一次loss"
+    )
+    parser.add_argument(
+        "--output_dir", default="model/", type=str, required=False, help="模型输出路径"
+    )
+    args = parser.parse_args()
+
+    devices: str = args.devices
+    accelerator: str = args.accelerator
+    val_examples = args.val_examples
+    vocab_path = args.vocab_path
+    batch_size = args.batch_size
+    epochs = args.epochs
+    output_path = args.output_dir
+    eval_interval = args.eval_interval
+    lr = args.lr
+    warmup_steps = args.warmup_steps
+    data_path = args.data_path
+    config_path = args.config_path
+
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=output_path,
+        verbose=True,
+        save_top_k=1,
+        monitor="val_loss",
+        mode="min",
+    )
+    learning_rate_callback = LearningRateMonitor()
+    trainer = pl.Trainer(
+        default_root_dir=output_path,
+        gradient_clip_val=1,
+        max_epochs=epochs,
+        devices=list(map(int, devices.split(','))) if devices else None,
+        accelerator=accelerator,
+        val_check_interval=eval_interval,
+        callbacks=[learning_rate_callback, checkpoint_callback],
+        # precision=32,
+        precision=16,
+    )
+
+    with open(data_path, encoding='utf-8') as f:
+        dataset: List[str] = [json.loads(line.strip()) for line in f]
+
+    net = Net(
+        dataset,
+        batch_size,
+        epochs,
+        config_path=config_path,
+        valid_examples=val_examples,
+        vocab_path=vocab_path,
+        warm_up_steps=warmup_steps,
+        lr=lr,
+    )
+    # d = torch.load('output_old/best.ckpt', map_location=torch.device("cpu"))["state_dict"]
+    # d.pop('model.classifier.bias')
+    # d.pop('model.classifier.weight')
+
+    # net.load_state_dict(d, strict=False)
+    trainer.fit(net)
+
+    net.model.save_pretrained(output_path)
+    net.tokenizer.save_pretrained(output_path)
+
+
+if __name__ == "__main__":
+    main()
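
Reviewer note (not part of the patch): gpt2c.py only defines the abstract hook, so the sketch below shows one way a subclass could implement context_to_embedding and run a forward pass. KeywordConditionedGPT2, its keyword_embedding table, and the layout of the context dict are illustrative assumptions for review purposes, not code that exists in this repository.

# Hypothetical usage sketch for gpt2c.py (assumption: not part of this patch).
import torch
from transformers import GPT2Config

from gpt2c import GPT2ModelWithContext


class KeywordConditionedGPT2(GPT2ModelWithContext):
    """Toy subclass: condition every position on a single keyword id."""

    def __init__(self, config: GPT2Config, num_keywords: int = 16):
        super().__init__(config)
        # Extra embedding table for the conditioning signal (illustrative only).
        self.keyword_embedding = torch.nn.Embedding(num_keywords, config.n_embd)

    def context_to_embedding(self, context: dict) -> torch.FloatTensor:
        # Assumed context layout: {"keyword_ids": LongTensor[batch], "seq_length": int}
        emb = self.keyword_embedding(context["keyword_ids"])  # [batch, n_embd]
        # Broadcast to [batch, seq_length, n_embd] so the size assert in
        # GPT2ModelWithAdditionalHiddenStates.forward holds.
        return emb.unsqueeze(1).expand(-1, context["seq_length"], -1)


if __name__ == "__main__":
    config = GPT2Config(n_layer=2, n_head=2, n_embd=64, vocab_size=100)
    model = KeywordConditionedGPT2(config)
    input_ids = torch.randint(0, 100, (2, 8))
    context = {"keyword_ids": torch.tensor([3, 7]), "seq_length": input_ids.size(1)}
    out = model(context=context, input_ids=input_ids)
    print(out.last_hidden_state.shape)  # torch.Size([2, 8, 64])

The only contract the sketch relies on is the one the patch itself states: additional_hidden_states must match inputs_embeds in shape so the two can be summed before the regular GPT2Model forward.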