File size: 5,466 Bytes
ac6138f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import json

import gradio as gr

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Sentence-embedding model used to encode search queries. It is created once
# at module import, so every request reuses the same instance.
model = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2"
)


def get_n_weighted_scores(embeddings, query, n, objective_weight, subjective_weight):
    """Rank anime by a weighted blend of objective and best-review similarity.

    The query is encoded with the module-level ``model``. For every entry, the
    cosine similarity between the query and its ``objective_embedding`` is
    blended with the similarity to each of its ``subjective_embeddings`` as::

        objective_score * objective_weight + review_score * subjective_weight

    and the subjective embedding giving the highest blended score is kept.

    Args:
        embeddings: Mapping of anime key -> record holding
            ``'objective_embedding'`` and ``'subjective_embeddings'``.
            NOTE(review): ``objective_embedding`` appears to be a single-row
            2-D array (``.item()`` requires a 1x1 similarity matrix), and the
            subjective embeddings appear to be one row per entry in
            ``reviews`` -- confirm against the data file.
        query: Free-text search string.
        n: Maximum number of results to return. Coerced to ``int`` because
            UI sliders may deliver whole numbers as floats.
        objective_weight: Weight applied to the objective similarity.
        subjective_weight: Weight applied to the per-review similarity.

    Returns:
        Up to ``n`` ``(key, score, review_index)`` tuples sorted by ``score``
        in descending order. Entries with equal scores keep the input order.
    """
    encoded_query = [model.encode(query)]

    weighted_scores = []
    for key, value in embeddings.items():
        objective_score = cosine_similarity(
            encoded_query, value['objective_embedding']
        ).item()
        review_scores = cosine_similarity(
            encoded_query, value['subjective_embeddings']
        )[0].tolist()

        # The objective term is the same for every review of this anime.
        objective_part = objective_score * objective_weight
        blended = [
            objective_part + (score * subjective_weight) for score in review_scores
        ]
        # Cosine similarity can be negative, so pick the true maximum instead
        # of starting from 0. A 0 floor collapsed every all-negative anime to
        # (score 0, review 0). max() returns the first maximal item, which
        # keeps the earliest review on ties.
        max_review_index, max_score = max(enumerate(blended), key=lambda pair: pair[1])

        weighted_scores.append((key, max_score, max_review_index))

    # list.sort is stable, so ties keep the insertion order of `embeddings`.
    weighted_scores.sort(key=lambda item: item[1], reverse=True)
    return weighted_scores[:int(n)]

def _matches(selected, values):
    """Return True if ``selected`` contains the 'ALL' wildcard or overlaps ``values``."""
    return 'ALL' in selected or not selected.isdisjoint(values)


def filter_anime(embeddings, genres, themes, rating):
    """Return the subset of ``embeddings`` that passes every selected filter.

    A record passes a filter when it shares at least one value with the
    selection, or when the selection contains the wildcard ``'ALL'``. An empty
    (or ``None``) selection matches nothing, so every record is excluded.

    Args:
        embeddings: Mapping of anime key -> record with ``'genres'`` and
            ``'themes'`` (iterables of names) and ``'rating'`` (a single
            value).
        genres: Selected genre names, possibly including ``'ALL'``.
        themes: Selected theme names, possibly including ``'ALL'``.
        rating: Selected rating labels, possibly including ``'ALL'``.

    Returns:
        A new dict holding the matching entries in their original order. The
        record objects are shared with ``embeddings``, not copied, and the
        input mapping is left unmodified.
    """
    # A cleared multiselect may arrive as None. Treat it like an empty
    # selection instead of crashing in set(None).
    selected_genres = set(genres or ())
    selected_themes = set(themes or ())
    selected_ratings = set(rating or ())

    return {
        key: anime
        for key, anime in embeddings.items()
        if _matches(selected_genres, anime['genres'])
        and _matches(selected_themes, anime['themes'])
        # Each record carries a single rating, so wrap it for the overlap test.
        and _matches(selected_ratings, [anime['rating']])
    }

