from __future__ import annotations

import torch
import cv2
import torch.nn.functional as F
import numpy as np
import config as CFG
from datasets import get_transforms
# for running this script as main
from utils import get_datasets, build_loaders, get_poem_embeddings
from models import PoemTextModel
import json
import os
import regex


def is_poem_duplicate(poem, poems):
    """
    Checks whether a poem (beyt) already appears in a list of poems, comparing only the
    letters so that differences in punctuation, spacing and zero-width non-joiners
    (U+200C) are ignored.
    """
    poem = regex.findall(r'\p{L}+', poem.replace('\u200c', ''))
    for other_poem in poems:
        other_poem = regex.findall(r'\p{L}+', other_poem.replace('\u200c', ''))
        if poem == other_poem:
            return True
    return False
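# Illustrative example: is_poem_duplicate("دل میرود *", ["دل می‌رود"]) returns True,
# since '*' is not a letter and the zero-width non-joiner is stripped, so both beyts
# reduce to the same word list ['دل', 'میرود'].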
def predict_poems_from_text(model, poem_embeddings, query, poems, text_tokenizer=None, n=10, return_similarities=False):
    """
    Returns the n poems most similar to a text query.

    Parameters:
    -----------
    model: PoemTextModel
        model used to compute the text query's embedding
    poem_embeddings: tensor of shape (#poems, CFG.projection_dim)
        poem embeddings to compute similarity against
    query: str
        text query
    poems: list of str
        poems corresponding to poem_embeddings
    text_tokenizer: huggingface Tokenizer, optional
        tokenizer to tokenize the query with. If None, a new text tokenizer is instantiated using the configs.
    n: int, optional
        number of poems to return
    return_similarities: bool, optional
        if True, returns a list of dicts holding the poem beyts and their similarities to the text

    Returns:
    --------
    A list of the n poem strings whose embeddings are most similar to the query text's embedding,
    or, if return_similarities is True, a list of {'beyt', 'similarity'} dicts.
    """
    # Tokenizing and encoding the query text
    if not text_tokenizer:
        text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)
    encoded_query = text_tokenizer([query])
    batch = {
        key: torch.tensor(values).to(CFG.device)
        for key, values in encoded_query.items()
    }
    # getting the query text's embeddings
    model.eval()
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)
    # normalizing, then computing the dot product of poem and text embeddings
    # (the dot product of L2-normalized vectors is their cosine similarity)
    poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ poem_embeddings_n.T
    # ranking all poems by embedding similarity; the top n distinct ones are selected below
    values, indices = torch.topk(dot_similarity.squeeze(0), len(poems))
    # since we collected poems from several sources, some entries are effectively equal
    # (the same beyt paired with different meanings), so we skip duplicates
    # (checked via the module-level is_poem_duplicate) when building the results
    results = []
    computed_k = 0
    for i in range(len(poems)):
        if computed_k == n:
            break
        if not is_poem_duplicate(poems[indices[i]], [res['beyt'] for res in results]):
            results.append({
                'beyt': poems[indices[i]].replace(' * * ', ' * ').replace('*** * ', ''),
                'similarity': values[i].item()  # convert the 0-dim tensor to a plain float
            })
            computed_k += 1
    if return_similarities:
        return results
    else:
        return [res['beyt'] for res in results]
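
# Usage sketch (illustrative, not part of the pipeline): assuming `model` and `poem_embeddings`
# come from get_poem_embeddings and `poems` lists the beyts in the same order as the embeddings:
#     top_beyts = predict_poems_from_text(model, poem_embeddings, "a moonlit garden", poems, n=5)
#     scored = predict_poems_from_text(model, poem_embeddings, "a moonlit garden", poems, n=5,
#                                      return_similarities=True)  # [{'beyt': ..., 'similarity': ...}, ...]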
def predict_poems_from_image(model, poem_embeddings, image_filename, poems, n=10, return_similarities=False):
    """
    Returns the n poems most similar to an image query.

    Parameters:
    -----------
    model: CLIPModel
        model used to compute the image query's embedding
    poem_embeddings: tensor of shape (#poems, CFG.projection_dim)
        poem embeddings to compute similarity against
    image_filename: str
        path and file name of the image query
    poems: list of str
        poems corresponding to poem_embeddings
    n: int, optional
        number of poems to return
    return_similarities: bool, optional
        if True, returns a list of dicts holding the poem beyts and their similarities to the image

    Returns:
    --------
    A list of the n poem strings whose embeddings are most similar to the image query's embedding,
    or, if return_similarities is True, a list of {'beyt', 'similarity'} dicts.
    """
    # Reading, processing and applying transforms to the image (all explained in datasets.py)
    image = cv2.imread(image_filename)
    if image is None:
        raise FileNotFoundError(f"Could not read image file: {image_filename}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = get_transforms(mode="test")(image=image)['image']
    image = torch.tensor(image).permute(2, 0, 1).float()  # HWC -> CHW
    # getting the image query's embeddings
    model.eval()
    with torch.no_grad():
        image_features = model.image_encoder(torch.unsqueeze(image, 0).to(CFG.device))
        image_embeddings = model.image_projection(image_features)
    # normalizing, then computing the dot product of poem and image embeddings
    # (the dot product of L2-normalized vectors is their cosine similarity)
    poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1)
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    dot_similarity = image_embeddings_n @ poem_embeddings_n.T
    # ranking all poems by embedding similarity; the top n distinct ones are selected below
    values, indices = torch.topk(dot_similarity.squeeze(0), len(poems))
    # since we collected poems from several sources, some entries are effectively equal
    # (the same beyt paired with different meanings), so we skip duplicates
    # (checked via the module-level is_poem_duplicate) when building the results
    results = []
    computed_k = 0
    for i in range(len(poems)):
        if computed_k == n:
            break
        if not is_poem_duplicate(poems[indices[i]], [res['beyt'] for res in results]):
            results.append({
                'beyt': poems[indices[i]].replace(' * * ', ' * ').replace('*** * ', ''),
                'similarity': values[i].item()  # convert the 0-dim tensor to a plain float
            })
            computed_k += 1
    if return_similarities:
        return results
    else:
        return [res['beyt'] for res in results]
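
# Usage sketch (illustrative; the image path is hypothetical):
#     beyts = predict_poems_from_image(clip_model, poem_embeddings, "samples/garden.jpg", poems, n=5)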
if __name__ == "__main__":
    """
    Creates a PoemTextModel based on configs, and outputs some examples of its predictions.
    """
    # get the datasets from dataset_path (the same train, val and test splits as the dataset files in the data directory)
    train_dataset, val_dataset, test_dataset = get_datasets()
    model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device)
    model.eval()
    # Inference: output some example predictions and write them to a file
    print("_" * 20)
    print("Output Examples from test set")
    model, poem_embeddings = get_poem_embeddings(test_dataset, model)
    example = {}
    for i, test_data in enumerate(test_dataset[:100]):
        example[i] = {
            'Text': test_data["text"],
            'True Beyt': test_data["beyt"],
            "Predicted Beyt": predict_poems_from_text(
                model, poem_embeddings, test_data["text"],
                [data['beyt'] for data in test_dataset], n=10
            )
        }
    for i in range(10):
        print("Text: ", example[i]['Text'])
        print("True Beyt: ", example[i]['True Beyt'])
        print("Predicted Beyts: \n\t", "\n\t".join(example[i]["Predicted Beyt"]))
    with open('example_output__{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model), 'w', encoding="utf-8") as f:
        f.write(json.dumps(example, ensure_ascii=False, indent=4))
print("Preparing model for user input...") | |
with open(CFG.dataset_path, encoding="utf-8") as f: | |
dataset = json.load(f) | |
model, poem_embeddings = get_poem_embeddings(dataset, model) | |
while(True): | |
user_text = input("Enter a Text to find poem beyts for: ") | |
beyts = predict_poems_from_text(model, poem_embeddings, user_text, [data['beyt'] for data in dataset], n=10) | |
print("predicted Beyts: \n\t", "\n\t".join(beyts)) | |
with open('{}_output__{}_{}.json'.format(user_text, CFG.poem_encoder_model, CFG.text_encoder_model),'a+', encoding="utf-8") as f: | |
f.write(json.dumps(beyts, ensure_ascii=False, indent= 4)) |