bwang0911 committed
Commit a71961f
1 parent: a10a808

Create custom_st.py

Files changed (1)
custom_st.py +196 -0
custom_st.py ADDED
@@ -0,0 +1,196 @@
+ import base64
+ import json
+ import os
+ from io import BytesIO
+ from typing import Any, Dict, List, Optional, Union
+ 
+ import requests
+ import torch
+ from PIL import Image
+ from torch import nn
+ from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoTokenizer
+ 
+ 
+ class Transformer(nn.Module):
+     """Wraps a Hugging Face AutoModel to generate text and image embeddings.
+     Loads the correct model class via AutoModel / AutoConfig.
+ 
+     Args:
+         model_name_or_path: Hugging Face model name
+             (https://huggingface.co/models)
+         max_seq_length: Truncate any inputs longer than max_seq_length
+         model_args: Keyword arguments passed to the Hugging Face
+             Transformers model
+         tokenizer_args: Keyword arguments passed to the Hugging Face
+             Transformers tokenizer
+         config_args: Keyword arguments passed to the Hugging Face
+             Transformers config
+         cache_dir: Cache directory for Hugging Face Transformers to
+             store/load models
+         do_lower_case: If true, lowercases the input (regardless of
+             whether the model is cased or not)
+         tokenizer_name_or_path: Name or path of the tokenizer. If
+             None, model_name_or_path is used
+     """
+ 
+     def __init__(
+         self,
+         model_name_or_path: str,
+         max_seq_length: Optional[int] = None,
+         model_args: Optional[Dict[str, Any]] = None,
+         tokenizer_args: Optional[Dict[str, Any]] = None,
+         config_args: Optional[Dict[str, Any]] = None,
+         cache_dir: Optional[str] = None,
+         do_lower_case: bool = False,
+         tokenizer_name_or_path: Optional[str] = None,
+     ) -> None:
+         super().__init__()
+         self.config_keys = ["max_seq_length", "do_lower_case"]
+         self.do_lower_case = do_lower_case
+         if model_args is None:
+             model_args = {}
+         if tokenizer_args is None:
+             tokenizer_args = {}
+         if config_args is None:
+             config_args = {}
+ 
+         config = AutoConfig.from_pretrained(
+             model_name_or_path, **config_args, cache_dir=cache_dir
+         )
+         self.jina_clip = AutoModel.from_pretrained(
+             model_name_or_path, config=config, cache_dir=cache_dir, **model_args
+         )
+ 
+         if max_seq_length is not None and "model_max_length" not in tokenizer_args:
+             tokenizer_args["model_max_length"] = max_seq_length
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             (
+                 tokenizer_name_or_path
+                 if tokenizer_name_or_path is not None
+                 else model_name_or_path
+             ),
+             cache_dir=cache_dir,
+             **tokenizer_args,
+         )
+         self.preprocessor = AutoImageProcessor.from_pretrained(
+             (
+                 tokenizer_name_or_path
+                 if tokenizer_name_or_path is not None
+                 else model_name_or_path
+             ),
+             cache_dir=cache_dir,
+             **tokenizer_args,
+         )
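+         # NOTE: the image processor is loaded with the same kwargs as the
+         # tokenizer; this module exposes no separate image_processor_args
+         # parameter.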
+ 
+         # No max_seq_length set. Try to infer from model
+         if max_seq_length is None:
+             if (
+                 hasattr(self.jina_clip, "config")
+                 and hasattr(self.jina_clip.config, "max_position_embeddings")
+                 and hasattr(self.tokenizer, "model_max_length")
+             ):
+                 max_seq_length = min(
+                     self.jina_clip.config.max_position_embeddings,
+                     self.tokenizer.model_max_length,
+                 )
+ 
+         self.max_seq_length = max_seq_length
+ 
+         if tokenizer_name_or_path is not None:
+             self.jina_clip.config.tokenizer_class = self.tokenizer.__class__.__name__
+ 
+     def forward(
+         self, features: Dict[str, torch.Tensor]
+     ) -> Dict[str, torch.Tensor]:
+         """Computes a text embedding if the features contain input_ids,
+         otherwise an image embedding from pixel_values."""
+         if "input_ids" in features:
+             embedding = self.jina_clip.get_text_features(
+                 input_ids=features["input_ids"]
+             )
+         else:
+             embedding = self.jina_clip.get_image_features(
+                 pixel_values=features["pixel_values"]
+             )
+         return {"sentence_embedding": embedding}
+ 
+     def get_word_embedding_dimension(self) -> int:
+         return self.jina_clip.config.text_config.embed_dim
+ 
+     @staticmethod
+     def decode_data_image(data_image_str: str) -> Image.Image:
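+         # Expects a data URI of the form "data:image/<fmt>;base64,<payload>".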
+         header, data = data_image_str.split(',', 1)
+         image_data = base64.b64decode(data)
+         return Image.open(BytesIO(image_data))
+ 
+     def tokenize(
+         self, batch: List[Union[str, Image.Image]], padding: Union[str, bool] = True
+     ) -> Dict[str, torch.Tensor]:
+         """Tokenizes texts or preprocesses images, depending on the batch contents"""
+         images = []
+         texts = []
+         for sample in batch:
+             if isinstance(sample, str):
+                 if sample.startswith('http'):
+                     response = requests.get(sample)
+                     images.append(Image.open(BytesIO(response.content)).convert('RGB'))
+                 elif sample.startswith('data:image/'):
+                     images.append(self.decode_data_image(sample).convert('RGB'))
+                 else:
+                     # TODO: Make sure that Image.open fails for non-image files
+                     try:
+                         images.append(Image.open(sample).convert('RGB'))
+                     except Exception:
+                         texts.append(sample)
+             elif isinstance(sample, Image.Image):
+                 images.append(sample.convert('RGB'))
+ 
+         if images and texts:
+             raise ValueError('Batch must contain either images or texts, not both')
+ 
+         if texts:
+             return self.tokenizer(
+                 texts,
+                 padding=padding,
+                 truncation="longest_first",
+                 return_tensors="pt",
+                 max_length=self.max_seq_length,
+             )
+         elif images:
+             return self.preprocessor(images)
+         return {}
+ 
+     def save(self, output_path: str, safe_serialization: bool = True) -> None:
+         self.jina_clip.save_pretrained(
+             output_path, safe_serialization=safe_serialization
+         )
+         self.tokenizer.save_pretrained(output_path)
+         self.preprocessor.save_pretrained(output_path)
+ 
+     @staticmethod
+     def load(input_path: str) -> "Transformer":
+         # Old classes used other config names than 'sentence_bert_config.json'
+         for config_name in [
+             "sentence_bert_config.json",
+             "sentence_roberta_config.json",
+             "sentence_distilbert_config.json",
+             "sentence_camembert_config.json",
+             "sentence_albert_config.json",
+             "sentence_xlm-roberta_config.json",
+             "sentence_xlnet_config.json",
+         ]:
+             sbert_config_path = os.path.join(input_path, config_name)
+             if os.path.exists(sbert_config_path):
+                 break
+ 
+         with open(sbert_config_path) as fIn:
+             config = json.load(fIn)
+         # Don't allow configs to set trust_remote_code
+         if "model_args" in config and "trust_remote_code" in config["model_args"]:
+             config["model_args"].pop("trust_remote_code")
+         if (
+             "tokenizer_args" in config
+             and "trust_remote_code" in config["tokenizer_args"]
+         ):
+             config["tokenizer_args"].pop("trust_remote_code")
+         if "config_args" in config and "trust_remote_code" in config["config_args"]:
+             config["config_args"].pop("trust_remote_code")
+         return Transformer(model_name_or_path=input_path, **config)
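
For context, a minimal usage sketch of this module through sentence-transformers. The model id and image URL below are illustrative; trust_remote_code=True is needed because custom_st.py is fetched from the model repo at load time.

from sentence_transformers import SentenceTransformer

# Load a checkpoint whose modules.json points at this Transformer class.
model = SentenceTransformer("jinaai/jina-clip-v1", trust_remote_code=True)

# tokenize() routes plain strings to the text tokenizer...
text_emb = model.encode(["A photo of a cat"])

# ...and URLs, data URIs, local paths, or PIL images to the image preprocessor.
# Base64 inputs take the form "data:image/png;base64,<payload>".
image_emb = model.encode(["https://example.com/cat.png"])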