|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from functools import partial |
|
import numpy as np |
|
import torch |
|
from tqdm import tqdm |
|
import math, random |
|
|
|
|
|
|
|
def tensor_kmeans_sklearn(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
    """Cluster the pixels of a single image tensor with ``KMeans``.

    NOTE(review): ``KMeans`` is not imported anywhere in this file; it is
    presumably ``sklearn.cluster.KMeans`` -- confirm and add the import.
    ``metric`` and ``max_iters`` are accepted for signature parity with
    ``tensor_kmeans_pytorch`` but are not used here (``max_iter`` is fixed at 300).

    Args:
        data_vecs: tensor of shape (1, C, H, W); every spatial position is one sample.
        n_clusters: number of clusters to form.
        metric: unused (see note above).
        need_layer_masks: when True, return one-hot cluster masks instead of a label map.
        max_iters: unused (see note above).

    Returns:
        A (1, n_clusters, H, W) float mask tensor if ``need_layer_masks`` is True,
        otherwise a (1, 1, H, W) long tensor of cluster ids, on ``data_vecs``' device.
    """
    N, C, H, W = data_vecs.shape
    assert N == 1, 'only support singe image tensor'

    # (1, C, H, W) -> (H*W, C): one row per pixel.
    data_vecs = data_vecs.permute(0, 2, 3, 1).view(-1, C)

    # Keep the array 2-D. The former .squeeze() collapsed it to 1-D when C == 1
    # (or H*W == 1), and KMeans rejects 1-D input.
    data_vecs_np = data_vecs.detach().cpu().numpy()

    km = KMeans(n_clusters=n_clusters, init='k-means++', n_init=10, max_iter=300)
    labels = km.fit_predict(data_vecs_np)

    id_maps = torch.as_tensor(labels, dtype=torch.long, device=data_vecs.device).reshape(1, 1, H, W)

    if need_layer_masks:
        # (1, H, W) -> (1, H, W, K) one-hot -> channels-first (1, K, H, W).
        one_hot_labels = F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float()
        return one_hot_labels.permute(0, 3, 1, 2)

    return id_maps
|
|
|
|
|
def tensor_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
    """Cluster the pixels of a single image tensor with the in-file PyTorch ``kmeans``.

    Args:
        data_vecs: tensor of shape (1, C, H, W); every spatial position is one sample.
        n_clusters: number of clusters to form.
        metric: distance name forwarded to ``kmeans`` ('euclidean' or 'cosine').
        need_layer_masks: when True, return one-hot cluster masks instead of a label map.
        max_iters: iteration cap forwarded to ``kmeans`` as ``iter_limit``.

    Returns:
        A (1, n_clusters, H, W) float mask tensor if ``need_layer_masks`` is True,
        otherwise a (1, 1, H, W) tensor of cluster ids.
    """
    batch, channels, height, width = data_vecs.shape
    assert batch == 1, 'only support singe image tensor'

    # (1, C, H, W) -> (H*W, C): one row per pixel.
    pixels = data_vecs.permute(0, 2, 3, 1).view(-1, channels)

    labels, _centers = kmeans(
        X=pixels,
        num_clusters=n_clusters,
        distance=metric,
        tqdm_flag=False,
        iter_limit=max_iters,
        device=pixels.device,
    )
    label_map = labels.reshape(1, 1, height, width)

    if not need_layer_masks:
        return label_map

    # (1, H, W) -> (1, H, W, K) one-hot, then move the cluster axis to dim 1.
    masks = F.one_hot(label_map.squeeze(1), num_classes=n_clusters).float()
    return masks.permute(0, 3, 1, 2)
|
|
|
|
|
def batch_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', use_sklearn_kmeans=False):
    """Cluster each image of a batch independently and stack the one-hot masks.

    Args:
        data_vecs: tensor of shape (N, C, H, W).
        n_clusters: number of clusters per image.
        metric: distance name forwarded to the per-image clustering function.
        use_sklearn_kmeans: pick ``tensor_kmeans_sklearn`` instead of
            ``tensor_kmeans_pytorch``.

    Returns:
        The per-image (1, n_clusters, H, W) masks concatenated along dim 0.
    """
    n_images, _, _, _ = data_vecs.shape
    cluster_fn = tensor_kmeans_sklearn if use_sklearn_kmeans else tensor_kmeans_pytorch

    # Slice with idx:idx+1 so each call still receives a 4-D (1, C, H, W) tensor.
    per_image_masks = [
        cluster_fn(data_vecs[idx:idx + 1, :, :, :], n_clusters, metric, True)
        for idx in range(n_images)
    ]
    return torch.cat(per_image_masks, dim=0)
|
|
|
|
|
def get_centroid_candidates(data_vecs, n_clusters=7, metric='euclidean', max_iters=20):
    """Return the k-means centroids of all pixel feature vectors in ``data_vecs``.

    Args:
        data_vecs: tensor of shape (N, C, H, W); every spatial position of every
            image in the batch becomes one sample.
        n_clusters: number of centroids to compute.
        metric: distance name forwarded to ``kmeans`` ('euclidean' or 'cosine').
        max_iters: iteration cap forwarded to ``kmeans`` as ``iter_limit``.

    Returns:
        The (n_clusters, C) tensor of cluster centers returned by ``kmeans``.
    """
    N, C, H, W = data_vecs.shape

    # reshape (not view): after the permute, the tensor can only be viewed as
    # (-1, C) when N == 1 or C == 1; reshape copies when a view is impossible.
    data_vecs = data_vecs.permute(0, 2, 3, 1).reshape(-1, C)

    _, cluster_centers = kmeans(X=data_vecs, num_clusters=n_clusters, distance=metric,
                                tqdm_flag=False, iter_limit=max_iters, device=data_vecs.device)

    return cluster_centers
|
|
|
|
|
def find_distinctive_elements(data_tensor, n_clusters=7, topk=3, metric='euclidean'):
    """Mark, for every k-means centroid of every image, its ``topk`` nearest pixels.

    Each image is clustered independently via ``get_centroid_candidates``. Then,
    per centroid, every pixel whose squared Euclidean distance to it is <= the
    ``topk``-th smallest such distance is set to 1 (ties are all marked).

    Args:
        data_tensor: tensor of shape (N, C, H, W).
        n_clusters: number of centroids per image.
        topk: how many nearest pixels to mark per centroid.
        metric: forwarded to ``get_centroid_candidates`` for clustering only; the
            pixel selection below always uses squared Euclidean distance.

    Returns:
        A (N, n_clusters, H, W) 0/1 mask in the centroids' floating dtype.
    """
    N, C, H, W = data_tensor.shape

    centroid_list = [
        get_centroid_candidates(data_tensor[idx:idx + 1, :, :, :], n_clusters, metric)
        for idx in range(N)
    ]
    # (N, K, C)
    batch_centroids = torch.stack(centroid_list, dim=0)
    # (N, C, H*W). kmeans returns float centroids, so match their dtype for the matmul.
    data_vecs = data_tensor.flatten(2).to(batch_centroids.dtype)

    # Squared distance ||a||^2 - 2 a.b + ||b||^2. The squared norms are plain row
    # sums; the previous code built the full (H*W) x (H*W) Gram matrix just to
    # read its diagonal.
    AtB = torch.matmul(batch_centroids, data_vecs)                 # (N, K, H*W)
    A2 = (batch_centroids ** 2).sum(dim=2, keepdim=True)           # (N, K, 1)
    B2 = (data_vecs ** 2).sum(dim=1, keepdim=True)                 # (N, 1, H*W)
    distance_map = A2 - 2 * AtB + B2                               # broadcasts to (N, K, H*W)

    values, _ = distance_map.topk(topk, dim=2, largest=False, sorted=True)
    # topk-1: keeps the last axis so the k-th smallest distance broadcasts over pixels.
    kth_distance = values[:, :, topk - 1:]
    cluster_mask = (distance_map <= kth_distance).to(distance_map.dtype)

    return cluster_mask.view(N, n_clusters, H, W)
|
|
|
|
|
|
|
''' |
|
The k-means helpers below are adapted from GitHub: https://github.com/subhadarship/kmeans_pytorch
|
''' |
|
|
|
|
|
def initialize(X, num_clusters, seed=1):
    """
    initialize cluster centers by sampling distinct rows of X

    :param X: (torch.tensor) matrix, one sample per row (indexed along dim 0)
    :param num_clusters: (int) number of clusters; must not exceed len(X)
    :param seed: (int) seed for the sampling RNG [default: 1]
    :return: (torch.tensor) initial state, the rows of X at the sampled indices
    :raises ValueError: (from numpy) when num_clusters > len(X)
    """
    # A private RandomState draws exactly the indices the former
    # np.random.seed(seed) + np.random.choice(...) did, without resetting the
    # process-wide NumPy RNG as a side effect.
    rng = np.random.RandomState(seed)
    indices = rng.choice(len(X), num_clusters, replace=False)
    return X[indices]
|
|
|
|
|
def kmeans(
        X,
        num_clusters,
        distance='euclidean',
        cluster_centers=None,
        tol=1e-4,
        tqdm_flag=True,
        iter_limit=0,
        device=torch.device('cpu'),
        gamma_for_soft_dtw=0.001
):
    """
    perform kmeans

    :param X: (torch.tensor) matrix, one sample per row
    :param num_clusters: (int) number of clusters
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param cluster_centers: (torch.tensor or None) centers to resume from. When None
        (or any list, for backward compatibility with the old ``[]`` default) the
        centers are sampled with ``initialize``. Otherwise each given center is first
        replaced by its nearest row of X.
    :param tol: (float) stop once the squared total center shift is below this [default: 0.0001]
    :param tqdm_flag: Allows to turn logs on and off
    :param iter_limit: hard limit for max number of iterations; 0 means no limit
    :param device: (torch.device) device [default: cpu]
    :param gamma_for_soft_dtw: unused here; kept for signature compatibility
    :return: (torch.tensor, torch.tensor) cluster ids, cluster centers
    :raises NotImplementedError: for an unsupported ``distance``
    """
    if tqdm_flag:
        print(f'running k-means on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
    elif distance == 'cosine':
        pairwise_distance_function = partial(pairwise_cosine, device=device)
    else:
        raise NotImplementedError

    X = X.float().to(device)

    if cluster_centers is None or isinstance(cluster_centers, list):
        initial_state = initialize(X, num_clusters)
    else:
        if tqdm_flag:
            print('resuming')
        # Snap each provided center to its nearest data point (argmin over samples).
        dis = pairwise_distance_function(X, cluster_centers)
        choice_points = torch.argmin(dis, dim=0)
        initial_state = X[choice_points]
    initial_state = initial_state.to(device)

    iteration = 0
    tqdm_meter = tqdm(desc='[running kmeans]') if tqdm_flag else None
    try:
        while True:
            dis = pairwise_distance_function(X, initial_state)
            # Assign every sample to its nearest current center.
            choice_cluster = torch.argmin(dis, dim=1)

            initial_state_pre = initial_state.clone()

            for index in range(num_clusters):
                selected = X[choice_cluster == index]
                if selected.shape[0] == 0:
                    # Re-seed an empty cluster with one random sample.
                    selected = X[torch.randint(len(X), (1,))]
                initial_state[index] = selected.mean(dim=0)

            # Sum over centers of the Euclidean distance each center moved.
            center_shift = torch.sum(
                torch.sqrt(torch.sum((initial_state - initial_state_pre) ** 2, dim=1))
            )

            iteration = iteration + 1

            if tqdm_meter is not None:
                tqdm_meter.set_postfix(
                    iteration=f'{iteration}',
                    center_shift=f'{center_shift ** 2:0.6f}',
                    tol=f'{tol:0.6f}'
                )
                tqdm_meter.update()
            if center_shift ** 2 < tol:
                break
            if iter_limit != 0 and iteration >= iter_limit:
                break
    finally:
        # The progress bar was previously never closed.
        if tqdm_meter is not None:
            tqdm_meter.close()

    # NOTE: the returned ids come from the last assignment step, i.e. they were
    # computed against the centers from before the final update.
    return choice_cluster.to(device), initial_state.to(device)
|
|
|
|
|
def kmeans_predict(
        X,
        cluster_centers,
        distance='euclidean',
        device=torch.device('cpu'),
        gamma_for_soft_dtw=0.001,
        tqdm_flag=True
):
    """
    predict using cluster centers

    :param X: (torch.tensor) matrix, one sample per row
    :param cluster_centers: (torch.tensor) cluster centers, one per row
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param device: (torch.device) device [default: 'cpu']
    :param gamma_for_soft_dtw: kept for signature compatibility; soft-DTW is not available here
    :param tqdm_flag: Allows to turn logs on and off
    :return: (torch.tensor) cluster ids, moved to CPU
    :raises NotImplementedError: for an unsupported ``distance`` (including 'soft_dtw')
    """
    if tqdm_flag:
        print(f'predicting on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
    elif distance == 'cosine':
        pairwise_distance_function = partial(pairwise_cosine, device=device)
    elif distance == 'soft_dtw':
        # The soft-DTW path needs ``SoftDTW`` and ``pairwise_soft_dtw``, neither of
        # which is defined or imported in this file, so it could only ever fail
        # with a NameError. Report it as unsupported instead.
        raise NotImplementedError("distance='soft_dtw' is not supported in this module")
    else:
        raise NotImplementedError

    X = X.float().to(device)

    dis = pairwise_distance_function(X, cluster_centers)
    choice_cluster = torch.argmin(dis, dim=1)

    return choice_cluster.cpu()
|
|
|
|
|
def pairwise_distance(data1, data2, device=torch.device('cpu'), tqdm_flag=True):
    """Squared Euclidean distance between every row of ``data1`` and every row of ``data2``.

    :param data1: (torch.tensor) matrix of shape (n, d)
    :param data2: (torch.tensor) matrix of shape (k, d)
    :param device: (torch.device) device the computation runs on [default: cpu]
    :param tqdm_flag: when True, print the device being used
    :return: (torch.tensor) (n, k) matrix of squared distances
    """
    if tqdm_flag:
        print(f'device is :{device}')

    data1, data2 = data1.to(device), data2.to(device)

    # (n, 1, d) - (1, k, d) broadcasts to (n, k, d).
    diff = data1.unsqueeze(dim=1) - data2.unsqueeze(dim=0)

    # No trailing squeeze(): callers in this file take argmin over dim 0 or 1,
    # which broke when n == 1 or k == 1 collapsed an axis.
    return (diff ** 2.0).sum(dim=-1)
|
|
|
|
|
def pairwise_cosine(data1, data2, device=torch.device('cpu'), eps=1e-12):
    """Cosine distance (1 - cosine similarity) between every row of ``data1`` and of ``data2``.

    :param data1: (torch.tensor) matrix of shape (n, d)
    :param data2: (torch.tensor) matrix of shape (k, d)
    :param device: (torch.device) device the computation runs on [default: cpu]
    :param eps: (float) lower bound on row norms so all-zero rows give distance 1
        instead of NaN [default: 1e-12]
    :return: (torch.tensor) (n, k) matrix of cosine distances
    """
    data1, data2 = data1.to(device), data2.to(device)

    # (n, 1, d) and (1, k, d) broadcast against each other to (n, k, d).
    A = data1.unsqueeze(dim=1)
    B = data2.unsqueeze(dim=0)

    A_normalized = A / A.norm(dim=-1, keepdim=True).clamp_min(eps)
    B_normalized = B / B.norm(dim=-1, keepdim=True).clamp_min(eps)

    cosine = (A_normalized * B_normalized).sum(dim=-1)

    # No trailing squeeze(): callers in this file take argmin over dim 0 or 1,
    # which broke when n == 1 or k == 1 collapsed an axis.
    return 1 - cosine