File size: 1,873 Bytes
8273d5f
2a438ba
8273d5f
9ce3948
580cc25
 
695df9a
35c5f77
580cc25
 
 
 
 
9ce3948
 
 
 
 
 
 
 
 
 
 
 
a41d9f9
 
580cc25
9ce3948
580cc25
 
 
 
 
2a438ba
 
9ce3948
8273d5f
 
9ce3948
 
8273d5f
a41d9f9
8273d5f
580cc25
8273d5f
2a438ba
d652f80
 
 
8273d5f
a41d9f9
8273d5f
d652f80
8273d5f
35c5f77
d652f80
 
 
8273d5f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import gradio as gr
from transformers import AutoModel, pipeline, AutoTokenizer
import spaces
import subprocess

# from issue: https://discuss.huggingface.co/t/how-to-install-flash-attention-on-hf-gradio-space/70698/2
# InternVL2 needs flash_attn; it has to be installed at runtime because the
# Spaces build step cannot compile the CUDA extension.
import os

# FIX: passing a bare dict as `env` would REPLACE the whole environment, so the
# shell could run without PATH/HOME and fail to find `pip`. Merge the skip-build
# flag into the current environment instead.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
# NOTE(review): the return code is deliberately not checked (best-effort install);
# a failure surfaces later when the model import of flash_attn fails.
# Load the model, tokenizer, and VQA pipeline once at import time so every
# request reuses them. Any failure (download, CUDA, remote code) is converted
# into a gr.Error so it shows up in the Gradio UI instead of a silent crash.
try:
    model_name = "OpenGVLab/InternVL2-8B"
    # model: <class 'transformers_modules.OpenGVLab.InternVL2-8B.0e6d592d957d9739b6df0f4b90be4cb0826756b9.modeling_internvl_chat.InternVLChatModel'>
    # trust_remote_code=True is required: InternVL2 ships its own modeling code.
    # bfloat16 + .cuda() assumes a GPU with bf16 support is available
    # (ZeroGPU hardware on Spaces) — TODO confirm for the target hardware.
    model = (
        AutoModel.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            # low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        .cuda()
        .eval()
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # pipeline: <class 'transformers.pipelines.visual_question_answering.VisualQuestionAnsweringPipeline'>
    # Wrap the preloaded model/tokenizer in a VQA pipeline; `predict` below
    # calls this `inference` object for every request.
    inference = pipeline(
        task="visual-question-answering", model=model, tokenizer=tokenizer
    )
except Exception as error:
    # Re-raise as a Gradio error so the message is visible in the web UI.
    raise gr.Error("👌" + str(error), duration=30)


@spaces.GPU
def predict(input_img, questions):
    """Answer *questions* about *input_img* via the module-level VQA pipeline.

    Returns the pipeline output rendered as a string; any failure is
    re-raised as a gr.Error so the message appears as a toast in the UI.
    """
    try:
        # Debug toasts: show which pipeline/model classes are actually live.
        gr.Info("pipeline: " + str(type(inference)))
        gr.Info("model: " + str(type(model)))
        answer = inference(question=questions, image=input_img)
        return str(answer)
    except Exception as exc:
        raise gr.Error("❌" + str(exc), duration=25)


# Wire `predict` into a simple UI: one image input (upload or webcam, delivered
# as a PIL image because of type="pil") plus one free-text question, one text output.
gradio_app = gr.Interface(
    predict,
    inputs=[
        gr.Image(label="Select A Image", sources=["upload", "webcam"], type="pil"),
        "text",
    ],
    outputs="text",
    title='ask me anything',
)

if __name__ == "__main__":
    # show_error=True surfaces Python tracebacks in the browser;
    # debug=True keeps the server in the foreground and streams logs.
    gradio_app.launch(show_error=True, debug=True)