# Page header captured with this file (not part of the program):
# uploaded by miscjose, commit ac6138f ("Uploading local repo"), 5.47 kB.
import os
import json
import gradio as gr
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
# Sentence-embedding model shared by all requests; used to encode search queries.
model = SentenceTransformer(model_name_or_path='sentence-transformers/all-MiniLM-L6-v2')
def get_n_weighted_scores(embeddings, query, n, objective_weight, subjective_weight):
    """Rank entries by a weighted blend of objective and review similarity.

    The query text is encoded once with the module-level ``model``. Each entry
    is compared against it through its ``objective_embedding`` and every row of
    its ``subjective_embeddings`` (presumably one row per user review, since the
    returned index is used to look up a review - verify against the data file).
    A row's weighted score is
    ``objective_score * objective_weight + row_score * subjective_weight``.
    Each entry is scored by its best row.

    Args:
        embeddings: mapping of entry key -> dict with ``objective_embedding``
            and ``subjective_embeddings``.
        query: free-text search string.
        n: maximum number of results to return (cast to ``int``).
        objective_weight: weight applied to the objective similarity.
        subjective_weight: weight applied to each subjective similarity.

    Returns:
        Up to ``n`` ``(key, score, review_index)`` tuples, highest score first.
        Equal scores keep the iteration order of *embeddings*.
    """
    # Wrap in a list so cosine_similarity receives a single-sample 2-D input.
    query_embedding = [model.encode(query)]
    weighted_scores = []
    for key, value in embeddings.items():
        objective_score = cosine_similarity(query_embedding, value['objective_embedding']).item()
        # Row 0 holds the query's similarity to every subjective embedding.
        review_scores = cosine_similarity(query_embedding, value['subjective_embeddings'])[0]
        weighted = objective_score * objective_weight + review_scores * subjective_weight
        # argmax returns the first maximal element, matching a strict ">" scan.
        # It also keeps genuinely negative scores, whereas a running max
        # initialised to 0 would have reported them as 0 / index 0.
        best_index = int(weighted.argmax())
        weighted_scores.append((key, float(weighted[best_index]), best_index))
    # list.sort is stable even with reverse=True, so ties keep their order.
    weighted_scores.sort(key=lambda item: item[1], reverse=True)
    return weighted_scores[:int(n)]
def filter_anime(embeddings, genres, themes, rating):
    """Return the entries of *embeddings* that match every selected filter.

    An entry is kept only if all three hold:
      * it shares at least one genre with *genres*;
      * it shares at least one theme with *themes*;
      * its single ``rating`` value is in *rating*.
    A selection containing ``'ALL'`` matches every entry for that criterion.
    An empty or ``None`` selection matches nothing.

    Args:
        embeddings: mapping of key -> dict with ``genres`` and ``themes``
            (iterables of hashable values) and a hashable ``rating``.
        genres: selected genre values, or ``None``.
        themes: selected theme values, or ``None``.
        rating: selected rating values, or ``None``.

    Returns:
        A new dict of the matching entries in the original key order. The
        values are the same objects as in *embeddings* (not copies).
    """
    wanted_genres = set(genres or ())
    wanted_themes = set(themes or ())
    wanted_ratings = set(rating or ())

    def _accepts(wanted, values):
        # 'ALL' is a wildcard; otherwise any overlap is enough.
        return 'ALL' in wanted or not wanted.isdisjoint(values)

    return {
        key: anime
        for key, anime in embeddings.items()
        if _accepts(wanted_genres, anime['genres'])
        and _accepts(wanted_themes, anime['themes'])
        and _accepts(wanted_ratings, [anime['rating']])
    }
