"""Gradio demo: segment an image by K-means colour clustering.

The user uploads an image, picks a number of clusters, and selects which
clusters to highlight.  Selected clusters are painted black and all other
pixels white in the output image.  Clusters can also be pre-selected by
typing hex colours that should be matched against the cluster centres.
"""

import numpy as np
from PIL import Image
from sklearn.cluster import KMeans
import gradio as gr

# Maximum Euclidean distance in RGB space at which a cluster colour is
# considered to match a user-supplied hex colour.
COLOR_MATCH_THRESHOLD = 64


def _to_rgb_array(image):
    """Return *image* as an array whose last axis holds R, G, B.

    PIL images are converted to mode "RGB" first so that RGBA, grayscale and
    palette images do not break (or silently scramble) the later
    ``reshape(-1, 3)``.  Array-like input is accepted as well; a 2-D array
    is treated as grayscale and a fourth (alpha) channel is dropped.
    """
    if isinstance(image, Image.Image):
        return np.array(image.convert("RGB"))
    data = np.asarray(image)
    if data.ndim == 2:
        data = np.stack([data] * 3, axis=-1)
    elif data.ndim == 3 and data.shape[-1] == 4:
        data = data[..., :3]
    return data


def _cluster_pixels(data, num_clusters):
    """Fit K-means on the pixels of *data*; return ``(labels, centers)``."""
    pixels = data.reshape(-1, 3)
    # Sliders may deliver floats, and KMeans rejects n_clusters > n_samples.
    n_clusters = max(1, min(int(num_clusters), len(pixels)))
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(pixels)
    return kmeans.labels_, kmeans.cluster_centers_


def _render_mask(shape, labels, target_labels):
    """Build the black/white PIL image for the given per-pixel *labels*.

    Pixels whose label is in *target_labels* become black (0, 0, 0); every
    other pixel becomes white (255, 255, 255).
    """
    selected = np.isin(labels, list(target_labels))
    channel = np.where(selected, 0, 255).astype(np.uint8)
    flat = np.repeat(channel[:, np.newaxis], 3, axis=1)
    return Image.fromarray(flat.reshape(shape))


def _summarize_clusters(labels, centers):
    """Return ``[(index, pixel_count, '#rrggbb'), ...]``, one per cluster."""
    counts = np.bincount(labels, minlength=len(centers))
    return [
        (
            index,
            int(counts[index]),
            "#%02x%02x%02x" % tuple(int(c) for c in centers[index]),
        )
        for index in range(len(centers))
    ]


def _format_option(index, count, color):
    """Format one cluster as the label shown in the checkbox group."""
    return f'{index}: {count} pixels (Color: {color})'


def apply_clustering(image, num_clusters, target_labels):
    """Cluster the image colours and return a black/white mask image.

    Args:
        image: PIL image (any mode) or array-like of pixels.
        num_clusters: Number of K-means clusters to fit.
        target_labels: Iterable of cluster indices to paint black.

    Returns:
        A PIL image in which pixels of the targeted clusters are black and
        all other pixels are white.
    """
    data = _to_rgb_array(image)
    labels, _ = _cluster_pixels(data, num_clusters)
    return _render_mask(data.shape, labels, target_labels)


def modify_image(image, num_clusters, target_labels):
    """Cluster the image once; return the mask and per-cluster statistics.

    Args:
        image: PIL image (any mode) or array-like of pixels.
        num_clusters: Number of K-means clusters to fit.
        target_labels: Iterable of cluster indices to paint black.

    Returns:
        ``(modified_image, cluster_info)`` where ``cluster_info`` is a list
        of ``(cluster_index, pixel_count, hex_color)`` tuples.
    """
    data = _to_rgb_array(image)
    # A single fit serves both the mask and the statistics; the fit is
    # deterministic (random_state=0), so fitting twice gave the same result
    # at double the cost.
    labels, centers = _cluster_pixels(data, num_clusters)
    cluster_info = _summarize_clusters(labels, centers)
    modified_image = _render_mask(data.shape, labels, target_labels)
    return modified_image, cluster_info


def gradio_interface(image, num_clusters, cluster_selections):
    """Gradio callback: render the mask and a Markdown cluster summary.

    Args:
        image: Uploaded PIL image, or ``None`` when no image is present.
        num_clusters: Number of clusters chosen on the slider.
        cluster_selections: Selected checkbox labels, each starting with
            ``"<cluster index>:"``.

    Returns:
        ``(modified_image, cluster_markdown)``; ``(None, "")`` when there is
        no image.
    """
    if image is None:
        return None, ""

    # Extract the numeric cluster index that prefixes each selected label.
    target_labels = [
        int(selection.split(":")[0])
        for selection in (cluster_selections or [])
    ]
    modified_image, cluster_info = modify_image(
        image, num_clusters, target_labels
    )
    # A blank line after each entry makes Markdown render one paragraph per
    # cluster.
    cluster_markdown = "".join(
        f'**Cluster {idx}**: {count} pixels (Color: {color})\n\n'
        for idx, count, color in cluster_info
    )
    return modified_image, cluster_markdown


