jiajunlong committed
Commit b1eaec8
1 Parent(s): a8863db
Files changed (1):
  1. modeling_tinyllava_elm.py +223 -2
modeling_tinyllava_elm.py CHANGED
@@ -2,6 +2,12 @@ from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 import ast
 import re
+from enum import auto, Enum
+import requests
+from PIL import Image
+from io import BytesIO
+import base64
+import time
 
 import torch
 import torch.utils.checkpoint
@@ -12,11 +18,10 @@ from transformers import PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.generation.utils import GenerateOutput
 from transformers import CLIPVisionModel, CLIPImageProcessor,SiglipVisionModel, SiglipImageProcessor
+from transformers import AutoConfig, AutoModelForCausalLM
 
 from .configuration import TinyLlavaConfig, IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
 
-from transformers import AutoConfig, AutoModelForCausalLM
-
 # from tinyllava.utils.data_utils import get_value_from_kwargs
 CONTROLLER_HEART_BEAT_EXPIRATION = 30
 WORKER_HEART_BEAT_INTERVAL = 15
@@ -47,6 +52,169 @@ import numpy as np
 from transformers import PretrainedConfig, AutoTokenizer
 
 
+
+logger = logging.get_logger(__name__)
+
+# Model Constants
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = -200
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+DEFAULT_IM_START_TOKEN = "<im_start>"
+DEFAULT_IM_END_TOKEN = "<im_end>"
+IMAGE_PLACEHOLDER = "<image-placeholder>"
+
+CONTROLLER_HEART_BEAT_EXPIRATION = 30
+WORKER_HEART_BEAT_INTERVAL = 15
+LOGDIR = "."
+
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+    SINGLE = auto()
+    TWO = auto()
+    MPT = auto()
+    PLAIN = auto()
+    LLAMA_2 = auto()
+    TINY_LLAMA = auto()
+    QWEN_2 = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+    sep: str = "###"
+    sep2: str = None
+    version: str = "Unknown"
+
+    skip_next: bool = False
+
+    def get_prompt(self):
+        messages = self.messages
+        if len(messages) > 0 and type(messages[0][1]) is tuple:
+            messages = self.messages.copy()
+            init_role, init_msg = messages[0].copy()
+            init_msg = init_msg[0].replace("<image>", "").strip()
+            if 'mmtag' in self.version:
+                messages[0] = (init_role, init_msg)
+                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+                messages.insert(1, (self.roles[1], "Received."))
+            else:
+                messages[0] = (init_role, "<image>\n" + init_msg)
+
+        if self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+        return ret
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            version=self.version)
+
+
+
+
+conv_phi_v0 = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="phi",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="<|endoftext|>",
+)
+
+
+def load_image_from_base64(image):
+    return Image.open(BytesIO(base64.b64decode(image)))
+
+
+def expand2square(pil_img, background_color):
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    elif width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
+
+def process_images(images, image_processor, model_cfg):
+    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+    new_images = []
+    if image_aspect_ratio == 'pad':
+        for image in images:
+            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            new_images.append(image)
+    else:
+        return image_processor(images, return_tensors='pt')['pixel_values']
+    if all(x.shape == new_images[0].shape for x in new_images):
+        new_images = torch.stack(new_images, dim=0)
+    return new_images
+
+
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
+
+    def insert_separator(X, sep):
+        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+
+    input_ids = []
+    offset = 0
+    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+        offset = 1
+        input_ids.append(prompt_chunks[0][0])
+
+    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+        input_ids.extend(x[offset:])
+
+    if return_tensors is not None:
+        if return_tensors == 'pt':
+            return torch.tensor(input_ids, dtype=torch.long)
+        raise ValueError(f'Unsupported tensor type: {return_tensors}')
+    return input_ids
+
+def load_image(image_file):
+    if image_file.startswith("http") or image_file.startswith("https"):
+        response = requests.get(image_file)
+        image = Image.open(BytesIO(response.content)).convert("RGB")
+    else:
+        image = Image.open(image_file).convert("RGB")
+    return image
+
+
 def make_divisible(
     v: Union[float, int],
     divisor: Optional[int] = 8,
@@ -1686,6 +1854,59 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
             position_ids = None
 
         return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
+
+    def chat(
+        self,
+        prompt: str,
+        tokenizer = None,
+        image: str = None,
+        max_new_tokens: int = 512,
+        num_beams = 1,
+        top_p=None,
+        temperature=0
+    ):
+        image_processor = self.vision_tower._image_processor
+
+        if image is not None:
+            prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
+        conv = conv_phi_v0.copy()
+        conv.append_message(conv.roles[0], prompt)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+        if image is not None:
+            image = load_image(image)
+            image_tensor = process_images(image, image_processor, self.config).to(self.device)
+
+        input_ids = (
+            tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+            .unsqueeze(0).to(self.device)
+        )
+        # Generate
+        stime = time.time()
+
+        with torch.inference_mode():
+            output_ids = self.generate(
+                input_ids,
+                images=image_tensor,
+                do_sample=True if temperature > 0 else False,
+                temperature=temperature,
+                top_p=top_p,
+                num_beams=num_beams,
+                pad_token_id=tokenizer.pad_token_id,
+                max_new_tokens=max_new_tokens,
+                use_cache=True,
+                # stopping_criteria=[stopping_criteria],
+            )
+
+        # print('inference over')
+        generation_time = time.time() - stime
+        outputs = tokenizer.batch_decode(
+            output_ids, skip_special_tokens=True
+        )[0]
+
+        outputs = outputs.strip()
+
+        return outputs, generation_time
 
 
 
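Usage note: the chat() method added in this commit wraps prompt construction (via conv_phi_v0), image preprocessing, and generate() into a single call on TinyLlavaForConditionalGeneration. A minimal sketch of how it might be invoked, assuming the checkpoint is loaded with trust_remote_code=True; the repository id and image URL below are placeholders, not part of this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "jiajunlong/TinyLLaVA-OpenELM-450M-SigLIP-0.89B"  # placeholder repo id
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# chat() builds the "USER: ... ASSISTANT:" prompt, preprocesses the image,
# tokenizes around the <image> placeholder, and returns (text, generation_time).
answer, seconds = model.chat(
    prompt="What is shown in this image?",
    image="https://example.com/sample.jpg",  # local path or URL (placeholder)
    tokenizer=tokenizer,
    max_new_tokens=512,
)
print(answer, f"({seconds:.2f}s)")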