qsitj committed on
Commit
fd08bcf
·
verified ·
1 Parent(s): aae3e7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -10
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
- from PIL import Image,ImageDraw
5
  from transformers import pipeline
6
  import torch
7
  from random import choice
@@ -28,9 +28,27 @@ label_color_dict = {}
28
 
29
  def query_data(in_pil_img: Image.Image):
30
  results = detector(in_pil_img)
31
- # print(f"检测结果:{results}")
32
  return results
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def get_annotated_image(in_pil_img):
35
  draw = ImageDraw.Draw(in_pil_img)
36
  in_results = query_data(in_pil_img)
@@ -40,23 +58,34 @@ def get_annotated_image(in_pil_img):
40
  label = prediction['label']
41
  score = round(prediction['score'] * 100, 1)
42
  if score < 50:
43
- continue # 过滤掉低置信度的预测结果
44
 
45
- if label not in label_color_dict: # 为每个类别随机分配颜色, 后续维持一致
46
  color = choice(COLORS)
47
  label_color_dict[label] = color
48
  else:
49
  color = label_color_dict[label]
50
 
 
 
 
 
 
 
 
 
 
51
  # 绘制矩形
52
  draw.rectangle([box['xmin'], box['ymin'], box['xmax'], box['ymax']], outline=color, width=3)
53
-
54
  # 添加文本
55
- draw.text((box['xmin'], box['ymin']), f"{label}: {score}%", fill=color, fontdict=fdic)
 
56
 
57
  # 返回的是原始图像对象,它已经被修改了
58
  return np.array(in_pil_img.convert('RGB'))
59
 
 
60
  def process_video(input_video_path):
61
  cap = cv2.VideoCapture(input_video_path)
62
  if not cap.isOpened():
@@ -74,7 +103,7 @@ def process_video(input_video_path):
74
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
75
  output_video_filename = f"output_{timestamp}.mp4"
76
  output_video_path = os.path.join(output_dir, output_video_filename)
77
- # print(f"输出视频信息:{output_video_path}, {width}x{height}, {fps}fps")
78
  out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
79
 
80
  while True:
@@ -84,15 +113,15 @@ def process_video(input_video_path):
84
 
85
  rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
86
  pil_image = Image.fromarray(rgb_frame)
87
- # print(f"Input frame of shape {rgb_frame.shape} and type {rgb_frame.dtype}") # 调试信息
88
  annotated_frame = get_annotated_image(pil_image)
89
  bgr_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
90
- # print(f"Annotated frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}") # 调试信息
91
  # 确保帧的尺寸与视频输出一致
92
  if bgr_frame.shape[:2] != (height, width):
93
  bgr_frame = cv2.resize(bgr_frame, (width, height))
94
 
95
- # print(f"Writing frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}") # 调试信息
96
  out.write(bgr_frame)
97
 
98
  cap.release()
 
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
+ from PIL import Image,ImageDraw, ImageFont
5
  from transformers import pipeline
6
  import torch
7
  from random import choice
 
28
 
29
  def query_data(in_pil_img: Image.Image):
30
  results = detector(in_pil_img)
31
+ print(f"检测结果:{results}")
32
  return results
33
 
34
+
35
def get_font_size(box_width, min_size=10, max_size=48, base_size=24, width_ratio=10):
    """Pick a label font size that scales with the bounding-box width.

    The size is ``box_width / width_ratio`` truncated to an integer, raised to
    at least ``base_size``, and finally clamped to ``[min_size, max_size]``.

    Args:
        box_width: Width of the bounding box.
        min_size: Lower clamp, applied last (it wins if it exceeds
            ``max_size``).
        max_size: Upper clamp for the font size.
        base_size: Floor applied before clamping. The default of 24 matches
            the value that used to be hard-coded here; with the default
            arguments it makes ``min_size`` irrelevant, so lower
            ``base_size`` to let small boxes get smaller text.
        width_ratio: Divisor turning the box width into a font size. The
            default of 10 matches the previously hard-coded ratio.

    Returns:
        The font size to use for the label.
    """
    # Scale with the box, but never go below the configured floor.
    scaled = max(base_size, int(box_width / width_ratio))
    return max(min(scaled, max_size), min_size)