def update_checkboxes(image, num_clusters):
    """Gradio callback: rebuild the cluster checkbox choices.

    Returns an update that lists one option per cluster and clears the
    current selection.  With no image, the choices are emptied.
    """
    if image is None:
        return gr.update(choices=[], value=[])

    _, cluster_info = modify_image(image, num_clusters, [])
    options = [_format_option(*info) for info in cluster_info]
    return gr.update(choices=options, value=[])


def hex_to_rgb(hex_color):
    """Convert ``'#rrggbb'`` (or shorthand ``'#rgb'``) to an ``(r, g, b)`` tuple.

    Surrounding whitespace and the leading ``#`` are optional.

    Raises:
        ValueError: If the string is not a valid hex colour.
    """
    digits = hex_color.strip().lstrip('#')
    if len(digits) == 3:
        # Expand CSS-style shorthand: "f0a" -> "ff00aa".
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) < 6:
        raise ValueError(f"invalid hex color: {hex_color!r}")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))


def sort_clusters_by_color(image, num_clusters, hex_inputs, current_selections):
    """Gradio callback: select the clusters whose colour matches the input.

    A cluster is selected when its hex colour lies within
    ``COLOR_MATCH_THRESHOLD`` (Euclidean RGB distance) of any colour listed
    in *hex_inputs*.

    Args:
        image: Uploaded PIL image, or ``None`` when no image is present.
        num_clusters: Number of clusters chosen on the slider.
        hex_inputs: Comma-separated hex colours, e.g. ``"#000100,#E51B1B"``.
        current_selections: Currently ticked options.  Not used: the new
            selection replaces it.  Kept because the click event passes it.

    Returns:
        A Gradio update with the refreshed choices and the matching ones
        selected.

    Raises:
        ValueError: If a non-blank entry of *hex_inputs* is not a hex colour.
    """
    if image is None:
        return gr.update(choices=[], value=[])

    # Ignore blank entries so empty input, stray spaces and trailing commas
    # do not raise.
    tokens = [token.strip() for token in (hex_inputs or "").split(',')]
    rgb_colors = [hex_to_rgb(token) for token in tokens if token]
    targets = np.array(rgb_colors, dtype=float).reshape(-1, 3)

    _, cluster_info = modify_image(image, num_clusters, [])
    options = [_format_option(*info) for info in cluster_info]

    updated_selections = []
    for option, (_, _, color) in zip(options, cluster_info):
        cluster_rgb = np.array(hex_to_rgb(color), dtype=float)
        if targets.size == 0:
            continue
        distances = np.linalg.norm(targets - cluster_rgb, axis=1)
        if distances.min() < COLOR_MATCH_THRESHOLD:
            updated_selections.append(option)

    return gr.update(choices=options, value=updated_selections)


# Initial value of the cluster-count slider.
num_clusters = 5

# NOTE(review): the nesting of the layout containers below was reconstructed
# from statement order -- confirm it matches the intended UI.
with gr.Blocks() as demo:
    with gr.Column():
        image_output = gr.Image(type="pil", label="Modified Image")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
        with gr.Column():
            hex_input = gr.Textbox(
                label="Hex Colors (e.g., #000100,#E51B1B)",
                placeholder="Enter hex colors separated by commas",
            )
            sort_button = gr.Button("Sort by Color")
            num_clusters_slider = gr.Slider(
                minimum=1,
                maximum=20,
                value=num_clusters,
                label="Number of Clusters",
                step=1,
            )
            cluster_selection = gr.CheckboxGroup(
                choices=[], label="Target Clusters"
            )
            sort_button.click(
                fn=sort_clusters_by_color,
                inputs=[
                    image_input,
                    num_clusters_slider,
                    hex_input,
                    cluster_selection,
                ],
                outputs=cluster_selection,
            )
    with gr.Row():
        cluster_info_output = gr.Markdown(label="Cluster Information")

    # Re-cluster and reset the choices whenever the image or the cluster
    # count changes.
    image_input.change(
        fn=update_checkboxes,
        inputs=[image_input, num_clusters_slider],
        outputs=cluster_selection,
    )
    num_clusters_slider.change(
        fn=update_checkboxes,
        inputs=[image_input, num_clusters_slider],
        outputs=cluster_selection,
    )
    # Re-render the mask and summary whenever the selection changes.
    cluster_selection.change(
        fn=gradio_interface,
        inputs=[image_input, num_clusters_slider, cluster_selection],
        outputs=[image_output, cluster_info_output],
    )

demo.launch()