98 changes: 49 additions & 49 deletions generate.py
@@ -1,16 +1,13 @@
import torch
-import torch.nn.functional as F
+import torch.nn.functional as functional
import os
import argparse
from tqdm import trange
from transformers import GPT2LMHeadModel, GPT2Config, BertTokenizer


-def is_word(word):
-    for item in list(word):
-        if item not in "qwertyuiopasdfghjklzxcvbnm":
-            return False
-    return True
+def is_word(word: str):
+    return word.isalpha()


def _is_chinese_char(char):
@@ -22,17 +19,17 @@ def _is_chinese_char(char):
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
-# like the all of the other languages.
+# like all of the other languages.
cp = ord(char)
if (
-    (cp >= 0x4E00 and cp <= 0x9FFF)
-    or (cp >= 0x3400 and cp <= 0x4DBF)  #
-    or (cp >= 0x20000 and cp <= 0x2A6DF)  #
-    or (cp >= 0x2A700 and cp <= 0x2B73F)  #
-    or (cp >= 0x2B740 and cp <= 0x2B81F)  #
-    or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
-    or (cp >= 0xF900 and cp <= 0xFAFF)
-    or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
+    (0x4E00 <= cp <= 0x9FFF)
+    or (0x3400 <= cp <= 0x4DBF)  #
+    or (0x20000 <= cp <= 0x2A6DF)  #
+    or (0x2A700 <= cp <= 0x2B73F)  #
+    or (0x2B740 <= cp <= 0x2B81F)  #
+    or (0x2B820 <= cp <= 0x2CEAF)  #
+    or (0xF900 <= cp <= 0xFAFF)
+    or (0x2F800 <= cp <= 0x2FA1F)  #
): #
return True

@@ -41,15 +38,18 @@ def _is_chinese_char(char):

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
-Args:
-    logits: logits distribution shape (vocabulary size)
-    top_k > 0: keep only top k tokens with highest probability (top-k filtering).
-    top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+:param logits: logits distribution, shape (vocabulary size)
+:param top_k: > 0 keeps only the top k tokens with the highest probability (top-k filtering).
+:param top_p: > 0.0 keeps the top tokens with cumulative probability >= top_p (nucleus filtering).
+:param filter_value: value assigned to filtered-out logits.
-Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+Nucleus filtering is described in Holtzman et al. (https://arxiv.org/abs/1904.09751)
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
assert (
logits.dim() == 1
) # batch size 1 for now - could be updated for more but the code would be less clear
top_k = min(top_k, logits.size(-1)) # Safety check
if top_k > 0:
@@ -59,7 +59,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")

if top_p > 0.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+cumulative_probs = torch.cumsum(functional.softmax(sorted_logits, dim=-1), dim=-1)

# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
@@ -73,43 +73,43 @@


def sample_sequence(
    model,
    context,
    length,
    n_ctx,
    tokenizer,
    temperature=1.0,
    top_k=30,
    top_p=0.0,
-    repitition_penalty=1.0,
+    repetition_penalty=1.0,
    device="cpu",
):
context = torch.tensor(context, dtype=torch.long, device=device)
context = context.unsqueeze(0)
generated = context
with torch.no_grad():
for _ in trange(length):
-inputs = {"input_ids": generated[0][-(n_ctx - 1) :].unsqueeze(0)}
+inputs = {"input_ids": generated[0][-(n_ctx - 1):].unsqueeze(0)}
outputs = model(
**inputs
-) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
+) # Note: we could also use 'past' with GPT-2/Transformer-XL/XLNet (cached hidden-states)
next_token_logits = outputs[0][0, -1, :]
-for id in set(generated):
-    next_token_logits[id] /= repitition_penalty
+for idx in set(generated[0].tolist()):
+    next_token_logits[idx] /= repetition_penalty
next_token_logits = next_token_logits / temperature
next_token_logits[tokenizer.convert_tokens_to_ids("[UNK]")] = -float("Inf")
filtered_logits = top_k_top_p_filtering(
next_token_logits, top_k=top_k, top_p=top_p
)
next_token = torch.multinomial(
-    F.softmax(filtered_logits, dim=-1), num_samples=1
+    functional.softmax(filtered_logits, dim=-1), num_samples=1
)
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
return generated.tolist()[0]
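One caveat about the repetition penalty above: dividing by a penalty > 1 only damps tokens whose logit is positive, while a negative logit moves closer to zero and actually becomes more likely (Hugging Face's RepetitionPenaltyLogitsProcessor multiplies negative scores by the penalty instead for this reason). A toy check with invented values:

import torch

repetition_penalty = 1.2
next_token_logits = torch.tensor([4.0, 2.0, -1.0])
for idx in set([0, 2, 2]):  # token ids generated so far
    next_token_logits[idx] /= repetition_penalty
print(next_token_logits)  # tensor([3.3333, 2.0000, -0.8333]); id 2 became more likely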


def fast_sample_sequence(
    model, context, length, temperature=1.0, top_k=30, top_p=0.0, device="cpu"
):
inputs = torch.LongTensor(context).view(1, -1).to(device)
if len(context) > 1:
@@ -120,7 +120,7 @@ def fast_sample_sequence(
prev = inputs
generate = [] + context
with torch.no_grad():
-for i in trange(length):
+for _ in trange(length):
output = model(prev, past=past)
output, past = output[:2]
output = output[-1].squeeze(0) / temperature
@@ -142,13 +142,13 @@ def main():
"--batch_size", default=1, type=int, required=False, help="生成的batch size"
)
parser.add_argument(
"--nsamples", default=10, type=int, required=False, help="生成几个样本"
"--n_samples", default=10, type=int, required=False, help="生成几个样本"
)
parser.add_argument(
"--temperature", default=1, type=float, required=False, help="生成温度"
)
parser.add_argument("--topk", default=8, type=int, required=False, help="最高几选一")
parser.add_argument("--topp", default=0, type=float, required=False, help="最高积累概率")
parser.add_argument("--top_k", default=8, type=int, required=False, help="最高几选一")
parser.add_argument("--top_p", default=0, type=float, required=False, help="最高积累概率")
parser.add_argument(
"--model_config",
default="config/model_config.json",
@@ -188,11 +188,11 @@ def main():
os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # select which GPUs the program uses
length = args.length
n_ctx = args.n_ctx
-batch_size = args.batch_size
-nsamples = args.nsamples
+# batch_size = args.batch_size
+n_samples = args.n_samples
temperature = args.temperature
-topk = args.topk
-topp = args.topp
+top_k = args.top_k
+top_p = args.top_p
repetition_penalty = args.repetition_penalty

device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -209,7 +209,7 @@ def main():
model.to(device)
model.eval()

-for i in range(nsamples):
+for i in range(n_samples):
raw_text = args.prefix
encoded = tokenizer.encode_plus(raw_text)["input_ids"][:-1]
out = sample_sequence(
@@ -219,9 +219,9 @@ def main():
n_ctx=n_ctx,
tokenizer=tokenizer,
temperature=temperature,
-    top_k=topk,
-    top_p=topp,
-    repitition_penalty=repetition_penalty,
+    top_k=top_k,
+    top_p=top_p,
+    repetition_penalty=repetition_penalty,
device=device,
)
print(tokenizer.decode(out))
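For reference, a minimal sketch of what top_k_top_p_filtering does to a hand-made logits vector (the values are invented; it assumes generate.py is on the import path):

import torch
import torch.nn.functional as functional

from generate import top_k_top_p_filtering

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
filtered = top_k_top_p_filtering(logits.clone(), top_k=2)
print(filtered)  # tensor([2., 1., -inf, -inf]); only the two largest logits survive

# Sampling then draws exclusively from the surviving tokens.
next_token = torch.multinomial(functional.softmax(filtered, dim=-1), num_samples=1)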
45 changes: 45 additions & 0 deletions gpt2c.py
@@ -0,0 +1,45 @@
from abc import ABC
from typing import Optional, Tuple, Union

import torch
from transformers import GPT2Model
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions


class GPT2ModelWithAdditionalHiddenStates(GPT2Model, ABC):
    def __init__(self, config):
        super().__init__(config)

    def forward(self, additional_hidden_states: Optional[torch.FloatTensor] = None,
                **kwargs) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        """
        :param additional_hidden_states: shape = [batch, seq_length, embed_size]
        """
        # Resolve the token embeddings first so the additional hidden states
        # can be added to them before the transformer blocks run.
        if 'input_ids' in kwargs:
            input_ids = kwargs.pop('input_ids')
            inputs_embeds: torch.FloatTensor = self.wte(input_ids)
        else:
            inputs_embeds: torch.FloatTensor = kwargs.pop('inputs_embeds')

        if additional_hidden_states is not None:
            assert inputs_embeds.size() == additional_hidden_states.size()
            inputs_embeds = inputs_embeds + additional_hidden_states

        kwargs['inputs_embeds'] = inputs_embeds

        return super().forward(**kwargs)


class GPT2ModelWithContext(GPT2ModelWithAdditionalHiddenStates, ABC):
    def __init__(self, config):
        super().__init__(config)

    def context_to_embedding(self, context: dict) -> torch.FloatTensor:
        """
        :return: additional_hidden_states, shape = [batch, seq_length, embed_size]
        """
        raise NotImplementedError

    def forward(self, context: dict = None, **kwargs) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        additional_hidden_states: torch.FloatTensor = self.context_to_embedding(context)
        return super().forward(additional_hidden_states, **kwargs)
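context_to_embedding is left abstract, so GPT2ModelWithContext only becomes usable through a subclass. A minimal sketch of what one could look like; the keyword-conditioning scheme, the class name, and the context dict keys are all invented for illustration:

import torch
from torch import nn
from transformers import GPT2Config

from gpt2c import GPT2ModelWithContext


class KeywordConditionedGPT2(GPT2ModelWithContext):
    # Hypothetical subclass: broadcasts one keyword embedding over every position.
    def __init__(self, config, num_keywords: int = 16):
        super().__init__(config)
        self.keyword_embedding = nn.Embedding(num_keywords, config.n_embd)

    def context_to_embedding(self, context: dict) -> torch.FloatTensor:
        # context["keyword_ids"]: shape [batch]; context["seq_length"]: int
        embeds = self.keyword_embedding(context["keyword_ids"])  # [batch, embed]
        return embeds.unsqueeze(1).expand(-1, context["seq_length"], -1)


config = GPT2Config(n_layer=2, n_head=2, n_embd=64)
model = KeywordConditionedGPT2(config)
input_ids = torch.randint(0, config.vocab_size, (1, 8))
output = model(context={"keyword_ids": torch.tensor([3]), "seq_length": 8}, input_ids=input_ids)

The parent classes then add the broadcast keyword embedding to the token embeddings before the first transformer block runs.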
9 changes: 5 additions & 4 deletions requirements.txt
@@ -1,4 +1,5 @@
-pytorch-lightning==1.2.2
-transformers==4.2.1
-torch==1.8.0
-tqdm==4.56.0
+# python~=3.10.8
+pytorch-lightning~=1.8.4.post0
+torch~=1.12.0
+transformers~=4.25.1
+tqdm~=4.65.0
2 changes: 1 addition & 1 deletion scripts/generate.sh
@@ -4,5 +4,5 @@ python generate.py \
--tokenizer_path cache/vocab_small.txt \
--model_path model/final_model \
--prefix "[CLS][MASK]" \
-  --topp 1 \
+  --top_p 1 \
--temperature 1.0
77 changes: 77 additions & 0 deletions tests.py
@@ -0,0 +1,77 @@
import unittest
from collections import OrderedDict
from typing import List

import torch
from transformers import BertTokenizer, GPT2Config, GPT2LMHeadModel, TextGenerationPipeline


class TestMain(unittest.TestCase):
    def setUp(self) -> None:
        # self.model_config_path: str = 'config/model_config_test.json'
        # self.model_config_path: str = 'config/model_config_small.json'
        self.model_config_path: str = 'config/model_config.json'
        self.vocab_path: str = 'vocab/vocab.txt'
        self.config: GPT2Config = GPT2Config.from_json_file(self.model_config_path)

    def test_train(self):
        import sys
        from train import main
        args = [
            '--data_path', 'data/train.txt',
            '--batch_size', '2',
            # '--devices', '0',
            '--config_path', self.model_config_path,
            # '--config_path', 'config/model_config_test.json',
            '--epochs', '2'
        ]
        sys.argv.extend(args)
        main()
        self.assertTrue(True)

    @classmethod
    def pipeline(cls, model: GPT2LMHeadModel, tokenizer: BertTokenizer, text: str) -> str:
        pad_token_id = tokenizer('[PAD]')['input_ids'][1]
        pipeline = TextGenerationPipeline(model, tokenizer)
        result = pipeline(text, max_length=100, pad_token_id=pad_token_id, do_sample=True)
        # The pipeline returns a list of dicts; unwrap the generated string.
        return result[0]['generated_text']

    @classmethod
    def generate(cls, model: GPT2LMHeadModel, tokenizer: BertTokenizer, text: str) -> str:
        pad_token_id = tokenizer('[PAD]')['input_ids'][1]
        input_ids = tokenizer('[CLS]' + text, return_tensors='pt', padding=False, add_special_tokens=False)['input_ids']
        output_ids: torch.Tensor = model.generate(input_ids, max_length=100, pad_token_id=pad_token_id, do_sample=True)
        output_tokens: List[str] = tokenizer.convert_ids_to_tokens(output_ids[0])
        output_text: str = ''.join(filter(lambda x: x not in ['[SEP]', '[PAD]', '[CLS]'], output_tokens))
        return output_text

    def load_from_checkpoint(self, checkpoint_path: str) -> GPT2LMHeadModel:
        model: GPT2LMHeadModel = GPT2LMHeadModel(config=self.config)
        checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        raw_state_dict: OrderedDict = checkpoint["state_dict"]
        # Lightning prefixes every key with the wrapped module's attribute name
        # ('model.'); strip it so the plain GPT2LMHeadModel can load the weights.
        state_dict: OrderedDict = OrderedDict({k.replace('model.', ''): v for k, v in raw_state_dict.items()})
        model.load_state_dict(state_dict)
        return model

    def test_generate(self):
        # load from checkpoint
        # checkpoint_path: str = 'model/epoch=1-step=862.ckpt'
        # model = self.load_from_checkpoint(checkpoint_path)
        # tokenizer: BertTokenizer = BertTokenizer(vocab_file=self.vocab_path)

        # load using the transformers API
        model: GPT2LMHeadModel = GPT2LMHeadModel.from_pretrained('model/gpt2-chinese-cluecorpussmall')
        tokenizer: BertTokenizer = BertTokenizer.from_pretrained('model/gpt2-chinese-cluecorpussmall')

        model.eval()

        text: str = '我叫'

        output_text: str = self.generate(model, tokenizer, text)
        # output_text: str = self.pipeline(model, tokenizer, text)
        print(output_text)

        self.assertTrue(True)


if __name__ == '__main__':
    unittest.main()