Spaces:
Runtime error
Runtime error
import os | |
import pandas as pd | |
import tempfile | |
from bertopic import BERTopic | |
from src.reddit import RedditBot | |
from flask import Blueprint, render_template, request, send_file, redirect, url_for, send_from_directory | |
# Flask's signature is Blueprint(name, import_name). The blueprint name must be
# 'views' because the views in this module build URLs with
# url_for('views.error', ...); the import name is this module's __name__, which
# Flask uses to locate the blueprint's root path.
views = Blueprint('views', __name__)
# Module-level clients created once at import time and reused by every request.
reddit = RedditBot()
# NOTE(review): this single model is refitted on every POST to `home`, so
# concurrent requests would share and overwrite its state — confirm the server
# handles one request at a time.
topic_model = BERTopic()
def retrieve_subreddits(data: dict) -> pd.DataFrame:
    """Fetch posts for the requested subreddit and add a combined ``Text`` column.

    Args:
        data: Form-like mapping with ``'subreddit'``, ``'type'`` and
            ``'amount'`` keys; ``amount`` must be convertible with ``int``.

    Returns:
        The DataFrame produced by ``reddit.convert_posts_to_df`` with an extra
        ``Text`` column holding ``"<Title>: <Content>"`` for each post.
    """
    posts = reddit.get_subreddits_posts(
        name=data.get('subreddit'),
        type=data.get('type'),
        amount=int(data.get('amount')),
    )
    df = reddit.convert_posts_to_df(posts=posts)
    # Column-wise concatenation instead of a row-wise apply(axis=1): it is
    # vectorised, and it still yields a Series on an empty frame (apply returns
    # a DataFrame there, which cannot be assigned to a single column).
    # Missing titles/contents are treated as empty text rather than raising.
    df['Text'] = df['Title'].fillna('') + ': ' + df['Content'].fillna('')
    return df
def _validation_error_type(data):
    """Return the error key for the first invalid form field, or None if all are valid.

    Checks run in order: ``amount`` (must be an integer in [0, 1000]), ``type``
    (one of hot/new/rising/top), then ``subreddit`` (``reddit.subreddit_exists``).
    The subreddit check runs last, only after the local checks pass — it
    presumably queries Reddit, so it is the expensive one.
    """
    try:
        amount = int(data.get('amount'))
    except (TypeError, ValueError):
        # Missing (None) or non-numeric amount: report it instead of raising a 500.
        return 'amount'
    if amount < 0 or amount > 1000:
        return 'amount'
    if data.get('type') not in ('hot', 'new', 'rising', 'top'):
        return 'type'
    if not reddit.subreddit_exists(data.get('subreddit')):
        return 'subreddit'
    return None


def _fit_topics(subreddits_df):
    """Fit the module-level topic model on ``subreddits_df.Text``.

    Returns a ``(topics_df, docs_df)`` pair:
    - ``topics_df``: ``topic_model.get_topic_info()`` plus a ``'Top words'``
      column holding ``str()`` of the word list returned by ``get_topic``.
    - ``docs_df``: ``topic_model.get_document_info`` for the same texts.
    """
    topic_model.fit_transform(subreddits_df.Text)
    topics_df = topic_model.get_topic_info()
    # One lookup per row instead of re-filtering the whole frame for every topic.
    topics_df['Top words'] = topics_df.Topic.map(
        lambda topic: str([word for word, _ in topic_model.get_topic(topic)])
    )
    docs_df = topic_model.get_document_info(subreddits_df.Text)
    return topics_df, docs_df


def home():
    """Render the search form (GET) or run topic modelling on a subreddit (POST).

    On POST, the form fields ``subreddit``, ``type`` and ``amount`` are
    validated; the first invalid one redirects to ``views.error``. Otherwise the
    posts are fetched and the topic model is fitted. ``topics.csv`` and
    ``docs_with_topics_info.csv`` are then written to the current working
    directory, and ``success.html`` is rendered with the topic table.

    NOTE(review): no route decorator is visible for this view; presumably it is
    registered elsewhere for GET and POST — confirm.
    """
    if request.method != 'POST':
        return render_template('index.html')

    data = request.form
    error_type = _validation_error_type(data)
    if error_type is not None:
        return redirect(url_for('views.error', type_of_error=error_type))

    subreddits_df = retrieve_subreddits(data=data)
    topics_df, docs_df = _fit_topics(subreddits_df)

    # Persist the results to the working directory. A Flask response is only
    # delivered when the view returns it, so serving these files as downloads
    # needs its own route (e.g. send_from_directory) rather than a discarded
    # send_file() call here.
    topics_df.to_csv('topics.csv', index=False)
    docs_df.to_csv('docs_with_topics_info.csv', index=False)

    return render_template(
        'success.html',
        topics=[topics_df.to_html(classes='data')],
        titles=topics_df.columns.values,
    )
def success():
    """Show the success page on its own, with no topic table passed in."""
    page = 'success.html'
    return render_template(page)
# User-facing explanation for each error key that the views redirect with.
ERROR_MESSAGES = {
    'amount': 'The amount is higher than 1000 or lower than 0',
    'type': 'The ordering is not within hot, rising, new, top',
    'subreddit': 'The subreddit does not exist',
}


def error(type_of_error: str):
    """Render ``error.html`` with a human-readable message for ``type_of_error``.

    Args:
        type_of_error: One of the keys of ``ERROR_MESSAGES``. Any other value
            gets a generic message instead of making the view return None,
            which Flask rejects as an invalid response.
    """
    message = ERROR_MESSAGES.get(type_of_error, 'An unknown error occurred')
    return render_template('error.html', type_of_error=message)