def get_recommendation(query, number_of_recommendations, genres, themes, rating, objective_weight, subjective_weight):
    """Build the Gradio output values for one recommendation request.

    Filters the module-level ``embeddings`` with ``filter_anime`` and ranks the
    survivors with ``get_n_weighted_scores``. For each hit it emits three
    component updates, in order: cover image, description, best user review.

    Returns:
        A flat list with exactly one value per output component
        (``total_rows`` rows x 3 components). Slots without a recommendation
        are cleared and hidden.
    """
    # Must match the number of result rows built in the UI (each row holds
    # one Image and two Textboxes).
    total_rows = 15

    filtered_anime = filter_anime(embeddings, genres, themes, rating)
    weighted_scores = get_n_weighted_scores(
        filtered_anime,
        query,
        # Never ask for more results than there are rows to show them in.
        min(int(number_of_recommendations), total_rows),
        float(objective_weight),
        float(subjective_weight),
    )

    results = []
    for idx, (key, _score, review_index) in enumerate(weighted_scores, start=1):
        data = embeddings[key]
        english = data['english']
        review = data['reviews'][review_index]['text']
        results.append(gr.Image(label=f"{english}", value=data['image'], height=435, width=500, visible=True))
        results.append(gr.Textbox(label=f"Recommendation {idx}: {english}", value=data['description'], max_lines=7, visible=True))
        results.append(gr.Textbox(label=f"Best User Review {idx}", value=review, max_lines=7, visible=True))

    # Pad from what was actually produced: the filters may leave fewer entries
    # than requested. Clearing and hiding the spare slots also stops rows from
    # an earlier, larger query from lingering on screen.
    for _ in range(total_rows - len(weighted_scores)):
        results.append(gr.Image(value=None, visible=False))
        results.append(gr.Textbox(value=None, visible=False))
        results.append(gr.Textbox(value=None, visible=False))
    return results
if __name__ == '__main__':
    # Load the precomputed embeddings and filter choices. The path is resolved
    # relative to this file so the app does not depend on the working directory.
    data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'embeddings', 'data.json')
    with open(data_path, encoding='utf-8') as f:
        data = json.load(f)
    # get_recommendation reads `embeddings` as a module global, so these
    # names must stay at module level.
    embeddings = data['embeddings']
    filters = data['filters']

    with gr.Blocks() as demo:
        with gr.Row():
            # Left: query and filter inputs.
            with gr.Column(visible=True):
                query = gr.Textbox(label="What are you looking for?")
                number_of_recommendations = gr.Slider(label="# of Recommendations", minimum=1, maximum=10, value=3, step=1)
                genres = gr.Dropdown(label='Genres', multiselect=True, choices=filters['genres'], value=['ALL'])
                themes = gr.Dropdown(label='Themes', multiselect=True, choices=filters['themes'], value=['ALL'])
                rating = gr.Dropdown(
                    label='Rating',
                    multiselect=True,
                    choices=filters['rating'],
                    value=['PG - Children', 'PG-13 - Teens 13 or older', 'G - All Ages', 'R - 17+ (violence & profanity)'],
                )
                objective_weight = gr.Slider(label="Objective Weight", minimum=0, maximum=1, value=.7, step=.1)
                subjective_weight = gr.Slider(label="Subjective Weight", minimum=0, maximum=1, value=.3, step=.1)
                submit_btn = gr.Button("Submit")
                gr.Examples(
                    examples=[
                        ['A show about pirates with super powers in search of gold', 3, ['Action', 'Adventure', 'Fantasy'], ['ALL'], ['PG-13 - Teens 13 or older'], .8, .2],
                    ],
                    inputs=[query, number_of_recommendations, genres, themes, rating, objective_weight, subjective_weight],
                )

            # Right: 15 initially hidden result rows. Each row holds three output
            # components, in the order get_recommendation fills them:
            # image, description, review.
            outputs = []
            with gr.Column():
                for i in range(15):
                    with gr.Row():
                        with gr.Column():
                            outputs.append(gr.Image(label=f"Image {i}", height=435, width=500, visible=False))
                        with gr.Column():
                            outputs.append(gr.Textbox(label=f"Recommendation {i}", max_lines=7, visible=False))
                            outputs.append(gr.Textbox(label="Best User Review", max_lines=7, visible=False))

        submit_btn.click(
            get_recommendation,
            [query, number_of_recommendations, genres, themes, rating, objective_weight, subjective_weight],
            outputs,
        )
    demo.launch()