File size: 2,381 Bytes
f797ad8
 
 
e889613
 
fb8832c
e889613
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f797ad8
 
 
 
 
 
 
 
 
 
7b0ac85
f797ad8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import numpy as np
import cv2
from PIL import Image
from skimage.color import rgb2lab
from skimage.color import lab2rgb
from sklearn.cluster import KMeans

def color_quantization(image, n_colors):
    """Reduce an RGB image to ``n_colors`` colors with K-means in CIELAB space.

    Pixels are clustered in LAB rather than RGB so that the clustering
    distance is closer to perceived color difference. Every pixel is then
    replaced by the center of the cluster it belongs to.

    Args:
        image: RGB image accepted by ``skimage.color.rgb2lab`` (channels on
            the last axis).
        n_colors: Number of K-means clusters, i.e. output colors.

    Returns:
        ``uint8`` array with the same shape as ``image``.
    """
    # Convert image to LAB color space and flatten to one row per pixel.
    lab_image = rgb2lab(image)
    pixels = lab_image.reshape(-1, 3)

    kmeans = KMeans(n_clusters=n_colors, random_state=0).fit(pixels)
    # fit() already assigns every training pixel to its final cluster
    # (labels_), so a second predict() pass over the same pixels is redundant.
    quantized_pixels = kmeans.cluster_centers_[kmeans.labels_]

    # Convert quantized image back to RGB color space.
    quantized_lab_image = quantized_pixels.reshape(lab_image.shape)
    quantized_rgb_image = lab2rgb(quantized_lab_image)
    # lab2rgb yields floats in [0, 1]: round to the nearest 8-bit level rather
    # than truncating (which darkens every channel), and clip defensively.
    return np.clip(np.rint(quantized_rgb_image * 255), 0, 255).astype(np.uint8)
    
def get_high_freq_colors(image):
    """Return the colors of ``image`` that occur more often than average.

    Args:
        image: A PIL ``Image`` (any object exposing ``getcolors``, ``width``
            and ``height`` works).

    Returns:
        List of ``(count, color)`` tuples as produced by ``getcolors``,
        sorted by ``count`` in descending order. It is restricted to colors
        whose count is strictly greater than both 2 and the mean count over
        all distinct colors. An image with no colors yields ``[]``.
    """
    # getcolors() returns None when the image has more than ``maxcolors``
    # distinct colors. An image cannot have more distinct colors than pixels,
    # so using the pixel count as the limit always yields a list.
    max_colors = max(1, image.width * image.height)
    color_counts = image.getcolors(maxcolors=max_colors)
    if not color_counts:
        return []

    sorted_colors = sorted(color_counts, key=lambda entry: entry[0], reverse=True)
    mean_freq = sum(count for count, _ in sorted_colors) / len(sorted_colors)

    # Drop rare colors: keep only those seen more than twice AND more often
    # than the mean frequency.
    threshold = max(2, mean_freq)
    return [entry for entry in sorted_colors if entry[0] > threshold]

def color_quantization_old(image, n_colors):
    """Quantize ``image`` to its ``n_colors`` most frequent exact colors.

    Legacy (non-clustering) variant. The palette is the most common RGB
    values in the image. Each pixel is mapped to the palette entry with the
    smallest squared Euclidean distance in RGB.

    Args:
        image: ``(H, W, 3)`` array. Presumably ``uint8`` RGB (the previous
            implementation histogrammed 256 levels per channel); colors are
            counted by exact value.
        n_colors: Maximum palette size; must be positive. If the image has
            fewer distinct colors, all of them are kept.

    Returns:
        ``(H, W, 3)`` ``uint8`` array containing only palette colors.

    Raises:
        ValueError: if ``n_colors`` is less than 1.
    """
    if n_colors < 1:
        raise ValueError(f"n_colors must be positive, got {n_colors}")

    image = np.asarray(image)
    height, width = image.shape[0], image.shape[1]
    pixels = image.reshape(-1, 3)
    if pixels.shape[0] == 0:
        return np.zeros((height, width, 3), dtype=np.uint8)

    # Count each distinct color directly instead of allocating a dense 256**3
    # histogram (~128 MiB of float64) that is almost entirely zeros.
    unique_colors, counts = np.unique(pixels, axis=0, return_counts=True)
    # Most frequent first; ties keep np.unique's lexicographic order so the
    # palette is deterministic.
    order = np.argsort(-counts, kind="stable")
    # int64 palette so the subtraction below promotes uint8 pixels to int64
    # instead of wrapping around.
    palette = unique_colors[order[:n_colors]].astype(np.int64)

    # Squared RGB distance from every pixel to every palette entry.
    diffs = pixels[:, None, :] - palette[None, :, :]
    labels = np.argmin(np.sum(diffs ** 2, axis=2), axis=1)
    return palette[labels].reshape((height, width, 3)).astype(np.uint8)

def create_binary_matrix(img_arr, target_color):
    """Write a black/white PNG mask of the pixels equal to ``target_color``.

    Args:
        img_arr: Image array with channels on the last axis (anything
            ``np.asarray`` accepts).
        target_color: Color to match. It is compared element-wise against the
            last axis, so its length should equal the channel count.

    Returns:
        The relative file name ``mask-<timestamp>.png`` of the written image.
        Matching pixels are 255 in it and all other pixels 0.

    Raises:
        OSError: if ``cv2.imwrite`` reports that the file was not written.
    """
    from datetime import datetime

    # True where every channel of the pixel equals the target color.
    mask = np.all(np.asarray(img_arr) == target_color, axis=-1)

    # Hand OpenCV an explicit 8-bit image (0 / 255) rather than a
    # platform-default int array.
    binary_image = np.where(mask, 255, 0).astype(np.uint8)

    # NOTE(review): the name is only unique per timestamp tick; two calls in
    # the same tick would overwrite each other's file.
    binary_file_name = f'mask-{datetime.now().timestamp()}.png'
    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(binary_file_name, binary_image):
        raise OSError(f"cv2.imwrite failed to write {binary_file_name}")
    return binary_file_name