import gradio as gr
from PIL import Image
import torch
from torchvision.transforms import InterpolationMode

BICUBIC = InterpolationMode.BICUBIC

from utils import setup, get_similarity_map, get_noun_phrase, rgb_to_hsv, hsv_to_rgb
from vpt.launch import default_argument_parser
from collections import OrderedDict
import numpy as np
import matplotlib.pyplot as plt
import models
import string
import nltk

nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
from nltk.tokenize import word_tokenize
import torchvision
import spacy

# Download and load the spaCy English model
spacy.cli.download("en_core_web_sm")
nlp = spacy.load("en_core_web_sm")


def extract_objects(prompt):
    doc = nlp(prompt)
    # Extract object nouns (including proper nouns and compound nouns)
    objects = set()
    for token in doc:
        # Keep tokens that are nouns, proper nouns, or part of a named entity
        if token.pos_ in {"NOUN", "PROPN"} or token.ent_type_:
            objects.add(token.text)
        # If the token is a compound modifier, keep its head noun as well
        if token.dep_ == "compound":
            objects.add(token.head.text)
    return list(objects)


args = default_argument_parser().parse_args()
cfg = setup(args)

multi_classes = True
device = "cuda" if torch.cuda.is_available() else "cpu"

Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)
state_dict = torch.load("sketch_seg_best_miou.pth", map_location=device)

# The checkpoint was trained on 2 GPUs, so remove the "module." prefix
# to load it on a single device
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove `module.`
    new_state_dict[name] = v

Ours.load_state_dict(new_state_dict)
Ours.eval()
print("Model loaded successfully")


def run(sketch, caption, threshold, seed):
    # pick a random offset into the colormap for the class colors
    color_seed = np.random.randint(0, 4)

    # build the candidate classes from the caption
    caption = caption.replace('\n', ' ')
    classes = extract_objects(caption)

    # translator = str.maketrans('', '', string.punctuation)
    # caption = caption.translate(translator).lower()
    # words = word_tokenize(caption)
    # classes = get_noun_phrase(words)
    # print(classes)

    # fall back to the whole caption as a single class
    if len(classes) == 0 or not multi_classes:
        classes = [caption]
    # print(classes)

    colors = plt.get_cmap("Set1").colors
    classes_colors = colors[color_seed:len(classes) + color_seed]

    sketch2 = sketch['composite']

    # When the drawing tool is used, the RGB channels are empty and the
    # strokes live in the alpha channel, so rebuild an RGBA image from it
    if sketch2[:, :, 0:3].sum() == 0:
        temp = sketch2[:, :, 3]
        temp = 255 - temp  # invert the alpha channel
        sketch2 = np.repeat(temp[:, :, np.newaxis], 3, axis=2)
        temp2 = np.full_like(temp, 255)
        sketch2 = np.dstack((sketch2, temp2))

    sketch2 = np.array(sketch2)
    pil_img = Image.fromarray(sketch2).convert('RGB')
    sketch_tensor = preprocess(pil_img).unsqueeze(0).to(device)
    # torchvision.utils.save_image(sketch_tensor, 'sketch_tensor.png')

    with torch.no_grad():
        text_features = models.encode_text_with_prompt_ensemble(Ours, classes, device, no_module=True)
        redundant_features = models.encode_text_with_prompt_ensemble(Ours, [""], device, no_module=True)

    num_of_tokens = 3
    with torch.no_grad():
        sketch_features = Ours.encode_image(sketch_tensor, layers=[12],
                                            text_features=text_features - redundant_features,
                                            mode="test").squeeze(0)
        sketch_features = sketch_features / sketch_features.norm(dim=1, keepdim=True)

    similarity = sketch_features @ (text_features - redundant_features).t()
    patches_similarity = similarity[0, num_of_tokens + 1:, :]
    pixel_similarity = get_similarity_map(patches_similarity.unsqueeze(0), pil_img.size).cpu()
    # visualize_attention_maps_with_tokens(pixel_similarity, classes)

    # zero out pixels below the confidence threshold
    pixel_similarity[pixel_similarity < threshold] = 0
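    # The thresholded per-pixel similarities are now rendered as an HSV
    # segmentation image: hue encodes the winning class, while saturation
    # and value encode the similarity score for that pixel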
    pixel_similarity_array = pixel_similarity.cpu().numpy().transpose(2, 0, 1)
    # display_segmented_sketch(pixel_similarity_array, sketch2, classes, classes_colors, live=True)

    # Find the class index with the highest similarity for each pixel
    class_indices = np.argmax(pixel_similarity_array, axis=0)

    # Create an HSV image placeholder of shape (H, W, 3)
    hsv_image = np.zeros(class_indices.shape + (3,))
    hsv_image[..., 2] = 1  # set Value to 1 for a white base

    # Set the hue, saturation, and value channels per class
    for i, color in enumerate(classes_colors):
        rgb_color = np.array(color).reshape(1, 1, 3)
        hsv_color = rgb_to_hsv(rgb_color)
        mask = class_indices == i
        if i < len(classes):
            # Color valid class indices according to their similarity
            hsv_image[..., 0][mask] = hsv_color[0, 0, 0]  # Hue
            hsv_image[..., 1][mask] = pixel_similarity_array[i][mask] > 0  # Saturation
            hsv_image[..., 2][mask] = pixel_similarity_array[i][mask]  # Value
        else:
            # Any remaining indices are painted black
            hsv_image[..., 0][mask] = 0  # Hue doesn't matter for black
            hsv_image[..., 1][mask] = 0  # Saturation set to 0
            hsv_image[..., 2][mask] = 0  # Value set to 0, making it black

    # Keep the original sketch strokes visible on top of the segmentation
    mask_tensor_org = sketch2[:, :, 0] / 255
    hsv_image[mask_tensor_org >= 0.5] = [0, 0, 1]

    # Convert the HSV image back to RGB to display and save
    rgb_image = hsv_to_rgb(hsv_image)

    if len(classes) > 1:
        # Calculate centroids and render class names
        for i, class_name in enumerate(classes):
            mask = class_indices == i
            if np.any(mask):
                y, x = np.nonzero(mask)
                centroid_x, centroid_y = np.mean(x), np.mean(y)
                plt.text(centroid_x, centroid_y, class_name,
                         color=classes_colors[i], ha='center', va='center', fontsize=10,
                         bbox=dict(facecolor='lightgrey', edgecolor='none',
                                   boxstyle='round,pad=0.2', alpha=0.8))

    # Display the image with class names
    plt.imshow(rgb_image)
    plt.axis('off')
    plt.tight_layout()
    # plt.savefig(f'poster_vis/{classes[0]}.png', bbox_inches='tight', pad_inches=0)
    plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
    plt.close()

    # rgb_image = Image.open(f'poster_vis/{classes[0]}.png')
    rgb_image = Image.open('output.png')

    return rgb_image


scripts = """
async () => {
    // START gallery format
    // Get all image elements with the class "image_gallery"
    var images = document.querySelectorAll('.image_gallery');
    var originalParent = document.querySelector('#component-0');

    // Create a new parent div to hold the gallery
    var parentDiv = document.createElement('div');
    var beforeDiv = document.querySelector('.table-wrap').parentElement;
    parentDiv.id = "gallery_container";

    // Move each image into the new parent div
    images.forEach(function(image, index) {
        parentDiv.appendChild(image);

        // Clicking a gallery image selects the matching example row
        image.addEventListener('click', function() {
            let nth_ch = index + 1;
            document.querySelector('.tr-body:nth-child(' + nth_ch + ')').click();
            console.log('.tr-body:nth-child(' + nth_ch + ')');
        });
    });

    // Insert the gallery before the examples table
    originalParent.insertBefore(parentDiv, beforeDiv);
    // END gallery format

    // START confidence span
    // Get the label span of the confidence slider
    var selectedDiv = document.querySelector("label[for='range_id_0'] > span");
    var textContent = selectedDiv.textContent;

    // Find the text before the first colon ':'
    var colonIndex = textContent.indexOf(':');
    var textBeforeColon = textContent.substring(0, colonIndex);

    // Wrap the text before the colon in a span element
    var spanElement = document.createElement('span');
    spanElement.textContent = textBeforeColon;

    // Replace the original text with the modified text containing the span
    selectedDiv.innerHTML = textContent.replace(textBeforeColon, spanElement.outerHTML);
    // END confidence span

    // START format the column names
    // Get all header cells of the examples table
    var elements = document.querySelectorAll('.tr-head > th');
    elements.forEach(function(element) {
        // Keep only the first word of the header text, without ":"
        var text = element.textContent.trim();
        var wordWithoutColon = text.replace(':', '');
        var words = wordWithoutColon.split(' ');
        var firstWord = words[0];
        element.textContent = firstWord;
    });
    // END format the column names

    document.querySelector('input[type=number]').disabled = true;
}
"""

css = """
gradio-app {
    background-color: white !important;
}
.white-bg {
    background-color: white !important;
}
.gray-border {
    border: 1px solid dimgrey !important;
}
.border-radius {
    border-radius: 8px !important;
}
.black-text {
    color: black !important;
}
th {
    color: black !important;
}
tr {
    background-color: white !important;
    color: black !important;
}
td {
    border-bottom: 1px solid black !important;
}
label[data-testid="block-label"] {
    background: white;
    color: black;
    font-weight: bold;
}
.controls-wrap button:disabled {
    color: gray !important;
    background-color: white !important;
}
.controls-wrap button:not(:disabled) {
    color: black !important;
    background-color: white !important;
}
.source-wrap button {
    color: black !important;
}
.toolbar-wrap button {
    color: black !important;
}
.empty.wrap {
    color: black !important;
}
textarea {
    background-color: #f7f9f8 !important;
    color: #afb0b1 !important;
}
input[data-testid="number-input"] {
    background-color: #f7f9f8 !important;
    color: black !important;
}
tr > th {
    border-bottom: 1px solid black !important;
}
tr:hover {
    background: #f7f9f8 !important;
}
#component-19 {
    justify-content: center !important;
}
#component-19 > button {
    flex: none !important;
    background-color: black !important;
    font-weight: bold !important;
}
.bold {
    font-weight: bold !important;
}
span[data-testid="block-info"] {
    color: black !important;
    font-weight: bold !important;
}
#component-14 > div {
    background-color: white !important;
}
button[aria-label="Clear"] {
    background-color: white !important;
    color: black !important;
}
#gallery_container {
    display: flex;
    flex-wrap: wrap;
    justify-content: start;
}
.image_gallery {
    margin-bottom: 1rem;
    margin-right: 1rem;
}
label[for='range_id_0'] > span > span {
    text-decoration: underline;
    font-size: normal !important;
}
.underline {
    text-decoration: underline;
}
.mt-mb-1 {
    margin-top: 1rem;
    margin-bottom: 1rem;
}
#gallery_container + div {
    visibility: hidden;
    height: 10px;
}
input[type=number][disabled] {
    background-color: rgb(247, 249, 248) !important;
    color: black !important;
    -webkit-text-fill-color: black !important;
}
#component-13 {
    display: flex;
    flex-direction: column;
    align-items: center;
}
"""

with gr.Blocks(js=scripts, css=css, theme='gstaff/xkcd') as demo:
    gr.HTML("