# import spaces
import html
import logging
import os
from urllib.parse import quote

import gradio as gr
import datamapplot
import numpy as np
from dotenv import load_dotenv
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from huggingface_hub import HfApi, InferenceClient
from sklearn.feature_extraction.text import CountVectorizer
from sentence_transformers import SentenceTransformer
from torch import cuda

from src.hub import create_space_with_content
from src.templates import LLAMA_3_8B_PROMPT, SPACE_REPO_CARD_CONTENT
from src.viewer_api import (
    get_split_rows,
    get_parquet_urls,
    get_docs_from_parquet,
    get_info,
)

# Load environment variables
load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
# Explicit check instead of `assert`: assertions are stripped under `python -O`.
if HF_TOKEN is None:
    raise RuntimeError("You need to set HF_TOKEN in your environment variables")

MAX_ROWS = int(os.getenv("MAX_ROWS", "8_000"))
CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "2_000"))
DATASETS_TOPICS_ORGANIZATION = os.getenv(
    "DATASETS_TOPICS_ORGANIZATION", "datasets-topics"
)
USE_CUML = int(os.getenv("USE_CUML", "1"))
USE_LLM_TEXT_GENERATION = int(os.getenv("USE_LLM_TEXT_GENERATION", "1"))

# Use cuml lib only if configured
if USE_CUML:
    from cuml.manifold import UMAP
    from cuml.cluster import HDBSCAN
else:
    from umap import UMAP
    from hdbscan import HDBSCAN

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

api = HfApi(token=HF_TOKEN)
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")

# Representation model
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
representation_model = KeyBERTInspired()
vectorizer_model = CountVectorizer(stop_words="english")
inference_client = InferenceClient(model_id)

NO_DATA_MESSAGE = "❌ Error: No data found for the selected dataset"

# Styling shared by every static DataMapPlot rendering.
DATAMAP_PLOT_KWARGS = {
    "title": "",
    "width": 800,
    "height": 700,
    "arrowprops": {
        "arrowstyle": "wedge,tail_width=0.5",
        "connectionstyle": "arc3,rad=0.05",
        "linewidth": 0,
        "fc": "#33333377",
    },
    "dynamic_label_size": True,
    # "label_wrap_width": 12,
    "label_over_points": True,
    "max_font_size": 36,
    "min_font_size": 4,
}


def calculate_embeddings(docs):
    """Encode `docs` with the module-level sentence-transformer model."""
    return sentence_model.encode(docs, show_progress_bar=True, batch_size=32)


