Upload t5.py
t5.py
ADDED
@@ -0,0 +1,201 @@
# Modified from:
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/t5.py
import os
import re
import html
import urllib.parse as ul

import ftfy
import torch
from bs4 import BeautifulSoup
from transformers import T5EncoderModel, AutoTokenizer
from huggingface_hub import hf_hub_download


class T5Embedder:
    available_models = ['t5-v1_1-xxl', 't5-v1_1-xl', 'flan-t5-xl']
    bad_punct_regex = re.compile(r'[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\\/\*]{1,}')  # noqa

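    # Wraps a frozen T5 encoder and its tokenizer for prompt embedding. For the
    # model names listed above, weights are downloaded from the DeepFloyd hub
    # repos on first use and cached; use_offload_folder is unused in this file.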
    def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None,
                 hf_token=None, use_text_preprocessing=True, t5_model_kwargs=None,
                 torch_dtype=None, use_offload_folder=None, model_max_length=120):
        self.device = torch.device(device)
        self.torch_dtype = torch_dtype or torch.bfloat16
        if t5_model_kwargs is None:
            t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
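            # Place the shared token-embedding table and the encoder stack on the
            # target device; T5EncoderModel has no decoder, so this covers the model.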
            t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device}

        self.use_text_preprocessing = use_text_preprocessing
        self.hf_token = hf_token
        self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
        self.dir_or_name = dir_or_name
        tokenizer_path, path = dir_or_name, dir_or_name
        if local_cache:
            cache_dir = os.path.join(self.cache_dir, dir_or_name)
            tokenizer_path, path = cache_dir, cache_dir
        elif dir_or_name in self.available_models:
            cache_dir = os.path.join(self.cache_dir, dir_or_name)
            for filename in [
                'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
                'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin'
            ]:
                hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
                                force_filename=filename, token=self.hf_token)
            tokenizer_path, path = cache_dir, cache_dir
        else:
            cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
            for filename in [
                'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
            ]:
                hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir,
                                force_filename=filename, token=self.hf_token)
            tokenizer_path = cache_dir

        print(tokenizer_path)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
        self.model_max_length = model_max_length

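    # Tokenize a batch of prompts (padded/truncated to model_max_length) and return
    # the encoder's last hidden states together with the attention mask.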
    def get_text_embeddings(self, texts):
        texts = [self.text_preprocessing(text) for text in texts]

        text_tokens_and_mask = self.tokenizer(
            texts,
            max_length=self.model_max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )

        with torch.no_grad():
            text_encoder_embs = self.model(
                input_ids=text_tokens_and_mask['input_ids'].to(self.device),
                attention_mask=text_tokens_and_mask['attention_mask'].to(self.device),
            )['last_hidden_state'].detach()
        return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.device)

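    # Illustrative call: embs, mask = embedder.get_text_embeddings(['a photo of a cat'])
    # yields embs of shape (batch, model_max_length, d_model) on self.device.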
    def text_preprocessing(self, text):
        if self.use_text_preprocessing:
            # The exact text cleaning used at the training stage; clean_caption
            # is deliberately applied twice.
            text = self.clean_caption(text)
            text = self.clean_caption(text)
            return text
        else:
            return text.lower().strip()

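    # ftfy repairs mojibake; html.unescape is applied twice to resolve
    # doubly-escaped entities such as '&amp;amp;'.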
    @staticmethod
    def basic_clean(text):
        text = ftfy.fix_text(text)
        text = html.unescape(html.unescape(text))
        return text.strip()

    def clean_caption(self, caption):
        caption = str(caption)
        caption = ul.unquote_plus(caption)
        caption = caption.strip().lower()
        caption = re.sub('<person>', 'person', caption)
        # urls:
        caption = re.sub(
            r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',  # noqa
            '', caption)  # regex for urls
        caption = re.sub(
            r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',  # noqa
            '', caption)  # regex for urls
        # html:
        caption = BeautifulSoup(caption, features='html.parser').text

        # @<nickname>
        caption = re.sub(r'@[\w\d]+\b', '', caption)

        # 31C0—31EF CJK Strokes
        # 31F0—31FF Katakana Phonetic Extensions
        # 3200—32FF Enclosed CJK Letters and Months
        # 3300—33FF CJK Compatibility
        # 3400—4DBF CJK Unified Ideographs Extension A
        # 4DC0—4DFF Yijing Hexagram Symbols
        # 4E00—9FFF CJK Unified Ideographs
        caption = re.sub(r'[\u31c0-\u31ef]+', '', caption)
        caption = re.sub(r'[\u31f0-\u31ff]+', '', caption)
        caption = re.sub(r'[\u3200-\u32ff]+', '', caption)
        caption = re.sub(r'[\u3300-\u33ff]+', '', caption)
        caption = re.sub(r'[\u3400-\u4dbf]+', '', caption)
        caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption)
        caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)
        #######################################################

        # all types of dash --> "-"
        caption = re.sub(
            r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+',  # noqa
            '-', caption)

        # normalize quotation marks to a single standard
        caption = re.sub(r'[`´«»“”¨]', '"', caption)
        caption = re.sub(r'[‘’]', "'", caption)

        # strip double quotes
        caption = re.sub(r'"?', '', caption)
        # strip ampersands
        caption = re.sub(r'&', '', caption)

        # IP addresses:
        caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)

        # article ids:
        caption = re.sub(r'\d:\d\d\s+$', '', caption)

        # literal "\n" sequences:
        caption = re.sub(r'\\n', ' ', caption)

        # "#123"
        caption = re.sub(r'#\d{1,3}\b', '', caption)
        # "#12345.."
        caption = re.sub(r'#\d{5,}\b', '', caption)
        # "123456.."
        caption = re.sub(r'\b\d{6,}\b', '', caption)
        # filenames:
        caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)

        # collapse repeated quotes and dots:
        caption = re.sub(r'[\"\']{2,}', r'"', caption)  # """AUSVERKAUFT"""
        caption = re.sub(r'[\.]{2,}', r' ', caption)

        caption = re.sub(self.bad_punct_regex, r' ', caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
        caption = re.sub(r'\s+\.\s+', r' ', caption)  # " . "

        # this-is-my-cute-cat / this_is_my_cute_cat
        regex2 = re.compile(r'(?:\-|\_)')
        if len(re.findall(regex2, caption)) > 3:
            caption = re.sub(regex2, ' ', caption)

        caption = self.basic_clean(caption)

        caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption)  # jc6640
        caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption)  # jc6640vc
        caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption)  # 6640vc231

        caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
        caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
        caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
        caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
        caption = re.sub(r'\bpage\s+\d+\b', '', caption)

        caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption)  # j2d1a2a...

        caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)  # dimensions, e.g. 1024x768

        caption = re.sub(r'\b\s+\:\s+', r': ', caption)
        caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
        caption = re.sub(r'\s+', ' ', caption)

        caption = caption.strip()

        caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
        caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
        caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
        caption = re.sub(r'^\.\S+$', '', caption)

        return caption.strip()
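

# Minimal usage sketch (illustrative; assumes a CUDA device with enough memory
# for the chosen encoder and network access for the first-run weight download;
# the prompt strings below are hypothetical).
if __name__ == '__main__':
    embedder = T5Embedder('cuda', dir_or_name='t5-v1_1-xxl', torch_dtype=torch.bfloat16)

    prompts = ['a watercolor painting of a fox', 'macro photo of a dew drop on moss']
    embs, mask = embedder.get_text_embeddings(prompts)
    print(embs.shape, embs.dtype)  # (2, 120, d_model) with the default model_max_length
    print(mask.sum(dim=1))         # non-padding token count per prompt

    # The cleaning pass can also be exercised on its own:
    print(embedder.clean_caption('Check https://example.com for <person>... FREE shipping!!!'))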