Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import warnings | |
import numpy as np | |
import cv2 | |
import torch | |
import torch.nn.functional as F | |
def gen_gaussian_heatmap(imgSize=200, sigma=40):
    """Build a square, disc-masked isotropic Gaussian used as a point stamp.

    The Gaussian is centred at ``(imgSize / 2, imgSize / 2)`` with standard
    deviation ``sigma`` pixels along both axes. It is multiplied by a filled disc
    of radius ``imgSize // 2``, so everything outside the inscribed circle
    (including the corners) is zero. The result is then rescaled so that its
    maximum maps to 255.

    Args:
        imgSize: Side length of the returned square image, in pixels.
        sigma: Gaussian standard deviation in pixels. Defaults to 40, the value
            that was previously hard-coded.

    Returns:
        ``np.uint8`` array of shape ``(imgSize, imgSize)`` with values in
        ``[0, 255]``.
    """
    # Filled disc: value 1 inside, 0 outside (negative thickness => filled).
    circle_mask = cv2.circle(
        np.zeros((imgSize, imgSize), np.float32),
        (imgSize // 2, imgSize // 2),
        imgSize // 2,
        1,
        -1,
    )

    # Evaluate the 2-D Gaussian on the whole grid at once rather than with a
    # per-pixel Python double loop. The first axis is i (rows), the second is
    # j (columns), and the arithmetic order matches the old scalar formula.
    offsets = np.arange(imgSize) - imgSize / 2
    scaled_sq = offsets ** 2 / (sigma ** 2)
    # The normalising constant cancels in the rescale below. It is kept so the
    # float32 intermediate matches the previous implementation bit-for-bit.
    gaussian = 1 / 2 / np.pi / (sigma ** 2) * np.exp(
        -1 / 2 * (scaled_sq[:, None] + scaled_sq[None, :]))
    heatmap = gaussian.astype(np.float32) * circle_mask

    # A single rescale to [0, 255]. The old two-step normalisation was
    # equivalent, because its first step already left the maximum at exactly 1.
    return (heatmap / np.max(heatmap) * 255).astype(np.uint8)
def draw_heatmap(img, center_coordinate, heatmap_template, side, width, height):
    """Stamp ``heatmap_template`` onto ``img`` in a square around one point.

    The target region is ``[cx - side, cx + side) x [cy - side, cy + side)``,
    clipped to ``[1, width - 1) x [1, height - 1)``. The template is resized
    with ``cv2.resize`` to whatever box survives the clipping and is written
    into ``img`` by plain assignment.

    Args:
        img: 2-D array indexed ``img[y, x]``. It is modified in place.
        center_coordinate: ``(x, y)`` pixel position of the point. Coordinates
            are truncated with ``int()`` after clipping.
        heatmap_template: Image to stamp (this module passes the output of
            ``gen_gaussian_heatmap``).
        side: Half side length of the stamp square, in pixels.
        width: Image width used for clipping.
        height: Image height used for clipping.

    Returns:
        ``img`` itself. It is returned unchanged (with a warning) when the
        clipped box is empty.
    """
    # Clip the square so it never writes row/column 0 or row/column
    # height-1/width-1 (the slice end is exclusive).
    x1 = max(center_coordinate[0] - side, 1)
    x2 = min(center_coordinate[0] + side, width - 1)
    y1 = max(center_coordinate[1] - side, 1)
    y2 = min(center_coordinate[1] + side, height - 1)
    x1, x2, y1, y2 = int(x1), int(x2), int(y1), int(y2)
    if (x2 - x1) < 1 or (y2 - y1) < 1:
        # Nothing left to draw after clipping. Report it through `warnings`
        # (filterable by callers) instead of printing to stdout.
        warnings.warn(
            f"draw_heatmap: empty stamp region for point {center_coordinate}: "
            f"x1={x1}, x2={x2}, y1={y1}, y2={y2}; point skipped",
            stacklevel=2,
        )
        return img
    # NOTE(review): the template is resized to the *clipped* box, so near a
    # border the stamp is squeezed and its peak no longer sits on the point.
    # Plain assignment also overwrites earlier stamps, including this stamp's
    # zero-valued corners. Both behaviours are kept because rendered maps may
    # feed a model trained on exactly this rendering; confirm before changing.
    need_map = cv2.resize(heatmap_template, (x2 - x1, y2 - y1))
    img[y1:y2, x1:x2] = need_map
    return img
def generate_gassian_heatmap(pred_tracks, pred_visibility=None, image_size=None, side=20):
    """Render one grayscale heatmap frame per time step for a set of tracked points.

    Each in-frame, visible point is stamped with the Gaussian template from
    ``gen_gaussian_heatmap`` via ``draw_heatmap``. Each frame is then converted
    to 3 channels and stacked into a video tensor.

    Args:
        pred_tracks: Indexable of shape ``(num_frames, num_points, 2)`` holding
            ``(x, y)`` pixel coordinates. Presumably a torch tensor or numpy
            array; TODO confirm against callers.
        pred_visibility: Optional ``(num_frames, num_points)`` mask. Points whose
            entry is falsy are not drawn in that frame.
        image_size: ``(width, height)`` of the output frames. Required despite
            the ``None`` default.
        side: Half side length, in pixels, of the square each point is drawn in.

    Returns:
        ``torch.uint8`` tensor of shape ``(num_frames, 3, height, width)``.
        ``cv2.COLOR_GRAY2RGB`` makes the three channels identical. When
        ``num_frames == 0`` an empty ``(0, 3, height, width)`` tensor is returned.

    Raises:
        TypeError: If ``image_size`` is ``None``.
    """
    if image_size is None:
        raise TypeError(
            "generate_gassian_heatmap: image_size=(width, height) is required")
    width, height = image_size
    num_frames, num_points = pred_tracks.shape[:2]
    heatmap_template = gen_gaussian_heatmap()

    frames = []
    for frame_idx in range(num_frames):
        canvas = np.zeros((height, width), np.float32)
        for point_idx in range(num_points):
            px, py = pred_tracks[frame_idx, point_idx]
            if px < 0 or py < 0 or px >= width or py >= height:
                # Out-of-frame points are skipped. Only the first and last
                # frames are reported, to keep the noise down.
                if frame_idx in (0, num_frames - 1):
                    warnings.warn(
                        f"generate_gassian_heatmap: frame {frame_idx}, point "
                        f"{point_idx} at ({px}, {py}) is outside the "
                        f"{width}x{height} image; skipped",
                        stacklevel=2,
                    )
                continue
            if pred_visibility is not None and not pred_visibility[frame_idx, point_idx]:
                continue
            canvas = draw_heatmap(canvas, (px, py), heatmap_template, side, width, height)
        rgb = cv2.cvtColor(canvas.astype(np.uint8), cv2.COLOR_GRAY2RGB)
        # HWC -> CHW for torch.
        frames.append(torch.from_numpy(rgb).permute(2, 0, 1).contiguous())

    if not frames:
        # torch.stack rejects an empty list; return a correctly shaped empty video.
        return torch.zeros((0, 3, height, width), dtype=torch.uint8)
    return torch.stack(frames, dim=0)