File size: 4,979 Bytes
8ce49c8
 
0dc45ee
8ce49c8
 
be9b2ea
8ce49c8
 
 
0dc45ee
8ce49c8
be9b2ea
8ce49c8
be9b2ea
8ce49c8
 
0dc45ee
 
be9b2ea
 
0dc45ee
8ce49c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
import torch
import spaces
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face Hub id of the ShareCaptioner checkpoint. It is loaded with
# trust_remote_code=True, so the tokenizer/model classes come from the
# checkpoint's own modeling code.
model_name = "Lin-Chen/ShareCaptioner"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Load the weights on CPU in float16 and switch to eval (inference) mode;
# the model is moved to the GPU just below.
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="cpu", torch_dtype=torch.float16, trust_remote_code=True).eval()
# The remote-code helpers used in this file (encode_text / decode_text)
# presumably read the tokenizer from this attribute -- confirm against the
# checkpoint's modeling code.
model.tokenizer = tokenizer

# Move the model to CUDA once, at import time.
# NOTE(review): combined with `@spaces.GPU` on detailed_caption, this
# presumably relies on the Spaces/ZeroGPU runtime tolerating import-time
# .cuda() calls -- confirm for the target deployment.
model.cuda()

# Fixed prompt pieces that surround the image embedding:
#   <|User|>: [image embedding] <instruction>{eoh}\n<|Bot|>:
# `model.eoh` is presumably the model's end-of-human-turn marker (defined in
# its remote code). Both segments are embedded once here, moved to the GPU,
# and reused by every detailed_caption request.
seg1 = '<|User|>:'
seg2 = f'Analyze the image in a comprehensive and detailed manner.{model.eoh}\n<|Bot|>:'
# Only the leading segment is encoded with special tokens, presumably so that
# markers such as BOS appear once, at the very start of the sequence.
seg_emb1 = model.encode_text(seg1, add_special_tokens=True).cuda()
seg_emb2 = model.encode_text(seg2, add_special_tokens=False).cuda()


@spaces.GPU
def detailed_caption(img_path):
    """Generate a detailed caption for one image with ShareCaptioner.

    The prompt fed to the language model is the concatenation
    ``[seg_emb1, image embedding, seg_emb2]`` along the sequence axis. This
    is the pre-embedded "<|User|>:" prefix, the encoded image, and the
    instruction/"<|Bot|>:" suffix. Generation uses beam sampling
    (``num_beams=3`` with ``do_sample=True``), so repeated calls on the same
    image may return different captions.

    Args:
        img_path: Filesystem path of the uploaded image, as delivered by a
            ``gr.Image(type="filepath")`` component. It is ``None``/empty when
            the user has not uploaded anything.

    Returns:
        The result of ``model.decode_text`` for the first generated sequence
        (presumably the caption text -- confirm against the model's remote
        code).

    Raises:
        gr.Error: if no image was provided. Gradio shows this message in the
            UI instead of an opaque traceback.
    """
    if not img_path:
        raise gr.Error("Please upload an image first.")

    # Decode the pixels and release the underlying file handle immediately;
    # convert() returns a new image that stays valid after the file closes.
    with Image.open(img_path) as raw_image:
        image = raw_image.convert("RGB")

    # A batch of one preprocessed image, on the GPU.
    pixels = model.vis_processor(image).unsqueeze(0).cuda()
    batch_size = pixels.shape[0]
    # Tile the fixed prompt embeddings to the batch size.
    prefix_emb = seg_emb1.repeat(batch_size, 1, 1)
    suffix_emb = seg_emb2.repeat(batch_size, 1, 1)

    # Mixed-precision inference without autograd bookkeeping.
    with torch.autocast(device_type="cuda"), torch.no_grad():
        image_emb = model.encode_img(pixels)
        input_emb = torch.cat([prefix_emb, image_emb, suffix_emb], dim=1)
        out_embeds = model.internlm_model.generate(
            inputs_embeds=input_emb,
            max_length=500,
            num_beams=3,
            min_length=1,
            do_sample=True,
            repetition_penalty=1.5,
            length_penalty=1.0,
            temperature=1.,
            eos_token_id=model.tokenizer.eos_token_id,
            num_return_sequences=1,
        )

    return model.decode_text([out_embeds[0]])


# Extra CSS for the Blocks app: keeps buttons inside the element with
# id "buttons" from collapsing below 120px (or the full width, if smaller).
block_css = """
#buttons button {
    min-width: min(120px,100%);
}
"""

# Page header: title plus project / code / paper links.
title_markdown = """
# 🐬 ShareGPT4V: Improving Large Multi-modal Models with Better Captions
[[Project Page](https://sharegpt4v.github.io/)] [[Code](https://github.com/ShareGPT4Omni/ShareGPT4V)] | [[Paper](https://github.com/InternLM/InternLM-XComposer/blob/main/projects/ShareGPT4V/ShareGPT4V.pdf)]
"""

# Usage terms shown below the main controls.
tos_markdown = """
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
"""

# Licensing notice.
learn_more_markdown = """
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
"""

# Credit for the UI template.
ack_markdown = """
### Acknowledgement
The template for this web demo is from [LLaVA](https://github.com/haotian-liu/LLaVA), and we are very grateful to LLaVA for their open source contributions to the community!
"""


def build_demo():
    """Assemble the Share-Captioner Gradio UI (without launching it).

    Layout:
      * Header markdown.
      * Left column: a single-entry model selector and an image upload.
      * Right column: the caption textbox, then the Generate and Regenerate
        buttons.
      * Terms-of-use, license and acknowledgement markdown.

    Both buttons run ``detailed_caption`` on the uploaded image and write the
    result into the caption box. Because that function samples during
    generation, Regenerate can produce a different caption for the same
    image.

    Returns:
        gr.Blocks: the configured demo application.
    """
    with gr.Blocks(title="Share-Captioner", theme=gr.themes.Default(), css=block_css) as demo:
        gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=5):
                with gr.Row(elem_id="Model ID"):
                    # Display-only selector: its value is not passed to any
                    # callback, so it does not affect generation.
                    gr.Dropdown(
                        choices=['Share-Captioner'],
                        value='Share-Captioner',
                        interactive=True,
                        label='Model ID',
                        container=False)
                img_path = gr.Image(label="Image", type="filepath")
            with gr.Column(scale=8):
                with gr.Row():
                    caption = gr.Textbox(label='Caption')
                # block_css styles `#buttons button`; the row needs this id
                # for that rule to apply to the buttons below.
                with gr.Row(elem_id="buttons"):
                    submit_btn = gr.Button(
                        value="🚀 Generate", variant="primary")
                    regenerate_btn = gr.Button(value="🔄 Regenerate")

        gr.Markdown(tos_markdown)
        gr.Markdown(learn_more_markdown)
        gr.Markdown(ack_markdown)

        # Both buttons trigger the same captioning call on the current image.
        for button in (submit_btn, regenerate_btn):
            button.click(detailed_caption, inputs=[img_path], outputs=[caption])

    return demo


if __name__ == "__main__":
    # Script entry point: build the Blocks app (kept bound to the module-level
    # name `demo`) and start the Gradio server with default launch settings.
    demo = build_demo()
    demo.launch()