{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: clustering\n",
    "### This demo, built with Blocks, generates 9 plots based on the input.\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q gradio \"matplotlib>=3.6.0\" \"scikit-learn>=1.0.1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "import math\n",
    "from functools import partial\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib.figure import Figure\n",
    "from sklearn.cluster import (\n",
    "    AgglomerativeClustering, Birch, DBSCAN, KMeans, MeanShift, OPTICS, SpectralClustering, estimate_bandwidth\n",
    ")\n",
    "from sklearn.datasets import make_blobs, make_circles, make_moons\n",
    "from sklearn.mixture import GaussianMixture\n",
    "from sklearn.neighbors import kneighbors_graph\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "plt.style.use('seaborn-v0_8')\n",
    "SEED = 0\n",
    "MAX_CLUSTERS = 10\n",
    "N_SAMPLES = 1000\n",
    "N_COLS = 3\n",
    "FIGSIZE = 7, 7  # does not affect size in webpage\n",
    "COLORS = [\n",
    "    'blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan'\n",
    "]\n",
    "# One color per requested cluster is enough: outliers (-1) are drawn in black.\n",
    "if len(COLORS) < MAX_CLUSTERS:\n",
    "    raise ValueError(\"Not enough different colors for all clusters\")\n",
    "np.random.seed(SEED)\n",
    "\n",
    "\n",
    "def normalize(X):\n",
    "    \"\"\"Standardize each column of X with scikit-learn's StandardScaler.\"\"\"\n",
    "    return StandardScaler().fit_transform(X)\n",
    "\n",
    "\n",
    "def get_regular(n_clusters):\n",
    "    \"\"\"Return normalized blobs and labels around the first n_clusters grid centers.\"\"\"\n",
    "    # spiral pattern\n",
    "    centers = [\n",
    "        [0, 0],\n",
    "        [1, 0],\n",
    "        [1, 1],\n",
    "        [0, 1],\n",
    "        [-1, 1],\n",
    "        [-1, 0],\n",
    "        [-1, -1],\n",
    "        [0, -1],\n",
    "        [1, -1],\n",
    "        [2, -1],\n",
    "    ][:n_clusters]\n",
    "    assert len(centers) == n_clusters\n",
    "    X, labels = make_blobs(n_samples=N_SAMPLES, centers=centers, cluster_std=0.25, random_state=SEED)\n",
    "    return normalize(X), labels\n",
    "\n",
    "\n",
    "def get_circles(n_clusters):\n",
    "    \"\"\"Return two normalized concentric circles; n_clusters is accepted but unused.\"\"\"\n",
    "    X, labels = make_circles(n_samples=N_SAMPLES, factor=0.5, noise=0.05, random_state=SEED)\n",
    "    return normalize(X), labels\n",
    "\n",
    "\n",
    "def get_moons(n_clusters):\n",
    "    \"\"\"Return two normalized half-moons; n_clusters is accepted but unused.\"\"\"\n",
    "    X, labels = make_moons(n_samples=N_SAMPLES, noise=0.05, random_state=SEED)\n",
    "    return normalize(X), labels\n",
    "\n",
    "\n",
    "def get_noise(n_clusters):\n",
    "    \"\"\"Return normalized uniform noise with random labels in [0, n_clusters).\"\"\"\n",
    "    # A private generator seeded with SEED produces the same stream as\n",
    "    # np.random.seed(SEED) without resetting NumPy's global state on each call.\n",
    "    rng = np.random.RandomState(SEED)\n",
    "    X = rng.rand(N_SAMPLES, 2)\n",
    "    labels = rng.randint(0, n_clusters, size=(N_SAMPLES,))\n",
    "    return normalize(X), labels\n",
    "\n",
    "\n",
    "def get_anisotropic(n_clusters):\n",
    "    \"\"\"Return normalized blobs that were sheared by a fixed linear transformation.\"\"\"\n",
    "    X, labels = make_blobs(n_samples=N_SAMPLES, centers=n_clusters, random_state=170)\n",
    "    transformation = [[0.6, -0.6], [-0.4, 0.8]]\n",
    "    X = np.dot(X, transformation)\n",
    "    # normalized like every other dataset so fixed model parameters see the same scale\n",
    "    return normalize(X), labels\n",
    "\n",
    "\n",
    "def get_varied(n_clusters):\n",
    "    \"\"\"Return normalized blobs whose standard deviation differs per cluster.\"\"\"\n",
    "    cluster_std = [1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0][:n_clusters]\n",
    "    assert len(cluster_std) == n_clusters\n",
    "    X, labels = make_blobs(\n",
    "        n_samples=N_SAMPLES, centers=n_clusters, cluster_std=cluster_std, random_state=SEED\n",
    "    )\n",
    "    return normalize(X), labels\n",
    "\n",
    "\n",
    "def get_spiral(n_clusters):\n",
    "    \"\"\"Return a normalized noisy spiral with all-zero labels; n_clusters is unused.\"\"\"\n",
    "    # from https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_clustering.html\n",
    "    # private generator: same values as np.random.seed(SEED), no global side effect\n",
    "    rng = np.random.RandomState(SEED)\n",
    "    t = 1.5 * np.pi * (1 + 3 * rng.rand(1, N_SAMPLES))\n",
    "    x = t * np.cos(t)\n",
    "    y = t * np.sin(t)\n",
    "    X = np.concatenate((x, y))\n",
    "    X += 0.7 * rng.randn(2, N_SAMPLES)\n",
    "    X = np.ascontiguousarray(X.T)\n",
    "\n",
    "    labels = np.zeros(N_SAMPLES, dtype=int)\n",
    "    return normalize(X), labels\n",
    "\n",
    "\n",
    "DATA_MAPPING = {\n",
    "    'regular': get_regular,\n",
    "    'circles': get_circles,\n",
    "    'moons': get_moons,\n",
    "    'spiral': get_spiral,\n",
    "    'noise': get_noise,\n",
    "    'anisotropic': get_anisotropic,\n",
    "    'varied': get_varied,\n",
    "}\n",
    "\n",
    "\n",
    "def _configure_and_fit(model, X, params):\n",
    "    \"\"\"Apply the keyword overrides in params to model, then fit it on X.\"\"\"\n",
    "    model.set_params(**params)\n",
    "    return model.fit(X)\n",
    "\n",
    "\n",
    "def get_groundtruth_model(X, labels, n_clusters, **kwargs):\n",
    "    \"\"\"Return a stand-in model whose labels_ attribute holds the true labels.\"\"\"\n",
    "    # dummy model to show true label distribution\n",
    "    class Dummy:\n",
    "        def __init__(self, y):\n",
    "            self.labels_ = y\n",
    "\n",
    "    return Dummy(labels)\n",
    "\n",
    "\n",
    "def get_kmeans(X, labels, n_clusters, **kwargs):\n",
    "    \"\"\"Fit KMeans with n_clusters clusters; kwargs override its parameters.\"\"\"\n",
    "    model = KMeans(init=\"k-means++\", n_clusters=n_clusters, n_init=10, random_state=SEED)\n",
    "    return _configure_and_fit(model, X, kwargs)\n",
    "\n",
    "\n",
    "def get_dbscan(X, labels, n_clusters, **kwargs):\n",
    "    \"\"\"Fit DBSCAN with eps=0.3; n_clusters is unused, kwargs override parameters.\"\"\"\n",
    "    model = DBSCAN(eps=0.3)\n",
    "    return _configure_and_fit(model, X, kwargs)\n",
    "\n",
    "\n",
    "def get_agglomerative(X, labels, n_clusters, **kwargs):\n",
    "    \"\"\"Fit Ward agglomerative clustering constrained by a k-nearest-neighbors graph.\"\"\"\n",
    "    connectivity = kneighbors_graph(\n",
    "        X, n_neighbors=n_clusters, include_self=False\n",
    "    )\n",
    "    # make connectivity symmetric\n",
    "    connectivity = 0.5 * (connectivity + connectivity.T)\n",
    "    model = AgglomerativeClustering(\n",
    "        n_clusters=n_clusters, linkage=\"ward\", connectivity=connectivity\n",
    "    )\n",
    "    return _configure_and_fit(model, X, kwargs)\n",
    "\n",
    "\n",
    "def get_meanshift(X, labels, n_clusters, **kwargs):\n",
    "    \"\"\"Fit MeanShift with an estimated bandwidth; n_clusters is unused.\"\"\"\n",
    "    bandwidth = estimate_bandwidth(X, quantile=0.25)\n",
    "    model = MeanShift(bandwidth=bandwidth, bin_seeding=True)\n",
    "    return _configure_and_fit(model, X, kwargs)\n",
    "\n",
    "\n",
    "def get_spectral(X, labels, n_clusters, **kwargs):\n",
    "    \"\"\"Fit SpectralClustering on a nearest-neighbors affinity with n_clusters clusters.\"\"\"\n",
    "    model = SpectralClustering(\n",
    "        n_clusters=n_clusters,\n",
    "        eigen_solver=\"arpack\",\n",
    "        affinity=\"nearest_neighbors\",\n",
    "        random_state=SEED,  # seeded like KMeans / GaussianMixture for repeatable plots\n",
    "    )\n",
    "    return _configure_and_fit(model, X, kwargs)\n",
    "\n",
    "\n",
    "def get_optics(X, labels, n_clusters, **kwargs):\n",
    "    \"\"\"Fit OPTICS with fixed parameters; n_clusters is unused.\"\"\"\n",
    "    model = OPTICS(\n",
    "        min_samples=7,\n",
    "        xi=0.05,\n",
    "        min_cluster_size=0.1,\n",
    "    )\n",
    "    return _configure_and_fit(model, X, kwargs)\n",
    "\n",
    "\n",
    "def get_birch(X, labels, n_clusters, **kwargs):\n",
    "    \"\"\"Fit Birch with n_clusters clusters; kwargs override its parameters.\"\"\"\n",
    "    model = Birch(n_clusters=n_clusters)\n",
    "    return _configure_and_fit(model, X, kwargs)\n",
    "\n",
    "\n",
    "def get_gaussianmixture(X, labels, n_clusters, **kwargs):\n",
    "    \"\"\"Fit a full-covariance GaussianMixture with n_clusters components.\"\"\"\n",
    "    model = GaussianMixture(\n",
    "        n_components=n_clusters, covariance_type=\"full\", random_state=SEED,\n",
    "    )\n",
    "    return _configure_and_fit(model, X, kwargs)\n",
    "\n",
    "\n",
    "MODEL_MAPPING = {\n",
    "    'True labels': get_groundtruth_model,\n",
    "    'KMeans': get_kmeans,\n",
    "    'DBSCAN': get_dbscan,\n",
    "    'MeanShift': get_meanshift,\n",
    "    'SpectralClustering': get_spectral,\n",
    "    'OPTICS': get_optics,\n",
    "    'Birch': get_birch,\n",
    "    'GaussianMixture': get_gaussianmixture,\n",
    "    'AgglomerativeClustering': get_agglomerative,\n",
    "}\n",
    "\n",
    "\n",
    "def plot_clusters(ax, X, labels):\n",
    "    \"\"\"Scatter the rows of X on ax, one color per label, and return ax.\n",
    "\n",
    "    Points labelled -1 are treated as outliers and drawn as black crosses.\n",
    "    \"\"\"\n",
    "    labels = np.asarray(labels)\n",
    "    # -1 signifies outliers, which we plot separately\n",
    "    cluster_ids = [label for label in np.unique(labels) if label != -1]\n",
    "    for position, label in enumerate(cluster_ids):\n",
    "        mask = labels == label\n",
    "        # Some models pick their own number of clusters, so cycle through the\n",
    "        # palette instead of dropping clusters once the colors run out.\n",
    "        ax.scatter(X[mask, 0], X[mask, 1], color=COLORS[position % len(COLORS)])\n",
    "\n",
    "    # show outliers (if any)\n",
    "    outliers = labels == -1\n",
    "    if outliers.any():\n",
    "        ax.scatter(X[outliers, 0], X[outliers, 1], c='k', marker='x')\n",
    "\n",
    "    ax.grid(False)\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    return ax\n",
    "\n",
    "\n",
    "def cluster(dataset: str, n_clusters: int, clustering_algorithm: str):\n",
    "    \"\"\"Generate a dataset, cluster it, and return a figure of the result.\n",
    "\n",
    "    dataset is a key of DATA_MAPPING and clustering_algorithm a key of\n",
    "    MODEL_MAPPING. n_clusters may be a number or a dict with a 'value' entry.\n",
    "    \"\"\"\n",
    "    if isinstance(n_clusters, dict):\n",
    "        n_clusters = n_clusters['value']\n",
    "    n_clusters = int(n_clusters)\n",
    "\n",
    "    X, labels = DATA_MAPPING[dataset](n_clusters)\n",
    "    model = MODEL_MAPPING[clustering_algorithm](X, labels, n_clusters=n_clusters)\n",
    "    # models without a labels_ attribute (e.g. GaussianMixture) must predict\n",
    "    if hasattr(model, \"labels_\"):\n",
    "        y_pred = model.labels_.astype(int)\n",
    "    else:\n",
    "        y_pred = model.predict(X)\n",
    "\n",
    "    # A standalone Figure is not registered with pyplot, so figures created on\n",
    "    # every callback are garbage-collected instead of piling up in the server.\n",
    "    fig = Figure(figsize=FIGSIZE)\n",
    "    ax = fig.subplots()\n",
    "\n",
    "    plot_clusters(ax, X, y_pred)\n",
    "    ax.set_title(clustering_algorithm, fontsize=16)\n",
    "\n",
    "    return fig\n",
    "\n",
    "\n",
    "title = \"Clustering with Scikit-learn\"\n",
    "description = (\n",
    "    \"This example shows how different clustering algorithms work. Simply pick \"\n",
    "    \"the dataset and the number of clusters to see how the clustering algorithms work. \"\n",
    "    \"Colored circles are (predicted) labels and black x are outliers.\"\n",
    ")\n",
    "\n",
    "\n",
    "def iter_grid(n_rows, n_cols):\n",
    "    \"\"\"Yield once inside each gr.Column of an n_rows x n_cols grid of gr.Row blocks.\n",
    "\n",
    "    Components the caller creates between two yields land in the current cell.\n",
    "    \"\"\"\n",
    "    # create a grid using gradio Block\n",
    "    for _ in range(n_rows):\n",
    "        with gr.Row():\n",
    "            for _ in range(n_cols):\n",
    "                with gr.Column():\n",
    "                    yield\n",
    "\n",
    "\n",
    "with gr.Blocks(title=title) as demo:\n",
    "    gr.HTML(title)\n",
    "    gr.Markdown(description)\n",
    "\n",
    "    input_models = list(MODEL_MAPPING)\n",
    "    input_data = gr.Radio(\n",
    "        list(DATA_MAPPING),\n",
    "        value=\"regular\",\n",
    "        label=\"dataset\"\n",
    "    )\n",
    "    input_n_clusters = gr.Slider(\n",
    "        minimum=1,\n",
    "        maximum=MAX_CLUSTERS,\n",
    "        value=4,\n",
    "        step=1,\n",
    "        label='Number of clusters'\n",
    "    )\n",
    "    n_rows = int(math.ceil(len(input_models) / N_COLS))\n",
    "    counter = 0\n",
    "    for _ in iter_grid(n_rows, N_COLS):\n",
    "        if counter >= len(input_models):\n",
    "            break\n",
    "\n",
    "        input_model = input_models[counter]\n",
    "        plot = gr.Plot(label=input_model)\n",
    "        fn = partial(cluster, clustering_algorithm=input_model)\n",
    "        # redraw this plot whenever either input changes\n",
    "        for register in (input_data.change, input_n_clusters.change):\n",
    "            register(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)\n",
    "        counter += 1\n",
    "\n",
    "demo.launch()\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}