def calculate_n_neighbors_and_components(n_rows):
    """Derive UMAP `n_neighbors` and `n_components` from the number of rows.

    `n_neighbors` is `n_rows // 20` clamped to the range [15, 100].
    `n_components` is 10 above 1000 rows and 5 otherwise.
    """
    n_neighbors = min(max(n_rows // 20, 15), 100)
    # Higher components for larger datasets
    n_components = 10 if n_rows > 1000 else 5
    return n_neighbors, n_components


def fit_model(docs, embeddings, n_neighbors, n_components):
    """Fit and return a new BERTopic model on one chunk of documents.

    Args:
        docs: Documents of the chunk.
        embeddings: Precomputed embeddings for `docs`.
        n_neighbors: UMAP neighbourhood size; also drives the HDBSCAN
            `min_cluster_size` and the BERTopic `min_topic_size`.
        n_components: Target dimensionality of the UMAP reduction.
    """
    umap_model = UMAP(
        n_neighbors=n_neighbors,
        n_components=n_components,
        min_dist=0.0,
        metric="cosine",
        random_state=42,
    )

    hdbscan_model = HDBSCAN(
        # Reducing min_cluster_size for fewer outliers
        min_cluster_size=max(5, n_neighbors // 2),
        metric="euclidean",
        cluster_selection_method="eom",
        prediction_data=True,
    )

    new_model = BERTopic(
        language="english",
        # Sub-models
        embedding_model=sentence_model,  # Step 1 - Extract embeddings
        umap_model=umap_model,  # Step 2 - UMAP model
        hdbscan_model=hdbscan_model,  # Step 3 - Cluster reduced embeddings
        vectorizer_model=vectorizer_model,  # Step 4 - Tokenize topics
        representation_model=representation_model,  # Step 5 - Label topics
        # Hyperparameters
        top_n_words=10,
        verbose=True,
        min_topic_size=n_neighbors,  # Coherent with n_neighbors?
    )
    logging.info("Fitting new model")
    new_model.fit(docs, embeddings)
    logging.info("End fitting new model")
    return new_model


def _no_data_outputs():
    """Build the UI update shown when the selected dataset yields no rows."""
    return (
        gr.Accordion(open=True),
        gr.DataFrame(value=[], interactive=False, visible=True),
        gr.Plot(value=None, visible=True),
        gr.Label({NO_DATA_MESSAGE: 0.0}, visible=True),
        "",
    )


def _build_topic_plot(
    model,
    docs,
    topics,
    reduced_embeddings,
    plot_type,
    sub_title,
    custom_labels=False,
):
    """Render the document map as a DataMapPlot figure or a Plotly figure.

    `plot_type == "DataMapPlot"` selects `visualize_document_datamap`; any
    other value falls back to `visualize_documents`.
    """
    if plot_type == "DataMapPlot":
        return model.visualize_document_datamap(
            docs=docs,
            topics=topics,
            custom_labels=custom_labels,
            reduced_embeddings=reduced_embeddings,
            sub_title=sub_title,
            **DATAMAP_PLOT_KWARGS,
        )
    return model.visualize_documents(
        docs=docs,
        topics=topics,
        reduced_embeddings=reduced_embeddings,
        custom_labels=custom_labels,
        title="",
    )


def _generate_topic_label(keywords):
    """Ask the inference endpoint for a readable label for `keywords`."""
    prompt = LLAMA_3_8B_PROMPT.replace("[KEYWORDS]", ",".join(keywords))
    prompt_messages = [
        {
            "role": "system",
            "content": "You are a helpful, respectful and honest assistant for labeling topics.",
        },
        {"role": "user", "content": prompt},
    ]
    output = inference_client.chat_completion(
        messages=prompt_messages,
        stream=False,
        max_tokens=500,
        top_p=0.8,
        seed=42,
    )
    inference_response = output.choices[0].message.content
    logging.info("Inference response:")
    logging.info(inference_response)
    return inference_response.replace("Topic=", "").strip()


# @spaces.GPU(duration=60 * 5)
def generate_topics(dataset, config, split, column, plot_type):
    """Discover topics for a dataset column and stream UI updates.

    This is a generator wired to the "Generate Topics" button. Every yielded
    value is a 5-tuple matching the button outputs: data-details accordion,
    topics dataframe, topic plot, progress label and a markdown link to the
    created Space (empty until the Space exists).

    Processing steps:
        1. Read up to `MAX_ROWS` rows in chunks of `CHUNK_SIZE`, fit one
           BERTopic model per chunk and merge it into the running model.
        2. Generate topic names with the `model_id` chat model.
        3. Save the static plot and an interactive datamapplot HTML file, and
           publish both through `create_space_with_content`.

    Args:
        dataset: Hub dataset id.
        config: Dataset subset name.
        split: Dataset split name.
        column: Name of the text column to model.
        plot_type: "DataMapPlot" or anything else for the Plotly rendering.
    """
    logging.info(
        f"Generating topics for {dataset=} {config=} {split=} {column=} {plot_type=}"
    )

    parquet_urls = get_parquet_urls(dataset, config, split)
    split_rows = get_split_rows(dataset, config, split)
    if split_rows is None or split_rows == 0:
        # Inside a generator a bare `return value` is never delivered to the
        # UI, so the error state has to be yielded before stopping.
        yield _no_data_outputs()
        return

    logging.info(f"Split number of rows: {split_rows}")

    limit = min(split_rows, MAX_ROWS)
    n_neighbors, n_components = calculate_n_neighbors_and_components(limit)

    reduce_umap_model = UMAP(
        n_neighbors=n_neighbors,
        n_components=2,  # For visualization, keeping it for 2D
        min_dist=0.0,
        metric="cosine",
        random_state=42,
    )

    offset = 0
    rows_processed = 0

    base_model = None
    all_docs = []
    reduced_embeddings_list = []
    topics_info, topic_plot = None, None
    full_processing = split_rows <= MAX_ROWS
    message = (
        f"Processing topics for full dataset: 0 of ({split_rows} rows)"
        if full_processing
        else f"Processing topics for partial dataset 0 of ({limit} rows)"
    )
    sub_title = (
        f"Data map for the entire dataset ({limit} rows) using the column '{column}'"
        if full_processing
        else f"Data map for a sample of the dataset (first {limit} rows) using the column '{column}'"
    )
    yield (
        gr.Accordion(open=False),
        gr.DataFrame(value=[], interactive=False, visible=True),
        gr.Plot(value=None, visible=True),
        gr.Label({"⏳ " + message: 0.0}, visible=True),
        "",
    )

    while offset < limit:
        logging.info(f"----> Getting records from {offset=} with {CHUNK_SIZE=}")
        docs = get_docs_from_parquet(parquet_urls, column, offset, CHUNK_SIZE)
        if not docs:
            break
        logging.info(f"Got {len(docs)} docs ✓")

        embeddings = calculate_embeddings(docs)
        new_model = fit_model(docs, embeddings, n_neighbors, n_components)

        if base_model is None:
            base_model = new_model
            logging.info(
                f"The following topics are newly found: {base_model.topic_labels_}"
            )
        else:
            updated_model = BERTopic.merge_models([base_model, new_model])
            nr_new_topics = len(set(updated_model.topics_)) - len(
                set(base_model.topics_)
            )
            # `[-0:]` would return the whole list, so handle "no new topics".
            new_topics = (
                list(updated_model.topic_labels_.values())[-nr_new_topics:]
                if nr_new_topics > 0
                else []
            )
            logging.info(f"The following topics are newly found: {new_topics}")
            base_model = updated_model

        logging.info("Reducing embeddings to 2D")
        reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
        reduced_embeddings_list.append(reduced_embeddings)
        logging.info("Reducing embeddings to 2D ✓")

        all_docs.extend(docs)
        reduced_embeddings_array = np.vstack(reduced_embeddings_list)

        topics_info = base_model.get_topic_info()
        all_topics = base_model.topics_
        logging.info(f"Preparing topics {plot_type} plot")

        topic_plot = _build_topic_plot(
            base_model,
            all_docs,
            all_topics,
            reduced_embeddings_array,
            plot_type,
            sub_title,
        )
        logging.info("Plot done ✓")

        rows_processed += len(docs)
        progress = min(rows_processed / limit, 1.0)
        logging.info(f"Progress: {progress:.0%} - {rows_processed} of {limit}")
        message = (
            f"Processing topics for full dataset: {rows_processed} of {limit}"
            if full_processing
            else f"Processing topics for partial dataset: {rows_processed} of {limit} rows"
        )

        yield (
            gr.Accordion(open=False),
            topics_info,
            topic_plot,
            gr.Label({"⏳ " + message: progress}, visible=True),
            "",
        )

        offset += CHUNK_SIZE
        del docs, embeddings, new_model, reduced_embeddings

    if base_model is None:
        # The very first chunk came back empty: nothing was fitted, so there
        # is no model to label or plot.
        yield _no_data_outputs()
        return

    logging.info("Finished processing all data")

    yield (
        gr.Accordion(open=False),
        topics_info,
        topic_plot,
        gr.Label(
            {
                "✅ " + message: 1.0,
                f"⏳ Generating topic names with {model_id}": 0.0,
            },
            visible=True,
        ),
        "",
    )

    all_topics = base_model.topics_
    topics_info = base_model.get_topic_info()

    new_topics_by_text_generation = {}
    for _, row in topics_info.iterrows():
        logging.info(
            f"Processing topic: {row['Topic']} - Representation: {row['Representation']}"
        )
        new_topics_by_text_generation[row["Topic"]] = _generate_topic_label(
            row["Representation"]
        )

    base_model.set_topic_labels(new_topics_by_text_generation)
    topics_info = base_model.get_topic_info()

    topic_plot = _build_topic_plot(
        base_model,
        all_docs,
        all_topics,
        reduced_embeddings_array,
        plot_type,
        sub_title,
        custom_labels=True,
    )

    dataset_clear_name = dataset.replace("/", "-")
    plot_png = f"{dataset_clear_name}-{plot_type.lower()}.png"
    if plot_type == "DataMapPlot":
        topic_plot.savefig(plot_png, format="png", dpi=300)
    else:
        topic_plot.write_image(plot_png)

    custom_labels = base_model.custom_labels_
    # NOTE(review): the `+ 1` offset presumes the outlier topic (-1) occupies
    # index 0 of `custom_labels_` — confirm for models without outliers.
    topic_names_array = [custom_labels[doc_topic + 1] for doc_topic in all_topics]

    yield (
        gr.Accordion(open=False),
        topics_info,
        topic_plot,
        gr.Label(
            {
                "✅ " + message: 1.0,
                f"✅ Generating topic names with {model_id}": 1.0,
                "⏳ Creating Interactive Space": 0.0,
            },
            visible=True,
        ),
        "",
    )

    interactive_plot = datamapplot.create_interactive_plot(
        reduced_embeddings_array,
        topic_names_array,
        hover_text=all_docs,
        title=dataset,
        # NOTE(review): this used to be `sub_title.replace("dataset", "dataset")`,
        # a no-op; if the word was meant to become a link, add the markup here.
        sub_title=sub_title,
        enable_search=True,
        # TODO: Export data to .arrow and also serve it
        inline_data=True,
        # offline_data_prefix=dataset_clear_name,
        initial_zoom_fraction=0.9,
        cluster_boundary_polygons=True,
    )
    html_content = str(interactive_plot)
    html_file_path = f"{dataset_clear_name}.html"
    with open(html_file_path, "w", encoding="utf-8") as html_file:
        html_file.write(html_content)

    repo_id = f"{DATASETS_TOPICS_ORGANIZATION}/{dataset_clear_name}"

    space_id = create_space_with_content(
        api=api,
        repo_id=repo_id,
        dataset_id=dataset,
        html_file_path=html_file_path,
        plot_file_path=plot_png,
        space_card=SPACE_REPO_CARD_CONTENT,
        token=HF_TOKEN,
    )

    space_link = f"https://huggingface.co/spaces/{space_id}"

    yield (
        gr.Accordion(open=False),
        topics_info,
        topic_plot,
        gr.Label(
            {
                "✅ " + message: 1.0,
                f"✅ Generating topic names with {model_id}": 1.0,
                "✅ Creating Interactive Space": 1.0,
            },
            visible=True,
        ),
        f"[![Go to interactive plot](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space-blue)]({space_link})",
    )

    # Drop the large objects before asking torch to release cached GPU memory.
    del reduce_umap_model, all_docs, reduced_embeddings_list
    del (
        base_model,
        all_topics,
        topics_info,
        topic_plot,
        topic_names_array,
        interactive_plot,
    )
    cuda.empty_cache()


# NOTE(review): the nesting of the layout containers below was inferred from
# the order of the statements — confirm against the intended page layout.
with gr.Blocks() as demo:
    # NOTE(review): these banners contain plain text only; add heading or
    # centering markup here if it is wanted.
    gr.HTML("\n\n💠 Dataset Topic Discovery 🔭\n\n")
    gr.HTML("\n\nSelect a dataset and text column for topic modeling\n\n")
    gr.HTML(
        "\n\n⚠ This space is in progress, and we're actively working on it, so you might find some bugs! Please report any issues you have in the Community tab to help us make it better for all.\n\n"
    )
    data_details_accordion = gr.Accordion("Data details", open=True)
    with data_details_accordion:
        with gr.Row():
            with gr.Column(scale=3):
                dataset_name = HuggingfaceHubSearch(
                    label="Hub Dataset ID",
                    placeholder="Search for dataset id on Huggingface",
                    search_type="dataset",
                )
            subset_dropdown = gr.Dropdown(label="Subset", visible=False)
            split_dropdown = gr.Dropdown(label="Split", visible=False)

        with gr.Accordion("Dataset preview", open=False):

            @gr.render(inputs=[dataset_name, subset_dropdown, split_dropdown])
            def embed(name, subset, split):
                """Embed the Hub dataset viewer for the current selection."""
                if not name:
                    return gr.HTML(value="")
                # Values come from user-editable widgets: percent-encode each
                # path segment and escape the final attribute value.
                path_parts = [quote(name, safe="/"), "embed", "viewer"]
                if subset:
                    path_parts.append(quote(subset, safe=""))
                    if split:
                        path_parts.append(quote(split, safe=""))
                viewer_url = "https://huggingface.co/datasets/" + "/".join(
                    path_parts
                )
                html_code = f"""
<iframe
  src="{html.escape(viewer_url, quote=True)}"
  frameborder="0"
  width="100%"
  height="600px"
></iframe>
"""
                return gr.HTML(value=html_code)

        with gr.Row():
            text_column_dropdown = gr.Dropdown(label="Text column name")
            plot_type_radio = gr.Radio(
                ["DataMapPlot", "Plotly"],
                value="DataMapPlot",
                label="Choose the plot type",
                interactive=True,
            )
        generate_button = gr.Button("Generate Topics", variant="primary")

    gr.Markdown("## Data map")
    full_topics_generation_label = gr.Label(visible=False, show_label=False)
    open_space_label = gr.Markdown()
    topics_plot = gr.Plot()
    with gr.Accordion("Topics Info", open=False):
        topics_df = gr.DataFrame(interactive=False, visible=True)
    gr.HTML(
        f"\n\n⚠ This space processes datasets in batches of {CHUNK_SIZE}, with a maximum of {MAX_ROWS} rows. If you need further assistance, please open a new issue in the Community tab.\n\n"
    )
    gr.Markdown(
        "_Powered by [bertopic](https://maartengr.github.io/BERTopic/index.html) [datamapplot](https://datamapplot.readthedocs.io/en/latest/) and [duckdb](https://duckdb.org/)_"
    )
    generate_button.click(
        generate_topics,
        inputs=[
            dataset_name,
            subset_dropdown,
            split_dropdown,
            text_column_dropdown,
            plot_type_radio,
        ],
        outputs=[
            data_details_accordion,
            topics_df,
            topics_plot,
            full_topics_generation_label,
            open_space_label,
        ],
    )

    def _hidden_selection():
        """Return the update that hides the subset/split dropdowns."""
        return {
            subset_dropdown: gr.Dropdown(visible=False),
            split_dropdown: gr.Dropdown(visible=False),
            text_column_dropdown: gr.Dropdown(label="Text column name"),
        }

    def _is_string_feature(feature):
        """Tell whether a feature description declares the `string` dtype."""
        return isinstance(feature, dict) and feature.get("dtype") == "string"

    def _resolve_dataset_selection(
        dataset: str, default_subset: str, default_split: str, text_feature
    ):
        """Compute dropdown updates for the given dataset selection.

        Args:
            dataset: Hub dataset id; ids without a namespace ("/") hide the
                subset and split dropdowns.
            default_subset: Preferred subset; the first available one is used
                when it does not exist.
            default_split: Preferred split; the first available one is used
                when it does not exist.
            text_feature: Currently selected text column. Accepted for
                interface compatibility; it does not change the result.

        Returns:
            A dict mapping the three dropdown components to their updates.
        """
        if not dataset or "/" not in dataset.strip().strip("/"):
            return _hidden_selection()
        try:
            info_resp = get_info(dataset)
        except Exception:
            # Best effort: keep the UI usable, but leave a trace of the failure.
            logging.exception(f"Could not load info for {dataset=}")
            return _hidden_selection()

        subsets: list[str] = list(info_resp)
        if not subsets:
            return _hidden_selection()
        subset = default_subset if default_subset in subsets else subsets[0]

        splits: list[str] = list(info_resp[subset]["splits"])
        if not splits:
            return _hidden_selection()
        split = default_split if default_split in splits else splits[0]

        features = info_resp[subset]["features"]
        text_features = [
            feature_name
            for feature_name, feature in features.items()
            if _is_string_feature(feature)
        ]

        return {
            subset_dropdown: gr.Dropdown(
                value=subset, choices=subsets, visible=len(subsets) > 1
            ),
            split_dropdown: gr.Dropdown(
                value=split, choices=splits, visible=len(splits) > 1
            ),
            text_column_dropdown: gr.Dropdown(
                choices=text_features, label="Text column name"
            ),
        }

    @dataset_name.change(
        inputs=[dataset_name],
        outputs=[
            subset_dropdown,
            split_dropdown,
            text_column_dropdown,
        ],
    )
    def show_input_from_dataset_name(dataset: str) -> dict:
        """Refresh the dropdowns when another dataset is chosen."""
        return _resolve_dataset_selection(
            dataset, default_subset="default", default_split="train", text_feature=None
        )

    @subset_dropdown.change(
        inputs=[dataset_name, subset_dropdown],
        outputs=[
            subset_dropdown,
            split_dropdown,
            text_column_dropdown,
        ],
    )
    def show_input_from_subset_dropdown(dataset: str, subset: str) -> dict:
        """Refresh the dropdowns when another subset is chosen."""
        return _resolve_dataset_selection(
            dataset, default_subset=subset, default_split="train", text_feature=None
        )

    @split_dropdown.change(
        inputs=[dataset_name, subset_dropdown, split_dropdown],
        outputs=[
            subset_dropdown,
            split_dropdown,
            text_column_dropdown,
        ],
    )
    def show_input_from_split_dropdown(dataset: str, subset: str, split: str) -> dict:
        """Refresh the dropdowns when another split is chosen."""
        return _resolve_dataset_selection(
            dataset, default_subset=subset, default_split=split, text_feature=None
        )

    @text_column_dropdown.change(
        inputs=[dataset_name, subset_dropdown, split_dropdown, text_column_dropdown],
        outputs=[
            subset_dropdown,
            split_dropdown,
            text_column_dropdown,
        ],
    )
    def show_input_from_text_column_dropdown(
        dataset: str, subset: str, split: str, text_column
    ) -> dict:
        """Refresh the dropdowns when another text column is chosen."""
        return _resolve_dataset_selection(
            dataset,
            default_subset=subset,
            default_split=split,
            text_feature=text_column,
        )


demo.launch()