Ziv Pollak commited on
Commit
5aadbcd
1 Parent(s): 2acfef6
Files changed (1) hide show
  1. app.py +27 -51
app.py CHANGED
@@ -12,6 +12,9 @@ from mediapipe import solutions
12
  from PIL import Image
13
 
14
  import torch, torchvision
 
 
 
15
 
16
  import matplotlib
17
  matplotlib.use("Agg")
@@ -19,10 +22,6 @@ import matplotlib.pyplot as plt
19
 
20
  cropped_image = []
21
  analyzed_image = []
22
-
23
- # colors for visualization
24
- COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
25
- [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
26
 
27
  finetuned_classes = [
28
  'iris',
@@ -41,20 +40,17 @@ options = vision.FaceLandmarkerOptions(base_options=base_options,
41
  num_faces=1)
42
  detector = vision.FaceLandmarker.create_from_options(options)
43
 
44
- model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
45
- model.eval();
 
 
 
 
46
 
47
  def video_identity(video):
48
  return video
49
 
50
 
51
- #demo = gr.Interface(video_identity,
52
- # gr.Video(shape = (1000,1000), source="webcam"),
53
- # "playable_video")
54
-
55
-
56
- import torchvision.transforms as T
57
-
58
  # standard PyTorch mean-std input image normalization
59
  transform = T.Compose([
60
  T.Resize(800),
@@ -65,11 +61,12 @@ transform = T.Compose([
65
  def handle_image(input_image):
66
  global cropped_image, analyzed_image
67
  cv2.imwrite("image.jpg", input_image)
68
- image = mp.Image.create_from_file("image.jpg")
 
69
 
70
- detection_result = detector.detect(image)
71
  cropped_image = image.numpy_view().copy()
72
  analyzed_image = image.numpy_view().copy()
 
73
 
74
  face_landmarks_list = detection_result.face_landmarks
75
 
@@ -87,26 +84,9 @@ def handle_image(input_image):
87
  cv2.circle(input_image, (p2[0], p2[1]), 10, (0, 0, 255), -1)
88
  cropped_image = cropped_image[p1[1]:p2[1], p1[0]:p2[0]]
89
 
90
- run_worflow(cropped_image, model)
91
-
92
- return (cropped_image)
93
-
94
- def load_model():
95
- print('load model')
96
- '''
97
- model = torch.hub.load('facebookresearch/detr',
98
- 'detr_resnet50',
99
- pretrained=False,
100
- num_classes=1)
101
-
102
- checkpoint = torch.load('outputs/checkpoint.pth',
103
- map_location='cpu')
104
 
105
- model.load_state_dict(checkpoint['model'],
106
- strict=False)
107
-
108
- model.eval();
109
- '''
110
  def filter_bboxes_from_outputs(img,
111
  outputs,
112
  threshold=0.7
@@ -124,22 +104,13 @@ def filter_bboxes_from_outputs(img,
124
  return probas_to_keep, bboxes_scaled
125
 
126
 
127
- def plot_finetuned_results(pil_img, prob=None, boxes=None):
128
- plt.figure(figsize=(16,10))
129
- plt.imshow(pil_img)
130
- ax = plt.gca()
131
- colors = COLORS * 100
132
  if prob is not None and boxes is not None:
133
- for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
134
- ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
135
- fill=False, color=c, linewidth=3))
136
- cl = p.argmax()
137
- #text = f'{finetuned_classes[cl]}: {p[cl]:0.2f}'
138
- text = 'results'
139
- ax.text(xmin, ymin, text, fontsize=15,
140
- bbox=dict(facecolor='yellow', alpha=0.5))
141
- plt.axis('off')
142
- plt.show()
143
 
144
  def rescale_bboxes(out_bbox, size):
145
  print (size)
@@ -167,15 +138,20 @@ def run_worflow(my_image, my_model):
167
  # propagate through the model
168
  outputs = my_model(img)
169
 
170
- for threshold in [0.2, 0.2]:
 
 
171
 
172
  probas_to_keep, bboxes_scaled = filter_bboxes_from_outputs(my_image,
173
  outputs,
174
  threshold=threshold)
175
 
176
- plot_finetuned_results(my_image,
 
177
  probas_to_keep,
178
  bboxes_scaled)
 
 
179
 
180
 
181
 
 
12
  from PIL import Image
13
 
14
  import torch, torchvision
15
+ import torchvision.transforms as T
16
+ from huggingface_hub import hf_hub_download
17
+
18
 
19
  import matplotlib
20
  matplotlib.use("Agg")
 
22
 
23
  cropped_image = []
24
  analyzed_image = []
 
 
 
 
25
 
26
  finetuned_classes = [
27
  'iris',
 
40
  num_faces=1)
41
  detector = vision.FaceLandmarker.create_from_options(options)
42
 
43
+ # Loading the model
44
+ model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=False, num_classes=1)
45
+ hf_hub_download(repo_id="zivpollak/ECXV001", filename="checkpoint.pth", local_dir='.')
46
+ checkpoint = torch.load('checkpoint.pth', map_location='cpu')
47
+ model.load_state_dict(checkpoint['model'], strict=False)
48
+ model.eval()
49
 
50
  def video_identity(video):
51
  return video
52
 
53
 
 
 
 
 
 
 
 
54
  # standard PyTorch mean-std input image normalization
55
  transform = T.Compose([
56
  T.Resize(800),
 
61
  def handle_image(input_image):
62
  global cropped_image, analyzed_image
63
  cv2.imwrite("image.jpg", input_image)
64
+ #image = mp.Image.create_from_file("image.jpg")
65
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
66
 
 
67
  cropped_image = image.numpy_view().copy()
68
  analyzed_image = image.numpy_view().copy()
69
+ detection_result = detector.detect(image)
70
 
71
  face_landmarks_list = detection_result.face_landmarks
72
 
 
84
  cv2.circle(input_image, (p2[0], p2[1]), 10, (0, 0, 255), -1)
85
  cropped_image = cropped_image[p1[1]:p2[1], p1[0]:p2[0]]
86
 
87
+ output_image = run_worflow(cropped_image, model)
88
+ return (output_image)
 
 
 
 
 
 
 
 
 
 
 
 
89
 
 
 
 
 
 
90
  def filter_bboxes_from_outputs(img,
91
  outputs,
92
  threshold=0.7
 
104
  return probas_to_keep, bboxes_scaled
105
 
106
 
107
+ def plot_finetuned_results(img, prob=None, boxes=None):
 
 
 
 
108
  if prob is not None and boxes is not None:
109
+ for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
110
+ print("adding rectangle")
111
+ cv2.rectangle(img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 255, 255), 1)
112
+ return img
113
+
 
 
 
 
 
114
 
115
  def rescale_bboxes(out_bbox, size):
116
  print (size)
 
138
  # propagate through the model
139
  outputs = my_model(img)
140
 
141
+ output_image = cv2.imread("img1.jpg")
142
+
143
+ for threshold in [0.4, 0.4]:
144
 
145
  probas_to_keep, bboxes_scaled = filter_bboxes_from_outputs(my_image,
146
  outputs,
147
  threshold=threshold)
148
 
149
+ print(bboxes_scaled)
150
+ output_image = plot_finetuned_results(output_image,
151
  probas_to_keep,
152
  bboxes_scaled)
153
+
154
+ return output_image
155
 
156
 
157