qsitj committed on
Commit
85114f8
·
verified ·
1 Parent(s): 38d0171

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -22
app.py CHANGED
@@ -1,12 +1,10 @@
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
- from PIL import Image
5
- import matplotlib.pyplot as plt
6
  from transformers import pipeline
7
  import torch
8
  from random import choice
9
- from io import BytesIO
10
  import os
11
  from datetime import datetime
12
 
@@ -21,41 +19,43 @@ COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
21
 
22
  fdic = {
23
  "style": "italic",
24
- "size": 15,
25
  "color": "yellow",
26
  "weight": "bold"
27
  }
28
 
 
 
29
  def query_data(in_pil_img: Image.Image):
30
  results = detector(in_pil_img)
31
  # print(f"检测结果:{results}")
32
  return results
33
 
34
  def get_annotated_image(in_pil_img):
35
- plt.figure(figsize=(16, 10))
36
- plt.imshow(in_pil_img)
37
- ax = plt.gca()
38
  in_results = query_data(in_pil_img)
39
 
40
  for prediction in in_results:
41
- color = choice(COLORS)
42
  box = prediction['box']
43
  label = prediction['label']
44
  score = round(prediction['score'] * 100, 1)
 
 
 
 
 
 
 
 
45
 
46
- ax.add_patch(plt.Rectangle((box['xmin'], box['ymin']),
47
- box['xmax'] - box['xmin'],
48
- box['ymax'] - box['ymin'],
49
- fill=False, color=color, linewidth=3))
50
- ax.text(box['xmin'], box['ymin'], f"{label}: {score}%", fontdict=fdic)
51
 
52
- plt.axis("off")
53
- buf = BytesIO()
54
- plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
55
- plt.close() # 关闭图形以释放内存
56
- buf.seek(0)
57
- annotated_image = Image.open(buf).convert('RGB')
58
- return np.array(annotated_image)
59
 
60
  def process_video(input_video_path):
61
  cap = cv2.VideoCapture(input_video_path)
@@ -74,7 +74,7 @@ def process_video(input_video_path):
74
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
75
  output_video_filename = f"output_{timestamp}.mp4"
76
  output_video_path = os.path.join(output_dir, output_video_filename)
77
-
78
  out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
79
 
80
  while True:
@@ -84,9 +84,10 @@ def process_video(input_video_path):
84
 
85
  rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
86
  pil_image = Image.fromarray(rgb_frame)
 
87
  annotated_frame = get_annotated_image(pil_image)
88
  bgr_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
89
-
90
  # 确保帧的尺寸与视频输出一致
91
  if bgr_frame.shape[:2] != (height, width):
92
  bgr_frame = cv2.resize(bgr_frame, (width, height))
 
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
+ from PIL import Image,ImageDraw
 
5
  from transformers import pipeline
6
  import torch
7
  from random import choice
 
8
  import os
9
  from datetime import datetime
10
 
 
19
 
20
  fdic = {
21
  "style": "italic",
22
+ "size": 16,
23
  "color": "yellow",
24
  "weight": "bold"
25
  }
26
 
27
+ label_color_dict = {}
28
+
29
  def query_data(in_pil_img: Image.Image):
30
  results = detector(in_pil_img)
31
  # print(f"检测结果:{results}")
32
  return results
33
 
34
  def get_annotated_image(in_pil_img):
35
+ draw = ImageDraw.Draw(in_pil_img)
 
 
36
  in_results = query_data(in_pil_img)
37
 
38
  for prediction in in_results:
 
39
  box = prediction['box']
40
  label = prediction['label']
41
  score = round(prediction['score'] * 100, 1)
42
+ if score < 50:
43
+ continue # 过滤掉低置信度的预测结果
44
+
45
+ if label not in label_color_dict: # 为每个类别随机分配颜色, 后续维持一致
46
+ color = choice(COLORS)
47
+ label_color_dict[label] = color
48
+ else:
49
+ color = label_color_dict[label]
50
 
51
+ # 绘制矩形
52
+ draw.rectangle([box['xmin'], box['ymin'], box['xmax'], box['ymax']], outline=color, width=3)
53
+
54
+ # 添加文本
55
+ draw.text((box['xmin'], box['ymin']), f"{label}: {score}%", fill=color, fontdict=fdic)
56
 
57
+ # 返回的是原始图像对象,它已经被修改了
58
+ return np.array(in_pil_img.convert('RGB'))
 
 
 
 
 
59
 
60
  def process_video(input_video_path):
61
  cap = cv2.VideoCapture(input_video_path)
 
74
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
75
  output_video_filename = f"output_{timestamp}.mp4"
76
  output_video_path = os.path.join(output_dir, output_video_filename)
77
+ # print(f"输出视频信息:{output_video_path}, {width}x{height}, {fps}fps")
78
  out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
79
 
80
  while True:
 
84
 
85
  rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
86
  pil_image = Image.fromarray(rgb_frame)
87
+ # print(f"Input frame of shape {rgb_frame.shape} and type {rgb_frame.dtype}") # 调试信息
88
  annotated_frame = get_annotated_image(pil_image)
89
  bgr_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
90
+ # print(f"Annotated frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}") # 调试信息
91
  # 确保帧的尺寸与视频输出一致
92
  if bgr_frame.shape[:2] != (height, width):
93
  bgr_frame = cv2.resize(bgr_frame, (width, height))