import base64
import json
import os
from io import BytesIO
from typing import Any, Dict, List, Optional, Union

import requests
import torch
from PIL import Image
from torch import nn
from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoTokenizer

from .custom_st_2 import OtherClass

OtherClass()


class Transformer(nn.Module):
    """Huggingface AutoModel to generate token embeddings.
    Loads the correct class, e.g. BERT / RoBERTa etc.

    Args:
        model_name_or_path: Huggingface models name (https://huggingface.co/models)
        max_seq_length: Truncate any inputs longer than max_seq_length
        model_args: Keyword arguments passed to the Huggingface Transformers model
        tokenizer_args: Keyword arguments passed to the Huggingface Transformers tokenizer
        config_args: Keyword arguments passed to the Huggingface Transformers config
        cache_dir: Cache dir for Huggingface Transformers to store/load models
        do_lower_case: If true, lowercases the input (independent of whether the model is cased or not)
        tokenizer_name_or_path: Name or path of the tokenizer. When None, then model_name_or_path is used
    """

    def __init__(
        self,
        model_name_or_path: str,
        max_seq_length: Optional[int] = None,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        config_args: Optional[Dict[str, Any]] = None,
        cache_dir: Optional[str] = None,
        do_lower_case: bool = False,
        tokenizer_name_or_path: Optional[str] = None,
    ) -> None:
        super().__init__()
        self.config_keys = ["max_seq_length", "do_lower_case"]
        self.do_lower_case = do_lower_case
        if model_args is None:
            model_args = {}
        if tokenizer_args is None:
            tokenizer_args = {}
        if config_args is None:
            config_args = {}

        config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
        self.jina_clip = AutoModel.from_pretrained(
            model_name_or_path, config=config, cache_dir=cache_dir, **model_args
        )

        if max_seq_length is not None and "model_max_length" not in tokenizer_args:
            tokenizer_args["model_max_length"] = max_seq_length
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
            cache_dir=cache_dir,
            **tokenizer_args,
        )
        self.preprocessor = AutoImageProcessor.from_pretrained(
            tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
            cache_dir=cache_dir,
            **tokenizer_args,
        )

        # No max_seq_length set. Try to infer it from the model.
        if max_seq_length is None:
            if (
                hasattr(self.jina_clip, "config")
                and hasattr(self.jina_clip.config, "max_position_embeddings")
                and hasattr(self.tokenizer, "model_max_length")
            ):
                max_seq_length = min(
                    self.jina_clip.config.max_position_embeddings, self.tokenizer.model_max_length
                )

        self.max_seq_length = max_seq_length

        if tokenizer_name_or_path is not None:
            self.jina_clip.config.tokenizer_class = self.tokenizer.__class__.__name__

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Returns the sentence embedding for either a text batch or an image batch."""
        if "input_ids" in features:
            embedding = self.jina_clip.get_text_features(input_ids=features["input_ids"])
        else:
            embedding = self.jina_clip.get_image_features(pixel_values=features["pixel_values"])
        return {"sentence_embedding": embedding}

    def get_word_embedding_dimension(self) -> int:
        return self.jina_clip.config.text_config.embed_dim

    @staticmethod
    def decode_data_image(data_image_str: str) -> Image.Image:
        # Strip the "data:image/...;base64," header and decode the payload
        header, data = data_image_str.split(",", 1)
        image_data = base64.b64decode(data)
        return Image.open(BytesIO(image_data))

    def tokenize(
        self, batch: List[Union[str, Image.Image]], padding: Union[str, bool] = True
    ) -> Dict[str, torch.Tensor]:
        """Tokenizes texts or preprocesses images; a batch must contain one modality only."""
        images = []
        texts = []
        for sample in batch:
            if isinstance(sample, str):
                if sample.startswith("http"):
                    response = requests.get(sample)
                    images.append(Image.open(BytesIO(response.content)).convert("RGB"))
                elif sample.startswith("data:image/"):
                    images.append(self.decode_data_image(sample).convert("RGB"))
                else:
                    # TODO: Make sure that Image.open fails for non-image files
                    try:
                        images.append(Image.open(sample).convert("RGB"))
                    except Exception:
                        texts.append(sample)
            elif isinstance(sample, Image.Image):
                images.append(sample.convert("RGB"))

        if images and texts:
            raise ValueError("Batch must contain either images or texts, not both")

        if texts:
            return self.tokenizer(
                texts,
                padding=padding,
                truncation="longest_first",
                return_tensors="pt",
                max_length=self.max_seq_length,
            )
        if images:
            return self.preprocessor(images)
        return {}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        self.jina_clip.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.tokenizer.save_pretrained(output_path)
        self.preprocessor.save_pretrained(output_path)

    @staticmethod
    def load(input_path: str) -> "Transformer":
        # Old classes used other config names than 'sentence_bert_config.json'
        for config_name in [
            "sentence_bert_config.json",
            "sentence_roberta_config.json",
            "sentence_distilbert_config.json",
            "sentence_camembert_config.json",
            "sentence_albert_config.json",
            "sentence_xlm-roberta_config.json",
            "sentence_xlnet_config.json",
        ]:
            sbert_config_path = os.path.join(input_path, config_name)
            if os.path.exists(sbert_config_path):
                break

        with open(sbert_config_path) as fIn:
            config = json.load(fIn)

        # Don't allow configs to set trust_remote_code
        if "model_args" in config and "trust_remote_code" in config["model_args"]:
            config["model_args"].pop("trust_remote_code")
        if "tokenizer_args" in config and "trust_remote_code" in config["tokenizer_args"]:
            config["tokenizer_args"].pop("trust_remote_code")
        if "config_args" in config and "trust_remote_code" in config["config_args"]:
            config["config_args"].pop("trust_remote_code")

        return Transformer(model_name_or_path=input_path, **config)
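

# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original module): a minimal way
# to exercise this custom Transformer directly. The checkpoint id
# "jinaai/jina-clip-v1" is assumed here; substitute any repository that ships
# this file as its custom sentence-transformers module. Because of the
# relative import of OtherClass above, run this as part of its package
# (python -m <package>.custom_st) rather than as a standalone script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    remote_args = {"trust_remote_code": True}  # custom model code requires explicit opt-in
    model = Transformer(
        "jinaai/jina-clip-v1",  # assumed checkpoint id
        model_args=dict(remote_args),
        tokenizer_args=dict(remote_args),
        config_args=dict(remote_args),
    )
    # Plain strings are routed to the tokenizer; URLs, data URIs, local image
    # paths, and PIL images would be routed to the image preprocessor instead
    # (one modality per batch).
    features = model.tokenize(["A blue cat", "A red cat"])
    with torch.no_grad():
        embeddings = model(features)["sentence_embedding"]
    print(embeddings.shape)  # -> (2, embedding_dim)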