Ziv Pollak
commited on
Commit
•
5aadbcd
1
Parent(s):
2acfef6
fixes
Browse files
app.py
CHANGED
@@ -12,6 +12,9 @@ from mediapipe import solutions
|
|
12 |
from PIL import Image
|
13 |
|
14 |
import torch, torchvision
|
|
|
|
|
|
|
15 |
|
16 |
import matplotlib
|
17 |
matplotlib.use("Agg")
|
@@ -19,10 +22,6 @@ import matplotlib.pyplot as plt
|
|
19 |
|
20 |
cropped_image = []
|
21 |
analyzed_image = []
|
22 |
-
|
23 |
-
# colors for visualization
|
24 |
-
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
|
25 |
-
[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
|
26 |
|
27 |
finetuned_classes = [
|
28 |
'iris',
|
@@ -41,20 +40,17 @@ options = vision.FaceLandmarkerOptions(base_options=base_options,
|
|
41 |
num_faces=1)
|
42 |
detector = vision.FaceLandmarker.create_from_options(options)
|
43 |
|
44 |
-
|
45 |
-
model.
|
|
|
|
|
|
|
|
|
46 |
|
47 |
def video_identity(video):
|
48 |
return video
|
49 |
|
50 |
|
51 |
-
#demo = gr.Interface(video_identity,
|
52 |
-
# gr.Video(shape = (1000,1000), source="webcam"),
|
53 |
-
# "playable_video")
|
54 |
-
|
55 |
-
|
56 |
-
import torchvision.transforms as T
|
57 |
-
|
58 |
# standard PyTorch mean-std input image normalization
|
59 |
transform = T.Compose([
|
60 |
T.Resize(800),
|
@@ -65,11 +61,12 @@ transform = T.Compose([
|
|
65 |
def handle_image(input_image):
|
66 |
global cropped_image, analyzed_image
|
67 |
cv2.imwrite("image.jpg", input_image)
|
68 |
-
image = mp.Image.create_from_file("image.jpg")
|
|
|
69 |
|
70 |
-
detection_result = detector.detect(image)
|
71 |
cropped_image = image.numpy_view().copy()
|
72 |
analyzed_image = image.numpy_view().copy()
|
|
|
73 |
|
74 |
face_landmarks_list = detection_result.face_landmarks
|
75 |
|
@@ -87,26 +84,9 @@ def handle_image(input_image):
|
|
87 |
cv2.circle(input_image, (p2[0], p2[1]), 10, (0, 0, 255), -1)
|
88 |
cropped_image = cropped_image[p1[1]:p2[1], p1[0]:p2[0]]
|
89 |
|
90 |
-
run_worflow(cropped_image, model)
|
91 |
-
|
92 |
-
return (cropped_image)
|
93 |
-
|
94 |
-
def load_model():
|
95 |
-
print('load model')
|
96 |
-
'''
|
97 |
-
model = torch.hub.load('facebookresearch/detr',
|
98 |
-
'detr_resnet50',
|
99 |
-
pretrained=False,
|
100 |
-
num_classes=1)
|
101 |
-
|
102 |
-
checkpoint = torch.load('outputs/checkpoint.pth',
|
103 |
-
map_location='cpu')
|
104 |
|
105 |
-
model.load_state_dict(checkpoint['model'],
|
106 |
-
strict=False)
|
107 |
-
|
108 |
-
model.eval();
|
109 |
-
'''
|
110 |
def filter_bboxes_from_outputs(img,
|
111 |
outputs,
|
112 |
threshold=0.7
|
@@ -124,22 +104,13 @@ def filter_bboxes_from_outputs(img,
|
|
124 |
return probas_to_keep, bboxes_scaled
|
125 |
|
126 |
|
127 |
-
def plot_finetuned_results(
|
128 |
-
plt.figure(figsize=(16,10))
|
129 |
-
plt.imshow(pil_img)
|
130 |
-
ax = plt.gca()
|
131 |
-
colors = COLORS * 100
|
132 |
if prob is not None and boxes is not None:
|
133 |
-
for p, (xmin, ymin, xmax, ymax)
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
text = 'results'
|
139 |
-
ax.text(xmin, ymin, text, fontsize=15,
|
140 |
-
bbox=dict(facecolor='yellow', alpha=0.5))
|
141 |
-
plt.axis('off')
|
142 |
-
plt.show()
|
143 |
|
144 |
def rescale_bboxes(out_bbox, size):
|
145 |
print (size)
|
@@ -167,15 +138,20 @@ def run_worflow(my_image, my_model):
|
|
167 |
# propagate through the model
|
168 |
outputs = my_model(img)
|
169 |
|
170 |
-
|
|
|
|
|
171 |
|
172 |
probas_to_keep, bboxes_scaled = filter_bboxes_from_outputs(my_image,
|
173 |
outputs,
|
174 |
threshold=threshold)
|
175 |
|
176 |
-
|
|
|
177 |
probas_to_keep,
|
178 |
bboxes_scaled)
|
|
|
|
|
179 |
|
180 |
|
181 |
|
|
|
12 |
from PIL import Image
|
13 |
|
14 |
import torch, torchvision
|
15 |
+
import torchvision.transforms as T
|
16 |
+
from huggingface_hub import hf_hub_download
|
17 |
+
|
18 |
|
19 |
import matplotlib
|
20 |
matplotlib.use("Agg")
|
|
|
22 |
|
23 |
cropped_image = []
|
24 |
analyzed_image = []
|
|
|
|
|
|
|
|
|
25 |
|
26 |
finetuned_classes = [
|
27 |
'iris',
|
|
|
40 |
num_faces=1)
|
41 |
detector = vision.FaceLandmarker.create_from_options(options)
|
42 |
|
43 |
+
# Loading the model
|
44 |
+
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=False, num_classes=1)
|
45 |
+
hf_hub_download(repo_id="zivpollak/ECXV001", filename="checkpoint.pth", local_dir='.')
|
46 |
+
checkpoint = torch.load('checkpoint.pth', map_location='cpu')
|
47 |
+
model.load_state_dict(checkpoint['model'], strict=False)
|
48 |
+
model.eval()
|
49 |
|
50 |
def video_identity(video):
|
51 |
return video
|
52 |
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
# standard PyTorch mean-std input image normalization
|
55 |
transform = T.Compose([
|
56 |
T.Resize(800),
|
|
|
61 |
def handle_image(input_image):
|
62 |
global cropped_image, analyzed_image
|
63 |
cv2.imwrite("image.jpg", input_image)
|
64 |
+
#image = mp.Image.create_from_file("image.jpg")
|
65 |
+
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
|
66 |
|
|
|
67 |
cropped_image = image.numpy_view().copy()
|
68 |
analyzed_image = image.numpy_view().copy()
|
69 |
+
detection_result = detector.detect(image)
|
70 |
|
71 |
face_landmarks_list = detection_result.face_landmarks
|
72 |
|
|
|
84 |
cv2.circle(input_image, (p2[0], p2[1]), 10, (0, 0, 255), -1)
|
85 |
cropped_image = cropped_image[p1[1]:p2[1], p1[0]:p2[0]]
|
86 |
|
87 |
+
output_image = run_worflow(cropped_image, model)
|
88 |
+
return (output_image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
|
|
|
|
|
|
|
|
|
|
90 |
def filter_bboxes_from_outputs(img,
|
91 |
outputs,
|
92 |
threshold=0.7
|
|
|
104 |
return probas_to_keep, bboxes_scaled
|
105 |
|
106 |
|
107 |
+
def plot_finetuned_results(img, prob=None, boxes=None):
|
|
|
|
|
|
|
|
|
108 |
if prob is not None and boxes is not None:
|
109 |
+
for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
|
110 |
+
print("adding rectangle")
|
111 |
+
cv2.rectangle(img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 255, 255), 1)
|
112 |
+
return img
|
113 |
+
|
|
|
|
|
|
|
|
|
|
|
114 |
|
115 |
def rescale_bboxes(out_bbox, size):
|
116 |
print (size)
|
|
|
138 |
# propagate through the model
|
139 |
outputs = my_model(img)
|
140 |
|
141 |
+
output_image = cv2.imread("img1.jpg")
|
142 |
+
|
143 |
+
for threshold in [0.4, 0.4]:
|
144 |
|
145 |
probas_to_keep, bboxes_scaled = filter_bboxes_from_outputs(my_image,
|
146 |
outputs,
|
147 |
threshold=threshold)
|
148 |
|
149 |
+
print(bboxes_scaled)
|
150 |
+
output_image = plot_finetuned_results(output_image,
|
151 |
probas_to_keep,
|
152 |
bboxes_scaled)
|
153 |
+
|
154 |
+
return output_image
|
155 |
|
156 |
|
157 |
|