Dabococo committed
Commit 0da8218
Parent: f05317a

Upload 2 files

Files changed (2)
  1. app.py +571 -0
  2. requirements.txt +13 -0
app.py ADDED
@@ -0,0 +1,571 @@
#!/usr/bin/env python
# encoding: utf-8
import spaces
import torch
import argparse
from transformers import AutoModel, AutoTokenizer
import gradio as gr
from PIL import Image
from decord import VideoReader, cpu
import io
import os
import copy
import requests
import base64
import json
import traceback
import re
import modelscope_studio as mgr


# README: how to run the demo on different devices.

# For Nvidia GPUs:
# python web_demo_2.6.py --device cuda

# For Macs with MPS (Apple silicon or AMD GPUs):
# PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo_2.6.py --device mps

# Argument parser
parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
parser.add_argument('--multi-gpus', action='store_true', default=False, help='use multiple GPUs')
args = parser.parse_args()
device = args.device
assert device in ['cuda', 'mps']

# Load model
model_path = 'openbmb/MiniCPM-V-2_6'
if 'int4' in model_path:
    if device == 'mps':
        print('Error: running the int4 model with bitsandbytes on Mac is not supported right now.')
        exit()
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
else:
    if False:  # args.multi_gpus:
        from accelerate import load_checkpoint_and_dispatch, init_empty_weights, infer_auto_device_map
        with init_empty_weights():
            #model = AutoModel.from_pretrained(model_path, trust_remote_code=True, attn_implementation='sdpa', torch_dtype=torch.bfloat16)
            model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
        device_map = infer_auto_device_map(model, max_memory={0: "10GB", 1: "10GB"},
                                           no_split_module_classes=['SiglipVisionTransformer', 'Qwen2DecoderLayer'])
        device_id = device_map["llm.model.embed_tokens"]
        device_map["llm.lm_head"] = device_id  # the first and last layers should be on the same device
        device_map["vpm"] = device_id
        device_map["resampler"] = device_id
        # Move the middle decoder layers onto the second device.
        device_id2 = device_map["llm.model.layers.26"]
        for i in range(8, 17):
            device_map[f"llm.model.layers.{i}"] = device_id2
        #print(device_map)

        #model = load_checkpoint_and_dispatch(model, model_path, dtype=torch.bfloat16, device_map=device_map)
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map=device_map)
    else:
        #model = AutoModel.from_pretrained(model_path, trust_remote_code=True, attn_implementation='sdpa', torch_dtype=torch.bfloat16)
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
        model = model.to(device=device)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.eval()

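# Note: the int4 branch above is only taken when model_path contains 'int4'
# (presumably a quantized checkpoint such as 'openbmb/MiniCPM-V-2_6-int4');
# with the default model_path, bf16 weights are loaded onto a single device.
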
ERROR_MSG = "Error, please retry"
model_name = 'MiniCPM-V 2.6'
MAX_NUM_FRAMES = 64
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}

def get_file_extension(filename):
    return os.path.splitext(filename)[1].lower()

def is_image(filename):
    return get_file_extension(filename) in IMAGE_EXTENSIONS

def is_video(filename):
    return get_file_extension(filename) in VIDEO_EXTENSIONS


form_radio = {
    'choices': ['Beam Search', 'Sampling'],
    #'value': 'Beam Search',
    'value': 'Sampling',
    'interactive': True,
    'label': 'Decode Type'
}


def create_component(params, comp='Slider'):
    if comp == 'Slider':
        return gr.Slider(
            minimum=params['minimum'],
            maximum=params['maximum'],
            value=params['value'],
            step=params['step'],
            interactive=params['interactive'],
            label=params['label']
        )
    elif comp == 'Radio':
        return gr.Radio(
            choices=params['choices'],
            value=params['value'],
            interactive=params['interactive'],
            label=params['label']
        )
    elif comp == 'Button':
        return gr.Button(
            value=params['value'],
            interactive=True
        )


def create_multimodal_input(upload_image_disabled=False, upload_video_disabled=False):
    return mgr.MultimodalInput(
        value=None,
        upload_image_button_props={'label': 'Upload Image', 'disabled': upload_image_disabled, 'file_count': 'multiple'},
        upload_video_button_props={'label': 'Upload Video', 'disabled': upload_video_disabled, 'file_count': 'single'},
        submit_button_props={'label': 'Submit'})

@spaces.GPU(duration=120)
def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
    try:
        if msgs[-1]['role'] == 'assistant':
            msgs = msgs[:-1]  # remove the last message, which was added as a placeholder for streaming
        print('msgs:', msgs)
        answer = model.chat(
            image=None,
            msgs=msgs,
            tokenizer=tokenizer,
            **params
        )
        if params['stream'] is False:
            # Strip grounding tags from the non-streamed answer.
            res = re.sub(r'(<box>.*</box>)', '', answer)
            res = res.replace('<ref>', '')
            res = res.replace('</ref>', '')
            res = res.replace('<box>', '')
            answer = res.replace('</box>', '')
        print('answer:')
        for char in answer:
            print(char, flush=True, end='')
            yield char
    except Exception as e:
        print(e)
        traceback.print_exc()
        yield ERROR_MSG

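# A minimal usage sketch of the generator above, kept as a comment so module
# import is unaffected ('demo.jpg' is a hypothetical local file; the params
# mirror the Sampling branch of respond() below):
#
#   params = {'sampling': True, 'stream': True, 'top_p': 0.8, 'top_k': 100,
#             'temperature': 0.7, 'repetition_penalty': 1.05, 'max_new_tokens': 2048}
#   msgs = [{'role': 'user', 'content': [encode_image(Image.open('demo.jpg')), 'Describe this image.']}]
#   for piece in chat('', msgs, None, params):
#       print(piece, end='', flush=True)
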
def encode_image(image):
    if not isinstance(image, Image.Image):
        if hasattr(image, 'path'):
            image = Image.open(image.path).convert("RGB")
        else:
            image = Image.open(image.file.path).convert("RGB")
    # resize to max_size
    max_size = 448 * 16
    if max(image.size) > max_size:
        w, h = image.size
        if w > h:
            new_w = max_size
            new_h = int(h * max_size / w)
        else:
            new_h = max_size
            new_w = int(w * max_size / h)
        image = image.resize((new_w, new_h), resample=Image.BICUBIC)
    return image
    ## save via BytesIO and convert to base64
    #buffered = io.BytesIO()
    #image.save(buffered, format="png")
    #im_b64 = base64.b64encode(buffered.getvalue()).decode()
    #return {"type": "image", "pairs": im_b64}


def encode_video(video):
    def uniform_sample(l, n):
        gap = len(l) / n
        idxs = [int(i * gap + gap / 2) for i in range(n)]
        return [l[i] for i in idxs]

    if hasattr(video, 'path'):
        vr = VideoReader(video.path, ctx=cpu(0))
    else:
        vr = VideoReader(video.file.path, ctx=cpu(0))
    sample_fps = round(vr.get_avg_fps() / 1)  # sample roughly one frame per second
    frame_idx = [i for i in range(0, len(vr), sample_fps)]
    if len(frame_idx) > MAX_NUM_FRAMES:
        frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
    video = vr.get_batch(frame_idx).asnumpy()
    video = [Image.fromarray(v.astype('uint8')) for v in video]
    video = [encode_image(v) for v in video]
    print('video frames:', len(video))
    return video

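# Worked example of the sampling in encode_video: a 5-minute clip at 30 fps has
# ~9000 frames; stepping by sample_fps (~30) keeps ~300 one-per-second indices,
# and uniform_sample thins them to MAX_NUM_FRAMES evenly spaced picks:
#
#   idxs = list(range(0, 9000, 30))   # 300 indices
#   gap = len(idxs) / 64              # = 4.6875
#   picks = [idxs[int(i * gap + gap / 2)] for i in range(64)]   # 64 frames
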
def check_mm_type(mm_file):
    if hasattr(mm_file, 'path'):
        path = mm_file.path
    else:
        path = mm_file.file.path
    if is_image(path):
        return "image"
    if is_video(path):
        return "video"
    return None


def encode_mm_file(mm_file):
    if check_mm_type(mm_file) == 'image':
        return [encode_image(mm_file)]
    if check_mm_type(mm_file) == 'video':
        return encode_video(mm_file)
    return None

def make_text(text):
    #return {"type": "text", "pairs": text}  # For remote call
    return text

def encode_message(_question):
    files = _question.files
    question = _question.text
    pattern = r"\[mm_media\]\d+\[/mm_media\]"
    matches = re.split(pattern, question)
    message = []
    if len(matches) != len(files) + 1:
        gr.Warning("The number of images does not match the number of placeholders in the text; please refresh the page to restart!")
        assert len(matches) == len(files) + 1

    text = matches[0].strip()
    if text:
        message.append(make_text(text))
    for i in range(len(files)):
        message += encode_mm_file(files[i])
        text = matches[i + 1].strip()
        if text:
            message.append(make_text(text))
    return message

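# Example of the placeholder protocol handled by encode_message above: with two
# uploaded files the frontend produces text such as
#   "Compare [mm_media]1[/mm_media] with [mm_media]2[/mm_media] please."
# and re.split(pattern, text) yields ['Compare ', ' with ', ' please.'], so
# len(matches) == len(files) + 1 and each file slots in between two text pieces.
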
def check_has_videos(_question):
    images_cnt = 0
    videos_cnt = 0
    for file in _question.files:
        if check_mm_type(file) == "image":
            images_cnt += 1
        else:
            videos_cnt += 1
    return images_cnt, videos_cnt


def count_video_frames(_context):
    num_frames = 0
    for message in _context:
        for item in message["content"]:
            #if item["type"] == "image":  # For remote call
            if isinstance(item, Image.Image):
                num_frames += 1
    return num_frames


def request(_question, _chat_bot, _app_cfg):
    images_cnt = _app_cfg['images_cnt']
    videos_cnt = _app_cfg['videos_cnt']
    files_cnts = check_has_videos(_question)
    if files_cnts[1] + videos_cnt > 1 or (files_cnts[1] + videos_cnt == 1 and files_cnts[0] + images_cnt > 0):
        gr.Warning("Only a single video file is supported as input right now!")
        return _question, _chat_bot, _app_cfg
    if files_cnts[1] + videos_cnt + files_cnts[0] + images_cnt <= 0:
        gr.Warning("Please chat with at least one image or video.")
        return _question, _chat_bot, _app_cfg
    _chat_bot.append((_question, None))
    images_cnt += files_cnts[0]
    videos_cnt += files_cnts[1]
    _app_cfg['images_cnt'] = images_cnt
    _app_cfg['videos_cnt'] = videos_cnt
    # Once a video is present, disable both uploads; once any image is present, disable video upload.
    upload_image_disabled = videos_cnt > 0
    upload_video_disabled = videos_cnt > 0 or images_cnt > 0
    return create_multimodal_input(upload_image_disabled, upload_video_disabled), _chat_bot, _app_cfg

def respond(_chat_bot, _app_cfg, params_form):
    if len(_app_cfg) == 0:
        yield (_chat_bot, _app_cfg)
    elif _app_cfg['images_cnt'] == 0 and _app_cfg['videos_cnt'] == 0:
        yield (_chat_bot, _app_cfg)
    else:
        _question = _chat_bot[-1][0]
        _context = _app_cfg['ctx'].copy()
        _context.append({'role': 'user', 'content': encode_message(_question)})

        videos_cnt = _app_cfg['videos_cnt']

        if params_form == 'Beam Search':
            params = {
                'sampling': False,
                'stream': False,
                'num_beams': 3,
                'repetition_penalty': 1.2,
                'max_new_tokens': 2048
            }
        else:
            params = {
                'sampling': True,
                'stream': True,
                'top_p': 0.8,
                'top_k': 100,
                'temperature': 0.7,
                'repetition_penalty': 1.05,
                'max_new_tokens': 2048
            }
        params["max_inp_length"] = 4352  # 4096 + 256

        if videos_cnt > 0:
            #params["max_inp_length"] = 4352  # 4096 + 256
            params["use_image_id"] = False
            params["max_slice_nums"] = 1 if count_video_frames(_context) > 16 else 2

        gen = chat("", _context, None, params)

        _context.append({"role": "assistant", "content": [""]})
        _chat_bot[-1][1] = ""

        for _char in gen:
            _chat_bot[-1][1] += _char
            _context[-1]["content"][0] += _char
            yield (_chat_bot, _app_cfg)

        _app_cfg['ctx'] = _context
        yield (_chat_bot, _app_cfg)

def fewshot_add_demonstration(_image, _user_message, _assistant_message, _chat_bot, _app_cfg):
    ctx = _app_cfg["ctx"]
    message_item = []
    if _image is not None:
        image = Image.open(_image).convert("RGB")
        ctx.append({"role": "user", "content": [encode_image(image), make_text(_user_message)]})
        message_item.append({"text": "[mm_media]1[/mm_media]" + _user_message, "files": [_image]})
        _app_cfg["images_cnt"] += 1
    else:
        if _user_message:
            ctx.append({"role": "user", "content": [make_text(_user_message)]})
            message_item.append({"text": _user_message, "files": []})
        else:
            message_item.append(None)
    if _assistant_message:
        ctx.append({"role": "assistant", "content": [make_text(_assistant_message)]})
        message_item.append({"text": _assistant_message, "files": []})
    else:
        message_item.append(None)

    _chat_bot.append(message_item)
    return None, "", "", _chat_bot, _app_cfg


def fewshot_request(_image, _user_message, _chat_bot, _app_cfg):
    if _app_cfg["images_cnt"] == 0 and not _image:
        gr.Warning("Please chat with at least one image.")
        return None, '', '', _chat_bot, _app_cfg
    if _image:
        _chat_bot.append([
            {"text": "[mm_media]1[/mm_media]" + _user_message, "files": [_image]},
            ""
        ])
        _app_cfg["images_cnt"] += 1
    else:
        _chat_bot.append([
            {"text": _user_message, "files": []},  # no image in this turn
            ""
        ])

    return None, '', '', _chat_bot, _app_cfg

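# Illustrative shape of the few-shot context built above (values are made up):
#
#   _app_cfg["ctx"] == [
#       {"role": "user",      "content": [<PIL.Image>, "What animal is this?"]},
#       {"role": "assistant", "content": ["A red panda."]},
#   ]
#
# fewshot_request() only appends the query to the chatbot; respond() then
# encodes it, appends it to this context, and streams the answer.
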
def regenerate_button_clicked(_chat_bot, _app_cfg):
    if len(_chat_bot) <= 1 or not _chat_bot[-1][1]:
        gr.Warning('No question to regenerate.')
        return None, None, '', '', _chat_bot, _app_cfg
    if _app_cfg["chat_type"] == "Chat":
        images_cnt = _app_cfg['images_cnt']
        videos_cnt = _app_cfg['videos_cnt']
        _question = _chat_bot[-1][0]
        _chat_bot = _chat_bot[:-1]
        _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
        files_cnts = check_has_videos(_question)
        images_cnt -= files_cnts[0]
        videos_cnt -= files_cnts[1]
        _app_cfg['images_cnt'] = images_cnt
        _app_cfg['videos_cnt'] = videos_cnt

        _question, _chat_bot, _app_cfg = request(_question, _chat_bot, _app_cfg)
        return _question, None, '', '', _chat_bot, _app_cfg
    else:
        last_message = _chat_bot[-1][0]
        last_image = None
        last_user_message = ''
        if last_message.text:
            last_user_message = last_message.text
        if last_message.files:
            last_image = last_message.files[0].file.path
        _chat_bot[-1][1] = ""
        _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
        # The Few Shot tab has no multimodal input to recreate, so leave it as None.
        return None, None, '', '', _chat_bot, _app_cfg


def flushed():
    return gr.update(interactive=True)


def clear(txt_message, chat_bot, app_session):
    txt_message.files.clear()
    txt_message.text = ''
    chat_bot = copy.deepcopy(init_conversation)
    app_session['sts'] = None
    app_session['ctx'] = []
    app_session['images_cnt'] = 0
    app_session['videos_cnt'] = 0
    return create_multimodal_input(), chat_bot, app_session, None, '', ''


def select_chat_type(_tab, _app_cfg):
    _app_cfg["chat_type"] = _tab
    return _app_cfg

init_conversation = [
    [
        None,
        {
            # The bot's first message disables the typewriter effect.
            "text": "You can talk to me now",
            "flushing": False
        }
    ],
]


css = """
.example label { font-size: 16px;}
"""

introduction = """

## Features:
1. Chat with a single image
2. Chat with multiple images
3. Chat with video
4. In-context few-shot learning

Click the `How to use` tab to see examples.
"""

with gr.Blocks(css=css) as demo:
    with gr.Tab(model_name):
        with gr.Row():
            with gr.Column(scale=1, min_width=300):
                gr.Markdown(value=introduction)
                params_form = create_component(form_radio, comp='Radio')
                regenerate = create_component({'value': 'Regenerate'}, comp='Button')
                clear_button = create_component({'value': 'Clear History'}, comp='Button')

            with gr.Column(scale=3, min_width=500):
                app_session = gr.State({'sts': None, 'ctx': [], 'images_cnt': 0, 'videos_cnt': 0, 'chat_type': 'Chat'})
                chat_bot = mgr.Chatbot(label=f"Chat with {model_name}", value=copy.deepcopy(init_conversation), height=560, flushing=False, bubble_full_width=False)

                with gr.Tab("Chat") as chat_tab:
                    txt_message = create_multimodal_input()
                    chat_tab_label = gr.Textbox(value="Chat", interactive=False, visible=False)

                    txt_message.submit(
                        request,
                        [txt_message, chat_bot, app_session],
                        [txt_message, chat_bot, app_session]
                    ).then(
                        respond,
                        [chat_bot, app_session, params_form],
                        [chat_bot, app_session]
                    )

                with gr.Tab("Few Shot") as fewshot_tab:
                    fewshot_tab_label = gr.Textbox(value="Few Shot", interactive=False, visible=False)
                    with gr.Row():
                        with gr.Column(scale=1):
                            image_input = gr.Image(type="filepath", sources=["upload"])
                        with gr.Column(scale=3):
                            user_message = gr.Textbox(label="User")
                            assistant_message = gr.Textbox(label="Assistant")
                            with gr.Row():
                                add_demonstration_button = gr.Button("Add Example")
                                generate_button = gr.Button(value="Generate", variant="primary")
                    add_demonstration_button.click(
                        fewshot_add_demonstration,
                        [image_input, user_message, assistant_message, chat_bot, app_session],
                        [image_input, user_message, assistant_message, chat_bot, app_session]
                    )
                    generate_button.click(
                        fewshot_request,
                        [image_input, user_message, chat_bot, app_session],
                        [image_input, user_message, assistant_message, chat_bot, app_session]
                    ).then(
                        respond,
                        [chat_bot, app_session, params_form],
                        [chat_bot, app_session]
                    )

                chat_tab.select(
                    select_chat_type,
                    [chat_tab_label, app_session],
                    [app_session]
                )
                chat_tab.select(  # clear state when switching tabs
                    clear,
                    [txt_message, chat_bot, app_session],
                    [txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
                )
                fewshot_tab.select(
                    select_chat_type,
                    [fewshot_tab_label, app_session],
                    [app_session]
                )
                fewshot_tab.select(  # clear state when switching tabs
                    clear,
                    [txt_message, chat_bot, app_session],
                    [txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
                )
                chat_bot.flushed(
                    flushed,
                    outputs=[txt_message]
                )
                regenerate.click(
                    regenerate_button_clicked,
                    [chat_bot, app_session],
                    [txt_message, image_input, user_message, assistant_message, chat_bot, app_session]
                ).then(
                    respond,
                    [chat_bot, app_session, params_form],
                    [chat_bot, app_session]
                )
                clear_button.click(
                    clear,
                    [txt_message, chat_bot, app_session],
                    [txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
                )

    with gr.Tab("How to use"):
        with gr.Column():
            with gr.Row():
                image_example = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/m_bear2.gif", label='1. Chat with single or multiple images', interactive=False, width=400, elem_classes="example")
                example2 = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/video2.gif", label='2. Chat with video', interactive=False, width=400, elem_classes="example")
                example3 = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/fshot.gif", label='3. Few shot', interactive=False, width=400, elem_classes="example")


# launch
#demo.launch(share=False, debug=True, show_api=False, server_port=8885, server_name="0.0.0.0")
demo.queue()
demo.launch(show_api=False)
requirements.txt ADDED
@@ -0,0 +1,13 @@
Pillow==10.1.0
torch==2.1.2
torchvision==0.16.2
transformers==4.40.2
sentencepiece==0.1.99
https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.2/flash_attn-2.6.2+cu123torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
opencv-python==4.10.0.84
decord
#gradio==4.22.0
gradio==4.41.0
http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/modelscope_studio-0.4.0.9-py3-none-any.whl
accelerate
numpy==1.24.4
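# Note on the environment: the flash-attn wheel above is built for Python 3.10
# (cp310), CUDA 12.3, and torch 2.1, so the interpreter and CUDA toolkit must
# match those versions. Install everything with `pip install -r requirements.txt`.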