File size: 5,292 Bytes
473101c
 
1819f26
 
473101c
1819f26
473101c
 
 
 
 
 
 
 
 
 
 
 
2f11cc1
1819f26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473101c
 
 
1819f26
 
 
473101c
 
1819f26
 
 
 
473101c
1819f26
473101c
1819f26
473101c
1819f26
473101c
1819f26
473101c
 
 
 
 
1819f26
 
473101c
 
 
 
 
 
 
 
1819f26
473101c
1819f26
 
473101c
 
 
 
1819f26
473101c
 
 
 
 
 
 
1819f26
 
473101c
1819f26
473101c
1819f26
 
 
473101c
1819f26
473101c
1819f26
473101c
1819f26
473101c
1819f26
473101c
 
 
 
 
1819f26
 
473101c
 
 
 
 
 
1819f26
 
 
473101c
1819f26
473101c
1819f26
 
 
b2557f3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
import peft
from peft import LoraConfig
from transformers import AutoTokenizer,BitsAndBytesConfig, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
import torch
from peft import PeftModel
import torch.nn as nn
import whisperx

# Hugging Face model identifiers for the vision encoder and the language model.
clip_model_name = "wkcn/TinyCLIP-ViT-61M-32-Text-29M-LAION400M"
phi_model_name  = "microsoft/phi-2"
# The text tokenizer is loaded from the phi-2 checkpoint, the image
# preprocessor from the CLIP checkpoint.
tokenizer  = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
processor  = AutoProcessor.from_pretrained(clip_model_name)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
# Token id whose embedding model_generate_ans appends right after the image
# embeddings; per the original note it is the token for the word "comment".
IMAGE_TOKEN_ID = 23893 # token for word comment
# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Input and output widths of the linear projection that maps CLIP features
# into the language model's embedding space.
clip_embed = 640
phi_embed  = 2560
# Passed as compute_type to whisperx.load_model.
compute_type = "float32"
# NOTE(review): audio_batch_size is not referenced anywhere else in the
# visible code — confirm whether it was meant to be passed to transcribe().
audio_batch_size = 16

class SimpleResBlock(nn.Module):
    """Layer-normalise the input, then add a two-layer GELU MLP of it back on.

    The skip connection carries the *normalised* tensor rather than the raw
    input, so the output is ``norm(x) + mlp(norm(x))``.

    Args:
        phi_embed: Size of the last dimension of the input; every layer keeps
            this width, so the output has the same shape as the input.
    """

    def __init__(self, phi_embed):
        super().__init__()
        # The attribute names and the Sequential layout determine the
        # state_dict keys used by the saved checkpoint, so they must not change.
        self.pre_norm = nn.LayerNorm(phi_embed)
        fc_in = nn.Linear(phi_embed, phi_embed)
        activation = nn.GELU()
        fc_out = nn.Linear(phi_embed, phi_embed)
        self.proj = nn.Sequential(fc_in, activation, fc_out)

    def forward(self, x):
        """Return ``pre_norm(x) + proj(pre_norm(x))``."""
        normed = self.pre_norm(x)
        residual = self.proj(normed)
        return normed + residual
        
# models
# Vision encoder, CLIP->phi projection, residual adapter, language model and
# speech-to-text model are all created once at import time and moved to `device`.
clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
resblock = SimpleResBlock(phi_embed).to(device)
phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name,trust_remote_code=True).to(device)
audio_model = whisperx.load_model("tiny", device, compute_type=compute_type)

# load weights
# Fold the fine-tuned LoRA adapter into the base language model so inference
# runs on a plain model without the PEFT wrapper.
model_to_merge = PeftModel.from_pretrained(phi_model,'./model_chkpt/qlora_adaptor')
merged_model   = model_to_merge.merge_and_unload()
# weights_only=True restricts unpickling to tensors and plain containers, so
# a tampered checkpoint file cannot execute arbitrary code while loading.
projection.load_state_dict(
    torch.load('./model_chkpt/ft_projection_layer.pth', map_location=device, weights_only=True)
)
resblock.load_state_dict(
    torch.load('./model_chkpt/ft_projection_model.pth', map_location=device, weights_only=True)
)

