scene-sketch-seg / utils.py
ahmedbrs's picture
multi-categories
37b5ba0
raw
history blame contribute delete
No virus
5.22 kB
import torch
import numpy as np
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
from vpt.src.configs.config import get_cfg
import os
from time import sleep
from random import randint
from vpt.src.utils.file_io import PathManager
import matplotlib.pyplot as plt
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
import warnings
import nltk
warnings.filterwarnings("ignore")
# Chunk grammar taken from the Su Nam Kim paper.
_NOUN_PHRASE_GRAMMAR = r"""
NBAR:
{<NN.*|JJ>*<NN.*>} # Nouns and Adjectives, terminated with Nouns
NP:
{<NBAR>}
{<NBAR><IN><NBAR>} # Above, connected with in/of/etc...
"""


def get_noun_phrase(tokenized):
    """Return the distinct noun phrases found in a tokenized sentence.

    The tokens are POS-tagged with ``nltk.pos_tag`` and chunked with an
    ``nltk.RegexpParser`` built from ``_NOUN_PHRASE_GRAMMAR``.  Consecutive
    chunk subtrees are joined with a single space into one phrase; any
    non-chunk token ends the phrase being built.

    Args:
        tokenized: Sequence of word tokens, passed directly to
            ``nltk.pos_tag``.

    Returns:
        A list of noun-phrase strings in order of first appearance, without
        duplicates.
    """
    chunker = nltk.RegexpParser(_NOUN_PHRASE_GRAMMAR)
    chunked = chunker.parse(nltk.pos_tag(tokenized))

    noun_phrases = []
    pending = []
    # The trailing ``None`` acts as an end-of-sentence sentinel: it is not a
    # Tree, so it forces a final flush and a phrase that ends the sentence is
    # no longer silently dropped.
    for node in list(chunked) + [None]:
        if isinstance(node, nltk.Tree):
            pending.append(' '.join(token for token, _pos in node.leaves()))
            continue
        if not pending:
            continue
        phrase = ' '.join(pending)
        if phrase not in noun_phrases:
            noun_phrases.append(phrase)
        # Always start over, even when the phrase was a duplicate; otherwise
        # the duplicate's words would be glued onto the next phrase.
        pending = []
    return noun_phrases
