import joblib
import streamlit as st
import json
import requests
from bs4 import BeautifulSoup
from datetime import date
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch

# Seconds to wait for the article server before giving up. Without a timeout,
# requests.get can block this script run indefinitely.
REQUEST_TIMEOUT = 15

# Length every LSTM input sequence is padded (at the end) or truncated to.
LSTM_MAX_LEN = 1000

# Encoded label -> display name, shared by all four models.
# {'business': 0, 'entertainment': 1, 'health': 2, 'politics': 3, 'sport': 4}
_LABEL_NAMES = {
    0: 'Business',
    1: 'Entertainment',
    2: 'Health',
    3: 'Politics',
    4: 'Sport',
}


@st.cache_resource
def _load_models():
    """Load all classifiers and text preprocessors once per server process.

    Streamlit re-executes this whole script on every widget interaction, so
    loading at module level without caching would re-read every model file on
    each rerun. ``st.cache_resource`` keeps one shared copy instead.

    Returns:
        Tuple ``(Seq_model, SVM_model, logistic_model, svm_model, vectorizer,
        tokenizer, device, tokenizer1, model)`` in that order.
    """
    seq_model = load_model("LSTM.h5")  # Sequential (LSTM)
    svm_linear = joblib.load("SVM_Linear_Kernel.joblib")  # SVM
    logistic = joblib.load("Logistic_Model.joblib")  # Logistic
    # Fed the linear SVM's decision-function scores to produce class
    # probabilities (see process_api).
    svm_proba = joblib.load('svm_model.joblib')
    # Global vocabulary shared by the Logistic and SVC models.
    shared_vectorizer = joblib.load("vectorizer.joblib")
    lstm_tokenizer = joblib.load("tokenizer.joblib")  # used for LSTM

    torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    bert_tokenizer = DistilBertTokenizer.from_pretrained("tokenizer_bert")
    bert = DistilBertForSequenceClassification.from_pretrained(
        'distilbert-base-uncased', num_labels=5
    )
    bert.load_state_dict(
        torch.load("fine_tuned_bert_model1.pth", map_location=torch_device)
    )
    # Keep weights on the same device as the inputs built in process_api, and
    # make sure dropout is disabled for inference.
    bert.to(torch_device)
    bert.eval()

    return (seq_model, svm_linear, logistic, svm_proba, shared_vectorizer,
            lstm_tokenizer, torch_device, bert_tokenizer, bert)


# load all the models and vectorizer (global vocabulary)
(Seq_model, SVM_model, logistic_model, svm_model, vectorizer, tokenizer,
 device, tokenizer1, model) = _load_models()


def decodedLabel(input_number):
    """Map an encoded label index to its display name.

    Args:
        input_number: Encoded label, 0-4.

    Returns:
        The category name (e.g. ``'Health'``), or ``None`` for an unknown index.
    """
    print('receive label encoded', input_number)
    result = _LABEL_NAMES.get(input_number)  # Ex: Health
    print('decoded result', result)
    return result


