Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 32 additions & 38 deletions language/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,49 +12,43 @@
from huggingface_hub import hf_hub_download


import os
import re
import torch
from transformers import AutoTokenizer, T5EncoderModel

import os
import re
import torch
from transformers import AutoTokenizer, T5EncoderModel

class T5Embedder:
available_models = ['t5-v1_1-xxl', 't5-v1_1-xl', 'flan-t5-xl']
bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa

def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True,
t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None, model_max_length=120):
bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}')

def __init__(
self,
device,
model_path,
*,
torch_dtype=None,
model_max_length=120
):
self.device = torch.device(device)
self.torch_dtype = torch_dtype or torch.bfloat16
if t5_model_kwargs is None:
t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device}

self.use_text_preprocessing = use_text_preprocessing
self.hf_token = hf_token
self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
self.dir_or_name = dir_or_name
tokenizer_path, path = dir_or_name, dir_or_name
if local_cache:
cache_dir = os.path.join(self.cache_dir, dir_or_name)
tokenizer_path, path = cache_dir, cache_dir
elif dir_or_name in self.available_models:
cache_dir = os.path.join(self.cache_dir, dir_or_name)
for filename in [
'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin'
]:
hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
force_filename=filename, token=self.hf_token)
tokenizer_path, path = cache_dir, cache_dir
else:
cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
for filename in [
'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
]:
hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir,
force_filename=filename, token=self.hf_token)
tokenizer_path = cache_dir

print(tokenizer_path)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
self.model_max_length = model_max_length

tokenizer_path = os.path.abspath(model_path)
model_path = tokenizer_path

self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
self.model = T5EncoderModel.from_pretrained(model_path, local_files_only=True,
low_cpu_mem_usage=True,
torch_dtype=self.torch_dtype,
device_map={'shared': self.device, 'encoder': self.device}).eval()



def get_text_embeddings(self, texts):
texts = [self.text_preprocessing(text) for text in texts]

Expand Down Expand Up @@ -198,4 +192,4 @@ def clean_caption(self, caption):
caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
caption = re.sub(r'^\.\S+$', '', caption)

return caption.strip()
return caption.strip()