def _embed_text(text):
    """Tokenize *text* and return its language-model embeddings with a leading batch axis of 1."""
    token_ids = tokenizer(text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
    return merged_model.model.embed_tokens(token_ids).unsqueeze(0)


def model_generate_ans(img=None, img_audio=None, val_q=None, max_generate_length=100):
    """Answer a query made of any combination of image, spoken audio and text.

    The prompt is assembled in embedding space in this order: projected image
    features followed by the IMAGE_TOKEN_ID embedding, then the embedded audio
    transcript, then the embedded text question. The answer is produced by
    greedy decoding (argmax at every step).

    Args:
        img: Image accepted by the CLIP processor, or None to skip the image.
        img_audio: Audio input accepted by ``audio_model.transcribe`` (the UI
            passes a file path), or None to skip audio.
        val_q: Text question; None or an empty string skips it.
        max_generate_length: Upper bound on the number of generated tokens.

    Returns:
        The decoded answer without special tokens, or a short instruction
        message when no usable input was supplied.
    """
    prompt_embeds = []

    with torch.no_grad():

        # image
        if img is not None:
            image_processed = processor(images=img, return_tensors="pt").to(device)
            # Skip position 0 of the sequence axis (CLIP's class token) and
            # keep only the remaining per-patch features.
            clip_patches = clip_model(**image_processed).last_hidden_state[:, 1:, :]
            # NOTE(review): the float16 cast is kept from the original; the
            # token embeddings use the language model's own dtype, so confirm
            # this is the intended precision for the image features.
            image_embeds = resblock(projection(clip_patches)).to(torch.float16)

            separator_id = torch.tensor(IMAGE_TOKEN_ID).to(device)
            separator_embeds = merged_model.model.embed_tokens(separator_id).unsqueeze(0).unsqueeze(0)

            prompt_embeds.append(image_embeds)
            prompt_embeds.append(separator_embeds)

        # audio
        if img_audio is not None:
            audio_result = audio_model.transcribe(img_audio)
            audio_text = ''.join(seg['text'] for seg in audio_result['segments']).strip()
            # An empty transcript (e.g. silence) contributes nothing to the prompt.
            if audio_text:
                prompt_embeds.append(_embed_text(audio_text))

        # text question (None and '' are both treated as "no question")
        if val_q:
            prompt_embeds.append(_embed_text(val_q))

        # Nothing to condition on: report it instead of failing inside torch.cat.
        if not prompt_embeds:
            return "Please provide an image, an audio query, or a text query."

        eos_token_id = tokenizer.eos_token_id
        sequence_embeds = torch.cat(prompt_embeds, dim=1)  # (1, prompt_len, embed_dim)
        generated_ids = []

        for _ in range(max_generate_length):
            logits = merged_model(inputs_embeds=sequence_embeds)['logits']  # (1, seq_len, vocab)
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # (1, 1)
            next_id = next_token.item()
            # Stop at end-of-sequence rather than decoding whatever follows it.
            if next_id == eos_token_id:
                break
            generated_ids.append(next_id)
            # Feed the chosen token back in through the same model's embedding table.
            next_embeds = merged_model.model.embed_tokens(next_token)  # (1, 1, embed_dim)
            sequence_embeds = torch.cat([sequence_embeds, next_embeds], dim=1)

    return tokenizer.decode(generated_ids, skip_special_tokens=True)
    

# Gradio front end: the three optional inputs sit in the first column and the
# generated answer in the second.
with gr.Blocks() as demo:

    gr.Markdown(
    """
    # Chat with MultiModal GPT !
    Build using combining clip model and phi-2 model.
    """
    )

    # app GUI
    with gr.Row():
        with gr.Column():
            # The image is requested as a PIL image and the audio as a file
            # path (see the type= arguments); audio may come from the
            # microphone or from an uploaded file.
            img_input    = gr.Image(label='Image',type="pil")
            img_audio    = gr.Audio(label="Audio Query", sources=['microphone', 'upload'], type='filepath')
            img_question = gr.Text(label ='Text Query')
        with gr.Column():
            img_answer   = gr.Text(label ='Answer')

    # The inputs list order must match model_generate_ans's parameter order
    # (img, img_audio, val_q); its return value fills the answer box.
    section_btn = gr.Button("Submit")
    section_btn.click(model_generate_ans, inputs=[img_input,img_audio,img_question], outputs=[img_answer])
    
# Called at module level, so the app starts as soon as this file is executed.
demo.launch()