def crawURL(url):
    """Download a CNN article page and return its body text.

    The text is taken from every ``<p class="paragraph inline-placeholder">``
    inside the element marked ``itemprop="articleBody"``; non-empty paragraphs
    are separated by a blank line so words from adjacent paragraphs never merge.

    Args:
        url: Address of the article page.

    Returns:
        The article text (possibly ``""`` if no paragraph matched), or ``None``
        if the page could not be fetched or has no ``articleBody`` element.
    """
    try:
        print(f"Crawling page: {url}")
        response = requests.get(url, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()  # don't try to parse 4xx/5xx error pages
        soup = BeautifulSoup(response.content, 'html.parser')

        article_body = soup.find(itemprop="articleBody")
        if article_body is None:
            raise ValueError('page has no element with itemprop="articleBody"')

        paragraphs = article_body.find_all('p', class_="paragraph inline-placeholder")
        paragraph_texts = (paragraph.text.strip() for paragraph in paragraphs)
        return "\n\n".join(text for text in paragraph_texts if text)
    except Exception as e:
        # Boundary of the crawler: report and signal failure with None.
        print(f"Failed to crawl page: {url}, Error: {str(e)}")
        return None


def _as_percent(probability):
    """Format a probability in [0, 1] as a truncated whole percent, e.g. 0.9876 -> '98%'."""
    return f"{int(float(probability) * 10000 // 100)}%"


def process_api(text):
    """Classify ``text`` with the Logistic, SVC, LSTM and DistilBERT models.

    Args:
        text: Article text to classify.

    Returns:
        Dict with ``predicted_label_<m>`` (category name) and
        ``probability_<m>`` (e.g. ``'87%'``) for each ``m`` in ``logistic``,
        ``svm``, ``lstm``, ``bert``, plus ``Article_Content`` echoing ``text``.
    """
    # --- Logistic regression and linear SVM: shared vectorized features ---
    features = vectorizer.transform([text])
    logistic_label = int(logistic_model.predict(features)[0])
    logistic_proba = logistic_model.predict_proba(features)
    svm_label = int(SVM_model.predict(features)[0])
    # The SVM's decision scores go through svm_model to obtain probabilities.
    svm_proba = svm_model.predict_proba(SVM_model.decision_function(features))

    # --- LSTM ---
    sequence = tokenizer.texts_to_sequences([text])
    padded_sequence = pad_sequences(sequence, maxlen=LSTM_MAX_LEN, padding='post')
    lstm_proba = Seq_model.predict(padded_sequence)
    lstm_label = int(np.argmax(lstm_proba))

    # --- DistilBERT ---
    encoding = tokenizer1(
        [text],
        truncation=True,
        # Never exceed the model's position-embedding limit, even if the saved
        # tokenizer has no model_max_length configured.
        max_length=model.config.max_position_embeddings,
        padding=True,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        logits = model(
            encoding['input_ids'], attention_mask=encoding['attention_mask']
        ).logits
    bert_proba = torch.softmax(logits, dim=1)
    bert_label = int(torch.argmax(logits, dim=1).item())

    return {
        'predicted_label_logistic': decodedLabel(logistic_label),
        'probability_logistic': _as_percent(np.max(logistic_proba)),
        'predicted_label_svm': decodedLabel(svm_label),
        'probability_svm': _as_percent(np.max(svm_proba)),
        'predicted_label_lstm': decodedLabel(lstm_label),
        'probability_lstm': _as_percent(np.max(lstm_proba)),
        'predicted_label_bert': decodedLabel(bert_label),
        'probability_bert': _as_percent(torch.max(bert_proba).item()),
        'Article_Content': text,
    }


def categorize(url):
    """Crawl ``url`` and classify the article it points to.

    Args:
        url: Article URL entered by the user.

    Returns:
        The ``process_api`` result dict on success, otherwise
        ``{"error_message": <str>}`` describing the failure.
    """
    try:
        article_content = crawURL(url)
        if not article_content:
            # crawURL returns None on failure and "" when no paragraph matched;
            # classifying either would only produce meaningless predictions.
            return {
                "error_message": (
                    "Could not extract any article text from this URL. "
                    "Please check that it is a CNN article link."
                )
            }
        return process_api(article_content)
    except Exception as error:
        # str() keeps the value displayable; Python 3 exceptions generally have
        # no ``.message`` attribute.
        return {"error_message": str(error)}


# ------------ Main App ------------
st.title('Instant Category Classification')
st.write("Unsure what category a CNN article belongs to? Our clever tool can help! Paste the URL below and press Enter. We'll sort it into one of our 5 categories in a flash! ⚡️")

# Category descriptions shown in the "Category List" expander.
categories = {
    "Business": [
        "Analyze market trends and investment opportunities.",
        "Gain insights into company performance and industry news.",
        "Stay informed about economic developments and regulations.",
    ],
    "Health": [
        "Discover healthy recipes and exercise tips.",
        "Learn about the latest medical research and advancements.",
        "Find resources for managing chronic conditions and improving well-being.",
    ],
    "Sport": [
        "Follow your favorite sports teams and athletes.",
        "Explore news and analysis from various sports categories.",
        "Stay updated on upcoming games and competitions.",
    ],
    "Politics": [
        "Get informed about current political events and policies.",
        "Understand different perspectives on political issues.",
        "Engage in discussions and debates about politics.",
    ],
    "Entertainment": [
        "Find recommendations for movies, TV shows, and music.",
        "Explore reviews and insights from entertainment critics.",
        "Stay updated on celebrity news and cultural trends.",
    ],
}

# Model descriptions shown in the "Available Models" expander.
models = {
    "Logistic Regression": "A widely used statistical method for classification problems. It excels at identifying linear relationships between features and the target variable.",
    "SVC (Support Vector Classifier)": "A powerful machine learning model that seeks to find a hyperplane that best separates data points of different classes. It's effective for high-dimensional data and can handle some non-linear relationships.",
    "LSTM (Long Short-Term Memory)": "A type of recurrent neural network (RNN) particularly well-suited for sequential data like text or time series. LSTMs can effectively capture long-term dependencies within the data.",
    "BERT (Bidirectional Encoder Representations from Transformers)": "A powerful pre-trained model based on the Transformer architecture. It excels at understanding the nuances of language and can be fine-tuned for various NLP tasks like text classification.",
}

# Example CNN article URLs users can try.
URL_Example = [
    'https://edition.cnn.com/2012/01/31/health/frank-njenga-mental-health/index.html',
    'https://edition.cnn.com/2024/04/30/entertainment/barbra-streisand-melissa-mccarthy-ozempic/index.html',
    'https://edition.cnn.com/2024/04/30/sport/lebron-james-lakers-future-nba-spt-intl/index.html',
    'https://edition.cnn.com/2024/04/30/business/us-home-prices-rose-in-february/index.html',
]

# Expander listing the categories that can be predicted.
with st.expander("Category List"):
    st.subheader("Available Categories:")
    for category in categories:
        st.write(f"- {category}")
    # Details for each category, separated from the summary by a rule.
    st.write("---")
    for category, content in categories.items():
        st.subheader(category)
        for item in content:
            st.write(f"- {item}")

# Expander listing the models used in this project.
with st.expander("Available Models"):
    st.subheader("List of Models:")
    for model_name in models:
        st.write(f"- {model_name}")
    st.write("---")
    for model_name, description in models.items():
        st.subheader(model_name)
        st.write(description)

with st.expander("URLs Example"):
    for url in URL_Example:
        st.write(f"- {url}")

# Explain why the crawler only works with the CNN domain.
with st.expander("Tips", expanded=True):
    st.write(
        '''
        This project works best with CNN articles right now. Our web crawler is like a special tool for CNN's website. It can't quite understand other websites because they're built differently
        '''
    )

st.divider()  # horizontal rule
st.title('Dive in! See what category your CNN story belongs to 😉.')

# URL input; Streamlit returns "" until the user submits something.
url = st.text_input("Find your favorite CNN story! Paste the URL and press ENTER 🔍.", placeholder='Ex: https://edition.cnn.com/2012/01/31/health/frank-njenga-mental-health/index.html')

if url:
    st.divider()  # horizontal rule
    result = categorize(url)

    if "error_message" in result:
        # Surface crawl/prediction failures instead of rendering empty results.
        st.error(result["error_message"])
    else:
        st.title('Article Content Fetched')
        # Non-empty label for accessibility; hidden because the title above labels it.
        st.text_area(
            "Article content",
            value=result.get('Article_Content'),
            height=400,
            label_visibility="collapsed",
        )

        st.divider()  # horizontal rule
        st.title('Predicted Results')
        st.json({
            "Logistic": {
                "predicted_label": result.get("predicted_label_logistic"),
                "probability": result.get("probability_logistic"),
            },
            "SVC": {
                "predicted_label": result.get("predicted_label_svm"),
                "probability": result.get("probability_svm"),
            },
            "LSTM": {
                "predicted_label": result.get("predicted_label_lstm"),
                "probability": result.get("probability_lstm"),
            },
            "BERT": {
                "predicted_label": result.get("predicted_label_bert"),
                "probability": result.get("probability_bert"),
            },
        })

st.divider()  # horizontal rule

# ------------ Training data distribution ------------
# Separate names so the category-description dict above is not shadowed.
training_categories = ["Sport", "Health", "Entertainment", "Politics", "Business"]
training_counts = [5638, 4547, 2658, 2461, 1362]

st.title("Training Data Category Distribution")
st.write("Here's a breakdown of the number of articles in each category:")
for category, count in zip(training_categories, training_counts):
    st.write(f"- {category}: {count}")
st.bar_chart(data=dict(zip(training_categories, training_counts)))

st.divider()  # horizontal rule

# ------------ Copyright Section ------------
current_year = date.today().year
copyright_text = f"Copyright © {current_year}"
st.title(copyright_text)

author_names = ["Trần Thanh Phước (Mentor)", "Lương Ngọc Phương (Member)", "Trịnh Cẩm Minh (Member)"]
# Authors with a personal page get a markdown link instead of plain text.
author_links = {"Trịnh Cẩm Minh (Member)": "https://minhct.netlify.app/"}

st.write("Meet the minds behind the work!")
for author in author_names:
    link = author_links.get(author)
    if link:
        st.markdown(f"- [{author}]({link})")
    else:
        st.markdown(f"- {author}\n")