"""Gradio demo: how the t-SNE perplexity setting changes a 2-D embedding.

The page shows the selected toy dataset next to its t-SNE embedding and
re-renders the plots when the dataset, sample count or perplexity changes.
"""

import inspect
from functools import partial

import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.ticker import NullFormatter
import numpy as np
from sklearn import datasets, manifold

SEED = 0
N_COMPONENTS = 2
TSNE_ITERATIONS = 400

# scikit-learn renamed TSNE's ``n_iter`` argument to ``max_iter``; use
# whichever keyword the installed version actually accepts.
_TSNE_ITER_KWARG = (
    "max_iter"
    if "max_iter" in inspect.signature(manifold.TSNE).parameters
    else "n_iter"
)

np.random.seed(SEED)


def get_circles(n_samples: int):
    """Return ``(X, color)`` from ``datasets.make_circles``.

    Uses ``factor=0.5``, ``noise=0.05`` and the module-wide ``SEED``.
    """
    X, color = datasets.make_circles(
        n_samples=n_samples, factor=0.5, noise=0.05, random_state=SEED
    )
    return X, color


def get_s_curve(n_samples: int):
    """Return ``(X, color)`` from ``datasets.make_s_curve`` with columns 1 and 2 swapped.

    ``plot_data`` scatters only the first two columns of the raw data, so the
    swap decides which coordinates end up on screen (presumably chosen so the
    S profile is visible in the untransformed plot).
    """
    X, color = datasets.make_s_curve(n_samples=n_samples, random_state=SEED)
    # Fancy indexing builds a new array, so no explicit copy is needed.
    X = X[:, [0, 2, 1]]
    return X, color


def get_uniform_grid(n_samples: int):
    """Return a square grid on ``[0, 1] x [0, 1]`` and its x-coordinates as colour.

    The grid has ``int(sqrt(n_samples))`` points per side, so the number of
    points actually returned can be smaller than ``n_samples``.
    """
    side = int(np.sqrt(n_samples))
    ticks = np.linspace(0, 1, side)
    xx, yy = np.meshgrid(ticks, ticks)
    X = np.column_stack([xx.ravel(), yy.ravel()])
    color = xx.ravel()
    return X, color


DATA_MAPPING = {
    "circles": get_circles,
    "s-curve": get_s_curve,
    "uniform grid": get_uniform_grid,
}


def plot_data(dataset: str, perplexity: int, n_samples: int, tsne: bool):
    """Build a scatter plot of ``dataset``, optionally after a t-SNE embedding.

    Args:
        dataset: Key into ``DATA_MAPPING`` selecting the generator.
        perplexity: t-SNE perplexity; either a number or a dict holding the
            number under ``"value"``. Ignored when ``tsne`` is false.
        n_samples: Requested number of points, forwarded to the generator.
        tsne: When true, plot the t-SNE embedding instead of the raw data.

    Returns:
        The matplotlib ``Figure`` containing the scatter plot.

    Raises:
        KeyError: If ``dataset`` is not a key of ``DATA_MAPPING``.
    """
    if isinstance(perplexity, dict):
        perplexity = perplexity["value"]
    perplexity = int(perplexity)
    # UI sliders may deliver floats; the dataset generators need an integer.
    n_samples = int(n_samples)

    X, color = DATA_MAPPING[dataset](n_samples)

    if tsne:
        # TSNE rejects a perplexity that is not strictly below the number of
        # points, and the sliders allow e.g. perplexity=100 with 100 samples.
        perplexity = min(perplexity, X.shape[0] - 1)
        embedder = manifold.TSNE(
            n_components=N_COMPONENTS,
            init="random",
            random_state=SEED,
            perplexity=perplexity,
            **{_TSNE_ITER_KWARG: TSNE_ITERATIONS},
        )
        Y = embedder.fit_transform(X)
    else:
        Y = X

    fig, ax = plt.subplots(figsize=(7, 7))
    ax.scatter(Y[:, 0], Y[:, 1], c=color)
    # Hide tick labels; only the shape of the point cloud matters.
    ax.xaxis.set_major_formatter(NullFormatter())
    ax.yaxis.set_major_formatter(NullFormatter())
    ax.axis("tight")
    # Drop the figure from pyplot's global registry so repeated callbacks do
    # not pile up open figures; the Figure object is still returned and usable.
    plt.close(fig)
    return fig


title = "t-SNE: The effect of various perplexity values on the shape"
description = (
    "An illustration of t-SNE on the two concentric circles and the "
    "S-curve datasets for different perplexity values."
)

with gr.Blocks(title=title) as demo:
    gr.HTML(title)
    gr.Markdown(description)

    input_data = gr.Radio(
        list(DATA_MAPPING), value="circles", label="dataset"
    )
    n_samples = gr.Slider(
        minimum=100, maximum=1000, value=150, step=25, label="Number of Samples"
    )
    perplexity = gr.Slider(
        minimum=2, maximum=100, value=5, step=1, label="Perplexity"
    )
    # Both plots read the same three controls, in plot_data's argument order.
    controls = [input_data, perplexity, n_samples]

    with gr.Row():
        with gr.Column():
            original_plot = gr.Plot(label="Original data")
            show_original = partial(plot_data, tsne=False)
            # Perplexity has no effect on the raw data, so it is not a trigger.
            input_data.change(fn=show_original, inputs=controls, outputs=original_plot)
            n_samples.change(fn=show_original, inputs=controls, outputs=original_plot)
            demo.load(fn=show_original, inputs=controls, outputs=original_plot)

        with gr.Column():
            tsne_plot = gr.Plot(label="t-SNE")
            show_tsne = partial(plot_data, tsne=True)
            input_data.change(fn=show_tsne, inputs=controls, outputs=tsne_plot)
            perplexity.change(fn=show_tsne, inputs=controls, outputs=tsne_plot)
            n_samples.change(fn=show_tsne, inputs=controls, outputs=tsne_plot)
            demo.load(fn=show_tsne, inputs=controls, outputs=tsne_plot)

demo.launch()