diff --git a/app.py b/app.py index 67aaef2f9ba2bbc829b1aa4e7f21772038699c13..d8452a9ed9e4a6815ffa5e7765f95e4a68149c0b 100644 --- a/app.py +++ b/app.py @@ -261,7 +261,7 @@ if __name__ == "__main__": argparser = argparse.ArgumentParser() argparser.add_argument("--server_name", default="0.0.0.0", type=str) argparser.add_argument("--port", default="6123", type=str) - argparser.add_argument("--model_path", default="", type=str) + argparser.add_argument("--model_path", default="lmms-lab/llava-next-interleave-qwen-7b", type=str) # argparser.add_argument("--model-path", type=str, default="facebook/opt-350m") argparser.add_argument("--model-base", type=str, default=None) argparser.add_argument("--num-gpus", type=int, default=1) diff --git a/llava/__init__.py b/llava/__init__.py deleted file mode 100644 index 4d1f016db1028101d45ba7d68cb3f0bcb558c2bb..0000000000000000000000000000000000000000 --- a/llava/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .model import LlavaLlamaForCausalLM diff --git a/llava/__pycache__/__init__.cpython-310.pyc b/llava/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 3aed37c2596f03824f9bfb7342019e1e6e6e527d..0000000000000000000000000000000000000000 Binary files a/llava/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/llava/__pycache__/constants.cpython-310.pyc b/llava/__pycache__/constants.cpython-310.pyc deleted file mode 100644 index 5d141b2cbf5dd5882af1f1e90b6d76cf0323d533..0000000000000000000000000000000000000000 Binary files a/llava/__pycache__/constants.cpython-310.pyc and /dev/null differ diff --git a/llava/__pycache__/conversation.cpython-310.pyc b/llava/__pycache__/conversation.cpython-310.pyc deleted file mode 100644 index d892221ed358f95aee8459a96857c959125fbda1..0000000000000000000000000000000000000000 Binary files a/llava/__pycache__/conversation.cpython-310.pyc and /dev/null differ diff --git a/llava/__pycache__/mm_utils.cpython-310.pyc b/llava/__pycache__/mm_utils.cpython-310.pyc deleted file mode 100644 index f51b41270ff21d99b9400f1d73fb07da02e8e76a..0000000000000000000000000000000000000000 Binary files a/llava/__pycache__/mm_utils.cpython-310.pyc and /dev/null differ diff --git a/llava/__pycache__/utils.cpython-310.pyc b/llava/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index f6361701d4f2c4734c12d4458facc27bb4bc2483..0000000000000000000000000000000000000000 Binary files a/llava/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/llava/constants.py b/llava/constants.py deleted file mode 100644 index be8cf0204969a6c973f442b383d8e425d684e826..0000000000000000000000000000000000000000 --- a/llava/constants.py +++ /dev/null @@ -1,12 +0,0 @@ -CONTROLLER_HEART_BEAT_EXPIRATION = 30 -WORKER_HEART_BEAT_INTERVAL = 15 - -LOGDIR = "." - -# Model Constants -IGNORE_INDEX = -100 -IMAGE_TOKEN_INDEX = -200 -DEFAULT_IMAGE_TOKEN = "" -DEFAULT_IMAGE_PATCH_TOKEN = "" -DEFAULT_IM_START_TOKEN = "" -DEFAULT_IM_END_TOKEN = "" diff --git a/llava/conversation.py b/llava/conversation.py deleted file mode 100644 index 716137bcbb294e31a86d18443a5e14bdec2a7269..0000000000000000000000000000000000000000 --- a/llava/conversation.py +++ /dev/null @@ -1,554 +0,0 @@ -import dataclasses -from enum import auto, Enum -from typing import List, Any, Dict, Union, Tuple -import re -import base64 -from io import BytesIO -from PIL import Image -from transformers import AutoTokenizer - - -class SeparatorStyle(Enum): - """Different separator style.""" - - SINGLE = auto() - TWO = auto() - MPT = auto() - PLAIN = auto() - CHATML = auto() - LLAMA_2 = auto() - LLAMA_3 = auto() - QWEN = auto() - GEMMA = auto() - - -@dataclasses.dataclass -class Conversation: - """A class that keeps all conversation history.""" - - system: str - roles: List[str] - messages: List[List[str]] - offset: int - sep_style: SeparatorStyle = SeparatorStyle.SINGLE - sep: str = "###" - sep2: str = None - version: str = "Unknown" - - tokenizer_id: str = "" - tokenizer: Any = None - # Stop criteria (the default one is EOS token) - stop_str: Union[str, List[str]] = None - # Stops generation if meeting any token in this list - stop_token_ids: List[int] = None - - skip_next: bool = False - - def get_prompt(self): - messages = self.messages - if len(messages) > 0 and type(messages[0][1]) is tuple: - messages = self.messages.copy() - init_role, init_msg = messages[0].copy() - init_msg = init_msg[0] - if "mmtag" in self.version: - init_msg = init_msg.replace("", "").strip() - messages[0] = (init_role, init_msg) - messages.insert(0, (self.roles[0], "")) - messages.insert(1, (self.roles[1], "Received.")) - elif not init_msg.startswith(""): - init_msg = init_msg.replace("", "").strip() - messages[0] = (init_role, "\n" + init_msg) - else: - messages[0] = (init_role, init_msg) - - if self.sep_style == SeparatorStyle.SINGLE: - ret = self.system + self.sep - for role, message in messages: - if message: - if type(message) is tuple: - message, _, _ = message - ret += role + ": " + message + self.sep - else: - ret += role + ":" - - elif self.sep_style == SeparatorStyle.TWO: - seps = [self.sep, self.sep2] - ret = self.system + seps[0] - for i, (role, message) in enumerate(messages): - if message: - if type(message) is tuple: - message, _, _ = message - ret += role + ": " + message + seps[i % 2] - else: - ret += role + ":" - - elif self.sep_style == SeparatorStyle.CHATML: - ret = "" if self.system == "" else self.system + self.sep + "\n" - for role, message in messages: - if message: - if type(message) is tuple: - message, images = message - message = "" * len(images) + message - ret += role + "\n" + message + self.sep + "\n" - else: - ret += role + "\n" - return ret - - elif self.sep_style == SeparatorStyle.LLAMA_3: - chat_template_messages = [{"role": "system", "content": self.system}] - for role, message in messages: - if message: - if type(message) is tuple: - message, images = message - message = "" * len(images) + message - chat_template_messages.append({"role": role, "content": message}) - - # print(chat_template_messages) - return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True) - # ret = "" if self.system == "" else self.system + self.sep + "\n" - # for role, message in messages: - # if message: - # if type(message) is tuple: - # message, images = message - # message = "" * len(images) + message - # ret += role + "\n" + message + self.sep + "\n" - # else: - # ret += role + "\n" - # return ret - - elif self.sep_style == SeparatorStyle.MPT: - ret = self.system + self.sep - for role, message in messages: - if message: - if type(message) is tuple: - message, _, _ = message - ret += role + message + self.sep - else: - ret += role - - elif self.sep_style == SeparatorStyle.GEMMA: - ret = "" - for i, (role, message) in enumerate(messages): - assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..." - if message: - if type(message) is tuple: - message, _, _ = message - ret += role + message + self.sep - else: - ret += role - - elif self.sep_style == SeparatorStyle.LLAMA_2: - wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg - wrap_inst = lambda msg: f"[INST] {msg} [/INST]" - ret = "" - - for i, (role, message) in enumerate(messages): - if i == 0: - assert message, "first message should not be none" - assert role == self.roles[0], "first message should come from user" - if message: - if type(message) is tuple: - message, _, _ = message - if i == 0: - message = wrap_sys(self.system) + message - if i % 2 == 0: - message = wrap_inst(message) - ret += self.sep + message - else: - ret += " " + message + " " + self.sep2 - else: - ret += "" - ret = ret.lstrip(self.sep) - - elif self.sep_style == SeparatorStyle.PLAIN: - seps = [self.sep, self.sep2] - ret = self.system - for i, (role, message) in enumerate(messages): - if message: - if type(message) is tuple: - message, _, _ = message - ret += message + seps[i % 2] - else: - ret += "" - else: - raise ValueError(f"Invalid style: {self.sep_style}") - - return ret - - def append_message(self, role, message): - self.messages.append([role, message]) - - def process_image(self, image, image_process_mode, return_pil=False, image_format="PNG"): - if image_process_mode == "Pad": - - def expand2square(pil_img, background_color=(122, 116, 104)): - width, height = pil_img.size - if width == height: - return pil_img - elif width > height: - result = Image.new(pil_img.mode, (width, width), background_color) - result.paste(pil_img, (0, (width - height) // 2)) - return result - else: - result = Image.new(pil_img.mode, (height, height), background_color) - result.paste(pil_img, ((height - width) // 2, 0)) - return result - - image = expand2square(image) - elif image_process_mode in ["Default", "Crop"]: - pass - elif image_process_mode == "Resize": - image = image.resize((336, 336)) - else: - raise ValueError(f"Invalid image_process_mode: {image_process_mode}") - - max_hw, min_hw = max(image.size), min(image.size) - aspect_ratio = max_hw / min_hw - max_len, min_len = 672, 448 - shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) - longest_edge = int(shortest_edge * aspect_ratio) - W, H = image.size - if H > W: - H, W = longest_edge, shortest_edge - else: - H, W = shortest_edge, longest_edge - image = image.resize((W, H)) - if return_pil: - return image - else: - buffered = BytesIO() - image.save(buffered, format=image_format) - img_b64_str = base64.b64encode(buffered.getvalue()).decode() - return img_b64_str - - def get_images(self, return_pil=False): - images = [] - for i, (role, msg) in enumerate(self.messages[self.offset :]): - if i % 2 == 0: - if type(msg) is tuple: - msg, image, image_process_mode = msg - if type(image) != list: - image = [image] - for img in image: - img = self.process_image(img, image_process_mode, return_pil=return_pil) - images.append(img) - return images - - def to_gradio_chatbot(self): - ret = [] - for i, (role, msg) in enumerate(self.messages[self.offset :]): - if i % 2 == 0: - if type(msg) is tuple: - msg, image, image_process_mode = msg - if type(image) != list: - image = [image] - if len(image) == 1: - msg = "\n" + msg.replace("", "").strip() - else: - msg = re.sub(r"()\n(?=)", r"\1 ", msg) - for img in image: - img_b64_str = self.process_image(img, "Default", return_pil=False, image_format="JPEG") - img_str = f'' - msg = msg.replace("", img_str, 1).strip() - if len(msg) > 0: - ret.append([msg, None]) - else: - ret.append([msg, None]) - else: - ret[-1][-1] = msg - return ret - - def copy(self): - return Conversation(system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, version=self.version) - - def dict(self): - if len(self.get_images()) > 0: - return { - "system": self.system, - "roles": self.roles, - "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], - "offset": self.offset, - "sep": self.sep, - "sep2": self.sep2, - } - return { - "system": self.system, - "roles": self.roles, - "messages": self.messages, - "offset": self.offset, - "sep": self.sep, - "sep2": self.sep2, - } - - -conv_vicuna_v0 = Conversation( - system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.", - roles=("Human", "Assistant"), - messages=[ - ["Human", "What are the key differences between renewable and non-renewable energy sources?"], - [ - "Assistant", - "Renewable energy sources are those that can be replenished naturally in a relatively " - "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " - "Non-renewable energy sources, on the other hand, are finite and will eventually be " - "depleted, such as coal, oil, and natural gas. Here are some key differences between " - "renewable and non-renewable energy sources:\n" - "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " - "energy sources are finite and will eventually run out.\n" - "2. Environmental impact: Renewable energy sources have a much lower environmental impact " - "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " - "and other negative effects.\n" - "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " - "have lower operational costs than non-renewable sources.\n" - "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " - "locations than non-renewable sources.\n" - "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " - "situations and needs, while non-renewable sources are more rigid and inflexible.\n" - "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " - "non-renewable sources are not, and their depletion can lead to economic and social instability.\n", - ], - ], - offset=2, - sep_style=SeparatorStyle.SINGLE, - sep="###", -) - -conv_vicuna_v1 = Conversation( - system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.", - roles=("USER", "ASSISTANT"), - version="v1", - messages=[], - offset=0, - sep_style=SeparatorStyle.TWO, - sep=" ", - sep2="", -) - -conv_llama_2 = Conversation( - system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. - -If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", - roles=("USER", "ASSISTANT"), - version="llama_v2", - messages=[], - offset=0, - sep_style=SeparatorStyle.LLAMA_2, - sep="", - sep2="", -) - -conv_llava_llama_2 = Conversation( - system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.", - roles=("USER", "ASSISTANT"), - version="llama_v2", - messages=[], - offset=0, - sep_style=SeparatorStyle.LLAMA_2, - sep="", - sep2="", -) - -try: - llama3_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") -except Exception as e: - print("Error loading llama3 tokenizer") - print(e) - -# conv_llava_llama_3 = Conversation( -# system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.", -# roles=("<|start_header_id|>user", "<|start_header_id|>assistant"), -# version="llama_v3", -# messages=[], -# offset=0, -# sep_style=SeparatorStyle.LLAMA_3, -# tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct", -# tokenizer=llama3_tokenizer, -# stop_token_ids=[128009], -# ) - -conv_mistral_instruct = Conversation( - system="", - roles=("USER", "ASSISTANT"), - version="llama_v2", - messages=[], - offset=0, - sep_style=SeparatorStyle.LLAMA_2, - sep="", - sep2="", -) - -conv_llava_llama_2_simple = Conversation( - system="Answer the questions about the visual content that the user provides.", - roles=("USER", "ASSISTANT"), - version="llama_v2", - messages=[], - offset=0, - sep_style=SeparatorStyle.LLAMA_2, - sep="", - sep2="", -) - -conv_llava_llama_2_mmtag = Conversation( - system="Answer the questions about the visual content that the user provides." "The visual content will be provided with the following format: visual content.", - roles=("USER", "ASSISTANT"), - version="llama_v2_mmtag", - messages=[], - offset=0, - sep_style=SeparatorStyle.LLAMA_2, - sep="", - sep2="", -) - -conv_mpt = Conversation( - system="""<|im_start|>system -A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", - roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), - version="mpt", - messages=[], - offset=0, - sep_style=SeparatorStyle.MPT, - sep="<|im_end|>", -) - -conv_qwen = Conversation( - system="""<|im_start|>system -You are a helpful assistant.""", - roles=("<|im_start|>user", "<|im_start|>assistant"), - version="qwen", - messages=[], - offset=0, - sep_style=SeparatorStyle.CHATML, - sep="<|im_end|>", -) - -conv_gemma_instruct = Conversation(system="", roles=("user\n", "model\n"), version="gemma", messages=[], offset=0, sep_style=SeparatorStyle.GEMMA, sep="\n") - -conv_llava_plain = Conversation( - system="", - roles=("", ""), - messages=[], - offset=0, - sep_style=SeparatorStyle.PLAIN, - sep="\n", -) - -conv_llava_v0 = Conversation( - system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.", - roles=("Human", "Assistant"), - messages=[], - offset=0, - sep_style=SeparatorStyle.SINGLE, - sep="###", -) - -conv_llava_v0_mmtag = Conversation( - system="A chat between a curious user and an artificial intelligence assistant. " - "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." - "The visual content will be provided with the following format: visual content.", - roles=("Human", "Assistant"), - messages=[], - offset=0, - sep_style=SeparatorStyle.SINGLE, - sep="###", - version="v0_mmtag", -) - -conv_llava_v1 = Conversation( - system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.", - roles=("USER", "ASSISTANT"), - version="v1", - messages=[], - offset=0, - sep_style=SeparatorStyle.TWO, - sep=" ", - sep2="", -) - -conv_llava_v1_mmtag = Conversation( - system="A chat between a curious user and an artificial intelligence assistant. " - "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." - "The visual content will be provided with the following format: visual content.", - roles=("USER", "ASSISTANT"), - messages=[], - offset=0, - sep_style=SeparatorStyle.TWO, - sep=" ", - sep2="", - version="v1_mmtag", -) - -conv_mistral_orca = Conversation( - system="""<|im_start|>system -You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!""", - roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), - version="mpt", - messages=[], - offset=0, - sep_style=SeparatorStyle.MPT, - sep="<|im_end|>", -) - -conv_mistral_zephyr = Conversation( - system="""<|system|> -You are a helpful AI assistant.""", - roles=("<|user|>\n", "<|assistant|>\n"), - version="mpt", - messages=[], - offset=0, - sep_style=SeparatorStyle.MPT, - sep="", -) - -conv_mistral_direct = Conversation( - system="""<|im_start|>system -Answer the questions.""", - roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), - version="mpt", - messages=[], - offset=0, - sep_style=SeparatorStyle.MPT, - sep="<|im_end|>", -) - -conv_chatml_direct = Conversation( - system="""<|im_start|>system -Answer the questions.""", - roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), - version="mpt", - messages=[], - offset=0, - sep_style=SeparatorStyle.MPT, - sep="<|im_end|>", -) - -default_conversation = conv_vicuna_v0 -conv_templates = { - "default": conv_vicuna_v0, - "v0": conv_vicuna_v0, - "v1": conv_vicuna_v1, - "vicuna_v1": conv_vicuna_v1, - "llama_2": conv_llama_2, - "mistral_instruct": conv_mistral_instruct, - "mistral_orca": conv_mistral_orca, - "mistral_zephyr": conv_mistral_zephyr, - "mistral_direct": conv_mistral_direct, - "plain": conv_llava_plain, - "v0_plain": conv_llava_plain, - "chatml_direct": conv_chatml_direct, - "llava_v0": conv_llava_v0, - "llava_v0_mmtag": conv_llava_v0_mmtag, - "llava_v1": conv_llava_v1, - "llava_v1_mmtag": conv_llava_v1_mmtag, - "llava_llama_2": conv_llava_llama_2, - # "llava_llama_3": conv_llava_llama_3, - "llava_llama_2_simple": conv_llava_llama_2_simple, - "llava_llama_2_mmtag": conv_llava_llama_2_mmtag, - "llava_mistral_instruct": conv_mistral_instruct, - "mpt": conv_mpt, - "qwen_1_5": conv_qwen, - "gemma_instruct": conv_gemma_instruct, -} - - -if __name__ == "__main__": - print(default_conversation.get_prompt()) diff --git a/llava/eval/evaluate_interleave.py b/llava/eval/evaluate_interleave.py deleted file mode 100644 index b00d32f2d3f732614384d648cbda85b69af32466..0000000000000000000000000000000000000000 --- a/llava/eval/evaluate_interleave.py +++ /dev/null @@ -1,339 +0,0 @@ -import re -from rouge import Rouge -import argparse -import os -import json -import numpy as np -from sklearn.feature_extraction.text import TfidfVectorizer -from sklearn.metrics.pairwise import cosine_similarity - - -spot_the_diff = ["Spot-the-Diff", "Birds-to-Words", "CLEVR-Change"] -image_edit_instruct = ["IEdit", "HQ-Edit", "MagicBrush"] -visual_story_telling = ["AESOP", "FlintstonesSV", "PororoSV", "VIST"] -visual_cloze = ["COMICS_Dialogue", "RecipeQA_VisualCloze"] -text_rich_vqa = ["WebQA", "TQA", "OCR-VQA", "DocVQA"] -multi_image_vqa = ["MIT-States_StateCoherence", "MIT-States_PropertyCoherence", "VISION", "RecipeQA_ImageCoherence"] - -puzzle = ["RAVEN"] -nlrv2 = ["NLVR2_Mantis"] -qbench = ["QBench"] - -class Eval: - def __init__(self): - self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") - self.commaStrip = re.compile("(\d)(\,)(\d)") - self.punct = [ - ";", - r"/", - "[", - "]", - '"', - "{", - "}", - "(", - ")", - "=", - "+", - "\\", - "_", - "-", - ">", - "<", - "@", - "`", - ",", - "?", - "!", - ] - - def processPunctuation(self, inText): - outText = inText - for p in self.punct: - if (p + " " in inText or " " + p in inText) or ( - re.search(self.commaStrip, inText) != None - ): - outText = outText.replace(p, "") - else: - outText = outText.replace(p, " ") - outText = self.periodStrip.sub("", outText, re.UNICODE) - return outText - - def process(self, answer): - answer = answer.replace("\n", " ") - answer = answer.replace("\t", " ") - answer = answer.strip() - answer = self.processPunctuation(answer) - answer = answer.strip('\'') - answer = answer.strip('\"') - answer = answer.strip(')') - answer = answer.strip('(') - answer = answer.strip().lower() - return answer - - def evaluate_rouge(self,preds): - rouge = Rouge() - acc = {'f': []} - eval_list = [] - for i, res in enumerate(preds): - sample_id = res['sample_id'] - # print(sample_id) - gt_ans = self.process(res["gt_response"]) - pred_ans = self.process(res["pred_response"]) - # assert gt_ans != '' - - if gt_ans == '': - continue - - if pred_ans == '': - s = 0 - else: - if len(pred_ans) > 512: - pred_ans = pred_ans[0: 512] - s = rouge.get_scores(pred_ans, gt_ans)[0]['rouge-l']['f'] - acc['f'].append(s) - eval_list.append({'id':str(sample_id),'score':str(round(s,3))}) - results = {'Rouge-L f': np.mean(acc['f'])} - return results,eval_list - - - def judge_multi_choice(self,sample): - sample_id = sample['sample_id'] - gt_ans = sample["gt_response"] - pred_ans = sample["pred_response"] - - if ":" in pred_ans: - a_list = pred_ans.split(":") - a_list = [a.strip() for a in a_list ] - for a in a_list: - if len(a) == 1 and a[-1] in ["a", "b", "c", "d", "e", "f", "g", "h"]: - pred_ans = a - - if pred_ans == gt_ans: - return 1 - else: - return 0 - - def process_sample(self,sample): - sample["gt_response"] = self.process(sample["gt_response"]) - sample["pred_response"] = self.process(sample["pred_response"]) - - def evaluate_multichoice(self, preditions): - correct = 0 - eval_list = [] - for i, sample in enumerate(preditions): - self.process_sample(sample) - score = self.judge_multi_choice(sample) - sample_id = sample['sample_id'] - sample['result'] = score - eval_list.append({'id':str(sample_id),'score':str(score)}) - correct+=score - return {'Accuracy':correct/len(preditions)},eval_list - - def evaluate_multi_choice_image(self,preditions): - correct = 0 - eval_list = [] - for i,sample in enumerate(preditions): - gt_ans = self.process(sample["gt_response"]) - pred_ans = self.process(sample["pred_response"]) - sample_id = sample['sample_id'] - - if ":" in pred_ans: - a_list = pred_ans.split(":") - a_list = [a.strip() for a in a_list ] - for a in a_list: - if len(a) == 1 and a[-1] in ["a", "b", "c", "d", "e", "f", "g", "h"]: - pred_ans = a - - if gt_ans == pred_ans: - score = 1 - else: - score = 0 - sample_id = sample['sample_id'] - sample['result'] = score - eval_list.append({'id':str(sample_id),'score':str(score)}) - correct+=score - return {'Accuracy':correct/len(preditions)},eval_list - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--result-dir', type=str, required=True) - - args = parser.parse_args() - - result_file = os.path.join(args.result_dir, "result.jsonl") - - if not os.path.exists(result_file): - print('No prediction file found') - exit(0) - with open(result_file, 'r') as f: - preds_all = [json.loads(line) for line in f] - - preds_all_dict = dict() - for pred in preds_all: - if pred["dataset"] not in preds_all_dict: - preds_all_dict[pred["dataset"]] = list() - preds_all_dict[pred["dataset"]].append(pred) - - image_choice_dataset_list = ["recipeqa-RecipeQA_VisualCloze", "RecipeQA_ImageCoherence", "COMICS_Panel"] - E = Eval() - - eval_result_list = dict() - eval_result_list_detail = dict() - - for dataset in preds_all_dict: - - preds = preds_all_dict[dataset] - question_type = preds[0]["question_type"] - - if question_type == 'open-ended': - eval_result, eval_list = E.evaluate_rouge(preds) - - elif question_type == 'multi-choice' or dataset == 'nlrv2': - if dataset in image_choice_dataset_list: - eval_result, eval_list = E.evaluate_multi_choice_image(preds) - else: - eval_result, eval_list = E.evaluate_multichoice(preds) - - else: - eval_result = 'Dataset not supported' - print('Dataset not supported') - exit(0) - - print(dataset, end = ': ') - print(eval_result) - - eval_result_list[dataset] = eval_result - eval_result_list_detail[dataset] = eval_list - - os.makedirs(args.result_dir, exist_ok=True) - with open(os.path.join(args.result_dir, 'eval_dataset.json'), 'w') as f: - json.dump(eval_result_list, f, indent=4) - - with open(os.path.join(args.result_dir,'eval_dataset_details.json'), 'w') as f: - json.dump(eval_result_list_detail, f, indent=4) - - - eval_cat_list = dict() - print() - - # spot_the_diff - score = 0 - count = 0 - for dataset in eval_result_list: - if dataset in spot_the_diff: - count += 1 - score += list(eval_result_list[dataset].values())[0] - if count > 0: - score /= count - eval_cat_list["spot_the_diff"] = score - print("spot_the_diff", end = ': ') - print('{:.2f}'.format(100 * score)) - - # image_edit_instruct - score = 0 - count = 0 - for dataset in eval_result_list: - if dataset in image_edit_instruct: - count += 1 - score += list(eval_result_list[dataset].values())[0] - if count > 0: - score /= count - eval_cat_list["image_edit_instruct"] = score - print("image_edit_instruct", end = ': ') - print('{:.2f}'.format(100 * score)) - - # visual_story_telling - score = 0 - count = 0 - for dataset in eval_result_list: - if dataset in visual_story_telling: - count += 1 - score += list(eval_result_list[dataset].values())[0] - if count > 0: - score /= count - eval_cat_list["visual_story_telling"] = score - print("visual_story_telling", end = ': ') - print('{:.2f}'.format(100 * score)) - - # visual_cloze - score = 0 - count = 0 - for dataset in eval_result_list: - if dataset in visual_cloze: - count += 1 - score += list(eval_result_list[dataset].values())[0] - if count > 0: - score /= count - eval_cat_list["visual_cloze"] = score - print("visual_cloze", end = ': ') - print('{:.2f}'.format(100 * score)) - - # text_rich_vqa - score = 0 - count = 0 - for dataset in eval_result_list: - if dataset in text_rich_vqa: - count += 1 - score += list(eval_result_list[dataset].values())[0] - if count > 0: - score /= count - eval_cat_list["text_rich_vqa"] = score - print("text_rich_vqa", end = ': ') - print('{:.2f}'.format(100 * score)) - - # multi_image_vqa - score = 0 - count = 0 - for dataset in eval_result_list: - if dataset in multi_image_vqa: - count += 1 - score += list(eval_result_list[dataset].values())[0] - if count > 0: - score /= count - eval_cat_list["multi_image_vqa"] = score - print("multi_image_vqa", end = ': ') - print('{:.2f}'.format(100 * score)) - - # puzzle - score = 0 - count = 0 - for dataset in eval_result_list: - if dataset in puzzle: - count += 1 - score += list(eval_result_list[dataset].values())[0] - if count > 0: - score /= count - eval_cat_list["puzzle"] = score - print("puzzle", end = ': ') - print('{:.2f}'.format(100 * score)) - - # nlrv2 - score = 0 - count = 0 - for dataset in eval_result_list: - if dataset in nlrv2: - count += 1 - score += list(eval_result_list[dataset].values())[0] - if count > 0: - score /= count - eval_cat_list["nlrv2"] = score - print("nlrv2", end = ': ') - print('{:.2f}'.format(100 * score)) - - # qbench - score = 0 - count = 0 - for dataset in eval_result_list: - if dataset in qbench: - count += 1 - score += list(eval_result_list[dataset].values())[0] - if count > 0: - score /= count - eval_cat_list["qbench"] = score - print("qbench", end = ': ') - print('{:.2f}'.format(100 * score)) - - with open(os.path.join(args.result_dir,'eval_cat.json'), 'w') as f: - json.dump(eval_cat_list, f, indent=4) \ No newline at end of file diff --git a/llava/eval/model_vqa.py b/llava/eval/model_vqa.py deleted file mode 100644 index 2ebceedafe23eaf90e51e0971fbdfcae45555838..0000000000000000000000000000000000000000 --- a/llava/eval/model_vqa.py +++ /dev/null @@ -1,240 +0,0 @@ -import argparse -import torch -import os -import json -from tqdm import tqdm -import shortuuid - -from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN -from llava.conversation import conv_templates, SeparatorStyle -from llava.model.builder import load_pretrained_model -from llava.utils import disable_torch_init -from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria - -from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX -from typing import Dict, Optional, Sequence, List -import transformers -import re - -from PIL import Image -import math - - -def split_list(lst, n): - """Split a list into n (roughly) equal-sized chunks""" - chunk_size = math.ceil(len(lst) / n) # integer division - return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] - - -def get_chunk(lst, n, k): - chunks = split_list(lst, n) - return chunks[k] - -def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict: - roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"} - - im_start, im_end = tokenizer.additional_special_tokens_ids - nl_tokens = tokenizer("\n").input_ids - _system = tokenizer("system").input_ids + nl_tokens - _user = tokenizer("user").input_ids + nl_tokens - _assistant = tokenizer("assistant").input_ids + nl_tokens - - # Apply prompt templates - input_ids, targets = [], [] - - source = sources - if roles[source[0]["from"]] != roles["human"]: - source = source[1:] - - input_id, target = [], [] - system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens - input_id += system - target += [im_start] + [IGNORE_INDEX] * (len(system) - 3) + [im_end] + nl_tokens - assert len(input_id) == len(target) - for j, sentence in enumerate(source): - role = roles[sentence["from"]] - if has_image and sentence["value"] is not None and "" in sentence["value"]: - num_image = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"])) - texts = sentence["value"].split('') - _input_id = tokenizer(role).input_ids + nl_tokens - for i,text in enumerate(texts): - _input_id += tokenizer(text).input_ids - if iuser": - _target = [im_start] + [IGNORE_INDEX] * (len(_input_id) - 3) + [im_end] + nl_tokens - elif role == "<|im_start|>assistant": - _target = [im_start] + [IGNORE_INDEX] * len(tokenizer(role).input_ids) + _input_id[len(tokenizer(role).input_ids) + 1 : -2] + [im_end] + nl_tokens - else: - raise NotImplementedError - target += _target - - input_ids.append(input_id) - targets.append(target) - input_ids = torch.tensor(input_ids, dtype=torch.long) - targets = torch.tensor(targets, dtype=torch.long) - return input_ids - -def eval_model(args): - - # Model - disable_torch_init() - model_path = os.path.expanduser(args.model_path) - model_name = get_model_name_from_path(model_path) - tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) - - # Data - with open(os.path.expanduser(args.question_file)) as f: - questions = json.load(f) - questions = get_chunk(questions, args.num_chunks, args.chunk_idx) - answers_file = os.path.expanduser(args.answers_file) - os.makedirs(os.path.dirname(answers_file), exist_ok=True) - ans_file = open(answers_file, "w") - - for line in tqdm(questions): - idx = line["sample_id"] - question_type = line["metadata"]["question_type"] - dataset_name = line["metadata"]["dataset"] - gt = line["conversations"][1]["value"] - - image_files = line["image"] - qs = line["conversations"][0]["value"] - cur_prompt = args.extra_prompt + qs - - args.conv_mode = "qwen_1_5" - - conv = conv_templates[args.conv_mode].copy() - conv.append_message(conv.roles[0], qs) - conv.append_message(conv.roles[1], None) - prompt = conv.get_prompt() - - input_ids = preprocess_qwen([line["conversations"][0],{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda() - img_num = list(input_ids.squeeze()).count(IMAGE_TOKEN_INDEX) - - image_tensors = [] - for image_file in image_files: - image = Image.open(os.path.join(args.image_folder, image_file)) - image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'] - image_tensors.append(image_tensor.half().cuda()) - # image_tensors = torch.cat(image_tensors, dim=0) - - stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 - keywords = [stop_str] - stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) - - with torch.inference_mode(): - output_ids = model.generate( - input_ids, - images=image_tensors, - do_sample=True if args.temperature > 0 else False, - temperature=args.temperature, - top_p=args.top_p, - num_beams=args.num_beams, - # no_repeat_ngram_size=3, - max_new_tokens=1024, - use_cache=True) - - - outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] - outputs = outputs.strip() - if outputs.endswith(stop_str): - outputs = outputs[:-len(stop_str)] - outputs = outputs.strip() - - ans_id = shortuuid.uuid() - ans_file.write(json.dumps({ - "dataset": dataset_name, - "sample_id": idx, - "prompt": cur_prompt, - "pred_response": outputs, - "gt_response": gt, - "shortuuid": ans_id, - "model_id": model_name, - "question_type": question_type, - }) + "\n") - ans_file.flush() - - if len(line["conversations"]) > 2: - - for i in range(2, len(line["conversations"]), 2): - input_ids = torch.cat((input_ids, output_ids), dim=1) - - gt = line["conversations"][i + 1]["value"] - qs = line["conversations"][i]["value"] - cur_prompt = args.extra_prompt + qs - - args.conv_mode = "qwen_1_5" - - conv = conv_templates[args.conv_mode].copy() - conv.append_message(conv.roles[0], qs) - conv.append_message(conv.roles[1], None) - prompt = conv.get_prompt() - - input_ids_new = preprocess_qwen([line["conversations"][i],{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda() - input_ids = torch.cat((input_ids, input_ids_new), dim=1) - img_num = list(input_ids_new.squeeze()).count(IMAGE_TOKEN_INDEX) - - stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 - keywords = [stop_str] - stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) - - with torch.inference_mode(): - output_ids = model.generate( - input_ids, - images=image_tensors, - do_sample=True if args.temperature > 0 else False, - temperature=args.temperature, - top_p=args.top_p, - num_beams=args.num_beams, - # no_repeat_ngram_size=3, - max_new_tokens=1024, - use_cache=True) - - outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] - outputs = outputs.strip() - if outputs.endswith(stop_str): - outputs = outputs[:-len(stop_str)] - outputs = outputs.strip() - - ans_id = shortuuid.uuid() - ans_file.write(json.dumps({ - "dataset": dataset_name, - "sample_id": idx, - "prompt": cur_prompt, - "pred_response": outputs, - "gt_response": gt, - "shortuuid": ans_id, - "model_id": model_name, - "question_type": question_type, - }) + "\n") - ans_file.flush() - - - ans_file.close() - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--model-path", type=str, default="facebook/opt-350m") - parser.add_argument("--model-base", type=str, default=None) - parser.add_argument("--image-folder", type=str, default="") - parser.add_argument("--extra-prompt", type=str, default="") - parser.add_argument("--question-file", type=str, default="tables/question.jsonl") - parser.add_argument("--answers-file", type=str, default="answer.jsonl") - parser.add_argument("--conv-mode", type=str, default="llava_v1") - parser.add_argument("--num-chunks", type=int, default=1) - parser.add_argument("--chunk-idx", type=int, default=0) - parser.add_argument("--temperature", type=float, default=0.2) - parser.add_argument("--top_p", type=float, default=None) - parser.add_argument("--num_beams", type=int, default=1) - parser.add_argument("--test_size", type=int, default=10000000) - args = parser.parse_args() - - eval_model(args) \ No newline at end of file diff --git a/llava/mm_utils.py b/llava/mm_utils.py deleted file mode 100644 index 3e5c8b0c277869bcd280ccf5f7e2509941eb5105..0000000000000000000000000000000000000000 --- a/llava/mm_utils.py +++ /dev/null @@ -1,381 +0,0 @@ -from PIL import Image -from io import BytesIO -import base64 -import math -import ast - -import torch -from transformers import StoppingCriteria -from llava.constants import IMAGE_TOKEN_INDEX - - -def resize_and_center_crop(image, shortest_edge_length): - # Calculate new dimensions and resize - aspect_ratio = float(image.width) / float(image.height) - if aspect_ratio > 1: - new_width = int(shortest_edge_length * aspect_ratio) - new_height = shortest_edge_length - else: - new_width = shortest_edge_length - new_height = int(shortest_edge_length / aspect_ratio) - resized_image = image.resize((new_width, new_height), Image.ANTIALIAS) - - # Calculate the position and perform the center crop - left = (new_width - shortest_edge_length) / 2 - top = (new_height - shortest_edge_length) / 2 - right = (new_width + shortest_edge_length) / 2 - bottom = (new_height + shortest_edge_length) / 2 - cropped_image = resized_image.crop((left, top, right, bottom)) - - return cropped_image - - -def auto_pad_images(image, grid_params): - assert isinstance(image, Image.Image), "Input should be a Pillow Image" - assert len(grid_params) > 0, "Grid parameters should not be empty" - - # Step 1: Calculate and find the closest aspect ratio - input_width, input_height = image.size - input_aspect_ratio = input_width / input_height - candidate_resolutions = [(w / h, w, h) for w in grid_params for h in grid_params] - closest_aspect_ratio = min(candidate_resolutions, key=lambda x: abs(input_aspect_ratio - x[0])) - - candidate_resolutions = [(x[1], x[2]) for x in candidate_resolutions if abs(x[0] - closest_aspect_ratio[0]) < 1e-3] - - target_resolution = min(candidate_resolutions, key=lambda res: abs(max(input_width, input_height) / max(res) - 1)) - - resize_width, resize_height = target_resolution - if input_width > input_height: - resize_height = int(resize_width / input_aspect_ratio) - else: - resize_width = int(resize_height * input_aspect_ratio) - resized_image = image.resize((resize_width, resize_height), Image.ANTIALIAS) - - # Step 5: Pad the resized image if necessary to match the target resolution - pad_width = target_resolution[0] - resize_width - pad_height = target_resolution[1] - resize_height - padded_image = Image.new("RGB", target_resolution, color=(0, 0, 0)) - padded_image.paste(resized_image, (pad_width // 2, pad_height // 2)) - - return padded_image - - -def extract_patches(image, patch_size, overlap_ratio): - assert isinstance(image, Image.Image), "Input should be a Pillow Image" - assert patch_size > 0, "Patch size should be greater than 0" - assert 0 <= overlap_ratio < 1, "Overlap ratio should be between 0 and 1" - - W, H = image.size - patches = [] - - stride = int(patch_size * (1 - overlap_ratio)) - - num_patches_y = (H - patch_size) // stride + 1 - num_patches_x = (W - patch_size) // stride + 1 - - y_start = (H - (num_patches_y - 1) * stride - patch_size) // 2 - x_start = (W - (num_patches_x - 1) * stride - patch_size) // 2 - - for y in range(y_start, y_start + num_patches_y * stride, stride): - for x in range(x_start, x_start + num_patches_x * stride, stride): - patch = image.crop((x, y, x + patch_size, y + patch_size)) - patches.append(patch) - - return patches - - -def process_highres_image_crop_split(image, data_args, processor=None): - crop_resolution = data_args.image_crop_resolution - split_resolution = data_args.image_split_resolution - if processor is None: - processor = data_args.image_processor - image_crop = resize_and_center_crop(image, crop_resolution) - image_patches = extract_patches(image_crop, patch_size=split_resolution, overlap_ratio=0) - image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches] - return torch.stack(image_patches, dim=0) - - -def process_highres_image(image, processor, grid_pinpoints): - grid_params = [int(x) for x in grid_pinpoints.split(",")] - width_height = max(image.size) - fit_grid_params = [x for x in grid_params if x >= width_height] - if len(fit_grid_params) == 0: - select_size = max(grid_params) - else: - select_size = min(fit_grid_params) - # FIXME: always select the 448 - select_size = max(grid_params) - image_padded = expand2square(image, tuple(int(x * 255) for x in processor.image_mean)) - - # FIXME: this seems to be a bug that it always resizes instead of padding - image_original_resize = image.resize((processor.size["shortest_edge"], processor.size["shortest_edge"])) - image_padded = image_padded.resize((select_size, select_size)) - image_patches = extract_patches(image_padded, patch_size=processor.size["shortest_edge"], overlap_ratio=0) - image_patches = [image_original_resize] + image_patches - image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches] - return torch.stack(image_patches, dim=0) - - -def select_best_resolution(original_size, possible_resolutions): - """ - Selects the best resolution from a list of possible resolutions based on the original size. - - Args: - original_size (tuple): The original size of the image in the format (width, height). - possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. - - Returns: - tuple: The best fit resolution in the format (width, height). - """ - original_width, original_height = original_size - best_fit = None - max_effective_resolution = 0 - min_wasted_resolution = float("inf") - - for width, height in possible_resolutions: - # Calculate the downscaled size to keep the aspect ratio - scale = min(width / original_width, height / original_height) - downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) - - # Calculate effective and wasted resolutions - effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) - wasted_resolution = (width * height) - effective_resolution - - if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): - max_effective_resolution = effective_resolution - min_wasted_resolution = wasted_resolution - best_fit = (width, height) - - return best_fit - - -def resize_and_pad_image(image, target_resolution): - """ - Resize and pad an image to a target resolution while maintaining aspect ratio. - - Args: - image (PIL.Image.Image): The input image. - target_resolution (tuple): The target resolution (width, height) of the image. - - Returns: - PIL.Image.Image: The resized and padded image. - """ - original_width, original_height = image.size - target_width, target_height = target_resolution - - # Determine which dimension (width or height) to fill - scale_w = target_width / original_width - scale_h = target_height / original_height - - if scale_w < scale_h: - # Width will be filled completely - new_width = target_width - new_height = min(math.ceil(original_height * scale_w), target_height) - else: - # Height will be filled completely - new_height = target_height - new_width = min(math.ceil(original_width * scale_h), target_width) - - # Resize the image - resized_image = image.resize((new_width, new_height)) - - # Create a new image with the target size and paste the resized image onto it - new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0)) - paste_x = (target_width - new_width) // 2 - paste_y = (target_height - new_height) // 2 - new_image.paste(resized_image, (paste_x, paste_y)) - - return new_image - - -def divide_to_patches(image, patch_size): - """ - Divides an image into patches of a specified size. - - Args: - image (PIL.Image.Image): The input image. - patch_size (int): The size of each patch. - - Returns: - list: A list of PIL.Image.Image objects representing the patches. - """ - patches = [] - width, height = image.size - for i in range(0, height, patch_size): - for j in range(0, width, patch_size): - box = (j, i, j + patch_size, i + patch_size) - patch = image.crop(box) - patches.append(patch) - - return patches - - -def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): - """ - Calculate the shape of the image patch grid after the preprocessing for images of any resolution. - - Args: - image_size (tuple): The size of the input image in the format (width, height). - grid_pinpoints (str): A string representation of a list of possible resolutions. - patch_size (int): The size of each image patch. - - Returns: - tuple: The shape of the image patch grid in the format (width, height). - """ - if isinstance(grid_pinpoints, str): - assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]" - grid_pinpoints = grid_pinpoints.replace(" ", "").replace("x", ",")[1:-1].split("),(") - grid_pinpoints = [[int(x) * patch_size for x in item.split(",")] for item in grid_pinpoints] - - if type(grid_pinpoints) is list: - possible_resolutions = grid_pinpoints - else: - possible_resolutions = ast.literal_eval(grid_pinpoints) - width, height = select_best_resolution(image_size, possible_resolutions) - return width // patch_size, height // patch_size - - -def process_anyres_image(image, processor, grid_pinpoints): - """ - Process an image with variable resolutions. - - Args: - image (PIL.Image.Image): The input image to be processed. - processor: The image processor object. - grid_pinpoints (str): A string representation of a list of possible resolutions. - - Returns: - torch.Tensor: A tensor containing the processed image patches. - """ - # Convert grid_pinpoints from string to list - if isinstance(grid_pinpoints, str): - vis_encoder_size = processor.size[0] - assert vis_encoder_size in [224, 336, 384, 448, 512], "vis_encoder_size should be in [224, 336, 384, 448, 512]" - grid_pinpoints = grid_pinpoints.replace(" ", "").replace("x", ",")[1:-1].split("),(") - grid_pinpoints = [[int(x) * vis_encoder_size for x in item.split(",")] for item in grid_pinpoints] - - if type(grid_pinpoints) is list: - possible_resolutions = grid_pinpoints - else: - possible_resolutions = ast.literal_eval(grid_pinpoints) - best_resolution = select_best_resolution(image.size, possible_resolutions) - image_padded = resize_and_pad_image(image, best_resolution) - - patches = divide_to_patches(image_padded, processor.crop_size["height"]) - - # FIXME: this seems to be a bug that it resizes instead of pad. - # but to keep it consistent with previous, i will keep it as it is - # TODO: uncomment below to ablate with the padding - if isinstance(processor.size, dict): - shortest_edge = processor.size["shortest_edge"] - else: - shortest_edge = min(processor.size) - image_original_resize = image.resize((shortest_edge, shortest_edge)) - # image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) - # image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge'])) - - image_patches = [image_original_resize] + patches - image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches] - return torch.stack(image_patches, dim=0) - - -def load_image_from_base64(image): - return Image.open(BytesIO(base64.b64decode(image))) - - -def expand2square(pil_img, background_color): - width, height = pil_img.size - if width == height: - return pil_img - elif width > height: - result = Image.new(pil_img.mode, (width, width), background_color) - result.paste(pil_img, (0, (width - height) // 2)) - return result - else: - result = Image.new(pil_img.mode, (height, height), background_color) - result.paste(pil_img, ((height - width) // 2, 0)) - return result - - -def process_images(images, image_processor, model_cfg): - image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) - new_images = [] - if image_aspect_ratio == "highres": - for image in images: - image = process_highres_image(image, image_processor, model_cfg.image_grid_pinpoints) - new_images.append(image) - elif image_aspect_ratio == "anyres": - for image in images: - image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) - new_images.append(image) - elif image_aspect_ratio == "crop_split": - for image in images: - image = process_highres_image_crop_split(image, model_cfg, image_processor) - new_images.append(image) - elif image_aspect_ratio == "pad": - for image in images: - image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean)) - image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0] - new_images.append(image) - else: - return image_processor(images, return_tensors="pt")["pixel_values"] - if all(x.shape == new_images[0].shape for x in new_images): - new_images = torch.stack(new_images, dim=0) - return new_images - - -def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): - prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("")] - - def insert_separator(X, sep): - return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1] - - input_ids = [] - offset = 0 - if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: - offset = 1 - input_ids.append(prompt_chunks[0][0]) - - for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): - input_ids.extend(x[offset:]) - - if return_tensors is not None: - if return_tensors == "pt": - return torch.tensor(input_ids, dtype=torch.long) - raise ValueError(f"Unsupported tensor type: {return_tensors}") - return input_ids - - -def get_model_name_from_path(model_path): - model_path = model_path.strip("/") - model_paths = model_path.split("/") - if model_paths[-1].startswith("checkpoint-"): - return model_paths[-2] + "_" + model_paths[-1] - else: - return model_paths[-1] - - -class KeywordsStoppingCriteria(StoppingCriteria): - def __init__(self, keywords, tokenizer, input_ids): - self.keywords = keywords - self.keyword_ids = [] - for keyword in keywords: - cur_keyword_ids = tokenizer(keyword).input_ids - if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: - cur_keyword_ids = cur_keyword_ids[1:] - self.keyword_ids.append(torch.tensor(cur_keyword_ids)) - self.tokenizer = tokenizer - self.start_len = input_ids.shape[1] - - def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: - assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO - offset = min(output_ids.shape[1] - self.start_len, 3) - self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] - for keyword_id in self.keyword_ids: - if output_ids[0, -keyword_id.shape[0] :] == keyword_id: - return True - outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] - for keyword in self.keywords: - if keyword in outputs: - return True - return False diff --git a/llava/model/__init__.py b/llava/model/__init__.py deleted file mode 100644 index 5fb3442cb7f2defc87df1f5d95b7e8888bfd7289..0000000000000000000000000000000000000000 --- a/llava/model/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -import os - -AVAILABLE_MODELS = { - "llava_llama": "LlavaLlamaForCausalLM, LlavaConfig", - "llava_gemma": "LlavaGemmaForCausalLM, LlavaGemmaConfig", - "llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig", - # "llava_qwen_moe": "LlavaQwenMoeForCausalLM, LlavaQwenMoeConfig", - "llava_mistral": "LlavaMistralForCausalLM, LlavaMistralConfig", - "llava_mixtral": "LlavaMixtralForCausalLM, LlavaMixtralConfig", - # Add other models as needed -} - -for model_name, model_classes in AVAILABLE_MODELS.items(): - try: - exec(f"from .language_model.{model_name} import {model_classes}") - except ImportError: - # import traceback - # traceback.print_exc() - print(f"Failed to import {model_name} from llava.language_model.{model_name}") - pass diff --git a/llava/model/__pycache__/__init__.cpython-310.pyc b/llava/model/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 28702f2ae90d7787204b82e6d6cd7a8d94f8964d..0000000000000000000000000000000000000000 Binary files a/llava/model/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/llava/model/__pycache__/builder.cpython-310.pyc b/llava/model/__pycache__/builder.cpython-310.pyc deleted file mode 100644 index c5b07ac66eaffa4db311dfc56154a1f6e4af5bb3..0000000000000000000000000000000000000000 Binary files a/llava/model/__pycache__/builder.cpython-310.pyc and /dev/null differ diff --git a/llava/model/__pycache__/llava_arch.cpython-310.pyc b/llava/model/__pycache__/llava_arch.cpython-310.pyc deleted file mode 100644 index 3a45b79a2331785cdb18d77ae63611853da5d356..0000000000000000000000000000000000000000 Binary files a/llava/model/__pycache__/llava_arch.cpython-310.pyc and /dev/null differ diff --git a/llava/model/apply_delta.py b/llava/model/apply_delta.py deleted file mode 100644 index c183ba19a4e91e9cb95155b542e7406ea5b287a0..0000000000000000000000000000000000000000 --- a/llava/model/apply_delta.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -Usage: -python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta -""" - -import argparse - -import torch -from tqdm import tqdm -from transformers import AutoTokenizer, AutoModelForCausalLM -from llava import LlavaLlamaForCausalLM - - -def apply_delta(base_model_path, target_model_path, delta_path): - print("Loading base model") - base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) - - print("Loading delta") - delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) - delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) - - print("Applying delta") - for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): - if name not in base.state_dict(): - assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" - continue - if param.data.shape == base.state_dict()[name].shape: - param.data += base.state_dict()[name] - else: - assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" - bparam = base.state_dict()[name] - param.data[: bparam.shape[0], : bparam.shape[1]] += bparam - - print("Saving target model") - delta.save_pretrained(target_model_path) - delta_tokenizer.save_pretrained(target_model_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--base-model-path", type=str, required=True) - parser.add_argument("--target-model-path", type=str, required=True) - parser.add_argument("--delta-path", type=str, required=True) - - args = parser.parse_args() - - apply_delta(args.base_model_path, args.target_model_path, args.delta_path) diff --git a/llava/model/builder.py b/llava/model/builder.py deleted file mode 100644 index e9bc92b9875cf2a4a04d11aaa7798615054be4f3..0000000000000000000000000000000000000000 --- a/llava/model/builder.py +++ /dev/null @@ -1,250 +0,0 @@ -# Copyright 2023 Haotian Liu -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import os -import warnings -import shutil - -from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig -import torch -from llava.model import * -from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN -from llava.utils import rank0_print - - -def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", attn_implementation="flash_attention_2", customized_config=None, **kwargs): - kwargs = {"device_map": device_map} - - if load_8bit: - kwargs["load_in_8bit"] = True - elif load_4bit: - kwargs["load_in_4bit"] = True - kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4") - else: - kwargs["torch_dtype"] = torch.float16 - - if customized_config is not None: - kwargs["config"] = customized_config - - if "llava" in model_name.lower(): - # Load LLaVA model - if "lora" in model_name.lower() and model_base is None: - warnings.warn( - "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged." - ) - if "lora" in model_name.lower() and model_base is not None: - lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) - tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) - rank0_print("Loading LLaVA from base model...") - if "mixtral" in model_name.lower(): - from llava.model.language_model.llava_mixtral import LlavaMixtralConfig - - lora_cfg_pretrained = LlavaMixtralConfig.from_pretrained(model_path) - tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) - model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs) - elif "mistral" in model_name.lower(): - from llava.model.language_model.llava_mistral import LlavaMistralConfig - - lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path) - tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) - model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs) - elif "gemma" in model_name.lower(): - from llava.model.language_model.llava_gemma import LlavaGemmaConfig - - lora_cfg_pretrained = LlavaGemmaConfig.from_pretrained(model_path) - tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) - model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs) - else: - from llava.model.language_model.llava_llama import LlavaConfig - - lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path) - tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) - model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs) - - token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features - if model.lm_head.weight.shape[0] != token_num: - model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) - model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) - - rank0_print("Loading additional LLaVA weights...") - if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")): - non_lora_trainables = torch.load(os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu") - else: - # this is probably from HF Hub - from huggingface_hub import hf_hub_download - - def load_from_hf(repo_id, filename, subfolder=None): - cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder) - return torch.load(cache_file, map_location="cpu") - - non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin") - non_lora_trainables = {(k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()} - if any(k.startswith("model.model.") for k in non_lora_trainables): - non_lora_trainables = {(k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()} - model.load_state_dict(non_lora_trainables, strict=False) - - from peft import PeftModel - - rank0_print("Loading LoRA weights...") - model = PeftModel.from_pretrained(model, model_path) - rank0_print("Merging LoRA weights...") - model = model.merge_and_unload() - rank0_print("Model is loaded...") - elif model_base is not None: - # this may be mm projector only - rank0_print(f"Loading LLaVA from base model {model_base}...") - if "mixtral" in model_name.lower(): - tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) - cfg_pretrained = AutoConfig.from_pretrained(model_path) - model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs) - elif "mistral" in model_name.lower() or "zephyr" in model_name.lower(): - tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) - cfg_pretrained = AutoConfig.from_pretrained(model_path) - model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs) - elif "gemma" in model_name.lower(): - tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) - cfg_pretrained = AutoConfig.from_pretrained(model_path) - model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs) - elif ( - "wizardlm-2" in model_name.lower() - and "vicuna" in model_name.lower() - or "llama" in model_name.lower() - or "yi" in model_name.lower() - or "nous-hermes" in model_name.lower() - or "llava-v1.6-34b" in model_name.lower() - or "llava-v1.5" in model_name.lower() - ): - from llava.model.language_model.llava_llama import LlavaConfig - - tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) - if customized_config is None: - llava_cfg = LlavaConfig.from_pretrained(model_path) - if "v1.5" in model_name.lower(): - llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models - else: - llava_cfg = customized_config - - tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) - llava_cfg = LlavaConfig.from_pretrained(model_path) - model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=llava_cfg, **kwargs) - else: - raise ValueError(f"Model {model_name} not supported") - - mm_projector_weights = torch.load(os.path.join(model_path, "mm_projector.bin"), map_location="cpu") - mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} - model.load_state_dict(mm_projector_weights, strict=False) - else: - rank0_print(f"Loaded LLaVA model: {model_path}") - if "mixtral" in model_name.lower(): - tokenizer = AutoTokenizer.from_pretrained(model_path) - model = LlavaMixtralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs) - elif "mistral" in model_name.lower() or "zephyr" in model_name.lower(): - tokenizer = AutoTokenizer.from_pretrained(model_path) - model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs) - elif ( - "wizardlm-2" in model_name.lower() - and "vicuna" in model_name.lower() - or "llama" in model_name.lower() - or "yi" in model_name.lower() - or "nous-hermes" in model_name.lower() - or "llava-v1.6-34b" in model_name.lower() - or "llava-v1.5" in model_name.lower() - ): - from llava.model.language_model.llava_llama import LlavaConfig - - tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) - if customized_config is None: - llava_cfg = LlavaConfig.from_pretrained(model_path) - if "v1.5" in model_name.lower(): - llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models - else: - llava_cfg = customized_config - - model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs) - elif "qwen" in model_name.lower(): - tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) - model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs) - elif "gemma" in model_name.lower(): - tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) - cfg_pretrained = AutoConfig.from_pretrained(model_path) - model = LlavaGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs) - else: - rank0_print("\n\n\nWarning : No matching llava architecture, auto load llava_llama. If it is not intended, specify it in model_name\n\n\n") - try: - from llava.model.language_model.llava_llama import LlavaConfig - - tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) - if customized_config is None: - llava_cfg = LlavaConfig.from_pretrained(model_path) - if "v1.5" in model_path.lower(): - llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models - else: - llava_cfg = customized_config - model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs) - except: - raise ValueError(f"Model {model_name} not supported") - - else: - # Load language model - if model_base is not None: - # PEFT model - from peft import PeftModel - - tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) - model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto") - print(f"Loading LoRA weights from {model_path}") - model = PeftModel.from_pretrained(model, model_path) - print(f"Merging weights") - model = model.merge_and_unload() - print("Convert to FP16...") - model.to(torch.float16) - else: - use_fast = False - if "mpt" in model_name.lower().replace("prompt", ""): - tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) - model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) - else: - tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) - model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) - - rank0_print(f"Model Class: {model.__class__.__name__}") - image_processor = None - - if "llava" in model_name.lower(): - mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) - mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) - if mm_use_im_patch_token: - tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) - if mm_use_im_start_end: - tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) - model.resize_token_embeddings(len(tokenizer)) - - vision_tower = model.get_vision_tower() - if not vision_tower.is_loaded: - vision_tower.load_model(device_map=device_map) - if device_map != "auto": - vision_tower.to(device="cuda", dtype=torch.float16) - image_processor = vision_tower.image_processor - - if hasattr(model.config, "max_sequence_length"): - context_len = model.config.max_sequence_length - elif hasattr(model.config, "max_position_embeddings"): - context_len = model.config.max_position_embeddings - elif hasattr(model.config, "tokenizer_model_max_length"): - context_len = model.config.tokenizer_model_max_length - else: - context_len = 2048 - - return tokenizer, model, image_processor, context_len diff --git a/llava/model/consolidate.py b/llava/model/consolidate.py deleted file mode 100644 index f02e575f6b8e4388e1758776cadd62309147a1ad..0000000000000000000000000000000000000000 --- a/llava/model/consolidate.py +++ /dev/null @@ -1,30 +0,0 @@ -""" -Usage: -python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate -""" - -import argparse - -import torch -from transformers import AutoTokenizer, AutoModelForCausalLM -from llava.model import * -from llava.model.utils import auto_upgrade - - -def consolidate_ckpt(src_path, dst_path): - print("Loading model") - auto_upgrade(src_path) - src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) - src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) - src_model.save_pretrained(dst_path) - src_tokenizer.save_pretrained(dst_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--src", type=str, required=True) - parser.add_argument("--dst", type=str, required=True) - - args = parser.parse_args() - - consolidate_ckpt(args.src, args.dst) diff --git a/llava/model/language_model/__pycache__/llava_gemma.cpython-310.pyc b/llava/model/language_model/__pycache__/llava_gemma.cpython-310.pyc deleted file mode 100644 index 0f34884dd719268aad38d38c96cee2c1709b5d60..0000000000000000000000000000000000000000 Binary files a/llava/model/language_model/__pycache__/llava_gemma.cpython-310.pyc and /dev/null differ diff --git a/llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc b/llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc deleted file mode 100644 index 207afcc4973eb39b256a715ec2cfa0e3dddd990a..0000000000000000000000000000000000000000 Binary files a/llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc and /dev/null differ diff --git a/llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc b/llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc deleted file mode 100644 index b90b45000c9a24cda553a6a6cebf39d1b8d0a785..0000000000000000000000000000000000000000 Binary files a/llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc and /dev/null differ diff --git a/llava/model/language_model/__pycache__/llava_mixtral.cpython-310.pyc b/llava/model/language_model/__pycache__/llava_mixtral.cpython-310.pyc deleted file mode 100644 index 83480b115fb08e8f2972bfc25ea188e1140dd60e..0000000000000000000000000000000000000000 Binary files a/llava/model/language_model/__pycache__/llava_mixtral.cpython-310.pyc and /dev/null differ diff --git a/llava/model/language_model/__pycache__/llava_qwen.cpython-310.pyc b/llava/model/language_model/__pycache__/llava_qwen.cpython-310.pyc deleted file mode 100644 index fd9d403c53e1f4b8840cd218ebe5d10e0812e910..0000000000000000000000000000000000000000 Binary files a/llava/model/language_model/__pycache__/llava_qwen.cpython-310.pyc and /dev/null differ diff --git a/llava/model/language_model/llava_gemma.py b/llava/model/language_model/llava_gemma.py deleted file mode 100644 index 5c0ac173017034bbbb03b158067d4e4f7ff970f6..0000000000000000000000000000000000000000 --- a/llava/model/language_model/llava_gemma.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2024 Duc Q. Nguyen, Haotian Liu and Bo Li -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn as nn -from torch.nn import CrossEntropyLoss - -from transformers import AutoConfig, AutoModelForCausalLM, GemmaConfig, GemmaModel, GemmaForCausalLM - -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.generation.utils import GenerateOutput - -from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM - - -class LlavaGemmaConfig(GemmaConfig): - model_type = "llava_gemma" - - -class LlavaGemmaModel(LlavaMetaModel, GemmaModel): - config_class = LlavaGemmaConfig - - def __init__(self, config: GemmaConfig): - super(LlavaGemmaModel, self).__init__(config) - - -class LlavaGemmaForCausalLM(GemmaForCausalLM, LlavaMetaForCausalLM): - config_class = LlavaGemmaConfig - - def __init__(self, config): - super(GemmaForCausalLM, self).__init__(config) - self.model = LlavaGemmaModel(config) - - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_model(self): - return self.model - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - images: Optional[torch.FloatTensor] = None, - image_sizes: Optional[List[List[int]]] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - - if inputs_embeds is None: - (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes) - - return super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - @torch.no_grad() - def generate( - self, - inputs: Optional[torch.Tensor] = None, - images: Optional[torch.Tensor] = None, - image_sizes: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[GenerateOutput, torch.LongTensor]: - position_ids = kwargs.pop("position_ids", None) - attention_mask = kwargs.pop("attention_mask", None) - if "inputs_embeds" in kwargs: - raise NotImplementedError("`inputs_embeds` is not supported") - - if images is not None: - (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes) - else: - inputs_embeds = self.get_model().embed_tokens(inputs) - - return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) - - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): - images = kwargs.pop("images", None) - image_sizes = kwargs.pop("image_sizes", None) - inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) - if images is not None: - inputs["images"] = images - if image_sizes is not None: - inputs["image_sizes"] = image_sizes - return inputs - - -AutoConfig.register("llava_gemma", LlavaGemmaConfig) -AutoModelForCausalLM.register(LlavaGemmaConfig, LlavaGemmaForCausalLM) diff --git a/llava/model/language_model/llava_llama.py b/llava/model/language_model/llava_llama.py deleted file mode 100644 index b00dd29eb41c1e27f531a7994cf8d6be715f80d3..0000000000000000000000000000000000000000 --- a/llava/model/language_model/llava_llama.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright 2023 Haotian Liu -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn as nn - -from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig - -# , LlamaModel, LlamaForCausalLM, GenerationConfig -# from .modeling_llama import LlamaModel, LlamaForCausalLM -from transformers import LlamaModel, LlamaForCausalLM -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.generation.utils import GenerateOutput - -from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM - - -class LlavaConfig(LlamaConfig): - model_type = "llava_llama" - temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna - max_new_tokens: int = 1024 - do_sample: bool = False - top_p: Optional[float] = None - rope_scaling: Optional[dict] = {} - - -class LlavaLlamaModel(LlavaMetaModel, LlamaModel): - config_class = LlavaConfig - - def __init__(self, config: LlamaConfig): - super(LlavaLlamaModel, self).__init__(config) - - -class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): - config_class = LlavaConfig - - def __init__(self, config): - LlamaForCausalLM.__init__(self, config) - - # configure default generation settings - config.model_type = "llava_llama" - config.rope_scaling = None - - self.model = LlavaLlamaModel(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - # Initialize weights and apply final processing - self.post_init() - - def get_model(self): - return self.model - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - images: Optional[torch.FloatTensor] = None, - image_sizes: Optional[List[List[int]]] = None, - return_dict: Optional[bool] = None, - cache_position=None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - - if inputs_embeds is None: - (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes) - - return super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - @torch.no_grad() - def generate( - self, - inputs: Optional[torch.Tensor] = None, - images: Optional[torch.Tensor] = None, - image_sizes: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[GenerateOutput, torch.LongTensor]: - position_ids = kwargs.pop("position_ids", None) - attention_mask = kwargs.pop("attention_mask", None) - if "inputs_embeds" in kwargs: - raise NotImplementedError("`inputs_embeds` is not supported") - - if images is not None: - (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes) - else: - inputs_embeds = self.get_model().embed_tokens(inputs) - - return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) - - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): - images = kwargs.pop("images", None) - image_sizes = kwargs.pop("image_sizes", None) - inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) - if images is not None: - inputs["images"] = images - if image_sizes is not None: - inputs["image_sizes"] = image_sizes - return inputs - - -AutoConfig.register("llava_llama", LlavaConfig) -AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) diff --git a/llava/model/language_model/llava_mistral.py b/llava/model/language_model/llava_mistral.py deleted file mode 100644 index 2cc3b0157be23db5ed4aab0c0802f40b4702e116..0000000000000000000000000000000000000000 --- a/llava/model/language_model/llava_mistral.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2023 Haotian Liu -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn as nn -from torch.nn import CrossEntropyLoss - -from transformers import AutoConfig, AutoModelForCausalLM, MistralConfig, MistralModel, MistralForCausalLM, GenerationConfig - -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.generation.utils import GenerateOutput - -from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM - - -class LlavaMistralConfig(MistralConfig): - model_type = "llava_mistral" - temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna - max_new_tokens: int = 1024 - do_sample: bool = False - top_p: Optional[float] = None - - -class LlavaMistralModel(LlavaMetaModel, MistralModel): - config_class = LlavaMistralConfig - - def __init__(self, config: MistralConfig): - super(LlavaMistralModel, self).__init__(config) - - -class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM): - config_class = LlavaMistralConfig - - def __init__(self, config): - super(MistralForCausalLM, self).__init__(config) - - config.model_type = "llava_mistral" - config.rope_scaling = None - - self.model = LlavaMistralModel(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - # Initialize weights and apply final processing - self.post_init() - - def get_model(self): - return self.model - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - images: Optional[torch.FloatTensor] = None, - image_sizes: Optional[List[List[int]]] = None, - return_dict: Optional[bool] = None, - cache_position=None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - - if inputs_embeds is None: - (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes) - - return super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - @torch.no_grad() - def generate( - self, - inputs: Optional[torch.Tensor] = None, - images: Optional[torch.Tensor] = None, - image_sizes: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[GenerateOutput, torch.LongTensor]: - position_ids = kwargs.pop("position_ids", None) - attention_mask = kwargs.pop("attention_mask", None) - if "inputs_embeds" in kwargs: - raise NotImplementedError("`inputs_embeds` is not supported") - - if images is not None: - (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes) - else: - inputs_embeds = self.get_model().embed_tokens(inputs) - - return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) - - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): - images = kwargs.pop("images", None) - image_sizes = kwargs.pop("image_sizes", None) - inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) - if images is not None: - inputs["images"] = images - if image_sizes is not None: - inputs["image_sizes"] = image_sizes - return inputs - - -AutoConfig.register("llava_mistral", LlavaMistralConfig) -AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM) diff --git a/llava/model/language_model/llava_mixtral.py b/llava/model/language_model/llava_mixtral.py deleted file mode 100644 index a9090abc236ef7cae263d531f695db1a4ce68a79..0000000000000000000000000000000000000000 --- a/llava/model/language_model/llava_mixtral.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2023 Haotian Liu -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn as nn -from torch.nn import CrossEntropyLoss - -from transformers import AutoConfig, AutoModelForCausalLM, MixtralConfig, MixtralModel, MixtralForCausalLM, GenerationConfig - -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.generation.utils import GenerateOutput - -from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM - - -class LlavaMixtralConfig(MixtralConfig): - model_type = "llava_mixtral" - - -class LlavaMixtralModel(LlavaMetaModel, MixtralModel): - config_class = LlavaMixtralConfig - - def __init__(self, config: MixtralConfig): - super(LlavaMixtralModel, self).__init__(config) - - -class LlavaMixtralForCausalLM(MixtralForCausalLM, LlavaMetaForCausalLM): - config_class = LlavaMixtralConfig - - def __init__(self, config): - super(MixtralForCausalLM, self).__init__(config) - - config.model_type = "llava_mixtral" - config.rope_scaling = None - self.model = LlavaMixtralModel(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - # Initialize weights and apply final processing - self.post_init() - - def get_model(self): - return self.model - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - images: Optional[torch.FloatTensor] = None, - image_sizes: Optional[List[List[int]]] = None, - return_dict: Optional[bool] = None, - cache_position=None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - - if inputs_embeds is None: - (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes) - - return super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - @torch.no_grad() - def generate( - self, - inputs: Optional[torch.Tensor] = None, - images: Optional[torch.Tensor] = None, - image_sizes: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[GenerateOutput, torch.LongTensor]: - position_ids = kwargs.pop("position_ids", None) - attention_mask = kwargs.pop("attention_mask", None) - if "inputs_embeds" in kwargs: - raise NotImplementedError("`inputs_embeds` is not supported") - - if images is not None: - (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes) - else: - inputs_embeds = self.get_model().embed_tokens(inputs) - - return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) - - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): - images = kwargs.pop("images", None) - image_sizes = kwargs.pop("image_sizes", None) - inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) - if images is not None: - inputs["images"] = images - if image_sizes is not None: - inputs["image_sizes"] = image_sizes - return inputs - - -AutoConfig.register("llava_mixtral", LlavaMixtralConfig) -AutoModelForCausalLM.register(LlavaMixtralConfig, LlavaMixtralForCausalLM) diff --git a/llava/model/language_model/llava_mpt.py b/llava/model/language_model/llava_mpt.py deleted file mode 100644 index c3bce7d37e9c0bff41cf8a5fe2f99d29d7cc4495..0000000000000000000000000000000000000000 --- a/llava/model/language_model/llava_mpt.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2023 Haotian Liu -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Optional, Tuple - -import torch - -from transformers import AutoConfig, AutoModelForCausalLM, MptConfig, MptForCausalLM, MptModel, GenerationConfig -from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM - - -class LlavaMptConfig(MptConfig): - model_type = "llava_mpt" - - -class LlavaMptModel(LlavaMetaModel, MptModel): - config_class = LlavaMptConfig - - def __init__(self, config: MptConfig): - config.hidden_size = config.d_model - super(LlavaMptModel, self).__init__(config) - - def embed_tokens(self, x): - return self.wte(x) - - -class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM): - config_class = LlavaMptConfig - supports_gradient_checkpointing = True - - def __init__(self, config): - super(MptForCausalLM, self).__init__(config) - - config.model_type = "llava_mpt" - config.rope_scaling = None - self.generation_config = GenerationConfig( - temperature=0.0, - max_new_tokens=1024, - do_sample=False, - top_p=None, - ) - - self.transformer = LlavaMptModel(config) - self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_model(self): - return self.transformer - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlavaMptModel): - module.gradient_checkpointing = value - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position=None, - images=None, - ): - - input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) - - return super().forward( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): - images = kwargs.pop("images", None) - _inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) - _inputs["images"] = images - return _inputs - - -AutoConfig.register("llava_mpt", LlavaMptConfig) -AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM) diff --git a/llava/model/language_model/llava_qwen.py b/llava/model/language_model/llava_qwen.py deleted file mode 100644 index 1f681cef24abc6e6e7c31a9ed763ebbce25b4aec..0000000000000000000000000000000000000000 --- a/llava/model/language_model/llava_qwen.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright 2024 Hao Zhang -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import List, Optional, Tuple, Union, Dict -import torch -import torch.nn as nn -from torch.nn import CrossEntropyLoss - -import transformers -from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM - -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.generation.utils import GenerateOutput - -# from ...constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN -from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM -from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM - -# from .qwen.modeling_qwen import QWenLMHeadModel, QWenModel -# from .qwen.configuration_qwen import QWenConfig - - -class LlavaQwenConfig(Qwen2Config): - model_type = "llava_qwen" - - -class LlavaQwenModel(LlavaMetaModel, Qwen2Model): - config_class = LlavaQwenConfig - - def __init__(self, config: Qwen2Config): - super(LlavaQwenModel, self).__init__(config) - - -class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM): - config_class = LlavaQwenConfig - - def __init__(self, config): - # super(Qwen2ForCausalLM, self).__init__(config) - Qwen2ForCausalLM.__init__(self, config) - config.model_type = "llava_qwen" - config.rope_scaling = None - - self.model = LlavaQwenModel(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - # Initialize weights and apply final processing - self.post_init() - - def get_model(self): - return self.model - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - images: Optional[torch.FloatTensor] = None, - image_sizes: Optional[List[List[int]]] = None, - return_dict: Optional[bool] = None, - cache_position=None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - - if inputs_embeds is None: - (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes) - - return super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - @torch.no_grad() - def generate( - self, - inputs: Optional[torch.Tensor] = None, - images: Optional[torch.Tensor] = None, - image_sizes: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[GenerateOutput, torch.LongTensor]: - position_ids = kwargs.pop("position_ids", None) - attention_mask = kwargs.pop("attention_mask", None) - if "inputs_embeds" in kwargs: - raise NotImplementedError("`inputs_embeds` is not supported") - - if images is not None: - (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes) - else: - inputs_embeds = self.get_model().embed_tokens(inputs) - - return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) - - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): - images = kwargs.pop("images", None) - image_sizes = kwargs.pop("image_sizes", None) - inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) - if images is not None: - inputs["images"] = images - if image_sizes is not None: - inputs["image_sizes"] = image_sizes - return inputs - - -AutoConfig.register("llava_qwen", LlavaQwenConfig) -AutoModelForCausalLM.register(LlavaQwenConfig, LlavaQwenForCausalLM) diff --git a/llava/model/language_model/llava_qwen_moe.py b/llava/model/language_model/llava_qwen_moe.py deleted file mode 100644 index 08c39667e85cd7a51d3ea99ccf7f22c71a1030bf..0000000000000000000000000000000000000000 --- a/llava/model/language_model/llava_qwen_moe.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright 2024 Hao Zhang -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import List, Optional, Tuple, Union, Dict -import torch -import torch.nn as nn -from torch.nn import CrossEntropyLoss - -import transformers -from transformers import AutoConfig, AutoModelForCausalLM - -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.generation.utils import GenerateOutput - -# from ...constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN -from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM -from transformers import Qwen2MoeConfig, Qwen2MoeModel, Qwen2MoeForCausalLM - -# from .qwen.modeling_qwen import QWenLMHeadModel, QWenModel -# from .qwen.configuration_qwen import QWenConfig - - -class LlavaQwenMoeConfig(Qwen2MoeConfig): - model_type = "llava_qwen_moe" - - -class LlavaQwenMoeModel(LlavaMetaModel, Qwen2MoeModel): - config_class = LlavaQwenMoeConfig - - def __init__(self, config: Qwen2MoeConfig): - super(LlavaQwenMoeModel, self).__init__(config) - - -class LlavaQwenMoeForCausalLM(Qwen2MoeForCausalLM, LlavaMetaForCausalLM): - config_class = LlavaQwenMoeConfig - - def __init__(self, config): - # super(Qwen2MoeForCausalLM, self).__init__(config) - Qwen2MoeForCausalLM.__init__(self, config) - config.model_type = "llava_qwen_moe" - config.rope_scaling = None - - self.model = LlavaQwenMoeModel(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - # Initialize weights and apply final processing - self.post_init() - - def get_model(self): - return self.model - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - images: Optional[torch.FloatTensor] = None, - image_sizes: Optional[List[List[int]]] = None, - return_dict: Optional[bool] = None, - cache_position=None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - - if inputs_embeds is None: - (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes) - - return super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - @torch.no_grad() - def generate( - self, - inputs: Optional[torch.Tensor] = None, - images: Optional[torch.Tensor] = None, - image_sizes: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[GenerateOutput, torch.LongTensor]: - position_ids = kwargs.pop("position_ids", None) - attention_mask = kwargs.pop("attention_mask", None) - if "inputs_embeds" in kwargs: - raise NotImplementedError("`inputs_embeds` is not supported") - - if images is not None: - (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes) - else: - inputs_embeds = self.get_model().embed_tokens(inputs) - - return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) - - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): - images = kwargs.pop("images", None) - image_sizes = kwargs.pop("image_sizes", None) - inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) - if images is not None: - inputs["images"] = images - if image_sizes is not None: - inputs["image_sizes"] = image_sizes - return inputs - - -AutoConfig.register("llava_qwen_moe", LlavaQwenMoeConfig) -AutoModelForCausalLM.register(LlavaQwenMoeConfig, LlavaQwenMoeForCausalLM) diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py deleted file mode 100644 index 688c3bcb7c285fe375d1f0874c3d1440eaa96a57..0000000000000000000000000000000000000000 --- a/llava/model/llava_arch.py +++ /dev/null @@ -1,389 +0,0 @@ -# Copyright 2023 Haotian Liu -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from abc import ABC, abstractmethod - -import torch -import torch.nn as nn - -from .multimodal_encoder.builder import build_vision_tower -from .multimodal_resampler.builder import build_vision_resampler -from .multimodal_projector.builder import build_vision_projector - -from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN - -from llava.mm_utils import get_anyres_image_grid_shape -from llava.utils import rank0_print - - -class LlavaMetaModel: - - def __init__(self, config): - super(LlavaMetaModel, self).__init__(config) - - if hasattr(config, "mm_vision_tower"): - delay_load = getattr(config, "delay_load", False) - self.vision_tower = build_vision_tower(config, delay_load=delay_load) - self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower) - self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config) - - if "unpad" in getattr(config, "mm_patch_merge_type", ""): - self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype)) - - def get_vision_tower(self): - vision_tower = getattr(self, "vision_tower", None) - if type(vision_tower) is list: - vision_tower = vision_tower[0] - return vision_tower - - def initialize_vision_modules(self, model_args, fsdp=None): - vision_tower = model_args.vision_tower - mm_vision_select_layer = model_args.mm_vision_select_layer - mm_vision_select_feature = model_args.mm_vision_select_feature - pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter - mm_patch_merge_type = model_args.mm_patch_merge_type - - self.config.mm_vision_tower = vision_tower - self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "") - - if self.get_vision_tower() is None: - vision_tower = build_vision_tower(model_args) - vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower) - for k, v in vision_resampler.config.items(): - setattr(self.config, k, v) - - if fsdp is not None and len(fsdp) > 0: - self.vision_tower = [vision_tower] - self.vision_resampler = [vision_resampler] - else: - self.vision_tower = vision_tower - self.vision_resampler = vision_resampler - else: - if fsdp is not None and len(fsdp) > 0: - vision_resampler = self.vision_resampler[0] - vision_tower = self.vision_tower[0] - else: - vision_resampler = self.vision_resampler - vision_tower = self.vision_tower - vision_tower.load_model() - - # In case it is frozen by LoRA - for p in self.vision_resampler.parameters(): - p.requires_grad = True - - self.config.use_mm_proj = True - self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear") - self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size) - self.config.mm_vision_select_layer = mm_vision_select_layer - self.config.mm_vision_select_feature = mm_vision_select_feature - self.config.mm_patch_merge_type = mm_patch_merge_type - - if getattr(self, "mm_projector", None) is None: - self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config) - - if "unpad" in mm_patch_merge_type: - embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype)) - self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std) - else: - # In case it is frozen by LoRA - for p in self.mm_projector.parameters(): - p.requires_grad = True - - if pretrain_mm_mlp_adapter is not None: - mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu") - - def get_w(weights, keyword): - return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k} - - incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector")) - rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}") - incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False) - rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}") - - -def unpad_image(tensor, original_size): - """ - Unpads a PyTorch tensor of a padded and resized image. - - Args: - tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. - original_size (tuple): The original size of the image (height, width). - - Returns: - torch.Tensor: The unpadded image tensor. - """ - original_width, original_height = original_size - current_height, current_width = tensor.shape[1:] - - # Compute aspect ratios - original_aspect_ratio = original_width / original_height - current_aspect_ratio = current_width / current_height - - # Determine padding size and direction - if original_aspect_ratio > current_aspect_ratio: - # Padding was added to the height - scale_factor = current_width / original_width - new_height = int(original_height * scale_factor) - padding = (current_height - new_height) // 2 - unpadded_tensor = tensor[:, padding : current_height - padding, :] - else: - # Padding was added to the width - scale_factor = current_height / original_height - new_width = int(original_width * scale_factor) - padding = (current_width - new_width) // 2 - unpadded_tensor = tensor[:, :, padding : current_width - padding] - - return unpadded_tensor - - -class LlavaMetaForCausalLM(ABC): - - @abstractmethod - def get_model(self): - pass - - def get_vision_tower(self): - return self.get_model().get_vision_tower() - - def encode_images(self, images): - image_features = self.get_model().get_vision_tower()(images) - image_features = self.get_model().vision_resampler(image_features, images=images) - image_features = self.get_model().mm_projector(image_features) - return image_features - - def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None): - vision_tower = self.get_vision_tower() - if vision_tower is None or images is None or input_ids.shape[1] == 1: - return input_ids, position_ids, attention_mask, past_key_values, None, labels - - if type(images) is list or images.ndim == 5: - if type(images) is list: - images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] - concat_images = torch.cat([image for image in images], dim=0) - image_features = self.encode_images(concat_images) - split_sizes = [image.shape[0] for image in images] - image_features = torch.split(image_features, split_sizes, dim=0) - mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat") - image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square") - if mm_patch_merge_type == "flat": - image_features = [x.flatten(0, 1) for x in image_features] - elif mm_patch_merge_type.startswith("spatial"): - new_image_features = [] - for image_idx, image_feature in enumerate(image_features): - # FIXME: now assume the image is square, and split to 2x2 patches - # num_patches = h * w, where h = w = sqrt(num_patches) - # currently image_feature is a tensor of shape (4, num_patches, hidden_size) - # we want to first unflatten it to (2, 2, h, w, hidden_size) - - if image_feature.shape[0] > 1: - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - height = width = self.get_vision_tower().num_patches_per_side - assert height * width == base_image_feature.shape[0] - if image_aspect_ratio == "anyres": - if hasattr(self.get_vision_tower(), "image_size"): - vision_tower_image_size = self.get_vision_tower().image_size - else: - raise ValueError("vision_tower_image_size is not found in the vision tower.") - num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size) - image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) - else: - image_feature = image_feature.view(2, 2, height, width, -1) - if "maxpool2x2" in mm_patch_merge_type: - image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = nn.functional.max_pool2d(image_feature, 2) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - elif "unpad" in mm_patch_merge_type: - image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image(image_feature, image_sizes[image_idx]) - image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - else: - image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() - image_feature = image_feature.flatten(0, 3) - if "nobase" in mm_patch_merge_type: - pass - else: - image_feature = torch.cat((base_image_feature, image_feature), dim=0) - else: - image_feature = image_feature[0] - if "unpad" in mm_patch_merge_type: - image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0) - new_image_features.append(image_feature) - image_features = new_image_features - else: - raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}") - else: - image_features = self.encode_images(images) - - # TODO: image start / end is not implemented here to support pretraining. - if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False): - raise NotImplementedError - - # Let's just add dummy tensors if they do not exist, - # it is a headache to deal with None all the time. - # But it is not ideal, and if you have a better idea, - # please open an issue / submit a PR, thanks. - _labels = labels - _position_ids = position_ids - _attention_mask = attention_mask - if attention_mask is None: - attention_mask = torch.ones_like(input_ids, dtype=torch.bool) - else: - attention_mask = attention_mask.bool() - if position_ids is None: - position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) - if labels is None: - labels = torch.full_like(input_ids, IGNORE_INDEX) - - # remove the padding using attention_mask -- FIXME - _input_ids = input_ids - input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] - labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] - - new_input_embeds = [] - new_labels = [] - cur_image_idx = 0 - for batch_idx, cur_input_ids in enumerate(input_ids): - num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() - if num_images == 0: - cur_image_features = image_features[cur_image_idx] - cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) - cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) - new_input_embeds.append(cur_input_embeds) - new_labels.append(labels[batch_idx]) - cur_image_idx += 1 - continue - - image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] - cur_input_ids_noim = [] - cur_labels = labels[batch_idx] - cur_labels_noim = [] - for i in range(len(image_token_indices) - 1): - cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]) - cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]) - split_sizes = [x.shape[0] for x in cur_labels_noim] - cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim)) - cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) - cur_new_input_embeds = [] - cur_new_labels = [] - - for i in range(num_images + 1): - cur_new_input_embeds.append(cur_input_embeds_no_im[i]) - cur_new_labels.append(cur_labels_noim[i]) - if i < num_images: - cur_image_features = image_features[cur_image_idx] - cur_image_idx += 1 - cur_new_input_embeds.append(cur_image_features) - cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) - - cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] - - cur_new_input_embeds = torch.cat(cur_new_input_embeds) - cur_new_labels = torch.cat(cur_new_labels) - - new_input_embeds.append(cur_new_input_embeds) - new_labels.append(cur_new_labels) - - # Truncate sequences to max length as image embeddings can make the sequence longer - tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None) - if tokenizer_model_max_length is not None: - new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] - new_labels = [x[:tokenizer_model_max_length] for x in new_labels] - - # Combine them - max_len = max(x.shape[0] for x in new_input_embeds) - batch_size = len(new_input_embeds) - - new_input_embeds_padded = [] - new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) - attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) - position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) - - for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): - cur_len = cur_new_embed.shape[0] - if getattr(self.config, "tokenizer_padding_side", "right") == "left": - new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0)) - if cur_len > 0: - new_labels_padded[i, -cur_len:] = cur_new_labels - attention_mask[i, -cur_len:] = True - position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) - else: - new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)) - if cur_len > 0: - new_labels_padded[i, :cur_len] = cur_new_labels - attention_mask[i, :cur_len] = True - position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) - - new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) - - if _labels is None: - new_labels = None - else: - new_labels = new_labels_padded - - if _attention_mask is None: - attention_mask = None - else: - attention_mask = attention_mask.to(dtype=_attention_mask.dtype) - - if _position_ids is None: - position_ids = None - - return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels - - def initialize_vision_tokenizer(self, model_args, tokenizer): - if model_args.mm_use_im_patch_token: - tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) - self.resize_token_embeddings(len(tokenizer)) - - if model_args.mm_use_im_start_end: - num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) - self.resize_token_embeddings(len(tokenizer)) - - if num_new_tokens > 0: - input_embeddings = self.get_input_embeddings().weight.data - output_embeddings = self.get_output_embeddings().weight.data - - input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) - output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) - - input_embeddings[-num_new_tokens:] = input_embeddings_avg - output_embeddings[-num_new_tokens:] = output_embeddings_avg - - if model_args.tune_mm_mlp_adapter: - for p in self.get_input_embeddings().parameters(): - p.requires_grad = True - for p in self.get_output_embeddings().parameters(): - p.requires_grad = False - - if model_args.pretrain_mm_mlp_adapter: - mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu") - embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"] - assert num_new_tokens == 2 - if input_embeddings.shape == embed_tokens_weight.shape: - input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] - elif embed_tokens_weight.shape[0] == num_new_tokens: - input_embeddings[-num_new_tokens:] = embed_tokens_weight - else: - raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") - elif model_args.mm_use_im_patch_token: - if model_args.tune_mm_mlp_adapter: - for p in self.get_input_embeddings().parameters(): - p.requires_grad = False - for p in self.get_output_embeddings().parameters(): - p.requires_grad = False diff --git a/llava/model/make_delta.py b/llava/model/make_delta.py deleted file mode 100644 index 7b3fbabe19506d5710dbc194db4000fee62c712d..0000000000000000000000000000000000000000 --- a/llava/model/make_delta.py +++ /dev/null @@ -1,52 +0,0 @@ -""" -Usage: -python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta -""" - -import argparse - -import torch -from tqdm import tqdm -from transformers import AutoTokenizer, AutoModelForCausalLM -from llava.model.utils import auto_upgrade - - -def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): - print("Loading base model") - base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) - - print("Loading target model") - auto_upgrade(target_model_path) - target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) - - print("Calculating delta") - for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): - if name not in base.state_dict(): - assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" - continue - if param.data.shape == base.state_dict()[name].shape: - param.data -= base.state_dict()[name] - else: - assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" - bparam = base.state_dict()[name] - param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam - - print("Saving delta") - if hub_repo_id: - kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} - else: - kwargs = {} - target.save_pretrained(delta_path, **kwargs) - target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) - target_tokenizer.save_pretrained(delta_path, **kwargs) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--base-model-path", type=str, required=True) - parser.add_argument("--target-model-path", type=str, required=True) - parser.add_argument("--delta-path", type=str, required=True) - parser.add_argument("--hub-repo-id", type=str, default=None) - args = parser.parse_args() - - make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) diff --git a/llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc b/llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc deleted file mode 100644 index 5b9a5268b4017b541bf0290fa1ae4c440c325a6f..0000000000000000000000000000000000000000 Binary files a/llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc and /dev/null differ diff --git a/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc b/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc deleted file mode 100644 index 3f3e48dba3b07e1cfc06fab3af254eaf395ca34d..0000000000000000000000000000000000000000 Binary files a/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc and /dev/null differ diff --git a/llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc b/llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc deleted file mode 100644 index 45076841d2b32b3647d3455381a0b20c14c068e8..0000000000000000000000000000000000000000 Binary files a/llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc and /dev/null differ diff --git a/llava/model/multimodal_encoder/builder.py b/llava/model/multimodal_encoder/builder.py deleted file mode 100644 index b06ce74fcb3afac011e76281064979c8e90b1892..0000000000000000000000000000000000000000 --- a/llava/model/multimodal_encoder/builder.py +++ /dev/null @@ -1,14 +0,0 @@ -import os -from .clip_encoder import CLIPVisionTower -from .siglip_encoder import SigLipVisionTower - - -def build_vision_tower(vision_tower_cfg, **kwargs): - vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) - is_absolute_path_exists = os.path.exists(vision_tower) - if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: - return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) - elif "siglip" in vision_tower: - return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs) - - raise ValueError(f'Unknown vision tower: {vision_tower}') diff --git a/llava/model/multimodal_encoder/clip_encoder.py b/llava/model/multimodal_encoder/clip_encoder.py deleted file mode 100644 index 3d46db5540a3dbed9d895858ea496d33a6da9805..0000000000000000000000000000000000000000 --- a/llava/model/multimodal_encoder/clip_encoder.py +++ /dev/null @@ -1,114 +0,0 @@ -import torch -import torch.nn as nn - -from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig - - -class CLIPVisionTower(nn.Module): - def __init__(self, vision_tower, args, delay_load=False): - super().__init__() - - self.is_loaded = False - - self.vision_tower_name = vision_tower - self.select_layer = args.mm_vision_select_layer - self.select_feature = getattr(args, "mm_vision_select_feature", "patch") - - if not delay_load: - self.load_model() - elif getattr(args, "unfreeze_mm_vision_tower", False): - # TODO: better detector is needed. - print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") - self.load_model() - else: - self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) - - def load_model(self, device_map=None): - if self.is_loaded: - print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name)) - return - - # import pdb; pdb.set_trace() - self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) - self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) - self.vision_tower.requires_grad_(False) - - self.is_loaded = True - - def feature_select(self, image_forward_outs): - select_feature_type = self.select_feature - - if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: - select_every_k_layer = len(image_forward_outs.hidden_states) // 4 - image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1) - select_feature_type = select_feature_type.replace("slicefour_", "") - elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]: - select_layers = [-2, -5, -8, -11, 6] - image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1) - select_feature_type = select_feature_type.replace("slice_m25811_f6_", "") - else: - image_features = image_forward_outs.hidden_states[self.select_layer] - - if select_feature_type == "patch": - image_features = image_features[:, 1:] - elif select_feature_type == "cls_patch": - image_features = image_features - else: - raise ValueError(f"Unexpected select feature: {select_feature_type}") - return image_features - - def forward(self, images): - if type(images) is list: - image_features = [] - for image in images: - image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) - image_feature = self.feature_select(image_forward_out).to(image.dtype) - image_features.append(image_feature) - else: - image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) - image_features = self.feature_select(image_forward_outs).to(images.dtype) - - return image_features - - @property - def dummy_feature(self): - return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) - - @property - def dtype(self): - return self.vision_tower.dtype - - @property - def device(self): - return self.vision_tower.device - - @property - def config(self): - if self.is_loaded: - return self.vision_tower.config - else: - return self.cfg_only - - @property - def hidden_size(self): - _hidden_size = self.config.hidden_size - if "slicefour" in self.select_feature: - _hidden_size *= 4 - if "slice_m25811_f6" in self.select_feature: - _hidden_size *= 5 - return _hidden_size - - @property - def num_patches_per_side(self): - return self.config.image_size // self.config.patch_size - - @property - def num_patches(self): - _num_patches = (self.config.image_size // self.config.patch_size) ** 2 - if "cls_patch" in self.select_feature: - _num_patches += 1 - return _num_patches - - @property - def image_size(self): - return self.config.image_size \ No newline at end of file diff --git a/llava/model/multimodal_encoder/siglip_encoder.py b/llava/model/multimodal_encoder/siglip_encoder.py deleted file mode 100644 index e6d297fc9f582f939bfe1809d4441f38091d54f9..0000000000000000000000000000000000000000 --- a/llava/model/multimodal_encoder/siglip_encoder.py +++ /dev/null @@ -1,620 +0,0 @@ -""" -# Adapted from https://huggingface.co/MILVLG/imp-v1-3b/blob/main/vision_encoder.py -""" - -from typing import Optional, Tuple, Union, Dict -from dataclasses import dataclass -from functools import partial, reduce -from PIL import Image -import torch -import torch.utils.checkpoint -from torch import nn -import os -from transformers.image_processing_utils import BatchFeature, get_size_dict -from transformers.image_transforms import ( - convert_to_rgb, - normalize, - rescale, - resize, - to_channel_dimension_format, -) -from transformers.image_utils import ( - ChannelDimension, - PILImageResampling, - to_numpy_array, -) -from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from transformers.modeling_utils import PreTrainedModel -from transformers import PretrainedConfig -from transformers.utils import ModelOutput -from llava.utils import rank0_print - - -class SigLipImageProcessor: - def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST): - crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384} - crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") - - self.image_mean = image_mean - self.image_std = image_std - self.size = size - self.resample = resample - self.rescale_factor = rescale_factor - self.data_format = data_format - self.crop_size = crop_size - - def preprocess(self, images, return_tensors): - if isinstance(images, Image.Image): - images = [images] - else: - assert isinstance(images, list) - - transforms = [ - convert_to_rgb, - to_numpy_array, - partial(resize, size=self.size, resample=self.resample, data_format=self.data_format), - partial(rescale, scale=self.rescale_factor, data_format=self.data_format), - partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format), - partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format), - ] - - images = reduce(lambda x, f: [*map(f, x)], transforms, images) - data = {"pixel_values": images} - - return BatchFeature(data=data, tensor_type=return_tensors) - - -class SigLipVisionConfig(PretrainedConfig): - model_type = "siglip_vision_model" - - def __init__( - self, - hidden_size=1152, - image_mean=(0.5, 0.5, 0.5), - intermediate_size=4304, - num_hidden_layers=27, - num_attention_heads=16, - num_channels=3, - image_size=384, - patch_size=14, - hidden_act="gelu_pytorch_tanh", - layer_norm_eps=1e-6, - attention_dropout=0.0, - **kwargs, - ): - super().__init__(**kwargs) - - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_channels = num_channels - self.patch_size = patch_size - self.image_size = image_size - self.attention_dropout = attention_dropout - self.layer_norm_eps = layer_norm_eps - self.hidden_act = hidden_act - self.image_mean = image_mean - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": - cls._set_token_in_kwargs(kwargs) - - config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) - - # get the vision config dict if we are loading from SigLipConfig - if config_dict.get("model_type") == "siglip": - config_dict = config_dict["vision_config"] - - if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: - print(f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors.") - - return cls.from_dict(config_dict, **kwargs) - - -@dataclass -# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->SigLip -class SigLipVisionModelOutput(ModelOutput): - """ - Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. - - Args: - image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): - The image embeddings obtained by applying the projection layer to the pooler_output. - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - image_embeds: Optional[torch.FloatTensor] = None - last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - -class SigLipVisionEmbeddings(nn.Module): - def __init__(self, config: SigLipVisionConfig): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.image_size = config.image_size - self.patch_size = config.patch_size - - self.patch_embedding = nn.Conv2d( - in_channels=config.num_channels, - out_channels=self.embed_dim, - kernel_size=self.patch_size, - stride=self.patch_size, - padding="valid", - ) - - self.num_patches = (self.image_size // self.patch_size) ** 2 - self.num_positions = self.num_patches - self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) - self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] - embeddings = patch_embeds.flatten(2).transpose(1, 2) - - embeddings = embeddings + self.position_embedding(self.position_ids) - return embeddings - - -class SigLipAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ - def __init__(self, config): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError(f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).") - self.scale = self.head_dim**-0.5 - self.dropout = config.attention_dropout - - self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - batch_size, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - k_v_seq_len = key_states.shape[-2] - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale - - if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): - raise ValueError(f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" f" {attn_weights.size()}") - - if attention_mask is not None: - if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): - raise ValueError(f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}") - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): - raise ValueError(f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}") - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->SigLip -class SigLipMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.activation_fn = ACT2FN[config.hidden_act] - self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->SigLip -class SigLipEncoderLayer(nn.Module): - def __init__(self, config: SigLipVisionConfig): - super().__init__() - self.embed_dim = config.hidden_size - self.self_attn = SigLipAttention(config) - self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.mlp = SigLipMLP(config) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - - # Ignore copy - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor]: - """ - Args: - hidden_states (`torch.FloatTensor`): - Input to the layer of shape `(batch, seq_len, embed_dim)`. - attention_mask (`torch.FloatTensor`): - Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - """ - residual = hidden_states - - hidden_states = self.layer_norm1(hidden_states) - hidden_states, attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -class SigLipPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = SigLipVisionConfig - base_model_prefix = "siglip" - supports_gradient_checkpointing = True - - def _init_weights(self, module): - """Initialize the weights""" - pass - - -# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip -class SigLipEncoder(nn.Module): - """ - Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a - [`SigLipEncoderLayer`]. - - Args: - config: SigLipVisionConfig - """ - - def __init__(self, config: SigLipVisionConfig): - super().__init__() - self.config = config - self.layers = nn.ModuleList([SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.gradient_checkpointing = False - - # Ignore copy - def forward( - self, - inputs_embeds, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutput]: - r""" - Args: - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - hidden_states = inputs_embeds - for encoder_layer in self.layers: - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions) - - -class SigLipVisionTransformer(nn.Module): - def __init__(self, config: SigLipVisionConfig): - super().__init__() - self.config = config - embed_dim = config.hidden_size - - self.embeddings = SigLipVisionEmbeddings(config) - self.encoder = SigLipEncoder(config) - self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - self.head = SigLipMultiheadAttentionPoolingHead(config) - - def forward( - self, - pixel_values, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: - r""" - Returns: - - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - hidden_states = self.embeddings(pixel_values) - - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_state = encoder_outputs[0] - last_hidden_state = self.post_layernorm(last_hidden_state) - - pooled_output = self.head(last_hidden_state) - - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - - return BaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -class SigLipMultiheadAttentionPoolingHead(nn.Module): - """Multihead Attention Pooling.""" - - def __init__(self, config: SigLipVisionConfig): - super().__init__() - - self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) - self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) - self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.mlp = SigLipMLP(config) - - def forward(self, hidden_state): - batch_size = hidden_state.shape[0] - probe = self.probe.repeat(batch_size, 1, 1) - - hidden_state = self.attention(probe, hidden_state, hidden_state)[0] - - residual = hidden_state - hidden_state = self.layernorm(hidden_state) - hidden_state = residual + self.mlp(hidden_state) - - return hidden_state[:, 0] - - -class SigLipVisionModel(SigLipPreTrainedModel): - config_class = SigLipVisionConfig - main_input_name = "pixel_values" - _no_split_modules = ["SigLipEncoderLayer"] - - def __init__(self, config: SigLipVisionConfig): - super().__init__(config) - - self.vision_model = SigLipVisionTransformer(config) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> nn.Module: - return self.vision_model.embeddings.patch_embedding - - def forward( - self, - pixel_values, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: - r""" - Returns: - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, SigLipVisionModel - - >>> model = SigLipVisionModel.from_pretrained("google/siglip-base-patch16-224") - >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, return_tensors="pt") - - >>> outputs = model(**inputs) - >>> last_hidden_state = outputs.last_hidden_state - >>> pooled_output = outputs.pooler_output # pooled features - ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - return self.vision_model( - pixel_values=pixel_values.to(self.device), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -class SigLipVisionTower(nn.Module): - def __init__(self, vision_tower, vision_tower_cfg, delay_load=False): - super().__init__() - - self.is_loaded = False - - self.config = SigLipVisionConfig() - - self.vision_tower_name = vision_tower - - self.image_processor = SigLipImageProcessor() - - if not delay_load: - rank0_print(f"Loading vision tower: {vision_tower}") - self.load_model() - elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False): - # TODO: better detector is needed. - rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") - self.load_model() - elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts: - rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") - self.load_model() - else: - self.cfg_only = self.config - - def load_model(self, device_map=None): - if self.is_loaded: - return - - self.vision_tower = SigLipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) - - del self.vision_tower.vision_model.encoder.layers[-1:] - self.vision_tower.vision_model.head = nn.Identity() - self.vision_tower.requires_grad_(False) - self.vision_tower.eval() - - self.is_loaded = True - - @torch.no_grad() - def forward(self, images): - if type(images) is list: - image_features = [] - for image in images: - image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) - image_feature = image_forward_out.hidden_states[-1].to(image.dtype) - assert image_features.shape[-2] == 729 - image_features.append(image_feature) - else: - images=images.to(device=self.device, dtype=self.dtype) - image_forward_outs = self.vision_tower(images, output_hidden_states=True) - image_features = image_forward_outs.hidden_states[-1].to(images.dtype) - assert image_features.shape[-2] == 729 - - return image_features - - @property - def dummy_feature(self): - return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) - - @property - def dtype(self): - for p in self.vision_tower.parameters(): - return p.dtype - - @property - def device(self): - for p in self.vision_tower.parameters(): - return p.device - - @property - def hidden_size(self): - return self.config.hidden_size - - @property - def num_patches(self): - return (self.config.image_size // self.config.patch_size) ** 2 - - @property - def num_patches_per_side(self): - return self.config.image_size // self.config.patch_size - # return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"] - - @property - def image_size(self): - return self.config.image_size diff --git a/llava/model/multimodal_projector/__pycache__/builder.cpython-310.pyc b/llava/model/multimodal_projector/__pycache__/builder.cpython-310.pyc deleted file mode 100644 index 875b01d47b1840feed01bdae58e6a145f9ab27f8..0000000000000000000000000000000000000000 Binary files a/llava/model/multimodal_projector/__pycache__/builder.cpython-310.pyc and /dev/null differ diff --git a/llava/model/multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc b/llava/model/multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc deleted file mode 100644 index d5eb78247dae00f8bb91731c0e43e1b0bd36df0c..0000000000000000000000000000000000000000 Binary files a/llava/model/multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc and /dev/null differ diff --git a/llava/model/multimodal_projector/builder.py b/llava/model/multimodal_projector/builder.py deleted file mode 100644 index 3122a0c3bc5b50f1c921ba9b186fd736f018c9cf..0000000000000000000000000000000000000000 --- a/llava/model/multimodal_projector/builder.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch -import torch.nn as nn -import re - -from .pooler_projector import PoolerProjector - - -class IdentityMap(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, *args, **kwargs): - return x - - @property - def config(self): - return {"mm_projector_type": "identity"} - - -class SimpleResBlock(nn.Module): - def __init__(self, channels): - super().__init__() - self.pre_norm = nn.LayerNorm(channels) - - self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)) - - def forward(self, x): - x = self.pre_norm(x) - return x + self.proj(x) - - -def build_vision_projector(config, delay_load=False, **kwargs): - projector_type = getattr(config, "mm_projector_type", "linear") - - if projector_type == "linear": - return nn.Linear(config.mm_hidden_size, config.hidden_size) - - if projector_type == "pooler": - return PoolerProjector(config, kwargs["vision_cfg"]) - - mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type) - if mlp_gelu_match: - mlp_depth = int(mlp_gelu_match.group(1)) - modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] - for _ in range(1, mlp_depth): - modules.append(nn.GELU()) - modules.append(nn.Linear(config.hidden_size, config.hidden_size)) - return nn.Sequential(*modules) - - mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type) - if mlp_gelu_resnet_match: - mlp_depth = int(mlp_gelu_resnet_match.group(1)) - res_depth = int(mlp_gelu_resnet_match.group(2)) - modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] - for _ in range(1, mlp_depth): - modules.append(nn.GELU()) - modules.append(nn.Linear(config.hidden_size, config.hidden_size)) - for _ in range(res_depth): - modules.append(SimpleResBlock(config.hidden_size)) - return nn.Sequential(*modules) - - if projector_type == "identity": - return IdentityMap() - - raise ValueError(f"Unknown projector type: {projector_type}") diff --git a/llava/model/multimodal_projector/pooler_projector.py b/llava/model/multimodal_projector/pooler_projector.py deleted file mode 100644 index ce5a2e05fa44ad2978272aea6dcf0aa9ca135e55..0000000000000000000000000000000000000000 --- a/llava/model/multimodal_projector/pooler_projector.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch -import torch.nn as nn - -import math - -from transformers.models.clip.modeling_clip import CLIPVisionModel - - -class PoolerProjector(nn.Module): - def __init__(self, config, vision_cfg): - super().__init__() - self._config = config - self.hw = vision_cfg.image_size // vision_cfg.patch_size - - self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2) - - self.proj = nn.Sequential( - nn.GELU(), - nn.Linear(config.hidden_size, config.hidden_size), - ) - - def forward(self, x, *args, **kwargs): - height = width = self.hw - assert height * width == x.shape[1] - x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2) - x = self.conv_pool(x) - x = x.flatten(2).transpose(1, 2) - x = self.proj(x) - return x - - @property - def config(self): - return {"mm_projector_type": "pooler"} diff --git a/llava/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc b/llava/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc deleted file mode 100644 index aee158db7487ab557faa491cc570905bce1fd93b..0000000000000000000000000000000000000000 Binary files a/llava/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc and /dev/null differ diff --git a/llava/model/multimodal_resampler/__pycache__/masked_drop.cpython-310.pyc b/llava/model/multimodal_resampler/__pycache__/masked_drop.cpython-310.pyc deleted file mode 100644 index 1c41df707a7d9736b35a79877ca69baaa405f649..0000000000000000000000000000000000000000 Binary files a/llava/model/multimodal_resampler/__pycache__/masked_drop.cpython-310.pyc and /dev/null differ diff --git a/llava/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc b/llava/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc deleted file mode 100644 index fe7343b3b93f094030f54a36d103824ddff2ae27..0000000000000000000000000000000000000000 Binary files a/llava/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc and /dev/null differ diff --git a/llava/model/multimodal_resampler/__pycache__/qformer.cpython-310.pyc b/llava/model/multimodal_resampler/__pycache__/qformer.cpython-310.pyc deleted file mode 100644 index 8e9983efc93fb8877df80e221f72697dad0ff078..0000000000000000000000000000000000000000 Binary files a/llava/model/multimodal_resampler/__pycache__/qformer.cpython-310.pyc and /dev/null differ diff --git a/llava/model/multimodal_resampler/__pycache__/spatial_pool.cpython-310.pyc b/llava/model/multimodal_resampler/__pycache__/spatial_pool.cpython-310.pyc deleted file mode 100644 index fd86cb55814fd0230cd94372c64bae131e004f9e..0000000000000000000000000000000000000000 Binary files a/llava/model/multimodal_resampler/__pycache__/spatial_pool.cpython-310.pyc and /dev/null differ diff --git a/llava/model/multimodal_resampler/builder.py b/llava/model/multimodal_resampler/builder.py deleted file mode 100644 index 7a4b207f3bded33b89ddef3899233c3825d91701..0000000000000000000000000000000000000000 --- a/llava/model/multimodal_resampler/builder.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch - -from .masked_drop import MaskedDrop -from .spatial_pool import SpatialPool -from .perceiver import PerceiverResampler -from .qformer import Qformer - - -class IdentityMap(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, *args, **kwargs): - return x - - @property - def config(self): - return {"mm_resampler_type": None} - - -def build_vision_resampler(model_args, delay_load=False, **kwargs): - resampler_type = getattr(model_args, "mm_resampler_type", None) - if resampler_type == "masked_drop": - return MaskedDrop(model_args) - elif resampler_type == "spatial_pool": - return SpatialPool(model_args, **kwargs) - elif resampler_type == "perceiver": - return PerceiverResampler(model_args, **kwargs) - elif resampler_type == "qformer": - return Qformer(model_args, **kwargs) - elif resampler_type is None: - return IdentityMap() - - raise ValueError(f"Unknown resampler type: {resampler_type}") diff --git a/llava/model/multimodal_resampler/masked_drop.py b/llava/model/multimodal_resampler/masked_drop.py deleted file mode 100644 index 03f0bf0b259ff5d96adea4aad91ee5498d459030..0000000000000000000000000000000000000000 --- a/llava/model/multimodal_resampler/masked_drop.py +++ /dev/null @@ -1,80 +0,0 @@ -import torch -import torch.nn as nn - -import random - - -class MaskedDrop(nn.Module): - def __init__(self, model_args): - super().__init__() - - self.mode = model_args.mm_mask_drop_mode - self.skip_percentage = model_args.mm_mask_drop_skip_percentage - self.ratio = model_args.mm_mask_drop_ratio - self.ratio_upper = model_args.mm_mask_drop_ratio_upper - self.ratio_lower = model_args.mm_mask_drop_ratio_lower - - def forward(self, image_features, *args, **kwargs): - - if not self.training: - return image_features - - if self.skip_percentage > random.random(): - return image_features - - masked_features = [] - - for image_feature in image_features: - num_tokens = image_feature.shape[0] - if self.mode == "fixed": - num_keep = int(num_tokens * self.ratio) - masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0]) - elif self.mode == "range": - num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper)) - masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0]) - elif self.mode == "cls_only": - masked_features.append(image_feature[0:1]) - else: - raise ValueError(f"Unexpected masked drop mode: {self.mode}") - - if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]): - masked_features = torch.stack(masked_features, dim=0) - - return masked_features - - @property - def config(self): - return { - "mm_resampler_type": "masked_drop", - "mm_mask_drop_mode": self.mode, - "mm_mask_drop_skip_percentage": self.skip_percentage, - "mm_mask_drop_ratio": self.ratio, - "mm_mask_drop_ratio_upper": self.ratio_upper, - "mm_mask_drop_ratio_lower": self.ratio_lower, - } - - def random_masking(self, x, len_keep): - """ - Perform per-sample random masking by per-sample shuffling. - Per-sample shuffling is done by argsort random noise. - x: [N, L, D], sequence - """ - N, L, D = x.shape # batch, length, dim - - noise = torch.rand(N, L, device=x.device) # noise in [0, 1] - - # sort noise for each sample - ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove - ids_restore = torch.argsort(ids_shuffle, dim=1) - - # keep the first subset - ids_keep = ids_shuffle[:, :len_keep] - x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) - - # generate the binary mask: 0 is keep, 1 is remove - mask = torch.ones([N, L], device=x.device) - mask[:, :len_keep] = 0 - # unshuffle to get the binary mask - mask = torch.gather(mask, dim=1, index=ids_restore) - - return x_masked, mask, ids_restore diff --git a/llava/model/multimodal_resampler/perceiver.py b/llava/model/multimodal_resampler/perceiver.py deleted file mode 100644 index d6b17a559b2225832c7d87c4fb6894617779b9c1..0000000000000000000000000000000000000000 --- a/llava/model/multimodal_resampler/perceiver.py +++ /dev/null @@ -1,155 +0,0 @@ -""" -Taken from https://github.com/lucidrains/flamingo-pytorch -""" - -import torch -from einops import rearrange, repeat - -try: - from einops_exts import rearrange_many -except: - pass - -from torch import einsum, nn - - -def exists(val): - return val is not None - - -def FeedForward(dim, mult=4): - inner_dim = int(dim * mult) - return nn.Sequential( - nn.LayerNorm(dim), - nn.Linear(dim, inner_dim, bias=False), - nn.GELU(), - nn.Linear(inner_dim, dim, bias=False), - ) - - -class PerceiverAttention(nn.Module): - def __init__(self, *, dim, dim_head=64, heads=8): - super().__init__() - self.scale = dim_head**-0.5 - self.heads = heads - inner_dim = dim_head * heads - - self.norm_media = nn.LayerNorm(dim) - self.norm_latents = nn.LayerNorm(dim) - - self.to_q = nn.Linear(dim, inner_dim, bias=False) - self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) - self.to_out = nn.Linear(inner_dim, dim, bias=False) - - def forward(self, x, latents): - """ - Args: - x (torch.Tensor): image features - shape (b, T, n1, D) - latent (torch.Tensor): latent features - shape (b, T, n2, D) - """ - x = self.norm_media(x) - latents = self.norm_latents(latents) - - h = self.heads - - q = self.to_q(latents) - kv_input = torch.cat((x, latents), dim=-2) - k, v = self.to_kv(kv_input).chunk(2, dim=-1) - q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) - q = q * self.scale - - # attention - sim = einsum("... i d, ... j d -> ... i j", q, k) - sim = sim - sim.amax(dim=-1, keepdim=True).detach() - attn = sim.softmax(dim=-1) - - out = einsum("... i j, ... j d -> ... i d", attn, v) - out = rearrange(out, "b h t n d -> b t n (h d)", h=h) - return self.to_out(out) - - -class PerceiverResamplerModule(nn.Module): - def __init__( - self, - *, - dim, - depth=6, - dim_head=64, - heads=8, - num_latents=64, - max_num_media=None, - max_num_frames=None, - ff_mult=4, - ): - super().__init__() - self.latents = nn.Parameter(torch.randn(num_latents, dim)) - self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None - self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None - - self.layers = nn.ModuleList([]) - for _ in range(depth): - self.layers.append( - nn.ModuleList( - [ - PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), - FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity(), - ] - ) - ) - - self.norm = nn.LayerNorm(dim) - - def forward(self, x): - """ - Args: - x (torch.Tensor): image features - shape (b, T, F, v, D) - Returns: - shape (b, T, n, D) where n is self.num_latents - """ - b, T, F, v = x.shape[:4] - - # frame and media time embeddings - if exists(self.frame_embs): - frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) - x = x + frame_embs - x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions - if exists(self.media_time_embs): - x = x + self.media_time_embs[:T] - - # blocks - latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) - for attn, ff in self.layers: - latents = attn(x, latents) + latents - latents = ff(latents) + latents - return self.norm(latents) - - -class PerceiverResampler(nn.Module): - def __init__(self, model_args, vision_tower): - super().__init__() - - self.depth = model_args.mm_perceiver_depth - self.num_latents = model_args.mm_perceiver_latents - self.ff_mult = model_args.mm_perceiver_ff_mult - self.pretrained = model_args.mm_perceiver_pretrained - - self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult) - - if self.pretrained is not None: - self.load_state_dict(torch.load(self.pretrained)) - - def forward(self, image_features, *args, **kwargs): - return self.perceiver(image_features[:, None, None]).squeeze(1) - - @property - def config(self): - return { - "mm_resampler_type": "perceiver", - "mm_perceiver_depth": self.depth, - "mm_perceiver_latents": self.num_latents, - "mm_perceiver_ff_mult": self.ff_mult, - "mm_perceiver_pretrained": self.pretrained, - } diff --git a/llava/model/multimodal_resampler/qformer.py b/llava/model/multimodal_resampler/qformer.py deleted file mode 100644 index b86754c24adbfa5ce34e37ee4726c74e3b7f910f..0000000000000000000000000000000000000000 --- a/llava/model/multimodal_resampler/qformer.py +++ /dev/null @@ -1,1160 +0,0 @@ -""" - * Copyright (c) 2023, salesforce.com, inc. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause - * By Junnan Li - * Based on huggingface code base - * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert -""" - -import math -import os -import warnings -from dataclasses import dataclass -from typing import Optional, Tuple, Dict, Any - -import torch -from torch import Tensor, device, dtype, nn -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss -import torch.nn.functional as F - -from transformers.activations import ACT2FN -from transformers.file_utils import ( - ModelOutput, -) -from transformers.modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions, - CausalLMOutputWithCrossAttentions, - MaskedLMOutput, - MultipleChoiceModelOutput, - NextSentencePredictorOutput, - QuestionAnsweringModelOutput, - SequenceClassifierOutput, - TokenClassifierOutput, -) -from transformers.modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) -from transformers.utils import logging -from transformers.models.bert.configuration_bert import BertConfig - -logger = logging.get_logger(__name__) - - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - - -class BertEmbeddings(nn.Module): - """Construct the embeddings from word and position embeddings.""" - - def __init__(self, config): - super().__init__() - self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) - self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) - - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") - - self.config = config - - def forward( - self, - input_ids=None, - position_ids=None, - query_embeds=None, - past_key_values_length=0, - ): - if input_ids is not None: - seq_length = input_ids.size()[1] - else: - seq_length = 0 - - if position_ids is None: - position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone() - - if input_ids is not None: - embeddings = self.word_embeddings(input_ids) - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings = embeddings + position_embeddings - - if query_embeds is not None: - embeddings = torch.cat((query_embeds, embeddings), dim=1) - else: - embeddings = query_embeds - - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - -class BertSelfAttention(nn.Module): - def __init__(self, config, is_cross_attention): - super().__init__() - self.config = config - if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): - raise ValueError("The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads)) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(config.hidden_size, self.all_head_size) - if is_cross_attention: - self.key = nn.Linear(config.encoder_width, self.all_head_size) - self.value = nn.Linear(config.encoder_width, self.all_head_size) - else: - self.key = nn.Linear(config.hidden_size, self.all_head_size) - self.value = nn.Linear(config.hidden_size, self.all_head_size) - - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - self.save_attention = False - - def save_attn_gradients(self, attn_gradients): - self.attn_gradients = attn_gradients - - def get_attn_gradients(self): - return self.attn_gradients - - def save_attention_map(self, attention_map): - self.attention_map = attention_map - - def get_attention_map(self): - return self.attention_map - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads, - self.attention_head_size, - ) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - ): - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - mixed_query_layer = self.query(hidden_states) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - seq_length = hidden_states.size()[1] - position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in BertModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = nn.Softmax(dim=-1)(attention_scores) - - if is_cross_attention and self.save_attention: - self.save_attention_map(attention_probs) - attention_probs.register_hook(self.save_attn_gradients) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs_dropped = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs_dropped = attention_probs_dropped * head_mask - - context_layer = torch.matmul(attention_probs_dropped, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - outputs = outputs + (past_key_value,) - return outputs - - -class BertSelfOutput(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class BertAttention(nn.Module): - def __init__(self, config, is_cross_attention=False): - super().__init__() - self.self = BertSelfAttention(config, is_cross_attention) - self.output = BertSelfOutput(config) - self.pruned_heads = set() - - def prune_heads(self, heads): - if len(heads) == 0: - return - heads, index = find_pruneable_heads_and_indices( - heads, - self.self.num_attention_heads, - self.self.attention_head_size, - self.pruned_heads, - ) - - # Prune linear layers - self.self.query = prune_linear_layer(self.self.query, index) - self.self.key = prune_linear_layer(self.self.key, index) - self.self.value = prune_linear_layer(self.self.value, index) - self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) - - # Update hyper params and store pruned heads - self.self.num_attention_heads = self.self.num_attention_heads - len(heads) - self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads - self.pruned_heads = self.pruned_heads.union(heads) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - ): - self_outputs = self.self( - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - attention_output = self.output(self_outputs[0], hidden_states) - - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - return outputs - - -class BertIntermediate(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.intermediate_size) - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = ACT2FN[config.hidden_act] - else: - self.intermediate_act_fn = config.hidden_act - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - return hidden_states - - -class BertOutput(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class BertLayer(nn.Module): - def __init__(self, config, layer_num): - super().__init__() - self.config = config - self.chunk_size_feed_forward = config.chunk_size_feed_forward - self.seq_len_dim = 1 - self.attention = BertAttention(config) - self.layer_num = layer_num - if self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0: - self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention) - self.has_cross_attention = True - else: - self.has_cross_attention = False - self.intermediate = BertIntermediate(config) - self.output = BertOutput(config) - - self.intermediate_query = BertIntermediate(config) - self.output_query = BertOutput(config) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - query_length=0, - ): - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - self_attention_outputs = self.attention( - hidden_states, - attention_mask, - head_mask, - output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, - ) - attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:-1] - - present_key_value = self_attention_outputs[-1] - - if query_length > 0: - query_attention_output = attention_output[:, :query_length, :] - - if self.has_cross_attention: - assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" - cross_attention_outputs = self.crossattention( - query_attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - output_attentions=output_attentions, - ) - query_attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - layer_output = apply_chunking_to_forward( - self.feed_forward_chunk_query, - self.chunk_size_feed_forward, - self.seq_len_dim, - query_attention_output, - ) - if attention_output.shape[1] > query_length: - layer_output_text = apply_chunking_to_forward( - self.feed_forward_chunk, - self.chunk_size_feed_forward, - self.seq_len_dim, - attention_output[:, query_length:, :], - ) - layer_output = torch.cat([layer_output, layer_output_text], dim=1) - else: - layer_output = apply_chunking_to_forward( - self.feed_forward_chunk, - self.chunk_size_feed_forward, - self.seq_len_dim, - attention_output, - ) - outputs = (layer_output,) + outputs - - outputs = outputs + (present_key_value,) - - return outputs - - def feed_forward_chunk(self, attention_output): - intermediate_output = self.intermediate(attention_output) - layer_output = self.output(intermediate_output, attention_output) - return layer_output - - def feed_forward_chunk_query(self, attention_output): - intermediate_output = self.intermediate_query(attention_output) - layer_output = self.output_query(intermediate_output, attention_output) - return layer_output - - -class BertEncoder(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)]) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - query_length=0, - ): - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - - next_decoder_cache = () if use_cache else None - - for i in range(self.config.num_hidden_layers): - layer_module = self.layer[i] - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None - - if getattr(self.config, "gradient_checkpointing", False) and self.training: - - if use_cache: - logger.warn("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") - use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions, query_length) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - query_length, - ) - - hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - -class BertPooler(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - - -class BertPredictionHeadTransform(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - if isinstance(config.hidden_act, str): - self.transform_act_fn = ACT2FN[config.hidden_act] - else: - self.transform_act_fn = config.hidden_act - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -class BertLMPredictionHead(nn.Module): - def __init__(self, config): - super().__init__() - self.transform = BertPredictionHeadTransform(config) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def forward(self, hidden_states): - hidden_states = self.transform(hidden_states) - hidden_states = self.decoder(hidden_states) - return hidden_states - - -class BertOnlyMLMHead(nn.Module): - def __init__(self, config): - super().__init__() - self.predictions = BertLMPredictionHead(config) - - def forward(self, sequence_output): - prediction_scores = self.predictions(sequence_output) - return prediction_scores - - -class BertPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = BertConfig - base_model_prefix = "bert" - _keys_to_ignore_on_load_missing = [r"position_ids"] - - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - -class BertModel(BertPreTrainedModel): - """ - The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of - cross-attention is added between the self-attention layers, following the architecture described in `Attention is - all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, - Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. - argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an - input to the forward pass. - """ - - def __init__(self, config, add_pooling_layer=False): - super().__init__(config) - self.config = config - - self.embeddings = BertEmbeddings(config) - - self.encoder = BertEncoder(config) - - self.pooler = BertPooler(config) if add_pooling_layer else None - - self.init_weights() - - def get_input_embeddings(self): - return self.embeddings.word_embeddings - - def set_input_embeddings(self, value): - self.embeddings.word_embeddings = value - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - - def get_extended_attention_mask( - self, - attention_mask: Tensor, - input_shape: Tuple[int], - device: device, - is_decoder: bool, - has_query: bool = False, - ) -> Tensor: - """ - Makes broadcastable attention and causal masks so that future and masked tokens are ignored. - - Arguments: - attention_mask (:obj:`torch.Tensor`): - Mask with ones indicating tokens to attend to, zeros for tokens to ignore. - input_shape (:obj:`Tuple[int]`): - The shape of the input to the model. - device: (:obj:`torch.device`): - The device of the input to the model. - - Returns: - :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. - """ - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - if attention_mask.dim() == 3: - extended_attention_mask = attention_mask[:, None, :, :] - elif attention_mask.dim() == 2: - # Provided a padding mask of dimensions [batch_size, seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] - if is_decoder: - batch_size, seq_length = input_shape - - seq_ids = torch.arange(seq_length, device=device) - causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] - - # add a prefix ones mask to the causal mask - # causal and attention masks must have same type with pytorch version < 1.3 - causal_mask = causal_mask.to(attention_mask.dtype) - - if causal_mask.shape[1] < attention_mask.shape[1]: - prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] - if has_query: # UniLM style attention mask - causal_mask = torch.cat( - [ - torch.zeros( - (batch_size, prefix_seq_len, seq_length), - device=device, - dtype=causal_mask.dtype, - ), - causal_mask, - ], - axis=1, - ) - causal_mask = torch.cat( - [ - torch.ones( - (batch_size, causal_mask.shape[1], prefix_seq_len), - device=device, - dtype=causal_mask.dtype, - ), - causal_mask, - ], - axis=-1, - ) - extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] - else: - extended_attention_mask = attention_mask[:, None, None, :] - else: - raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(input_shape, attention_mask.shape)) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - return extended_attention_mask - - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - head_mask=None, - query_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - is_decoder=False, - ): - r""" - encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` - (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` - instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. - use_cache (:obj:`bool`, `optional`): - If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up - decoding (see :obj:`past_key_values`). - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # use_cache = use_cache if use_cache is not None else self.config.use_cache - - if input_ids is None: - assert query_embeds is not None, "You have to specify query_embeds when input_ids is None" - - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0 - - query_length = query_embeds.shape[1] if query_embeds is not None else 0 - - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - query_embeds=query_embeds, - past_key_values_length=past_key_values_length, - ) - - input_shape = embedding_output.size()[:-1] - batch_size, seq_length = input_shape - device = embedding_output.device - - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - if is_decoder: - extended_attention_mask = self.get_extended_attention_mask( - attention_mask, - input_ids.shape, - device, - is_decoder, - has_query=(query_embeds is not None), - ) - else: - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device, is_decoder) - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if encoder_hidden_states is not None: - if type(encoder_hidden_states) == list: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() - else: - ( - encoder_batch_size, - encoder_sequence_length, - _, - ) = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - - if type(encoder_attention_mask) == list: - encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] - elif encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - encoder_outputs = self.encoder( - embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - query_length=query_length, - ) - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - return BaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, - ) - - -class BertLMHeadModel(BertPreTrainedModel): - - _keys_to_ignore_on_load_unexpected = [r"pooler"] - _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] - - def __init__(self, config): - super().__init__(config) - - self.bert = BertModel(config, add_pooling_layer=False) - self.cls = BertOnlyMLMHead(config) - - self.init_weights() - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - def set_output_embeddings(self, new_embeddings): - self.cls.predictions.decoder = new_embeddings - - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - head_mask=None, - query_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - past_key_values=None, - use_cache=True, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - return_logits=False, - is_decoder=True, - reduction="mean", - ): - r""" - encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in - ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are - ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` - past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` - (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` - instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. - use_cache (:obj:`bool`, `optional`): - If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up - decoding (see :obj:`past_key_values`). - Returns: - Example:: - >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig - >>> import torch - >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') - >>> config = BertConfig.from_pretrained("bert-base-cased") - >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") - >>> outputs = model(**inputs) - >>> prediction_logits = outputs.logits - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if labels is not None: - use_cache = False - if past_key_values is not None: - query_embeds = None - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - query_embeds=query_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - is_decoder=is_decoder, - ) - - sequence_output = outputs[0] - if query_embeds is not None: - sequence_output = outputs[0][:, query_embeds.shape[1] :, :] - - prediction_scores = self.cls(sequence_output) - - if return_logits: - return prediction_scores[:, :-1, :].contiguous() - - lm_loss = None - if labels is not None: - # we are doing next-token prediction; shift prediction scores and input ids by one - shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() - labels = labels[:, 1:].contiguous() - loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) - lm_loss = loss_fct( - shifted_prediction_scores.view(-1, self.config.vocab_size), - labels.view(-1), - ) - if reduction == "none": - lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((lm_loss,) + output) if lm_loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=lm_loss, - logits=prediction_scores, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs): - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_ids.shape) - query_mask = input_ids.new_ones(query_embeds.shape[:-1]) - attention_mask = torch.cat([query_mask, attention_mask], dim=-1) - - # cut decoder_input_ids if past is used - if past is not None: - input_ids = input_ids[:, -1:] - - return { - "input_ids": input_ids, - "query_embeds": query_embeds, - "attention_mask": attention_mask, - "past_key_values": past, - "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), - "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), - "is_decoder": True, - } - - def _reorder_cache(self, past, beam_idx): - reordered_past = () - for layer_past in past: - reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) - return reordered_past - - -class BertForMaskedLM(BertPreTrainedModel): - - _keys_to_ignore_on_load_unexpected = [r"pooler"] - _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] - - def __init__(self, config): - super().__init__(config) - - self.bert = BertModel(config, add_pooling_layer=False) - self.cls = BertOnlyMLMHead(config) - - self.init_weights() - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - def set_output_embeddings(self, new_embeddings): - self.cls.predictions.decoder = new_embeddings - - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - head_mask=None, - query_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - return_logits=False, - is_decoder=False, - ): - r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., - config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored - (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` - """ - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - query_embeds=query_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - is_decoder=is_decoder, - ) - - if query_embeds is not None: - sequence_output = outputs[0][:, query_embeds.shape[1] :, :] - prediction_scores = self.cls(sequence_output) - - if return_logits: - return prediction_scores - - masked_lm_loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - - return MaskedLMOutput( - loss=masked_lm_loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class Qformer(nn.Module): - def __init__(self, model_args, vision_tower): - super().__init__() - - self.depth = model_args.mm_qformer_depth - self.num_latents = model_args.mm_qformer_latents - self.pretrained = model_args.mm_qformer_pretrained - - self.Qformer, self.query_tokens, self.ln_vision = self.build_Qformer(vision_tower.hidden_size, self.depth, self.num_latents) - - if self.pretrained is not None: - pretrained_dict = torch.load(self.pretrained, map_location="cpu")["model"] - pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith("t5_proj")} - self.load_state_dict(pretrained_dict) - - def build_Qformer(self, vision_width, cross_attention_freq, num_query_token): - encoder_config = BertConfig.from_pretrained("bert-base-uncased") - encoder_config.encoder_width = vision_width - # insert cross-attention layer every other block - encoder_config.add_cross_attention = True - encoder_config.cross_attention_freq = cross_attention_freq - encoder_config.query_length = num_query_token - Qformer = BertLMHeadModel(config=encoder_config) - query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size)) - query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) - Qformer.cls = None - Qformer.bert.embeddings.word_embeddings = None - Qformer.bert.embeddings.position_embeddings = None - for layer in Qformer.bert.encoder.layer: - layer.output = None - layer.intermediate = None - return Qformer, query_tokens, nn.LayerNorm(vision_width) - - def forward(self, image_features, *args, **kwargs): - x = self.ln_vision(image_features) - image_atts = torch.ones(x.size()[:-1], dtype=torch.long).to(x.device) - - query_tokens = self.query_tokens.expand(x.shape[0], -1, -1) - query_output = self.Qformer.bert( - query_embeds=query_tokens, - encoder_hidden_states=x, - encoder_attention_mask=image_atts, - return_dict=True, - ) - - return query_output.last_hidden_state - - @property - def hidden_size(self): - return 768 - - @property - def config(self): - return { - "mm_resampler_type": "qformer", - "mm_qformer_depth": self.depth, - "mm_qformer_latents": self.num_latents, - "mm_qformer_pretrained": self.pretrained, - } diff --git a/llava/model/multimodal_resampler/spatial_pool.py b/llava/model/multimodal_resampler/spatial_pool.py deleted file mode 100644 index 4bdbe3aecc91183341816c800c8ad1fcfba9a169..0000000000000000000000000000000000000000 --- a/llava/model/multimodal_resampler/spatial_pool.py +++ /dev/null @@ -1,45 +0,0 @@ -import torch -import torch.nn as nn -import math - - -class SpatialPool(nn.Module): - def __init__(self, model_args, vision_tower): - super().__init__() - - self.mode = model_args.mm_spatial_pool_mode - self.stride = model_args.mm_spatial_pool_stride - self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size) - - if self.mode == "average": - self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride) - elif self.mode == "max": - self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride) - elif self.mode == "conv": - self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride) - else: - raise ValueError(f"Unknown pooling mode: {self.pool}.") - - def forward(self, image_features, images, *args, **kwargs): - ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2])) - ori_H = int(ori_W * images.shape[2] // images.shape[3]) - - B, _, F = image_features.shape - - image_features_spatial = image_features.view(B, ori_H, ori_H, F).permute(0, 3, 1, 2) - image_features_spatial_pool = self.pool(image_features_spatial) - - return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() - - @property - def config(self): - return { - "mm_resampler_type": "spatial_pool", - "mm_spatial_pool_stride": self.stride, - "mm_spatial_pool_mode": self.mode, - "mm_spatial_pool_out_channels": self.out_channels, - } - - @property - def hidden_size(self): - return self.out_channels diff --git a/llava/model/utils.py b/llava/model/utils.py deleted file mode 100644 index 10652a5f9aaa2e0cddaef0b1a7bc39013a0d957b..0000000000000000000000000000000000000000 --- a/llava/model/utils.py +++ /dev/null @@ -1,20 +0,0 @@ -from transformers import AutoConfig - - -def auto_upgrade(config): - cfg = AutoConfig.from_pretrained(config) - if "llava" in config and "llava" not in cfg.model_type: - assert cfg.model_type == "llama" - print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") - print("You must upgrade the checkpoint to the new code base (this can be done automatically).") - confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") - if confirm.lower() in ["y", "yes"]: - print("Upgrading checkpoint...") - assert len(cfg.architectures) == 1 - setattr(cfg.__class__, "model_type", "llava") - cfg.architectures[0] = "LlavaLlamaForCausalLM" - cfg.save_pretrained(config) - print("Checkpoint upgraded.") - else: - print("Checkpoint upgrade aborted.") - exit(1) diff --git a/llava/utils.py b/llava/utils.py deleted file mode 100644 index 2bf3bd19db050ead12bae844ac7a6e95517c74d6..0000000000000000000000000000000000000000 --- a/llava/utils.py +++ /dev/null @@ -1,134 +0,0 @@ -import datetime -import logging -import logging.handlers -import os -import sys - -import requests - -from llava.constants import LOGDIR - -server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" -moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." - -handler = None - -import torch.distributed as dist - - -def rank0_print(*args): - if dist.is_initialized(): - if dist.get_rank() == 0: - print(f"Rank {dist.get_rank()}: ", *args) - - -def build_logger(logger_name, logger_filename): - global handler - - formatter = logging.Formatter( - fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - - # Set the format of root handlers - if not logging.getLogger().handlers: - logging.basicConfig(level=logging.INFO) - logging.getLogger().handlers[0].setFormatter(formatter) - - # Redirect stdout and stderr to loggers - stdout_logger = logging.getLogger("stdout") - stdout_logger.setLevel(logging.INFO) - sl = StreamToLogger(stdout_logger, logging.INFO) - sys.stdout = sl - - stderr_logger = logging.getLogger("stderr") - stderr_logger.setLevel(logging.ERROR) - sl = StreamToLogger(stderr_logger, logging.ERROR) - sys.stderr = sl - - # Get logger - logger = logging.getLogger(logger_name) - logger.setLevel(logging.INFO) - - # Add a file handler for all loggers - if handler is None: - os.makedirs(LOGDIR, exist_ok=True) - filename = os.path.join(LOGDIR, logger_filename) - handler = logging.handlers.TimedRotatingFileHandler(filename, when="D", utc=True) - handler.setFormatter(formatter) - - for name, item in logging.root.manager.loggerDict.items(): - if isinstance(item, logging.Logger): - item.addHandler(handler) - - return logger - - -class StreamToLogger(object): - """ - Fake file-like stream object that redirects writes to a logger instance. - """ - - def __init__(self, logger, log_level=logging.INFO): - self.terminal = sys.stdout - self.logger = logger - self.log_level = log_level - self.linebuf = "" - - def __getattr__(self, attr): - return getattr(self.terminal, attr) - - def write(self, buf): - temp_linebuf = self.linebuf + buf - self.linebuf = "" - for line in temp_linebuf.splitlines(True): - # From the io.TextIOWrapper docs: - # On output, if newline is None, any '\n' characters written - # are translated to the system default line separator. - # By default sys.stdout.write() expects '\n' newlines and then - # translates them so this is still cross platform. - if line[-1] == "\n": - self.logger.log(self.log_level, line.rstrip()) - else: - self.linebuf += line - - def flush(self): - if self.linebuf != "": - self.logger.log(self.log_level, self.linebuf.rstrip()) - self.linebuf = "" - - -def disable_torch_init(): - """ - Disable the redundant torch default initialization to accelerate model creation. - """ - import torch - - setattr(torch.nn.Linear, "reset_parameters", lambda self: None) - setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) - - -def violates_moderation(text): - """ - Check whether the text violates OpenAI moderation API. - """ - url = "https://api.openai.com/v1/moderations" - headers = {"Content-Type": "application/json", "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} - text = text.replace("\n", "") - data = "{" + '"input": ' + f'"{text}"' + "}" - data = data.encode("utf-8") - try: - ret = requests.post(url, headers=headers, data=data, timeout=5) - flagged = ret.json()["results"][0]["flagged"] - except requests.exceptions.RequestException as e: - flagged = False - except KeyError as e: - flagged = False - - return flagged - - -def pretty_print_semaphore(semaphore): - if semaphore is None: - return "None" - return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..8a8cd6f6357b7ff5901c363a9b66e930f9e6e578 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,33 @@ +#git+https://github.com/LLaVA-VL/LLaVA-NeXT.git +opencv-python +open_clip_torch +fastapi +gradio==3.35.2 +markdown2[all] +numpy +requests +sentencepiece +torch==2.1.2 +torchvision==0.16.2 +uvicorn +wandb==0.16.5 +deepspeed==0.12.2 +peft==0.4.0 +accelerate>=0.29.1 +tokenizers~=0.15.2 +transformers@git+https://github.com/huggingface/transformers.git@1c39974a4c4036fd641bc1191cc32799f85715a4 +bitsandbytes==0.41.0 +scikit-learn==1.2.2 +sentencepiece~=0.1.99 +einops==0.6.1 +einops-exts==0.0.4 +gradio_client==0.2.9 +pydantic==1.10.8 +timm +hf_transfer +decord +datasets==2.16.1 +tyro +scipy +rouge +urllib3<=2.0.0