semantic_clustering / twc_clustering.py
taskswithcode
Fixes
5fe6115
raw
history blame
6.9 kB
from scipy.spatial.distance import cosine
import argparse
import json
import pdb
import torch
import torch.nn.functional as F
import numpy as np
import time
from collections import OrderedDict
class TWCClustering:
    """Cluster embeddings by thresholding their pairwise cosine similarity.

    The cut-off is expressed as a z-score: a pair counts as "similar" when its
    cosine similarity is at least ``mean + threshold * std``, where ``mean``
    and ``std`` are taken over the whole similarity matrix.
    """

    def __init__(self):
        print("In Zscore Clustering")

    def compute_matrix(self, embeddings):
        """Return the pairwise cosine-similarity matrix of ``embeddings``.

        Args:
            embeddings: Sequence of equal-length vectors, one embedding per
                row (anything ``np.array`` accepts).

        Returns:
            Array whose entry ``[i][j]`` is the cosine similarity of
            embeddings ``i`` and ``j``.

        Note:
            An all-zero embedding has norm 0, so its row and column come out
            as NaN (NumPy emits a RuntimeWarning rather than raising).
        """
        vectors = np.array(embeddings)
        # Work on the transpose so each column is one embedding; the norm
        # along axis 0 is then the length of each embedding.
        columns = vectors.T
        unit_columns = columns / np.linalg.norm(columns, axis=0)
        unit_rows = unit_columns.T
        # The inner product of unit vectors is their cosine similarity.
        return np.inner(unit_rows, unit_rows)

    def get_terms_above_threshold(self, matrix, embeddings, pivot_index, threshold):
        """Return indices ``>= pivot_index`` similar enough to the pivot.

        Args:
            matrix: Similarity matrix indexable as ``matrix[i][j]``.
            embeddings: Only its length is used (number of items).
            pivot_index: Row of ``matrix`` to scan; also the first column
                examined, so earlier indices are never returned.
            threshold: Minimum similarity (inclusive) for an index to be kept.

        Returns:
            List of matching indices in ascending order.
        """
        pivot_row = matrix[pivot_index]
        return [
            index
            for index in range(pivot_index, len(embeddings))
            if pivot_row[index] >= threshold
        ]

    def update_picked_dict_arr(self, picked_dict, arr):
        """Mark every element of ``arr`` as picked (``picked_dict[x] = 1``)."""
        for item in arr:
            picked_dict[item] = 1

    def update_picked_dict(self, picked_dict, in_dict):
        """Mark every key of ``in_dict`` as picked (``picked_dict[k] = 1``)."""
        for key in in_dict:
            picked_dict[key] = 1

    def find_pivot_subgraph(self, pivot_index, arr, matrix, threshold, strict_cluster=True):
        """Pick the member of ``arr`` with the largest total similarity.

        For each node in ``arr`` the similarities to every node in ``arr``
        (itself included) are summed; the node with the largest sum becomes
        the cluster centre. Ties keep the earliest node in ``arr``.

        Args:
            pivot_index: Index the cluster was grown from; reported back as
                ``orig_index`` and used as the centre when no node qualifies.
            arr: Candidate member indices.
            matrix: Similarity matrix indexable as ``matrix[i][j]``.
            threshold: Similarity cut-off, only applied when
                ``strict_cluster`` is true.
            strict_cluster: When true, pairs below ``threshold`` are left out
                of both the score and the neighbour list.

        Returns:
            Dict with ``pivot_index`` (chosen centre), ``orig_index`` (the
            ``pivot_index`` argument) and ``neighs`` (OrderedDict mapping
            member index to its similarity with the centre, highest first).
        """
        center_index = pivot_index
        # None means "no candidate accepted yet". Starting from 0 instead
        # would discard clusters whose best similarity sum is zero or
        # negative and return them with no members at all.
        center_score = None
        center_dict = {}
        for node_i in arr:
            score = 0
            neighbours = {}
            for node_j in arr:
                similarity = matrix[node_i][node_j]
                if strict_cluster and similarity < threshold:
                    continue
                score += similarity
                neighbours[node_j] = similarity
            if not neighbours:
                # Every pair was filtered out: this node cannot be a centre.
                continue
            if center_score is None or score > center_score:
                center_index = node_i
                center_dict = neighbours
                center_score = score
        sorted_d = OrderedDict(sorted(center_dict.items(), key=lambda kv: kv[1], reverse=True))
        return {"pivot_index": center_index, "orig_index": pivot_index, "neighs": sorted_d}

    def update_overlap_stats(self, overlap_dict, cluster_info):
        """Count, per member index, how many clusters it has appeared in.

        ``overlap_dict`` is updated in place from ``cluster_info["neighs"]``.
        """
        for member in cluster_info["neighs"]:
            overlap_dict[member] = overlap_dict.get(member, 0) + 1

    def bucket_overlap(self, overlap_dict):
        """Histogram the membership counts held in ``overlap_dict``.

        Returns:
            OrderedDict mapping "number of clusters an item belongs to" to
            "how many items have that count", ordered by ascending item count.
        """
        bucket_dict = {}
        for membership_count in overlap_dict.values():
            bucket_dict[membership_count] = bucket_dict.get(membership_count, 0) + 1
        return OrderedDict(sorted(bucket_dict.items(), key=lambda kv: kv[1], reverse=False))

    def merge_clusters(self, ref_cluster, curr_cluster):
        """Append to ``ref_cluster`` the items of ``curr_cluster`` it lacks.

        ``ref_cluster`` is modified in place; existing order is preserved and
        new items are appended in ``curr_cluster`` order.
        """
        # Snapshot the membership first so items appended during this call do
        # not influence the test; a set makes each lookup O(1).
        already_present = set(ref_cluster)
        additions = [item for item in curr_cluster if item not in already_present]
        ref_cluster.extend(additions)

    def non_overlapped_clustering(self, matrix, embeddings, threshold, mean, std, cluster_dict):
        """Partition the items into clusters that share no member.

        Candidate clusters are grown from each not-yet-picked pivot, then any
        candidates sharing a member are merged. The resulting clusters are
        appended to ``cluster_dict["clusters"]``.

        Args:
            matrix: Similarity matrix indexable as ``matrix[i][j]``.
            embeddings: Only its length is used (number of items).
            threshold: Number of standard deviations above ``mean`` used as
                the similarity cut-off.
            mean: Mean of the similarity matrix.
            std: Standard deviation of the similarity matrix.
            cluster_dict: Dict with a ``"clusters"`` list to append to.

        Returns:
            An empty dict (there is no overlap to report in this mode).
        """
        cutoff = mean + threshold * std
        picked_dict = {}
        candidates = []
        for pivot in range(len(embeddings)):
            if pivot in picked_dict:
                continue
            members = self.get_terms_above_threshold(matrix, embeddings, pivot, cutoff)
            if not members:
                # Even the pivot's self-similarity is below the cut-off (for
                # example the cut-off exceeds 1). Keep the pivot as a
                # singleton so no item is lost and members[0] below is safe.
                members = [pivot]
            candidates.append(members)
            self.update_picked_dict_arr(picked_dict, members)

        # Merge candidates that share a member to create non-overlapping sets.
        ref_pos = 0
        while ref_pos < len(candidates):
            ref_cluster = candidates[ref_pos]
            merged = False
            for other_pos in range(ref_pos + 1, len(candidates)):
                other_cluster = candidates[other_pos]
                if any(member in ref_cluster for member in other_cluster):
                    self.merge_clusters(ref_cluster, other_cluster)
                    candidates.pop(other_pos)
                    merged = True
                    break
            # After a merge the reference grew, so rescan the later
            # candidates against it. Earlier candidates were already found
            # disjoint from everything after them, so no restart is needed.
            if not merged:
                ref_pos += 1

        for members in candidates:
            cluster_info = self.find_pivot_subgraph(
                members[0], members, matrix, cutoff, strict_cluster=False
            )
            cluster_dict["clusters"].append(cluster_info)
        return {}

    def overlapped_clustering(self, matrix, embeddings, threshold, mean, std, cluster_dict):
        """Build clusters that are allowed to share members.

        A cluster is grown from every index not already placed in an earlier
        cluster and appended to ``cluster_dict["clusters"]``.

        Args:
            matrix: Similarity matrix indexable as ``matrix[i][j]``.
            embeddings: Only its length is used (number of items).
            threshold: Number of standard deviations above ``mean`` used as
                the similarity cut-off.
            mean: Mean of the similarity matrix.
            std: Standard deviation of the similarity matrix.
            cluster_dict: Dict with a ``"clusters"`` list to append to.

        Returns:
            The overlap histogram produced by ``bucket_overlap``.
        """
        cutoff = mean + threshold * std
        picked_dict = {}
        overlap_dict = {}
        for pivot in range(len(embeddings)):
            if pivot in picked_dict:
                continue
            members = self.get_terms_above_threshold(matrix, embeddings, pivot, cutoff)
            cluster_info = self.find_pivot_subgraph(
                pivot, members, matrix, cutoff, strict_cluster=True
            )
            self.update_picked_dict(picked_dict, cluster_info["neighs"])
            self.update_overlap_stats(overlap_dict, cluster_info)
            cluster_dict["clusters"].append(cluster_info)
        return self.bucket_overlap(overlap_dict)

    def cluster(self, output_file, texts, embeddings, threshold, clustering_type):
        """Cluster ``embeddings`` and return the clusters plus summary info.

        Args:
            output_file: Accepted for interface compatibility; not used here.
            texts: Accepted for interface compatibility; not used here.
            embeddings: Sequence of equal-length vectors, one per item.
            threshold: Number of standard deviations above the mean
                similarity used as the cut-off.
            clustering_type: ``"overlapped"`` selects overlapped clustering;
                any other value selects non-overlapped clustering.

        Returns:
            Dict with ``"clusters"`` (list of cluster descriptions) and
            ``"info"`` (mean, std, the current threshold as text, the table
            of z-score steps with their cosine values, and overlap stats).
        """
        is_overlapped = clustering_type == "overlapped"
        matrix = self.compute_matrix(embeddings)
        mean = np.mean(matrix)
        std = np.std(matrix)

        # Table of z-score steps whose cosine cut-off is still below 1.
        zscores = []
        step = 0
        cosine = mean
        while cosine < 1:
            zscores.append({"threshold": step, "cosine": round(cosine, 2)})
            if not std > 0:
                # A zero spread never moves the cut-off, so further steps
                # would repeat the same entry forever.
                break
            step += 1
            cosine = mean + step * std

        cluster_dict = {"clusters": []}
        if is_overlapped:
            sorted_d = self.overlapped_clustering(matrix, embeddings, threshold, mean, std, cluster_dict)
        else:
            sorted_d = self.non_overlapped_clustering(matrix, embeddings, threshold, mean, std, cluster_dict)
        curr_threshold = f"{threshold} (cosine:{mean+threshold*std:.2f})"
        cluster_dict["info"] = {
            "mean": mean,
            "std": std,
            "current_threshold": curr_threshold,
            "zscores": zscores,
            "overlap": list(sorted_d.items()),
        }
        return cluster_dict