40
+
41
def get_text_position(box, text_bbox):
    """Return the (x, y) draw origin for a label attached to a bounding box.

    The label is placed directly above the box when its ink fits between the
    top of the image and the box; otherwise it is placed at the box's
    top-left corner, i.e. inside the box.

    Args:
        box: Mapping with at least ``'xmin'`` and ``'ymin'`` entries.
        text_bbox: ``(left, top, right, bottom)`` of the text measured with
            the draw origin at ``(0, 0)``, e.g. the result of
            ``ImageDraw.textbbox((0, 0), text, font=font)``.

    Returns:
        A ``(x, y)`` tuple to pass as the text origin.
    """
    xmin, ymin = box['xmin'], box['ymin']
    text_top, text_bottom = text_bbox[1], text_bbox[3]
    text_height = text_bottom - text_top

    if ymin - text_height >= 0:
        # Offset by the bbox bottom rather than the ink height: the ink
        # starts ``text_top`` pixels below the draw origin, so subtracting
        # only the height would push the label that far into the box.
        return (xmin, ymin - text_bottom)
    # Not enough room above the box: start the label at its top-left corner.
    return (xmin, ymin)
51
+
52
  def get_annotated_image(in_pil_img):
53
  draw = ImageDraw.Draw(in_pil_img)
54
  in_results = query_data(in_pil_img)
 
58
  label = prediction['label']
59
  score = round(prediction['score'] * 100, 1)
60
  if score < 50:
61
+ continue # 过滤掉低置信度的预测结果
62
 
63
+ if label not in label_color_dict: # 为每个类别随机分配颜色, 后续维持一致
64
  color = choice(COLORS)
65
  label_color_dict[label] = color
66
  else:
67
  color = label_color_dict[label]
68
 
69
+ # 计算字体大小
70
+ box_width = box['xmax'] - box['xmin']
71
+ font_size = get_font_size(box_width)
72
+ font = ImageFont.truetype("arial.ttf", size=font_size) # 确保你有可用的字体文件
73
+
74
+ # 获取文本边界框
75
+ text = f"{label}: {score}%"
76
+ text_bbox = draw.textbbox((0, 0), text, font=font)
77
+
78
  # 绘制矩形
79
  draw.rectangle([box['xmin'], box['ymin'], box['xmax'], box['ymax']], outline=color, width=3)
80
+
81
  # 添加文本
82
+ text_pos = get_text_position(box, text_bbox)
83
+ draw.text(text_pos, text, fill=color, font=font)
84
 
85
  # 返回的是原始图像对象,它已经被修改了
86
  return np.array(in_pil_img.convert('RGB'))
87
 
88
+
89
  def process_video(input_video_path):
90
  cap = cv2.VideoCapture(input_video_path)
91
  if not cap.isOpened():
 
103
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
104
  output_video_filename = f"output_{timestamp}.mp4"
105
  output_video_path = os.path.join(output_dir, output_video_filename)
106
+ print(f"输出视频信息:{output_video_path}, {width}x{height}, {fps}fps")
107
  out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
108
 
109
  while True:
 
113
 
114
  rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
115
  pil_image = Image.fromarray(rgb_frame)
116
+ print(f"Input frame of shape {rgb_frame.shape} and type {rgb_frame.dtype}") # 调试信息
117
  annotated_frame = get_annotated_image(pil_image)
118
  bgr_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
119
+ print(f"Annotated frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}") # 调试信息
120
  # 确保帧的尺寸与视频输出一致
121
  if bgr_frame.shape[:2] != (height, width):
122
  bgr_frame = cv2.resize(bgr_frame, (width, height))
123
 
124
+ print(f"Writing frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}") # 调试信息
125
  out.write(bgr_frame)
126
 
127
  cap.release()