update cap

- app.py +1 -1
- multimodal/open_flamingo/chat/conversation.py +6 -3
app.py
CHANGED
@@ -63,7 +63,7 @@ if "vision_encoder.logit_scale" in model_state_dict:
     del model_state_dict["vision_encoder.visual.ln_post.weight"]
     del model_state_dict["vision_encoder.visual.ln_post.bias"]
     flamingo.load_state_dict(model_state_dict, strict=True)
-chat = Chat(flamingo, image_processor, tokenizer, vis_embed_size)
+chat = Chat(flamingo, image_processor, tokenizer, vis_embed_size,model_name)
 
 
 def get_outputs(
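The only functional change in app.py is the extra argument passed to the Chat wrapper. Below is a minimal sketch of the updated call site; the surrounding names (flamingo, image_processor, tokenizer, vis_embed_size) are assumed to be defined earlier in app.py, and the value "pythiaS" is only an illustration borrowed from the conversation.py diff, not necessarily how the Space actually sets model_name.

    # Sketch only: model_name must now be in scope before Chat is constructed.
    model_name = "pythiaS"  # assumption: set elsewhere in app.py (config, dropdown, etc.)
    chat = Chat(flamingo, image_processor, tokenizer, vis_embed_size, model_name)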
multimodal/open_flamingo/chat/conversation.py
CHANGED
@@ -276,18 +276,20 @@ def preprocess_image(sample, image_processor):
 
 
 class Chat:
-    def __init__(self, model, vis_processor, tokenizer, vis_embed_size):
+    def __init__(self, model, vis_processor, tokenizer, vis_embed_size,model_name):
         self.model = model
         self.vis_processor = vis_processor
         self.tokenizer = tokenizer
         self.vis_embed_size = vis_embed_size
         self.conv = []
+        self.model_name = model_name
         # stop_words_ids = [torch.tensor([835]).to(self.device),
         #                   torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways.
         # self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
 
-    def ask(self, text, conv, radio
-
+    def ask(self, text, conv, radio):
+        name = self.model_name
+        if name=="pythiaS":
         conv.append({
             "from": "human",
             "value": text,
@@ -363,6 +365,7 @@ class Chat:
         previsual_token_id = self.tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
         prebox_token_id = self.tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
         size = 224
+        model_name = self.model_name
         self.model.eval()
         # "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/tmp_img/chat_vis/chat19.png"
         # image_path = input("Please enter the image path: ")
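Taken together, the two hunks make Chat carry the model name so that ask() (and, through the local model_name variable in the second hunk, the generation path) can branch per model. A hedged usage sketch follows; it assumes a Chat instance built as in app.py, that radio is the UI selection value the method already expects, and that "pythiaS" is the only model name actually checked, since it is the only one visible in the diff.

    # Sketch only: the constructor keyword names match the diff's __init__ signature.
    chat = Chat(model=flamingo, vis_processor=image_processor, tokenizer=tokenizer,
                vis_embed_size=vis_embed_size, model_name="pythiaS")
    conv = []
    # ask() now reads self.model_name before appending the human turn to conv.
    chat.ask("Describe the image.", conv, radio=None)  # radio value is an assumption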