Shashidhar226 committed
Commit 0b0d3ca
1 Parent(s): cf65b89

Update app.py

Files changed (1)
  1. app.py +178 -177
app.py CHANGED
@@ -6,184 +6,185 @@ import requests
  import torch
  import torchvision

- from langchain_google_genai import GoogleGenerativeAI
- from langchain_google_genai import ChatGoogleGenerativeAI
+ # from langchain_google_genai import GoogleGenerativeAI
+ # from langchain_google_genai import ChatGoogleGenerativeAI

- from langchain.prompts import PromptTemplate
- from langchain.chains import LLMChain
- from langchain.chat_models import ChatOpenAI
+ # from langchain.prompts import PromptTemplate
+ # from langchain.chains import LLMChain
+ # from langchain.chat_models import ChatOpenAI

- from transformers import AutoProcessor, AutoModelForCausalLM
- from huggingface_hub import hf_hub_download
+ # from transformers import AutoProcessor, AutoModelForCausalLM
+ # from huggingface_hub import hf_hub_download

- from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
- from transformers import BlipProcessor, BlipForConditionalGeneration
+ # from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
+ # from transformers import BlipProcessor, BlipForConditionalGeneration
  import os
-
- # os.environ["OPENAI_API_KEY"] = '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'
- os.environ["GOOGLE_API_KEY"] = 'AIzaSyAsZTv6rUZq0TAh6yfmVCDA0tPIcGU3VxA'
-
- # llm = ChatOpenAI(temperature=0.2, model_name="gpt-3.5-turbo")
- llm = ChatGoogleGenerativeAI(temperature=0.2, model="gemini-pro")
-
- prompt = PromptTemplate(
-     input_variables=["question", "elements"],
-     template="""You are a helpful assistant that can answer question related to an image. You have the ability to see the image and answer questions about it.
- I will give you a question and element about the image and you will answer the question.
- \n\n
- #Question: {question}
- #Elements: {elements}
- \n\n
- Your structured response:""",
- )
-
- def convert_png_to_jpg(image):
-     rgb_image = image.convert('RGB')
-     byte_arr = BytesIO()
-     rgb_image.save(byte_arr, format='JPEG')
-     byte_arr.seek(0)
-     return Image.open(byte_arr)
-
- def vilt(image, query):
-     processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
-     model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
-     encoding = processor(image, query, return_tensors="pt")
-     outputs = model(**encoding)
-     logits = outputs.logits
-     idx = logits.argmax(-1).item()
-     sol = model.config.id2label[idx]
-     return sol
-
- def blip(image, query):
-     processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
-     model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
-     # unconditional image captioning
-     inputs = processor(image, return_tensors="pt")
-
-     out = model.generate(**inputs)
-     sol = processor.decode(out[0], skip_special_tokens=True)
-     return sol
-
- def GIT(image, query):
-     processor = AutoProcessor.from_pretrained("microsoft/git-base-textvqa")
-     model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-textvqa")
-
-     # file_path = hf_hub_download(repo_id="nielsr/textvqa-sample", filename="bus.png", repo_type="dataset")
-     # image = Image.open(file_path).convert("RGB")
-
-     pixel_values = processor(images=image, return_tensors="pt").pixel_values
-
-     question = query
-
-     input_ids = processor(text=question, add_special_tokens=False).input_ids
-     input_ids = [processor.tokenizer.cls_token_id] + input_ids
-     input_ids = torch.tensor(input_ids).unsqueeze(0)
-
-     generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
-     response = processor.batch_decode(generated_ids, skip_special_tokens=True)
-
-     generated_ids_1 = model.generate(pixel_values=pixel_values, max_length=50)
-     generated_caption = processor.batch_decode(generated_ids_1, skip_special_tokens=True)[0]
-
-     return response[0] + " " + generated_caption
-
- @st.cache_data(show_spinner="Processing image...")
- def generate_table(uploaded_file):
-     image = Image.open(uploaded_file)
-     print("graph start")
-     model = Pix2StructForConditionalGeneration.from_pretrained('google/deplot')
-     processor = Pix2StructProcessor.from_pretrained('google/deplot')
-     print("graph start 1")
-     inputs = processor(images=image, text="Generate underlying data table of the figure below and give the text as well:", return_tensors="pt")
-     predictions = model.generate(**inputs, max_new_tokens=512)
-     print("end")
-     table = processor.decode(predictions[0], skip_special_tokens=True)
-     print(table)
-     return table
-
- def process_query(image, query):
-     blip_sol = blip(image, query)
-     vilt_sol = vilt(image, query)
-     GIT_sol = GIT(image, query)
-     llm_sol = blip_sol + " " + vilt_sol + " " + GIT_sol
-     print(llm_sol)
-     chain = LLMChain(llm=llm, prompt=prompt)
-     response = chain.run(question=query, elements=llm_sol)
-     return response
-
- def process_query_graph(data_table, query):
-     prompt = PromptTemplate(
-         input_variables=["question", "elements"],
-         template="""You are a helpful assistant capable of answering questions related to graph images.
- You possess the ability to view the graph image and respond to inquiries about it.
- I will provide you with a question and the associated data table of the graph, and you will answer the question
- \n\n
- #Question: {question}
- #Elements: {elements}
- \n\n
- Your structured response:""",
-     )
-     chain = LLMChain(llm=llm, prompt=prompt)
-     response = chain.run(question=query, elements=data_table)
-     return response
-
- def chart_with_Image():
-     st.header("Chat with Image", divider='rainbow')
-     uploaded_file = st.file_uploader('Upload your IMAGE', type=['png', 'jpeg', 'jpg'], key="imageUploader")
-     if uploaded_file is not None:
-         image = Image.open(uploaded_file)
-
-         # ViLT model only supports JPG images
-         if image.format == 'PNG':
-             image = convert_png_to_jpg(image)
-
-         st.image(image, caption='Uploaded Image.', width=300)
-
-         cancel_button = st.button('Cancel')
-         query = st.text_input('Ask a question to the IMAGE')
-
-         if query:
-             with st.spinner('Processing...'):
-                 answer = process_query(image, query)
-                 st.write(answer)
-
-         if cancel_button:
-             st.stop()
-
- def chat_with_graph():
-     st.header("Chat with Graph", divider='rainbow')
-     uploaded_file = st.file_uploader('Upload your GRAPH', type=['png', 'jpeg', 'jpg'], key="graphUploader")
-
-     if uploaded_file is not None:
-         image = Image.open(uploaded_file)
-
-         # if image.format == 'PNG':
-         #     image = convert_png_to_jpg(image)
-
-         # data_table = generate_table(uploaded_file)
-
-         st.image(image, caption='Uploaded Image.')
-         data_table = generate_table(uploaded_file)
-         cancel_button = st.button('Cancel')
-         query = st.text_input('Ask a question to the IMAGE')
-         if query:
-             with st.spinner('Processing...'):
-                 answer = process_query_graph(data_table, query)
-                 st.write(answer)
-
-         if cancel_button:
-             st.stop()
-
- st.title("Image Querying App ")
- option = st.selectbox(
-     "Who would you like to chart with?",
-     ("Image", "Graph"),
-     index=None,
-     placeholder="Select contact method...",
- )
-
- st.write('You selected:', option)
- if option == "Image":
-     chart_with_Image()
- elif option == "Graph":
-     chat_with_graph()
+ print(os.getenv('GOOGLE_API_KEY'))
+
+ # # os.environ["OPENAI_API_KEY"] = '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'
+ # os.environ["GOOGLE_API_KEY"] = 'AIzaSyAsZTv6rUZq0TAh6yfmVCDA0tPIcGU3VxA'
+
+ # # llm = ChatOpenAI(temperature=0.2, model_name="gpt-3.5-turbo")
+ # llm = ChatGoogleGenerativeAI(temperature=0.2, model="gemini-pro")
+
+ # prompt = PromptTemplate(
+ # input_variables=["question", "elements"],
+ # template="""You are a helpful assistant that can answer question related to an image. You have the ability to see the image and answer questions about it.
+ # I will give you a question and element about the image and you will answer the question.
+ # \n\n
+ # #Question: {question}
+ # #Elements: {elements}
+ # \n\n
+ # Your structured response:""",
+ # )
+
+ # def convert_png_to_jpg(image):
+ # rgb_image = image.convert('RGB')
+ # byte_arr = BytesIO()
+ # rgb_image.save(byte_arr, format='JPEG')
+ # byte_arr.seek(0)
+ # return Image.open(byte_arr)
+
+ # def vilt(image, query):
+ # processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
+ # model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
+ # encoding = processor(image, query, return_tensors="pt")
+ # outputs = model(**encoding)
+ # logits = outputs.logits
+ # idx = logits.argmax(-1).item()
+ # sol = model.config.id2label[idx]
+ # return sol
+
+ # def blip(image, query):
+ # processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+ # model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
+ # # unconditional image captioning
+ # inputs = processor(image, return_tensors="pt")
+
+ # out = model.generate(**inputs)
+ # sol = processor.decode(out[0], skip_special_tokens=True)
+ # return sol
+
+ # def GIT(image, query):
+ # processor = AutoProcessor.from_pretrained("microsoft/git-base-textvqa")
+ # model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-textvqa")
+
+ # # file_path = hf_hub_download(repo_id="nielsr/textvqa-sample", filename="bus.png", repo_type="dataset")
+ # # image = Image.open(file_path).convert("RGB")
+
+ # pixel_values = processor(images=image, return_tensors="pt").pixel_values
+
+ # question = query
+
+ # input_ids = processor(text=question, add_special_tokens=False).input_ids
+ # input_ids = [processor.tokenizer.cls_token_id] + input_ids
+ # input_ids = torch.tensor(input_ids).unsqueeze(0)
+
+ # generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
+ # response = processor.batch_decode(generated_ids, skip_special_tokens=True)
+
+ # generated_ids_1 = model.generate(pixel_values=pixel_values, max_length=50)
+ # generated_caption = processor.batch_decode(generated_ids_1, skip_special_tokens=True)[0]
+
+ # return response[0] + " " + generated_caption
+
+ # @st.cache_data(show_spinner="Processing image...")
+ # def generate_table(uploaded_file):
+ # image = Image.open(uploaded_file)
+ # print("graph start")
+ # model = Pix2StructForConditionalGeneration.from_pretrained('google/deplot')
+ # processor = Pix2StructProcessor.from_pretrained('google/deplot')
+ # print("graph start 1")
+ # inputs = processor(images=image, text="Generate underlying data table of the figure below and give the text as well:", return_tensors="pt")
+ # predictions = model.generate(**inputs, max_new_tokens=512)
+ # print("end")
+ # table = processor.decode(predictions[0], skip_special_tokens=True)
+ # print(table)
+ # return table
+
+ # def process_query(image, query):
+ # blip_sol = blip(image, query)
+ # vilt_sol = vilt(image, query)
+ # GIT_sol = GIT(image, query)
+ # llm_sol = blip_sol + " " + vilt_sol + " " + GIT_sol
+ # print(llm_sol)
+ # chain = LLMChain(llm=llm, prompt=prompt)
+ # response = chain.run(question=query, elements=llm_sol)
+ # return response
+
+ # def process_query_graph(data_table, query):
+ # prompt = PromptTemplate(
+ # input_variables=["question", "elements"],
+ # template="""You are a helpful assistant capable of answering questions related to graph images.
+ # You possess the ability to view the graph image and respond to inquiries about it.
+ # I will provide you with a question and the associated data table of the graph, and you will answer the question
+ # \n\n
+ # #Question: {question}
+ # #Elements: {elements}
+ # \n\n
+ # Your structured response:""",
+ # )
+ # chain = LLMChain(llm=llm, prompt=prompt)
+ # response = chain.run(question=query, elements=data_table)
+ # return response
+
+ # def chart_with_Image():
+ # st.header("Chat with Image", divider='rainbow')
+ # uploaded_file = st.file_uploader('Upload your IMAGE', type=['png', 'jpeg', 'jpg'], key="imageUploader")
+ # if uploaded_file is not None:
+ # image = Image.open(uploaded_file)
+
+ # # ViLT model only supports JPG images
+ # if image.format == 'PNG':
+ # image = convert_png_to_jpg(image)
+
+ # st.image(image, caption='Uploaded Image.', width=300)
+
+ # cancel_button = st.button('Cancel')
+ # query = st.text_input('Ask a question to the IMAGE')
+
+ # if query:
+ # with st.spinner('Processing...'):
+ # answer = process_query(image, query)
+ # st.write(answer)
+
+ # if cancel_button:
+ # st.stop()
+
+ # def chat_with_graph():
+ # st.header("Chat with Graph", divider='rainbow')
+ # uploaded_file = st.file_uploader('Upload your GRAPH', type=['png', 'jpeg', 'jpg'], key="graphUploader")
+
+ # if uploaded_file is not None:
+ # image = Image.open(uploaded_file)
+
+ # # if image.format == 'PNG':
+ # # image = convert_png_to_jpg(image)
+
+ # # data_table = generate_table(uploaded_file)
+
+ # st.image(image, caption='Uploaded Image.')
+ # data_table = generate_table(uploaded_file)
+ # cancel_button = st.button('Cancel')
+ # query = st.text_input('Ask a question to the IMAGE')
+ # if query:
+ # with st.spinner('Processing...'):
+ # answer = process_query_graph(data_table, query)
+ # st.write(answer)
+
+ # if cancel_button:
+ # st.stop()
+
+ # st.title("Image Querying App ")
+ # option = st.selectbox(
+ # "Who would you like to chart with?",
+ # ("Image", "Graph"),
+ # index=None,
+ # placeholder="Select contact method...",
+ # )
+
+ # st.write('You selected:', option)
+ # if option == "Image":
+ # chart_with_Image()
+ # elif option == "Graph":
+ # chat_with_graph()
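The only executable line added in this revision is print(os.getenv('GOOGLE_API_KEY')), whereas the previous revision hard-coded the key with os.environ["GOOGLE_API_KEY"] = '...'. Below is a minimal sketch of the environment-based setup this change points toward, assuming langchain_google_genai is installed and GOOGLE_API_KEY is supplied externally (for example as a Space secret); the explicit google_api_key argument and the error check are illustrative additions, not part of this commit.

import os

from langchain_google_genai import ChatGoogleGenerativeAI

# Read the key from the environment (e.g. a Space secret) rather than hard-coding it in app.py.
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
    # Illustrative guard: fail fast with a clear message when the secret is missing.
    raise RuntimeError("GOOGLE_API_KEY is not set")

# Same model settings as the commented-out llm above; passing the key explicitly is optional,
# since ChatGoogleGenerativeAI can also pick up GOOGLE_API_KEY from the environment.
llm = ChatGoogleGenerativeAI(temperature=0.2, model="gemini-pro", google_api_key=google_api_key)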