File size: 9,970 Bytes
ceabdd9
1
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: clustering\n",
    "### This demo built with Blocks generates 9 plots based on the input.\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Version specifiers are quoted: an unquoted `>=` is parsed by the shell as an\n",
    "# output redirection, which drops the constraint and creates stray files.\n",
    "%pip install -q gradio \"matplotlib>=3.5.2\" \"scikit-learn>=1.0.1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "import math\n",
    "from functools import partial\n",
    "from itertools import cycle\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib.figure import Figure\n",
    "from sklearn.cluster import (\n",
    "    AgglomerativeClustering, Birch, DBSCAN, KMeans, MeanShift, OPTICS, SpectralClustering, estimate_bandwidth\n",
    ")\n",
    "from sklearn.datasets import make_blobs, make_circles, make_moons\n",
    "from sklearn.mixture import GaussianMixture\n",
    "from sklearn.neighbors import kneighbors_graph\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "plt.style.use('seaborn-v0_8')\n",
    "SEED = 0\n",
    "MAX_CLUSTERS = 10\n",
    "N_SAMPLES = 1000\n",
    "N_COLS = 3\n",
    "FIGSIZE = 7, 7  # does not affect size in webpage\n",
    "COLORS = [\n",
    "    'blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan'\n",
    "]\n",
    "assert len(COLORS) >= MAX_CLUSTERS, \"Not enough different colors for all clusters\"\n",
    "np.random.seed(SEED)\n",
    "\n",
    "\n",
    "def normalize(X):\n",
    "    \"\"\"Standardize every feature of X with a freshly fitted StandardScaler.\"\"\"\n",
    "    return StandardScaler().fit_transform(X)\n",
    "\n",
    "\n",
    "def get_regular(n_clusters):\n",
    "    \"\"\"Return standardized Gaussian blobs (and their labels) around fixed centers.\"\"\"\n",
    "    # spiral pattern\n",
    "    centers = [\n",
    "        [0, 0],\n",
    "        [1, 0],\n",
    "        [1, 1],\n",
    "        [0, 1],\n",
    "        [-1, 1],\n",
    "        [-1, 0],\n",
    "        [-1, -1],\n",
    "        [0, -1],\n",
    "        [1, -1],\n",
    "        [2, -1],\n",
    "    ][:n_clusters]\n",
    "    assert len(centers) == n_clusters\n",
    "    X, labels = make_blobs(n_samples=N_SAMPLES, centers=centers, cluster_std=0.25, random_state=SEED)\n",
    "    return normalize(X), labels\n",
    "\n",
    "\n",
    "def get_circles(n_clusters):\n",
    "    \"\"\"Return two noisy concentric circles; n_clusters is accepted but unused.\"\"\"\n",
    "    X, labels = make_circles(n_samples=N_SAMPLES, factor=0.5, noise=0.05, random_state=SEED)\n",
    "    return normalize(X), labels\n",
    "\n",
    "\n",
    "def get_moons(n_clusters):\n",
    "    \"\"\"Return two noisy interleaving half-moons; n_clusters is accepted but unused.\"\"\"\n",
    "    X, labels = make_moons(n_samples=N_SAMPLES, noise=0.05, random_state=SEED)\n",
    "    return normalize(X), labels\n",
    "\n",
    "\n",
    "def get_noise(n_clusters):\n",
    "    \"\"\"Return uniformly random 2-D points with random labels in [0, n_clusters).\"\"\"\n",
    "    # re-seed so that every call yields the same point cloud\n",
    "    np.random.seed(SEED)\n",
    "    X, labels = np.random.rand(N_SAMPLES, 2), np.random.randint(0, n_clusters, size=(N_SAMPLES,))\n",
    "    return normalize(X), labels\n",
    "\n",
    "\n",
    "def get_anisotropic(n_clusters):\n",
    "    \"\"\"Return blobs stretched by a fixed linear transformation, then standardized.\"\"\"\n",
    "    X, labels = make_blobs(n_samples=N_SAMPLES, centers=n_clusters, random_state=170)\n",
    "    transformation = [[0.6, -0.6], [-0.4, 0.8]]\n",
    "    X = np.dot(X, transformation)\n",
    "    # standardize like every other dataset: the fixed model parameters below\n",
    "    # (e.g. DBSCAN's eps) are only meaningful on a common scale\n",
    "    return normalize(X), labels\n",
    "\n",
    "\n",
    "def get_varied(n_clusters):\n",
    "    \"\"\"Return standardized blobs whose standard deviation differs per cluster.\"\"\"\n",
    "    cluster_std = [1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0][:n_clusters]\n",
    "    assert len(cluster_std) == n_clusters\n",
    "    X, labels = make_blobs(\n",
    "        n_samples=N_SAMPLES, centers=n_clusters, cluster_std=cluster_std, random_state=SEED\n",
    "    )\n",
    "    return normalize(X), labels\n",
    "\n",
    "\n",
    "def get_spiral(n_clusters):\n",
    "    \"\"\"Return a noisy spiral; all points share label 0 and n_clusters is unused.\"\"\"\n",
    "    # from https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_clustering.html\n",
    "    np.random.seed(SEED)\n",
    "    t = 1.5 * np.pi * (1 + 3 * np.random.rand(1, N_SAMPLES))\n",
    "    x = t * np.cos(t)\n",
    "    y = t * np.sin(t)\n",
    "    X = np.concatenate((x, y))\n",
    "    X += 0.7 * np.random.randn(2, N_SAMPLES)\n",
    "    X = np.ascontiguousarray(X.T)\n",
    "\n",
    "    labels = np.zeros(N_SAMPLES, dtype=int)\n",
    "    return normalize(X), labels\n",
    "\n",
    "\n",
    "# dataset name shown in the UI -> generator taking the requested cluster count\n",
    "DATA_MAPPING = {\n",
    "    'regular': get_regular,\n",
    "    'circles': get_circles,\n",
    "    'moons': get_moons,\n",
    "    'spiral': get_spiral,\n",
    "    'noise': get_noise,\n",
    "    'anisotropic': get_anisotropic,\n",
    "    'varied': get_varied,\n",
    "}\n",
    "\n",
    "\n",
    "# Every model factory below shares the signature (X, labels, n_clusters, **kwargs)\n",
    "# and returns a fitted object; extra keyword arguments go to `set_params`.\n",
    "\n",
    "def get_groundtruth_model(X, labels, n_clusters, **kwargs):\n",
    "    \"\"\"Return a dummy model whose `labels_` are the true labels.\"\"\"\n",
    "    # dummy model to show true label distribution\n",
    "    class Dummy:\n",
    "        def __init__(self, y):\n",
    "            # use the constructor argument rather than the enclosing scope\n",
    "            self.labels_ = y\n",
    "\n",
    "    return Dummy(labels)\n",
    "\n",
    "\n",
    "def get_kmeans(X, labels, n_clusters, **kwargs):\n",
    "    \"\"\"Fit KMeans with k-means++ initialization on X.\"\"\"\n",
    "    model = KMeans(init=\"k-means++\", n_clusters=n_clusters, n_init=10, random_state=SEED)\n",
    "    model.set_params(**kwargs)\n",
    "    return model.fit(X)\n",
    "\n",
    "\n",
    "def get_dbscan(X, labels, n_clusters, **kwargs):\n",
    "    \"\"\"Fit DBSCAN (eps=0.3) on X; n_clusters is unused.\"\"\"\n",
    "    model = DBSCAN(eps=0.3)\n",
    "    model.set_params(**kwargs)\n",
    "    return model.fit(X)\n",
    "\n",
    "\n",
    "def get_agglomerative(X, labels, n_clusters, **kwargs):\n",
    "    \"\"\"Fit Ward agglomerative clustering constrained by a k-neighbors graph.\"\"\"\n",
    "    connectivity = kneighbors_graph(\n",
    "        X, n_neighbors=n_clusters, include_self=False\n",
    "    )\n",
    "    # make connectivity symmetric\n",
    "    connectivity = 0.5 * (connectivity + connectivity.T)\n",
    "    model = AgglomerativeClustering(\n",
    "        n_clusters=n_clusters, linkage=\"ward\", connectivity=connectivity\n",
    "    )\n",
    "    model.set_params(**kwargs)\n",
    "    return model.fit(X)\n",
    "\n",
    "\n",
    "def get_meanshift(X, labels, n_clusters, **kwargs):\n",
    "    \"\"\"Fit MeanShift with a bandwidth estimated from X; n_clusters is unused.\"\"\"\n",
    "    bandwidth = estimate_bandwidth(X, quantile=0.25)\n",
    "    model = MeanShift(bandwidth=bandwidth, bin_seeding=True)\n",
    "    model.set_params(**kwargs)\n",
    "    return model.fit(X)\n",
    "\n",
    "\n",
    "def get_spectral(X, labels, n_clusters, **kwargs):\n",
    "    \"\"\"Fit nearest-neighbors SpectralClustering on X.\"\"\"\n",
    "    model = SpectralClustering(\n",
    "        n_clusters=n_clusters,\n",
    "        eigen_solver=\"arpack\",\n",
    "        affinity=\"nearest_neighbors\",\n",
    "        # seeded like KMeans / GaussianMixture so the plot is reproducible\n",
    "        random_state=SEED,\n",
    "    )\n",
    "    model.set_params(**kwargs)\n",
    "    return model.fit(X)\n",
    "\n",
    "\n",
    "def get_optics(X, labels, n_clusters, **kwargs):\n",
    "    \"\"\"Fit OPTICS with fixed parameters on X; n_clusters is unused.\"\"\"\n",
    "    model = OPTICS(\n",
    "        min_samples=7,\n",
    "        xi=0.05,\n",
    "        min_cluster_size=0.1,\n",
    "    )\n",
    "    model.set_params(**kwargs)\n",
    "    return model.fit(X)\n",
    "\n",
    "\n",
    "def get_birch(X, labels, n_clusters, **kwargs):\n",
    "    \"\"\"Fit Birch on X.\"\"\"\n",
    "    model = Birch(n_clusters=n_clusters)\n",
    "    model.set_params(**kwargs)\n",
    "    return model.fit(X)\n",
    "\n",
    "\n",
    "def get_gaussianmixture(X, labels, n_clusters, **kwargs):\n",
    "    \"\"\"Fit a full-covariance GaussianMixture on X (no `labels_`; use `predict`).\"\"\"\n",
    "    model = GaussianMixture(\n",
    "        n_components=n_clusters, covariance_type=\"full\", random_state=SEED,\n",
    "    )\n",
    "    model.set_params(**kwargs)\n",
    "    return model.fit(X)\n",
    "\n",
    "\n",
    "# algorithm name shown above each plot -> factory returning a fitted model\n",
    "MODEL_MAPPING = {\n",
    "    'True labels': get_groundtruth_model,\n",
    "    'KMeans': get_kmeans,\n",
    "    'DBSCAN': get_dbscan,\n",
    "    'MeanShift': get_meanshift,\n",
    "    'SpectralClustering': get_spectral,\n",
    "    'OPTICS': get_optics,\n",
    "    'Birch': get_birch,\n",
    "    'GaussianMixture': get_gaussianmixture,\n",
    "    'AgglomerativeClustering': get_agglomerative,\n",
    "}\n",
    "\n",
    "\n",
    "def plot_clusters(ax, X, labels):\n",
    "    \"\"\"Scatter-plot the rows of X on ax, one color per cluster label.\n",
    "\n",
    "    Label -1 signifies outliers, which are drawn separately as black crosses.\n",
    "    Colors are recycled when there are more clusters than entries in COLORS,\n",
    "    so no cluster is silently left out of the plot. Returns ax.\n",
    "    \"\"\"\n",
    "    labels = np.asarray(labels)\n",
    "    # np.unique returns the distinct labels in sorted order\n",
    "    cluster_ids = [label for label in np.unique(labels) if label != -1]\n",
    "    for label, color in zip(cluster_ids, cycle(COLORS)):\n",
    "        idx = labels == label\n",
    "        ax.scatter(X[idx, 0], X[idx, 1], color=color)\n",
    "\n",
    "    # show outliers (if any)\n",
    "    idx = labels == -1\n",
    "    if idx.any():\n",
    "        ax.scatter(X[idx, 0], X[idx, 1], c='k', marker='x')\n",
    "\n",
    "    ax.grid(False)\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    return ax\n",
    "\n",
    "\n",
    "def cluster(dataset: str, n_clusters: int, clustering_algorithm: str):\n",
    "    \"\"\"Generate `dataset`, cluster it with `clustering_algorithm`, return the figure.\n",
    "\n",
    "    `dataset` and `clustering_algorithm` are keys of DATA_MAPPING and\n",
    "    MODEL_MAPPING. `n_clusters` may be a number or a dict carrying the number\n",
    "    under the key 'value'.\n",
    "    \"\"\"\n",
    "    if isinstance(n_clusters, dict):\n",
    "        n_clusters = n_clusters['value']\n",
    "    n_clusters = int(n_clusters)\n",
    "\n",
    "    X, labels = DATA_MAPPING[dataset](n_clusters)\n",
    "    model = MODEL_MAPPING[clustering_algorithm](X, labels, n_clusters=n_clusters)\n",
    "    if hasattr(model, \"labels_\"):\n",
    "        y_pred = model.labels_.astype(int)\n",
    "    else:\n",
    "        y_pred = model.predict(X)\n",
    "\n",
    "    # Build the figure without pyplot: pyplot keeps a global reference to every\n",
    "    # figure it creates, so each callback would leak one figure.\n",
    "    fig = Figure(figsize=FIGSIZE)\n",
    "    ax = fig.subplots()\n",
    "\n",
    "    plot_clusters(ax, X, y_pred)\n",
    "    ax.set_title(clustering_algorithm, fontsize=16)\n",
    "\n",
    "    return fig\n",
    "\n",
    "\n",
    "title = \"Clustering with Scikit-learn\"\n",
    "description = (\n",
    "    \"This example shows how different clustering algorithms work. Simply pick \"\n",
    "    \"the dataset and the number of clusters to see how the clustering algorithms work. \"\n",
    "    \"Colored circles are (predicted) labels and black x are outliers.\"\n",
    ")\n",
    "\n",
    "\n",
    "def iter_grid(n_rows, n_cols):\n",
    "    \"\"\"Yield once inside each cell of an n_rows x n_cols grid of gradio Row/Column.\"\"\"\n",
    "    # create a grid using gradio Block\n",
    "    for _ in range(n_rows):\n",
    "        with gr.Row():\n",
    "            for _ in range(n_cols):\n",
    "                with gr.Column():\n",
    "                    yield\n",
    "\n",
    "\n",
    "with gr.Blocks(title=title) as demo:\n",
    "    gr.HTML(f\"<b>{title}</b>\")\n",
    "    gr.Markdown(description)\n",
    "\n",
    "    input_models = list(MODEL_MAPPING)\n",
    "    input_data = gr.Radio(\n",
    "        list(DATA_MAPPING),\n",
    "        value=\"regular\",\n",
    "        label=\"dataset\"\n",
    "    )\n",
    "    input_n_clusters = gr.Slider(\n",
    "        minimum=1,\n",
    "        maximum=MAX_CLUSTERS,\n",
    "        value=4,\n",
    "        step=1,\n",
    "        label='Number of clusters'\n",
    "    )\n",
    "    n_rows = int(math.ceil(len(input_models) / N_COLS))\n",
    "    # one grid cell per algorithm; the grid comes first in zip so that each\n",
    "    # plot is created inside the Column the generator has just opened\n",
    "    for _, input_model in zip(iter_grid(n_rows, N_COLS), input_models):\n",
    "        plot = gr.Plot(label=input_model)\n",
    "        fn = partial(cluster, clustering_algorithm=input_model)\n",
    "        # redraw this plot whenever the dataset or the cluster count changes\n",
    "        input_data.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)\n",
    "        input_n_clusters.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)\n",
    "\n",
    "demo.launch()\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}