Spaces:
Running
Running
from functools import partial | |
from datasets import load_dataset | |
import torch | |
from torchmetrics.functional.multimodal import clip_score | |
def load_prompts(num_prompts, batch_size):
    """Load Parti prompts for the CLIP Score metric, grouped into full batches.

    Args:
        num_prompts (int): number of prompts to select. If 0, every prompt in
            the dataset is used, in dataset order. Otherwise the dataset is
            shuffled and the first ``num_prompts`` entries are taken.
        batch_size (int): number of prompts per batch; must be positive.

    Returns:
        A tuple (prompts, batched_prompts). ``batched_prompts`` is the list of
        selected prompts split into chunks of exactly ``batch_size``; a
        trailing chunk shorter than ``batch_size`` is dropped. ``prompts`` is
        the flat list of the prompts kept in ``batched_prompts``, so its
        length is the largest multiple of ``batch_size`` that does not exceed
        the number of selected prompts. Both lists are empty when fewer than
        ``batch_size`` prompts are selected.

    Raises:
        ValueError: if ``num_prompts`` is negative or ``batch_size`` is not
            positive.
    """
    # A negative count would silently turn the slice below into
    # "everything except the last k prompts"; reject it explicitly.
    if num_prompts < 0:
        raise ValueError(f"num_prompts must be >= 0, got {num_prompts}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be > 0, got {batch_size}")

    dataset = load_dataset("nateraw/parti-prompts", split="train")
    if num_prompts == 0:
        num_prompts = len(dataset)
    else:
        dataset = dataset.shuffle()
    selected = dataset[:num_prompts]["Prompt"]

    # Keep only as many prompts as fill complete batches. Computing the cut
    # point up front avoids indexing the last batch of a possibly empty list.
    usable = len(selected) - len(selected) % batch_size
    batched_prompts = [
        selected[start : start + batch_size]
        for start in range(0, usable, batch_size)
    ]
    prompts = list(selected[:usable])
    return prompts, batched_prompts
def calculate_clip_score(images, prompts):
    """Compute the CLIP Score of a batch of images against their prompts.

    Args:
        images (np.ndarray): 4-D array of images. It is multiplied by 255 and
            cast to uint8, and its last axis is moved to position 1 before
            scoring (presumably float NHWC data in [0, 1] -- confirm against
            the caller).
        prompts (list): list of prompts, assumes same size as images

    Returns:
        The clip score across all images and prompts as a float.
    """
    # Convert to 8-bit pixel values, then reorder axes (0, 1, 2, 3) -> (0, 3, 1, 2).
    pixels = torch.from_numpy((images * 255).astype("uint8"))
    pixels = pixels.permute(0, 3, 1, 2)

    score = clip_score(
        pixels, prompts, model_name_or_path="openai/clip-vit-base-patch16"
    )
    return float(score.detach())