import numpy as np |
import pandas as pd |
from openTSNE import TSNE |
import plotly.graph_objs as go |
import matplotlib.pyplot as plt |
import matplotlib.colors as mcolors |
from sklearn.decomposition import PCA |
from scipy.optimize import linear_sum_assignment |
class TSNE_Plot:
    """Interactive 2-D t-SNE scatter plot of sentence embeddings.

    The embeddings are reduced with PCA, projected to 2-D with openTSNE and
    drawn with plotly. A subset of the sentences is annotated with a
    (possibly truncated) copy of its text in boxes placed along the left and
    right edges of the plot, each linked to its point by an elbow line.

    When no labels are supplied, points are coloured by a KMeans clustering
    of the PCA-reduced embeddings.
    """

    def __init__(self, sentence, embed, label=None, n_clusters: int = 3, n_annotation_positions: int = 20):
        """Compute the t-SNE layout and choose the initial annotations.

        Args:
            sentence: sequence of texts, one per point.
            embed: sequence of embedding vectors aligned with ``sentence``.
            label: optional sequence of labels aligned with ``sentence``;
                when ``None`` labels are produced by KMeans (see ``cluster``).
            n_clusters: number of KMeans clusters; only used when ``label``
                is ``None``.
            n_annotation_positions: number of sentences to annotate.

        Raises:
            AssertionError: if ``n_clusters`` is not positive.
        """
        assert n_clusters > 0, "N must be greater than 0"
        self.N = n_clusters
        # list(...) drops any index carried by the caller's containers so all
        # frames below are addressed by row positions 0..n-1.
        self.test_X = pd.DataFrame({"text": list(sentence), "embed": [np.array(vec) for vec in embed]})
        if label is not None:
            self.test_y = pd.DataFrame({"label": list(label)})
        else:
            self.test_y = pd.DataFrame({"label": self.cluster()})
        self.embed = self.calculate_tsne()
        self.init_df()
        self.n_annotation_positions = n_annotation_positions
        self.show_sentence = []  # row positions of annotated sentences
        self.random_sentence()
        self.annotation_positions = []  # (x, y) slots for annotation boxes
        self.get_annotation_positions()
        self.mapping = {}  # row position -> assigned annotation slot

    def _reduce_dimensions(self):
        """Return the embeddings projected onto at most 50 principal components.

        PCA requires ``n_components <= min(n_samples, n_features)``, so the
        number of components is capped by both dimensions of the data.
        """
        vectors = np.array(self.test_X["embed"].tolist())
        n_samples, n_features = vectors.shape[0], vectors.shape[-1]
        pca = PCA(n_components=min(50, n_samples, n_features))
        return pca.fit_transform(vectors)

    def cluster(self):
        """Cluster the PCA-reduced embeddings into ``self.N`` groups with KMeans.

        Returns:
            The KMeans label array, one entry per sentence.
        """
        from sklearn.cluster import KMeans

        kmeans = KMeans(n_clusters=self.N)
        kmeans.fit(self._reduce_dimensions())
        return kmeans.labels_

    def generate_colormap(self, n_labels):
        """Map each label in ``n_labels`` (a sequence of labels) to a hex colour.

        Colours are spread evenly over matplotlib's ``jet`` colormap in the
        order the labels are given.
        """
        labels = list(n_labels)
        color_norm = mcolors.Normalize(vmin=0, vmax=len(labels) - 1)
        scalar_map = plt.cm.ScalarMappable(norm=color_norm, cmap="jet")
        return {label: mcolors.to_hex(scalar_map.to_rgba(k)) for k, label in enumerate(labels)}

    def divide_hex_color_by_half(self, hex_color):
        """Return a pale tint of ``hex_color`` for use as an annotation background.

        Each RGB channel ``c`` (0-255) is mapped to ``c // 10 + 230``, i.e.
        squeezed into [230, 255]; despite the name, channels are not halved.
        ``"#rrggbb"`` / ``"rrggbb"`` strings are parsed directly; any other
        matplotlib colour spec (e.g. ``"black"``, ``"#fff"``) is first
        normalised with ``matplotlib.colors.to_hex``.
        """
        digits = hex_color[1:] if hex_color.startswith("#") else hex_color
        if len(digits) != 6 or any(ch not in "0123456789abcdefABCDEF" for ch in digits):
            digits = mcolors.to_hex(hex_color)[1:]
        channels = (int(digits[k:k + 2], 16) // 10 + (255 - 25) for k in (0, 2, 4))
        return "#{:02x}{:02x}{:02x}".format(*channels)

    def get_annotation_positions(self):
        """Lay out annotation slots in two columns left and right of the data.

        The left column sits 1 unit left of the smallest x and the right column
        2 units right of the largest x. Slots are spaced evenly over three
        times the vertical extent of the data, scaled about its midpoint.
        Enough slots are produced for every sentence in ``show_sentence``, so
        each annotated sentence can be assigned its own slot.
        """
        min_x, max_x = self.df["x"].min() - 1, self.df["x"].max() + 2
        n = max(self.n_annotation_positions, len(getattr(self, "show_sentence", [])))
        y_low, y_high = self.df["y"].min(), self.df["y"].max()
        centre = (y_low + y_high) / 2
        half_span = (y_high - y_low) / 2 * 3
        per_side = n // 2 + n % 2  # ceil(n / 2) slots in each column
        y_values = np.linspace(centre - half_span, centre + half_span, per_side)
        left_positions = [(min_x, y) for y in y_values]
        right_positions = [(max_x, y) for y in y_values]
        self.annotation_positions = left_positions + right_positions

    def euclidean_distance(self, p1, p2):
        """Return the Euclidean distance between two 2-D points."""
        return np.hypot(p1[0] - p2[0], p1[1] - p2[1])

    def map_points(self):
        """Assign each annotated sentence to a distinct annotation slot.

        Uses ``scipy.optimize.linear_sum_assignment`` to minimise the total
        point-to-slot distance. Rebuilds ``self.mapping`` from scratch so no
        stale assignments from a previous selection survive.
        """
        self.mapping = {}
        if not self.show_sentence or not self.annotation_positions:
            return
        coords = np.asarray(self.embed, dtype=float)
        points = coords[list(self.show_sentence), :2]
        slots = np.asarray(self.annotation_positions, dtype=float)
        # Pairwise (n_points, n_slots) distance matrix via broadcasting.
        distance_matrix = np.linalg.norm(points[:, None, :] - slots[None, :, :], axis=-1)
        row_ind, col_ind = linear_sum_assignment(distance_matrix)
        for i, j in zip(row_ind, col_ind):
            self.mapping[self.show_sentence[i]] = self.annotation_positions[j]

    def show_text(self, show_sentence, text):
        """Return display strings for ``text``: empty unless selected.

        Selected texts (row positions listed in ``show_sentence``) longer than
        23 characters are shortened to their first and last 10 characters
        joined by ``"..."``; shorter ones are shown whole, since truncating
        them would not save space (and would repeat characters).
        """
        head = tail = 10
        ellipsis = "..."
        shown = set(show_sentence)
        sentence = []
        for i, raw in enumerate(text):
            if i not in shown:
                sentence.append("")
                continue
            s = str(raw)
            if len(s) > head + len(ellipsis) + tail:
                s = s[:head] + ellipsis + s[-tail:]
            sentence.append(s)
        return sentence

    def init_df(self):
        """Store the 2-D t-SNE coordinates in ``self.df`` (columns ``x``, ``y``)."""
        coords = np.asarray(self.embed)
        self.df = pd.DataFrame({"x": coords[:, 0], "y": coords[:, 1]})

    def format_data(self):
        """Rebuild ``self.df`` with everything ``plot`` needs.

        Columns: ``x``, ``y``, ``label``, ``sentence`` (display text or ""),
        ``size`` (marker size: 20 for annotated points, else 10), ``pos``
        (assigned slot as an ``x_offset``/``y_offset`` dict, (0, 0) if none)
        and ``annotate`` (selected AND assigned a slot).
        """
        shown = set(self.show_sentence)
        coords = np.asarray(self.embed)
        n = len(self.test_X)
        slots = [self.mapping.get(i, (0, 0)) for i in range(n)]
        data = {
            "x": coords[:, 0],
            "y": coords[:, 1],
            "label": self.test_y["label"].to_numpy(),
            "sentence": self.show_text(self.show_sentence, self.test_X["text"]),
            "size": [20 if i in shown else 10 for i in range(n)],
            "pos": [{"x_offset": sx, "y_offset": sy} for sx, sy in slots],
            # Only annotate points that actually received a slot; otherwise the
            # connector would be drawn to the (0, 0) placeholder.
            "annotate": [i in shown and i in self.mapping for i in range(n)],
        }
        self.df = pd.DataFrame(data)

    def calculate_tsne(self):
        """Return the 2-D openTSNE embedding of the PCA-reduced embeddings.

        Uses cosine distance, perplexity 30 and a fixed ``random_state`` for
        reproducibility; after ``fit`` a further 1000 iterations are run with
        momentum 0.8.
        """
        tsne = TSNE(
            perplexity=30,
            metric="cosine",
            n_jobs=8,
            random_state=42,
            verbose=False,
        )
        embedding_train = tsne.fit(self._reduce_dimensions())
        return embedding_train.optimize(n_iter=1000, momentum=0.8)

    def random_sentence(self):
        """Choose which sentences get a text annotation.

        Draws ``min(n_annotation_positions, n_samples)`` distinct row positions
        uniformly at random (global NumPy RNG), then, for every label with no
        drawn member, adds the first row carrying that label so that each
        label is represented. Stores the sorted positions in
        ``self.show_sentence``.
        """
        labels = self.test_y["label"].tolist()
        n_samples = len(labels)
        n_draw = max(0, min(self.n_annotation_positions, n_samples))
        chosen = {int(i) for i in np.random.choice(n_samples, size=n_draw, replace=False)}
        represented = {labels[i] for i in chosen}
        for position, label in enumerate(labels):
            if label not in represented:
                chosen.add(position)
                represented.add(label)
        self.show_sentence = sorted(chosen)

    def plot(self, return_fig=False):
        """Draw the scatter plot with its annotations.

        Args:
            return_fig: when True return the plotly figure instead of showing it.

        Returns:
            The ``go.Figure`` if ``return_fig`` is True, otherwise ``None``.
        """
        if "annotate" not in self.df.columns:
            # plot() called before tsne_plot(): build the annotation layout now.
            self.map_points()
            self.format_data()
        min_x = self.df["x"].min() - 1
        fig = go.Figure(layout=go.Layout(
            autosize=False,
            height=800,
            width=1500,
        ))
        label_colors = self.generate_colormap(self.df["label"].unique())
        line_legend_group = "lines"
        for label, color in label_colors.items():
            subset = self.df[self.df["label"] == label]
            fig.add_trace(go.Scatter(
                x=subset["x"], y=subset["y"], mode="markers",
                marker=dict(color=color, size=subset["size"]),
                showlegend=True, legendgroup=line_legend_group,
                name=str(label),
            ))
        rows = zip(self.df["x"], self.df["y"], self.df["label"],
                   self.df["sentence"], self.df["pos"], self.df["annotate"])
        for x, y, label, sentence, pos, annotate in rows:
            if not sentence or not annotate:
                continue
            color = label_colors.get(label, "black")
            # Slots in the left column sit exactly at min_x; their boxes grow
            # leftwards (anchored right), right-column boxes grow rightwards.
            on_left = abs(pos["x_offset"] - min_x) < 1e-2
            fig.add_annotation(dict(
                x=pos["x_offset"],
                y=pos["y_offset"],
                xref="x",
                yref="y",
                text=sentence,
                showarrow=False,
                xanchor="right" if on_left else "left",
                yanchor="middle",
                font=dict(color="black"),
                bordercolor=color,
                borderpad=2,
                bgcolor=self.divide_hex_color_by_half(color),
            ))
            # Elbow connector: horizontal from the box, then a short diagonal
            # into the point.
            x_start = x - 1 if on_left else x + 1
            x_turn = x - 0.5 if on_left else x + 0.5
            fig.add_trace(go.Scatter(
                x=[pos["x_offset"], x_start, x_turn, x],
                y=[pos["y_offset"], pos["y_offset"], y, y],
                mode="lines",
                line=dict(color=color),
                showlegend=False,
                legendgroup=line_legend_group,
            ))
        fig.update_xaxes(tickvals=[])
        fig.update_yaxes(tickvals=[])
        if not return_fig:
            fig.show()
            return None
        return fig

    def tsne_plot(self, n_sentence=20, return_fig=False):
        """Re-layout the annotations and plot the cached t-SNE embedding.

        The embedding computed in ``__init__`` is reused. If the requested
        number of annotations (capped at the number of sentences) differs from
        the current one, a new sample of sentences and slots is drawn.

        Args:
            n_sentence: number of sentences to annotate.
            return_fig: when True return the plotly figure instead of showing it.
        """
        n_wanted = min(n_sentence, len(self.test_y))
        if self.n_annotation_positions != n_wanted:
            self.n_annotation_positions = n_wanted
            self.random_sentence()
            self.get_annotation_positions()
        self.map_points()
        self.format_data()
        if not return_fig:
            self.plot()
            return None
        return self.plot(return_fig=return_fig)