chendl committed 9fabf9a (1 parent: 3280d5f)

update cap

app.py CHANGED
@@ -63,7 +63,7 @@ if "vision_encoder.logit_scale" in model_state_dict:
     del model_state_dict["vision_encoder.visual.ln_post.weight"]
     del model_state_dict["vision_encoder.visual.ln_post.bias"]
 flamingo.load_state_dict(model_state_dict, strict=True)
-chat = Chat(flamingo, image_processor, tokenizer, vis_embed_size)
+chat = Chat(flamingo, image_processor, tokenizer, vis_embed_size, model_name)
 
 
 def get_outputs(
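
The hunk above only touches the call site: app.py now forwards the selected model name into Chat, so conversation.py can branch on it without receiving it through every call. A minimal sketch of the new wiring, where the concrete value of model_name is purely an assumption (the model-selection code is not part of this hunk) and the other arguments are the objects built earlier in app.py:

# Sketch under assumptions: model_name's value is hypothetical; flamingo,
# image_processor, tokenizer and vis_embed_size come from the setup code above.
model_name = "pythiaS"  # assumed example value, matching the branch in conversation.py
chat = Chat(flamingo, image_processor, tokenizer, vis_embed_size, model_name)
# Callers of chat.ask(...) no longer pass model_name; the stored self.model_name is used instead.
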
multimodal/open_flamingo/chat/conversation.py CHANGED
@@ -276,18 +276,20 @@ def preprocess_image(sample, image_processor):
 
 
 class Chat:
-    def __init__(self, model, vis_processor, tokenizer, vis_embed_size):
+    def __init__(self, model, vis_processor, tokenizer, vis_embed_size, model_name):
         self.model = model
         self.vis_processor = vis_processor
         self.tokenizer = tokenizer
         self.vis_embed_size = vis_embed_size
         self.conv = []
+        self.model_name = model_name
         # stop_words_ids = [torch.tensor([835]).to(self.device),
         #                   torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways.
         # self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
 
-    def ask(self, text, conv, radio, model_name):
-        if "pythiaS" in model_name:
+    def ask(self, text, conv, radio):
+        name = self.model_name
+        if name == "pythiaS":
             conv.append({
                 "from": "human",
                 "value": text,
@@ -363,6 +365,7 @@ class Chat:
         previsual_token_id = self.tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
         prebox_token_id = self.tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
         size = 224
+        model_name = self.model_name
         self.model.eval()
         # "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/tmp_img/chat_vis/chat19.png"
         # image_path = input("Please enter the image path: ")
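
Read together, the two hunks in conversation.py store the model name once at construction time and read it back wherever a model-specific branch is needed. A rough reconstruction of the affected parts of Chat after this commit; everything marked as not shown lies outside the visible hunks and is only summarized here:

class Chat:
    def __init__(self, model, vis_processor, tokenizer, vis_embed_size, model_name):
        self.model = model
        self.vis_processor = vis_processor
        self.tokenizer = tokenizer
        self.vis_embed_size = vis_embed_size
        self.conv = []
        self.model_name = model_name  # stored once, read by ask() and the generation path

    def ask(self, text, conv, radio):
        # The model-specific branch now reads the stored name instead of a parameter.
        name = self.model_name
        if name == "pythiaS":
            conv.append({
                "from": "human",
                "value": text,
            })  # remainder of the prompt handling is outside this hunk
        # ... other branches and the rest of ask() are not shown in the diff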