# wwen1997's picture
# Upload 13 files
# 7615afe verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import warnings
import numpy as np
import cv2
import torch
import torch.nn.functional as F
def gen_gaussian_heatmap(imgSize=200, sigma=40):
    """Build a square, isotropic Gaussian blob used as a stamping template.

    The Gaussian is centred at ``imgSize / 2`` on both axes. It is multiplied
    by the filled disc of radius ``imgSize // 2`` drawn with ``cv2.circle``,
    so pixels outside that disc are 0. The result is then rescaled so its
    peak value is 255.

    Args:
        imgSize: Side length, in pixels, of the returned square template.
        sigma: Standard deviation of the Gaussian, in pixels. Defaults to 40,
            the value that used to be hard-coded.

    Returns:
        ``np.ndarray`` of shape ``(imgSize, imgSize)`` and dtype ``uint8``.
    """
    circle_img = np.zeros((imgSize, imgSize), np.float32)
    circle_mask = cv2.circle(circle_img, (imgSize // 2, imgSize // 2), imgSize // 2, 1, -1)

    # Vectorised form of the former per-pixel double loop. The arithmetic is
    # done in float64, in the same order as the scalar version, and only then
    # cast to float32 (as storing into the float32 array used to do).
    # Row index -> axis 0, column index -> axis 1.
    offsets_sq = (np.arange(imgSize, dtype=np.float64) - imgSize / 2) ** 2 / (sigma ** 2)
    exponent = -1 / 2 * (offsets_sq[:, None] + offsets_sq[None, :])
    gauss = (1 / 2 / np.pi / (sigma ** 2) * np.exp(exponent)).astype(np.float32)

    gauss = gauss * circle_mask
    # One normalise-and-scale step. Normalising to [0, 1] first and then
    # dividing by the new max (exactly 1.0) again was a no-op.
    return (gauss / np.max(gauss) * 255).astype(np.uint8)
def draw_heatmap(img, center_coordinate, heatmap_template, side, width, height):
    """Stamp a resized copy of ``heatmap_template`` into ``img`` around a point.

    The target window is ``[left, right) x [top, bottom)``, where
    ``left = max(cx - side, 1)``, ``right = min(cx + side, width - 1)``, and
    likewise for rows using ``height``. The bounds are truncated with
    ``int()``. The template is resized to exactly fill that window and
    overwrites those pixels of ``img`` in place.

    Args:
        img: 2-D array that is written to in place and returned.
        center_coordinate: ``(x, y)`` pair; index 0 is the column, index 1
            the row.
        heatmap_template: Image passed to ``cv2.resize``.
        side: Half-extent of the window before clamping.
        width, height: Canvas size used for clamping.

    Returns:
        ``img`` itself. If the clamped window is empty, a diagnostic line is
        printed and ``img`` is returned untouched.

    NOTE(review): several behaviours here look questionable; confirm they
    are intentional before changing them:
      * Clamping to ``1`` / ``size - 1`` means the first and last row and
        column are never written.
      * Near a border the template is squashed into the clipped window
        instead of being cropped.
      * Overlapping stamps overwrite each other (last one wins) instead of
        being max-combined.
    """
    cx = center_coordinate[0]
    cy = center_coordinate[1]
    left, right = int(max(cx - side, 1)), int(min(cx + side, width - 1))
    top, bottom = int(max(cy - side, 1)), int(min(cy + side, height - 1))

    window_w = right - left
    window_h = bottom - top
    if window_w < 1 or window_h < 1:
        print(center_coordinate, "x1, x2, y1, y2", left, right, top, bottom)
        return img

    # cv2.resize takes dsize as (width, height).
    img[top:bottom, left:right] = cv2.resize(heatmap_template, (window_w, window_h))
    return img
def generate_gassian_heatmap(pred_tracks, pred_visibility=None, image_size=None, side=20):
    """Render per-frame Gaussian heatmaps for a set of tracked points.

    For each frame, a blank ``(height, width)`` canvas is created. The
    Gaussian template from ``gen_gaussian_heatmap()`` is stamped at every
    point that is inside the canvas and (when ``pred_visibility`` is given)
    visible. The canvas is then converted to a 3-channel uint8 tensor.

    Args:
        pred_tracks: Indexed as ``pred_tracks[frame, point]`` and unpacked
            into ``(x, y)``. ``shape[:2]`` is ``(num_frames, num_points)``.
            Presumably a NumPy array or torch tensor; confirm with callers.
        pred_visibility: Optional, indexed as ``[frame, point]``. Points
            whose entry is falsy are skipped.
        image_size: ``(width, height)`` pair. Required in practice despite
            the ``None`` default, because it is unpacked unconditionally.
        side: Half-extent of the square each point is stamped into.

    Returns:
        uint8 ``torch.Tensor`` of shape ``(num_frames, 3, height, width)``.

    Points with ``x < 0``, ``y < 0``, ``x >= width`` or ``y >= height`` are
    skipped. Their coordinates are printed, but only on the first and last
    frames. ``torch.stack`` raises if ``num_frames`` is 0.
    """
    width, height = image_size
    num_frames, num_points = pred_tracks.shape[:2]
    template = gen_gaussian_heatmap()

    def _render(t):
        # Render frame ``t`` as a (3, height, width) uint8 tensor.
        canvas = np.zeros((height, width), np.float32)
        for k in range(num_points):
            px, py = pred_tracks[t, k]
            # Kept as an explicit disjunction: NaN coordinates fail every
            # comparison and therefore are NOT treated as out of bounds.
            if px < 0 or py < 0 or px >= width or py >= height:
                if t in (0, num_frames - 1):
                    print(t, k, px, py)
                continue
            if pred_visibility is not None and not pred_visibility[t, k]:
                continue
            canvas = draw_heatmap(canvas, (px, py), template, side, width, height)
        rgb = cv2.cvtColor(canvas.astype(np.uint8), cv2.COLOR_GRAY2RGB)
        # HWC -> CHW
        return torch.from_numpy(rgb).permute(2, 0, 1).contiguous()

    frames = [_render(t) for t in range(num_frames)]
    return torch.stack(frames, dim=0)