import numpy as np
from PIL import Image
from sklearn.cluster import KMeans
import gradio as gr
def apply_clustering(image, num_clusters, target_labels):
data = np.array(image)
pixels = data.reshape(-1, 3)
kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(pixels)
labels = kmeans.labels_
new_pixels = np.zeros_like(pixels)
for i in range(num_clusters):
if i in target_labels:
new_pixels[labels == i] = [0, 0, 0]
else:
new_pixels[labels == i] = [255, 255, 255]
new_image_data = new_pixels.reshape(data.shape)
new_image = Image.fromarray(new_image_data)
return new_image
def modify_image(image, num_clusters, target_labels):
data = np.array(image)
pixels = data.reshape(-1, 3)
kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(pixels)
labels = kmeans.labels_
label_counts = np.bincount(labels)
cluster_info = []
for i in range(num_clusters):
cluster_pixels = pixels[labels == i]
num_pixels = len(cluster_pixels)
color_hex = '#%02x%02x%02x' % tuple(map(int, kmeans.cluster_centers_[i]))
cluster_info.append((i, num_pixels, color_hex))
modified_image = apply_clustering(image, num_clusters, target_labels)
return modified_image, cluster_info
def gradio_interface(image, num_clusters, cluster_selections):
target_labels = [int(i.split(":")[0]) for i in cluster_selections] # 提取類別數字部分
modified_image, cluster_info = modify_image(image, num_clusters, target_labels)
cluster_markdown = ""
for idx, count, color in cluster_info:
cluster_markdown += f'**Cluster {idx}**: {count} pixels (Color: {color})
\n'
return modified_image, cluster_markdown
def update_checkboxes(image, num_clusters):
_, cluster_info = modify_image(image, num_clusters, [])
options = [f'{i}: {count} pixels (Color: {color})' for i, count, color in cluster_info]
return gr.update(choices=options, value=[])
def hex_to_rgb(hex_color):
hex_color = hex_color.lstrip('#')
return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
def sort_clusters_by_color(image, num_clusters, hex_inputs, current_selections):
hex_colors = hex_inputs.split(',')
rgb_colors = [hex_to_rgb(hex_color) for hex_color in hex_colors]
_, cluster_info = modify_image(image, num_clusters, [])
selected_labels = set()
for rgb in rgb_colors:
for i, _, color in cluster_info:
cluster_rgb = hex_to_rgb(color)
distance = np.linalg.norm(np.array(cluster_rgb) - np.array(rgb))
if distance < 64:
selected_labels.add(i)
options = [f'{i}: {count} pixels (Color: {color})' for i, count, color in cluster_info]
updated_selections = [option for option in options if int(option.split(":")[0]) in selected_labels]
return gr.update(choices=options, value=updated_selections)
num_clusters = 5
with gr.Blocks() as demo:
with gr.Column():
image_output = gr.Image(type="pil", label="Modified Image")
with gr.Row():
with gr.Column():
image_input = gr.Image(type="pil", label="Upload Image")
with gr.Column():
hex_input = gr.Textbox(label="Hex Colors (e.g., #000100,#E51B1B)", placeholder="Enter hex colors separated by commas")
sort_button = gr.Button("Sort by Color")
num_clusters_slider = gr.Slider(minimum=1, maximum=20, value=num_clusters, label="Number of Clusters", step=1)
cluster_selection = gr.CheckboxGroup(choices=[], label="Target Clusters")
sort_button.click(fn=sort_clusters_by_color, inputs=[image_input, num_clusters_slider, hex_input, cluster_selection], outputs=cluster_selection)
with gr.Row():
cluster_info_output = gr.Markdown(label="Cluster Information")
image_input.change(fn=update_checkboxes, inputs=[image_input, num_clusters_slider], outputs=cluster_selection)
num_clusters_slider.change(fn=update_checkboxes, inputs=[image_input, num_clusters_slider], outputs=cluster_selection)
cluster_selection.change(fn=gradio_interface, inputs=[image_input, num_clusters_slider, cluster_selection], outputs=[image_output, cluster_info_output])
demo.launch()