# Based on https://github.com/openai/CLIP/blob/main/clip/model.py

import onnxruntime
import numpy as np
from typing import List, Union
from PIL import Image, ImageOps

from clip.simple_tokenizer import SimpleTokenizer

def onnx_node_type_np_type(type):
    if type == "tensor(float)":
        return np.float32
    if type == "tensor(float16)":
        return np.float16
    if type == "tensor(int32)":
        return np.int32
    if type == "tensor(int64)":
        return np.int64
    raise NotImplementedError(f"Unsupported onnx type: {type}")

def ensure_input_type(input, type):
    # Cast to the numpy dtype expected by the ONNX node, if it does not already match.
    np_type = onnx_node_type_np_type(type)
    if input.dtype == np_type:
        return input
    return input.astype(dtype=np_type)

class VisualModel:
    def __init__(self, path, providers=None):
        self.path = path
        print(f"Loading visual model: {path}")
        self.sess = onnxruntime.InferenceSession(path, providers=providers)
        self.input = self.sess.get_inputs()[0]
        self.output = self.sess.get_outputs()[0]
        
        if len(self.input.shape) != 4 or self.input.shape[2] != self.input.shape[3]:
            raise ValueError(f"unexpected shape {self.input.shape}")
        self.input_size = self.input.shape[2]
        print(f"Visual inference ready, input size {self.input_size}, type {self.input.type}")

    def encode(self, image_input):
        image_input = ensure_input_type(image_input, self.input.type)
        return self.sess.run([self.output.name], {self.input.name: image_input})[0]

    def fitted(self, size, w, h):
        # Scale (w, h) so that the shorter side becomes `size`, preserving aspect ratio.
        short, long = (w, h) if w <= h else (h, w)
        new_short, new_long = size, int(size * long / short)
        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
        return [new_w, new_h]
        
    def resize_to(self, img, size):
        new_size = self.fitted(size, img.width, img.height)
        return img.resize(size=new_size, resample=Image.Resampling.BICUBIC)

    def center_crop(self, img, size):
        image_height = img.height
        image_width = img.width
        if size > image_width or size > image_height:
            padding_ltrb = [
                (size - image_width) // 2 if size > image_width else 0,
                (size - image_height) // 2 if size > image_height else 0,
                (size - image_width + 1) // 2 if size > image_width else 0,
                (size - image_height + 1) // 2 if size > image_height else 0,
            ]
            # PIL images have no pad() method; expand with a black border instead.
            img = ImageOps.expand(img, border=tuple(padding_ltrb), fill=0)
            image_width = img.width
            image_height = img.height
            if size == image_width and size == image_height:
                return img
        top = int(round((image_height - size) / 2.0))
        left = int(round((image_width - size) / 2.0))
        return img.crop((left, top, left + size, top + size))

    def to_numpy(self, pic):
        mode_to_nptype = {"I": np.int32, "I;16": np.int16, "F": np.float32}
        img = np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True)
        if pic.mode == "1":
            img = 255 * img
        # HWC -> CHW, then scale pixel values to [0, 1].
        img = np.transpose(img, (2, 0, 1))
        img = img.astype(np.float32)
        img = np.divide(img, 255)
        return img

    def normalize(self, img):
        # Per-channel RGB mean/std used by CLIP image preprocessing.
        mean = np.array([0.48145466, 0.4578275, 0.40821073]).reshape((-1, 1, 1))
        std = np.array([0.26862954, 0.26130258, 0.27577711]).reshape((-1, 1, 1))
        return np.divide(np.subtract(img, mean), std)

    def preprocess(self, img):
        img = self.resize_to(img, self.input_size)
        img = self.center_crop(img, self.input_size)
        img = img.convert("RGB")
        img_np = self.to_numpy(img)
        img_np = self.normalize(img_np)
        return img_np

    def preprocess_images(self, images):
        preprocessed = []
        for img in images:
            if isinstance(img, str):
                img = Image.open(img)
            preprocessed.append(self.preprocess(img))
        return np.stack(preprocessed)

class TextualModel:
    def __init__(self, path, providers=None):
        self.path = path
        print(f"Loading textual model: {path}")
        self.sess = onnxruntime.InferenceSession(path, providers=providers)
        self.input = self.sess.get_inputs()[0]
        self.output = self.sess.get_outputs()[0]
        self.tokenizer = SimpleTokenizer()

        if len(self.input.shape) != 2 or self.input.shape[1] != 77:
            raise ValueError(f"unexpected shape {self.input.shape}")
        self.input_size = self.input.shape[1]
        print(f"Textual inference ready, input size {self.input_size}, type {self.input.type}")

    def encode(self, texts):
        return self.sess.run([self.output.name], {self.input.name: texts})[0]

    def tokenize(self, texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> np.ndarray:
        """
        Returns the tokenized representation of given input string(s)

        Parameters
        ----------
        texts : Union[str, List[str]]
            An input string or a list of input strings to tokenize

        context_length : int
            The context length to use; all CLIP models use 77 as the context length

        truncate: bool
            Whether to truncate the text in case its encoding is longer than the context length

        Returns
        -------
        A two-dimensional numpy array containing the resulting tokens, shape = [number of input strings, context_length]
        """
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self.tokenizer.encoder["<|startoftext|>"]
        eot_token = self.tokenizer.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
        input_type = onnx_node_type_np_type(self.input.type)
        result = np.zeros(shape=(len(all_tokens), context_length), dtype=input_type)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                if truncate:
                    tokens = tokens[:context_length]
                    tokens[-1] = eot_token
                else:
                    raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            result[i, :len(tokens)] = np.array(tokens)

        return result
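
if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). The .onnx file names,
    # the example image path and the prompts below are placeholders: point them at
    # your own exported CLIP visual/textual models and inputs.
    visual = VisualModel("clip_visual.onnx", providers=["CPUExecutionProvider"])
    textual = TextualModel("clip_textual.onnx", providers=["CPUExecutionProvider"])

    image_features = visual.encode(visual.preprocess_images(["example.jpg"]))
    text_features = textual.encode(textual.tokenize(["a photo of a cat", "a photo of a dog"]))

    # L2-normalize both embeddings so the dot product equals cosine similarity.
    image_features = image_features / np.linalg.norm(image_features, axis=-1, keepdims=True)
    text_features = text_features / np.linalg.norm(text_features, axis=-1, keepdims=True)
    print("similarity:", image_features @ text_features.T)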