Spaces: Running on Zero
Update chatbot.py

chatbot.py CHANGED (+22 -14)
@@ -35,19 +35,27 @@ model.to("cuda")
 # Credit to merve for code of llava interleave qwen
 
 def sample_frames(video_file, num_frames):
-    video = cv2.VideoCapture(video_file)
-    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
-    interval = total_frames // num_frames
-    frames = []
-    for i in range(total_frames):
-        ret, frame = video.read()
-        pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-        if not ret:
-            continue
-        if i % interval == 0:
-            frames.append(pil_img)
-    video.release()
-    return frames
+    try:
+        video = cv2.VideoCapture(video_file)
+        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+        fps = int(video.get(cv2.CAP_PROP_FPS))
+        # extracts 5 images/sec of video
+        num_frames = (total_frames // fps) * 5
+        interval = total_frames // num_frames
+        frames = []
+        for i in range(total_frames):
+            ret, frame = video.read()
+            pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+            if not ret:
+                continue
+            if i % interval == 0:
+                frames.append(pil_img)
+        video.release()
+        return frames
+    except:
+        frames = []
+        return frames
+
 
 # Path to example images
 examples_path = os.path.dirname(__file__)
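A note on the committed sample_frames: cv2.cvtColor runs on each frame before ret is checked, so the first failed read passes None into cvtColor and raises, and the bare except: then discards every frame gathered so far and returns an empty list. The num_frames parameter is also overwritten, and the divisions can hit zero when cv2 reports an fps of 0 for a stream. A rough sketch of a more defensive variant (a hypothetical sample_frames_safe, not this Space's code) that checks ret first, guards the divisions, and always releases the capture:

import cv2
from PIL import Image

def sample_frames_safe(video_file, num_frames):
    # Same idea as sample_frames above (~5 frames per second of video),
    # but with explicit guards instead of a bare except.
    video = cv2.VideoCapture(video_file)
    if not video.isOpened():
        return []
    try:
        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = int(video.get(cv2.CAP_PROP_FPS))
        if total_frames <= 0 or fps <= 0:
            return []
        num_frames = max((total_frames // fps) * 5, 1)
        interval = max(total_frames // num_frames, 1)
        frames = []
        for i in range(total_frames):
            ret, frame = video.read()
            if not ret:
                break  # stop at the first failed read instead of continuing
            if i % interval == 0:
                # convert BGR (OpenCV) to RGB (PIL) only for kept frames
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        return frames
    finally:
        video.release()

Releasing the capture in a finally block keeps the handle from leaking when decoding fails partway through a video.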
@@ -279,7 +287,7 @@ def model_inference(
 
     inputs = processor(prompt, image, return_tensors="pt").to("cuda", torch.float16)
     streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": True})
-    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048, do_sample=True
+    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048, do_sample=True)
    generated_text = ""
 
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
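The second hunk is the actual one-character fix: the old line was missing its closing parenthesis, a syntax error that would have stopped chatbot.py from loading at all. For context, a minimal sketch of the threaded streaming pattern that line feeds, with model, processor, prompt, and image assumed to be set up as elsewhere in the file (stream_reply is a hypothetical name, not a function in the Space):

import torch
from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(model, processor, prompt, image):
    # generate() runs on a worker thread; the caller drains the streamer
    # token by token, e.g. to update a Gradio chat window incrementally.
    inputs = processor(prompt, image, return_tensors="pt").to("cuda", torch.float16)
    streamer = TextIteratorStreamer(processor, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048, do_sample=True)

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:  # blocks until new tokens are decoded
        generated_text += new_text
        yield generated_text
    thread.join()

Without the worker thread, model.generate would block until all 2048 new tokens were produced and nothing could be streamed to the UI.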