def setup(args):
    """
    Build the run configuration and claim an output directory for this run.

    The config from ``get_cfg()`` is overridden first by ``args.config_file``
    and then by ``args.opts``.  The function then probes
    ``<OUTPUT_DIR>/<DATA.NAME>/<DATA.FEATURE>/lr{lr}_wd{wd}/run{k}`` for
    k = 1 .. ``cfg.RUN_N_TIMES``; the first path that does not exist yet is
    created and stored in ``cfg.OUTPUT_DIR``.  If every candidate already
    exists, ``cfg.OUTPUT_DIR`` is left as configured.

    Returns the frozen config.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    base_dir = cfg.OUTPUT_DIR
    run_group = os.path.join(
        cfg.DATA.NAME,
        cfg.DATA.FEATURE,
        f"lr{cfg.SOLVER.BASE_LR}_wd{cfg.SOLVER.WEIGHT_DECAY}",
    )

    # One candidate directory per allowed run (train cfg.RUN_N_TIMES times).
    run_idx = 1
    while run_idx <= cfg.RUN_N_TIMES:
        candidate = os.path.join(base_dir, run_group, f"run{run_idx}")
        # Wait a random amount of time before probing so that concurrent
        # processes with the same settings do not interfere with each other.
        sleep(randint(3, 30))
        if PathManager.exists(candidate):
            run_idx += 1
            continue
        PathManager.mkdirs(candidate)
        cfg.OUTPUT_DIR = candidate
        break

    cfg.freeze()
    return cfg
def get_similarity_map(sm, shape):
    """Turn per-token similarity scores into a normalised spatial map.

    Args:
        sm: Tensor laid out as (batch, tokens, channels), e.g.
            ``torch.Size([1, 196, 1])``.  The token count must be a perfect
            square because the tokens are reshaped into a ``side x side``
            grid.
        shape: Target spatial size handed to
            ``torch.nn.functional.interpolate``.

    Returns:
        The map min-max normalised along the token dimension, bilinearly
        resized to ``shape`` and laid out channels-last.  The leading batch
        dimension is squeezed away when it has size 1.  A map that is
        constant along the token dimension normalises to all zeros.
    """
    # min-max norm over the token dimension
    low = sm.min(1, keepdim=True)[0]
    span = sm.max(1, keepdim=True)[0] - low
    # A constant map has zero span and would turn into NaN (0 / 0).  The
    # numerator is already 0 there, so dividing by 1 instead yields 0 and
    # leaves every non-degenerate map untouched.
    span = torch.where(span > 0, span, torch.ones_like(span))
    sm = (sm - low) / span

    # reshape the token axis into a square grid, channels first
    side = int(sm.shape[1] ** 0.5)
    sm = sm.reshape(sm.shape[0], side, side, -1).permute(0, 3, 1, 2)

    # interpolate to the requested size, then go back to channels last
    sm = torch.nn.functional.interpolate(sm, shape, mode='bilinear')
    sm = sm.permute(0, 2, 3, 1)
    return sm.squeeze(0)
def display_segmented_sketch(pixel_similarity_array, binary_sketch, classes, classes_colors, save_path=None, live=False):
    """Render a class-coloured segmentation of a sketch with matplotlib.

    Each pixel is assigned to the class with the highest similarity and
    painted with that class's hue; its brightness is the similarity score.
    When more than one class is given, each class name is drawn at the
    centroid of its pixels.

    Args:
        pixel_similarity_array: Per-class similarity maps, indexed as
            ``[class, row, col]`` (argmax is taken over axis 0).
        binary_sketch: Image array indexed as ``[row, col, channel]``; pixels
            whose first channel equals 255 are painted white.
        classes: Class names to label.  Entries of ``classes_colors`` at
            indices ``>= len(classes)`` are rendered black.
        classes_colors: One RGB triple per similarity map (presumably in
            [0, 1], as ``rgb_to_hsv`` expects -- confirm against callers).
        save_path: Destination file used when ``live`` is false.  The figure
            is only saved when the path contains a directory part (a ``/``);
            for ``None``, an empty string or a bare file name the figure is
            shown instead.
        live: When true, ignore ``save_path`` and write ``output.png`` in the
            current working directory.

    NOTE(review): this draws on matplotlib's current figure and never clears
    or closes it, so repeated calls presumably accumulate artists -- confirm
    that callers reset the figure.
    """
    # Find the class index with the highest similarity for each pixel
    class_indices = np.argmax(pixel_similarity_array, axis=0)

    # HSV canvas, initialised to white (value = 1, saturation = 0)
    hsv_image = np.zeros(class_indices.shape + (3,))
    hsv_image[..., 2] = 1

    num_named = len(classes)
    for i, color in enumerate(classes_colors):
        rgb_color = np.array(color).reshape(1, 1, 3)
        hsv_color = rgb_to_hsv(rgb_color)
        mask = class_indices == i
        if i < num_named:
            # Named class: hue from its colour, brightness from the score
            scores = pixel_similarity_array[i][mask]
            hsv_image[..., 0][mask] = hsv_color[0, 0, 0]  # Hue
            hsv_image[..., 1][mask] = scores > 0  # Saturation
            hsv_image[..., 2][mask] = scores  # Value
        else:
            # Colour without a matching class name: paint its pixels black
            hsv_image[mask] = 0

    # Pixels whose first channel is 255 in the input sketch are forced white
    mask_tensor_org = binary_sketch[:, :, 0] / 255
    hsv_image[mask_tensor_org == 1] = [0, 0, 1]

    # Convert the HSV image back to RGB to display and save
    rgb_image = hsv_to_rgb(hsv_image)

    if num_named > 1:
        # Render each class name at the centroid of its pixels
        for i, class_name in enumerate(classes):
            mask = class_indices == i
            if not np.any(mask):
                continue
            y, x = np.nonzero(mask)
            plt.text(
                np.mean(x), np.mean(y), class_name,
                color=classes_colors[i], ha='center', va='center', fontsize=10,
                bbox=dict(facecolor='lightgrey', edgecolor='none',
                          boxstyle='round,pad=0.2', alpha=0.8),
            )

    # Display the image with class names
    plt.imshow(rgb_image)
    plt.axis('off')
    plt.tight_layout()

    if live:
        plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
        return

    # Everything before the last "/" is the target directory.  A missing
    # save_path used to crash on ``None.split``; it is now treated like an
    # empty path, which falls through to showing the figure.
    save_dir = os.fspath(save_path).rpartition("/")[0] if save_path else ''
    if save_dir != '':
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
    else:
        plt.show()