update chat
- app.py +8 -3
- multimodal/open_flamingo/chat/conversation.py +31 -7
app.py
CHANGED
@@ -248,12 +248,17 @@ def gradio_ask(user_message, chatbot, chat_state):
 
 
 def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
-    llm_message = \
+    llm_message,image = \
         chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature,
-                    max_length=2000)
+                    max_length=2000)
 
     chatbot[-1][1] = llm_message
-    return chatbot, chat_state, img_list
+    if image==None:
+        return chatbot, chat_state, img_list
+    else:
+        path = build_image(image)
+        chatbot = chatbot + [[(path,), None]]
+
 
 
 with gr.Blocks() as demo:
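Two things in the rewritten gradio_answer are easy to miss from the hunk alone: build_image is not defined anywhere in this diff, and the else branch never returns, so Gradio would receive None for chatbot, chat_state and img_list whenever an image comes back. Below is a minimal sketch of how the pieces could fit together, assuming build_image simply saves the PIL image to a temporary PNG and returns its path (the (path,) tuple is Gradio's Chatbot convention for rendering a local image file); the helper's behaviour is inferred from the call site, not from this repository.

import tempfile

def build_image(image):
    # Hypothetical helper (not part of this diff): persist the PIL image to a
    # temporary PNG so gr.Chatbot can render it from a file path.
    path = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name
    image.save(path)
    return path

def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
    llm_message, image = \
        chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1,
                    temperature=temperature, max_length=2000)
    chatbot[-1][1] = llm_message
    if image is not None:
        # Show the grounded image (e.g. with the predicted box) as an extra chatbot turn.
        chatbot = chatbot + [[(build_image(image),), None]]
    # Return in both branches so the Gradio outputs are always updated.
    return chatbot, chat_state, img_list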
multimodal/open_flamingo/chat/conversation.py
CHANGED
@@ -260,6 +260,12 @@ def preprocess_conv(data):
         conversation += (BEGIN_SIGNAL + from_str + ": " + d["value"] + END_SIGNAL)
     return conversation
 
+def preprocess_image(sample, image_processor):
+    image = image_processor(sample)
+    if isinstance(image, transformers.image_processing_utils.BatchFeature):
+        image = torch.tensor(image["pixel_values"][0])
+    return image
+
 class Chat:
     def __init__(self, model, vis_processor, tokenizer, vis_embed_size ):
         self.model = model
@@ -322,6 +328,7 @@ class Chat:
         # "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/tmp_img/chat_vis/chat19.png"
         # image_path = input("Please enter the image path: ")
         image = img_list[0].convert("RGB")
+        image_ori = image
         image = image.resize((size, size))
         print(f"image size: {image.size}")
         batch_images = preprocess_image(image, self.vis_processor).unsqueeze(0).unsqueeze(1).unsqueeze(0)
@@ -370,14 +377,31 @@ class Chat:
             image_start_index_list=image_start_index_list,
             image_nums=image_nums,
         )
-        output_token = outputs[0, input_ids.shape[1]:]
-        output_text = tokenizer.decode(output_token, skip_special_tokens=True).strip()
-        conv[-1]["value"] = output_text
-        # conv.messages[-1][1] = output_text
-        print(
-            f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
+        boxes = outputs["boxes"]
+        scores = outputs["scores"]
+        if len(scores) > 0:
+            box = boxes[scores.argmax()] / 224
+            print(f"{box}")
+        out_image = None
+
+        if len(boxes)>0:
+            open_cv_image = np.array(image_ori)
+            # Convert RGB to BGR
+            open_cv_image = open_cv_image[:, :, ::-1].copy()
+            box = box * [width, height, width, height]
+            # for box in boxes:
+            open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
+            out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
+
 
-        return output_text
+        # output_token = outputs[0, input_ids.shape[1]:]
+        # output_text = tokenizer.decode(output_token, skip_special_tokens=True).strip()
+        # conv[-1]["value"] = output_text
+        # # conv.messages[-1][1] = output_text
+        # print(
+        #     f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
+        output_text = "here"
+        return output_text, out_image
 
     def upload_img(self, image, conv, img_list):
         img_list.append(image)
|