def get_recommendation(query, number_of_recommendations, genres, themes, rating, objective_weight, subjective_weight):
    """Build the Gradio output updates for the best-matching anime.

    Reads the module-level ``embeddings`` mapping (loaded in the ``__main__``
    section), filters it with ``filter_anime`` and ranks it with
    ``get_n_weighted_scores``.

    Each recommendation fills three consecutive output slots: cover image,
    description and best-matching user review. Every remaining slot is hidden
    so results from an earlier, larger query do not linger on screen.

    Args:
        query: Free-text search string.
        number_of_recommendations: How many results to show. Coerced to
            ``int`` because sliders may deliver floats.
        genres, themes, rating: Filter selections passed to ``filter_anime``.
        objective_weight, subjective_weight: Blend weights passed to
            ``get_n_weighted_scores``.

    Returns:
        A list with exactly ``3 * max_results`` updates, one per output
        component wired to the submit button.
    """
    # Must match the number of result rows built in the UI (3 components each).
    max_results = 15
    count = int(number_of_recommendations)

    filtered_anime = filter_anime(embeddings, genres, themes, rating)
    weighted_scores = get_n_weighted_scores(
        filtered_anime, query, count, float(objective_weight), float(subjective_weight)
    )

    results = []
    for idx, (key, _score, review_index) in enumerate(weighted_scores[:max_results], start=1):
        anime = embeddings[key]
        english = anime['english']
        review = anime['reviews'][review_index]['text']

        results.append(gr.Image(label=f"{english}", value=anime['image'], height=435, width=500, visible=True))
        results.append(gr.Textbox(label=f"Recommendation {idx}: {english}", value=anime['description'], max_lines=7, visible=True))
        results.append(gr.Textbox(label=f"Best User Review {idx}", value=review, max_lines=7, visible=True))

    # Hide every unused slot. The previous padding appended
    # 3 * (45 - 3 * n) "N/A" strings, far more values than there are outputs.
    # It also ignored how many matches were actually found, sent "N/A" to
    # image components, and never hid slots revealed by an earlier query.
    for _ in range(max_results - len(results) // 3):
        results.append(gr.Image(visible=False))
        results.append(gr.Textbox(visible=False))
        results.append(gr.Textbox(visible=False))

    return results

if __name__ == '__main__':
    # Resolve the data file next to this script rather than the current
    # working directory, and decode it as UTF-8 explicitly because the
    # platform default encoding is not UTF-8 everywhere (e.g. Windows).
    data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'embeddings', 'data.json')
    with open(data_path, encoding='utf-8') as f:
        data = json.load(f)
    # Kept at module level on purpose: get_recommendation reads `embeddings`
    # as a global.
    embeddings = data['embeddings']
    filters = data['filters']

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(visible=True):
                query = gr.Textbox(label="What are you looking for?")
                number_of_recommendations = gr.Slider(label= "# of Recommendations", minimum=1, maximum=10, value=3, step=1)
                genres = gr.Dropdown(label='Genres',multiselect=True,choices=filters['genres'], value=['ALL'])
                themes = gr.Dropdown(label='Themes',multiselect=True,choices=filters['themes'], value=['ALL'])
                rating = gr.Dropdown(label='Rating',multiselect=True,choices=filters['rating'], value=['PG - Children','PG-13 - Teens 13 or older','G - All Ages','R - 17+ (violence & profanity)'])
                objective_weight = gr.Slider(label= "Objective Weight", minimum=0, maximum=1, value=.7, step=.1)
                subjective_weight = gr.Slider(label= "Subjective Weight", minimum=0, maximum=1, value=.3, step=.1)
                submit_btn = gr.Button("Submit")

                examples = gr.Examples(
                    examples=[
                        ['A show about pirates with super powers in search of gold', 3, ['Action', 'Adventure', 'Fantasy'], ['ALL'], ['PG-13 - Teens 13 or older'], .8, .2]
                    ],
                    inputs=[query, number_of_recommendations, genres, themes, rating, objective_weight, subjective_weight],
                )

            # 15 hidden result rows of (image, description, review). This count
            # must match the slot count assumed by get_recommendation.
            outputs = []
            with gr.Column():
                for i in range(15):
                    with gr.Row():
                        with gr.Column():
                            # `label=` is required: gr.Image's first positional
                            # parameter is `value`, which would treat the text
                            # as an image source.
                            outputs.append(gr.Image(label=f"Image {i}", height=435, width=500, visible=False))
                        with gr.Column():
                            outputs.append(gr.Textbox(label=f"Recommendation {i}", max_lines=7, visible=False))
                            outputs.append(gr.Textbox(label="Best User Review", max_lines=7, visible=False))

        submit_btn.click(
            get_recommendation,
            [query, number_of_recommendations, genres, themes, rating, objective_weight, subjective_weight],
            outputs
        )

    demo.launch()