from lollms.config import TypedConfig, BaseConfig, ConfigTemplate, InstallOption |
from lollms.types import MSG_TYPE |
from lollms.personality import APScript, AIPersonality |
from lollms.helpers import ASCIIColors |
import numpy as np |
import json |
from pathlib import Path |
class TextVectorizer: |
def __init__(self, model_name, database_file:Path|str, visualize_data_at_startup=False, visualize_data_at_add_file=False, visualize_data_at_generate=False): |
from transformers import AutoTokenizer, AutoModel |
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
self.model = AutoModel.from_pretrained(model_name) |
self.embeddings = {} |
self.texts = {} |
self.ready = False |
self.database_file = Path(database_file) |
self.visualize_data_at_startup = visualize_data_at_startup |
self.visualize_data_at_add_file = visualize_data_at_add_file |
self.visualize_data_at_generate = visualize_data_at_generate |
if Path(self.database_file).exists(): |
ASCIIColors.success(f"Database file found : {self.database_file}") |
self.load_from_json() |
if visualize_data_at_startup: |
self.show_document() |
self.ready = True |
else: |
ASCIIColors.info(f"No database file found : {self.database_file}") |
def show_document(self, query_text="What is the main idea of this text?", use_pca=True): |
import textwrap |
import seaborn as sns |
import matplotlib.pyplot as plt |
import mplcursors |
from tkinter import Tk, Text, Scrollbar, Frame, Label, TOP, BOTH, RIGHT, LEFT, Y, N, END |
from sklearn.manifold import TSNE |
from sklearn.decomposition import PCA |
import torch |
if use_pca: |
print("Showing pca representation :") |
else: |
print("Showing t-sne representation :") |
texts = list(self.texts.values()) |
embeddings = torch.stack(list(self.embeddings.values())).detach().squeeze(1).numpy() |
norms = np.linalg.norm(embeddings, axis=1) |
normalized_embeddings = embeddings / norms[:, np.newaxis] |
query_embedding = self.embed_query(query_text) |
query_embedding = query_embedding.detach().squeeze().numpy() |
query_normalized_embedding = query_embedding / np.linalg.norm(query_embedding) |
combined_embeddings = np.vstack((normalized_embeddings, query_normalized_embedding)) |
if use_pca: |
pca = PCA(n_components=2) |
embeddings_2d = pca.fit_transform(combined_embeddings) |
else: |
perplexity = min(30, combined_embeddings.shape[0] - 1) |
tsne = TSNE(n_components=2, perplexity=perplexity) |
embeddings_2d = tsne.fit_transform(combined_embeddings) |
sns.scatterplot(x=embeddings_2d[:-1, 0], y=embeddings_2d[:-1, 1]) |
plt.scatter(embeddings_2d[-1, 0], embeddings_2d[-1, 1], color='red') |
for i, (x, y) in enumerate(embeddings_2d[:-1]): |
plt.text(x, y, str(i), fontsize=8) |
plt.xlabel('Dimension 1') |
plt.ylabel('Dimension 2') |
if use_pca: |
plt.title('Embeddings Scatter Plot based on PCA') |
else: |
plt.title('Embeddings Scatter Plot based on t-SNE') |
cursor = mplcursors.cursor(hover=True) |
@cursor.connect("add") |
def on_hover(sel): |
index = sel.target.index |
if index > 0: |
text = texts[index] |
wrapped_text = textwrap.fill(text, width=50) |
sel.annotation.set_text(f"Index: {index}\nText:\n{wrapped_text}") |
else: |
sel.annotation.set_text("Query") |
def on_click(event): |
if event.xdata is not None and event.ydata is not None: |
x, y = event.xdata, event.ydata |
distances = ((embeddings_2d[:, 0] - x) ** 2 + (embeddings_2d[:, 1] - y) ** 2) |
index = distances.argmin() |
text = texts[index] if index < len(texts) else query_text |
root = Tk() |
root.title(f"Text for Index {index}") |
frame = Frame(root) |
frame.pack(fill=BOTH, expand=True) |
label = Label(frame, text="Text:") |
label.pack(side=TOP, padx=5, pady=5) |
text_box = Text(frame) |
text_box.pack(side=TOP, padx=5, pady=5, fill=BOTH, expand=True) |
text_box.insert(END, text) |
scrollbar = Scrollbar(frame) |
scrollbar.pack(side=RIGHT, fill=Y) |
scrollbar.config(command=text_box.yview) |
text_box.config(yscrollcommand=scrollbar.set) |
text_box.config(state="disabled") |
root.mainloop() |
plt.gcf().canvas.mpl_connect("button_press_event", on_click) |
plt.show() |
def index_document(self, document_id, text, chunk_size, overlap_size, force_vectorize=False): |
import torch |
if document_id in self.embeddings and not force_vectorize: |
print(f"Document {document_id} already exists. Skipping vectorization.") |
return |
tokens = self.tokenizer.encode_plus(text, add_special_tokens=False, return_attention_mask=False)['input_ids'] |
sentences = self.tokenizer.decode(tokens).split('. ') |
chunks = [] |
current_chunk = [] |
for sentence in sentences: |
sentence_tokens = self.tokenizer.encode_plus(sentence, add_special_tokens=False, return_attention_mask=False)['input_ids'] |
if len(current_chunk) + len(sentence_tokens) <= chunk_size: |
current_chunk.extend(sentence_tokens) |
else: |
if current_chunk: |
chunks.append(current_chunk) |
current_chunk = sentence_tokens |
if current_chunk: |
chunks.append(current_chunk) |
overlapping_chunks = [] |
for i in range(len(chunks)): |
chunk_start = i * (chunk_size - overlap_size) |
chunk_end = min(chunk_start + chunk_size, len(tokens)) |
chunk = tokens[chunk_start:chunk_end] |
overlapping_chunks.append(chunk) |
for i, chunk in enumerate(overlapping_chunks): |
if len(chunk) < chunk_size: |
padding = [self.tokenizer.pad_token_id] * (chunk_size - len(chunk)) |
chunk.extend(padding) |
input_ids = chunk[:chunk_size] |
input_tensor = torch.tensor([input_ids]) |
with torch.no_grad(): |
self.model.eval() |
outputs = self.model(input_tensor) |
embeddings = outputs.last_hidden_state.mean(dim=1) |
chunk_id = f"{document_id}_chunk_{i + 1}" |
self.embeddings[chunk_id] = embeddings |
self.texts[chunk_id] = self.tokenizer.decode(chunk[:chunk_size], skip_special_tokens=True) |
self.save_to_json() |
self.ready = True |
if self.visualize_data_at_add_file: |
self.show_document() |
def embed_query(self, query_text): |
import torch |
query_tokens = self.tokenizer.encode(query_text) |
query_input_tensor = torch.tensor([query_tokens]) |
with torch.no_grad(): |
self.model.eval() |
query_outputs = self.model(query_input_tensor) |
query_embedding = query_outputs.last_hidden_state.mean(dim=1) |
return query_embedding |
def recover_text(self, query_embedding, top_k=1): |
from sklearn.metrics.pairwise import cosine_similarity |
similarities = {} |
for chunk_id, chunk_embedding in self.embeddings.items(): |
similarity = cosine_similarity(query_embedding.numpy(), chunk_embedding.numpy())[0][0] |
similarities[chunk_id] = similarity |
sorted_similarities = sorted(similarities.items(), key=lambda x: x[1], reverse=True)[:top_k] |
texts = [self.texts[chunk_id] for chunk_id, _ in sorted_similarities] |
if self.visualize_data_at_generate: |
self.show_document() |
return texts |
def save_to_json(self): |
state = { |
"embeddings": {str(k): v.tolist() for k, v in self.embeddings.items()}, |
"texts": self.texts, |
} |
with open(self.database_file, "w") as f: |
json.dump(state, f) |
def load_from_json(self): |
import torch |
ASCIIColors.info("Loading vectorized documents") |
with open(self.database_file, "r") as f: |
state = json.load(f) |
self.embeddings = {k: torch.tensor(v) for k, v in state["embeddings"].items()} |
self.texts = state["texts"] |
self.ready = True |
class Processor(APScript):
    """
    Personality script that answers questions about user supplied documents.

    Files are vectorized into a :class:`TextVectorizer` store. For a regular
    prompt the most similar chunks are retrieved and placed in front of the
    question before calling the generator. A few literal prompts act as
    commands: ``send_file``, ``help``, ``show_database``, ``set_database`` and
    ``clear_database``.

    Inherits from APScript.
    """

    def __init__(
                    self,
                    personality: AIPersonality
                ) -> None:
        """Declare the configuration entries and open the vector database."""
        self.word_callback = None

        personality_config_template = ConfigTemplate(
            [
                {"name":"database_path","type":"str","value":f"{personality.name}_db.json", "help":"Path to the database"},
                {"name":"max_chunk_size","type":"int","value":512, "min":10, "max":personality.config["ctx_size"],"help":"Maximum size of text chunks to vectorize"},
                {"name":"chunk_overlap","type":"int","value":20, "min":0, "max":personality.config["ctx_size"],"help":"Overlap between chunks"},
                {"name":"max_answer_size","type":"int","value":512, "min":10, "max":personality.config["ctx_size"],"help":"Maximum number of tokens to allow the generator to generate as an answer to your question"},
                {"name":"visualize_data_at_startup","type":"bool","value":False, "help":"If true, the database will be visualized at startup"},
                {"name":"visualize_data_at_add_file","type":"bool","value":False, "help":"If true, the database will be visualized when a new file is added"},
                {"name":"visualize_data_at_generate","type":"bool","value":False, "help":"If true, the database will be visualized at generation time"},
            ]
        )
        personality_config_vals = BaseConfig.from_template(personality_config_template)

        personality_config = TypedConfig(
            personality_config_template,
            personality_config_vals
        )
        super().__init__(
                            personality,
                            personality_config
                        )
        # 0: idle, 1: waiting for a file path, 2: waiting for a database path
        self.state = 0
        self.ready = False
        self.personality = personality
        self.callback = None
        self.vector_store = self._build_vector_store()
        if len(self.vector_store.embeddings)>0:
            self.ready = True

    def _database_file_path(self):
        """Return the configured database path inside the personal data folder."""
        return self.personality.lollms_paths.personal_data_path/self.personality_config["database_path"]

    def _build_vector_store(self):
        """Create a TextVectorizer from the current personality configuration."""
        return TextVectorizer(
            "bert-base-uncased",
            self._database_file_path(),
            visualize_data_at_startup=self.personality_config["visualize_data_at_startup"],
            visualize_data_at_add_file=self.personality_config["visualize_data_at_add_file"],
            visualize_data_at_generate=self.personality_config["visualize_data_at_generate"]
        )

    def _notify(self, message):
        """Forward ``message`` as a full message to the current callback, if any."""
        if self.callback is not None:
            self.callback(message, MSG_TYPE.MSG_TYPE_FULL)

    def _read_int_setting(self, key, default, label):
        """Read an integer configuration entry, falling back to ``default``."""
        try:
            return int(self.personality_config[key])
        except Exception:
            # The configuration object may raise its own error types for a
            # missing entry, hence the broad (but not bare) handler.
            ASCIIColors.warning(f"Couldn't read {label}. Verify your configuration file")
            return default

    @staticmethod
    def _as_text(content):
        """Turn the result of one of the ``read_*`` helpers into a plain string."""
        if isinstance(content, str):
            return content
        is_table = isinstance(content, list) and all(
            isinstance(row, list) and all(isinstance(cell, str) for cell in row)
            for row in content
        )
        if is_table:
            # CSV rows: one line per row.
            return "\n".join(", ".join(row) for row in content)
        # Parsed JSON (dict, list, number, ...): serialize it back to text.
        return json.dumps(content, ensure_ascii=False)

    @staticmethod
    def read_pdf_file(file_path):
        """Return the concatenated text of every page of a PDF file."""
        import PyPDF2

        with open(file_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            # Guard against pages for which no text could be extracted.
            return "".join(page.extract_text() or "" for page in pdf_reader.pages)

    @staticmethod
    def read_docx_file(file_path):
        """Return the paragraphs of a .docx file, one per line."""
        from docx import Document

        doc = Document(file_path)
        return "".join(paragraph.text + "\n" for paragraph in doc.paragraphs)

    @staticmethod
    def read_json_file(file_path):
        """Return the parsed content of a JSON file."""
        with open(file_path, 'r', encoding='utf-8') as file:
            return json.load(file)

    @staticmethod
    def read_csv_file(file_path):
        """Return the rows of a CSV file as a list of lists of strings."""
        import csv

        # newline='' is what the csv module expects for files it reads.
        with open(file_path, 'r', encoding='utf-8', newline='') as file:
            return list(csv.reader(file))

    @staticmethod
    def read_html_file(file_path):
        """Return the visible text of an HTML file."""
        from bs4 import BeautifulSoup

        # Binary mode lets BeautifulSoup detect the document encoding itself.
        with open(file_path, 'rb') as file:
            soup = BeautifulSoup(file, 'html.parser')
            return soup.get_text()

    @staticmethod
    def read_pptx_file(file_path):
        """Return the text of every text frame of a .pptx file, one paragraph per line."""
        from pptx import Presentation

        prs = Presentation(file_path)
        lines = []
        for slide in prs.slides:
            for shape in slide.shapes:
                if not shape.has_text_frame:
                    continue
                for paragraph in shape.text_frame.paragraphs:
                    # Keep paragraphs apart so words of different paragraphs
                    # are not glued together.
                    lines.append("".join(run.text for run in paragraph.runs))
        return "\n".join(lines)

    @staticmethod
    def read_text_file(file_path):
        """Return the content of a UTF-8 text file."""
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()

    def build_db(self):
        """
        Vectorize every file of ``self.files`` into the vector store.

        A failure on one file is reported and does not stop the others.
        """
        ASCIIColors.info("-> Vectorizing the database"+ASCIIColors.color_orange)
        if self.callback is not None:
            self.callback("Vectorizing the database", MSG_TYPE.MSG_TYPE_CHUNK)

        readers = {
            ".pdf": Processor.read_pdf_file,
            ".docx": Processor.read_docx_file,
            ".pptx": Processor.read_pptx_file,
            ".json": Processor.read_json_file,
            ".csv": Processor.read_csv_file,
            ".html": Processor.read_html_file,
        }
        # The template declares "max_chunk_size"; there is no "chunk_size" entry.
        chunk_size = self._read_int_setting("max_chunk_size", 512, "chunk size")
        overlap_size = self._read_int_setting("chunk_overlap", 50, "chunk overlap")

        for file in self.files:
            try:
                reader = readers.get(Path(file).suffix.lower(), Processor.read_text_file)
                text = Processor._as_text(reader(file))
                self.vector_store.index_document(file, text, chunk_size=chunk_size, overlap_size=overlap_size)

                print(ASCIIColors.color_reset)
                ASCIIColors.success(f"File {file} vectorized successfully")
                self.ready = True
            except Exception as ex:
                ASCIIColors.error(f"Couldn't vectorize {file}: The vectorizer threw this exception:{ex}")

    def add_file(self, path):
        """
        Register ``path`` with the base class and rebuild the database.

        Returns:
            True when the database was rebuilt, False if rebuilding raised.
        """
        super().add_file(path)
        try:
            self.build_db()
            self.ready = True
            return True
        except Exception as ex:
            ASCIIColors.error(f"Couldn't vectorize the database: The vectgorizer threw this exception: {ex}")
            return False

    def _clear_database(self):
        """Delete the database file, if present, and start from an empty store."""
        database_file = self._database_file_path()
        if database_file.exists():
            database_file.unlink()
            self.vector_store = self._build_vector_store()
            # Nothing is left to query until a new document is added.
            self.ready = len(self.vector_store.embeddings) > 0
            self._notify("Database file cleared successfully")
        else:
            self._notify("The database file does not exist yet, so you can't clear it")

    def _receive_file_path(self, prompt):
        """Handle the prompt following ``send_file``; return the text to output."""
        file_path = prompt.strip()
        try:
            if self.add_file(file_path):
                self._notify(f"File {file_path} added successfully")
                return ""
            message = f"Couldn't load file {file_path}."
            self._notify(message)
            return message
        except Exception as ex:
            ASCIIColors.error(f"Exception: {ex}")
            self._notify(f"Couldn't load file {file_path}.\nThe following exception was thrown: {ex}")
            return str(ex)

    def _receive_database_path(self, prompt):
        """Handle the prompt following ``set_database``; return the text to output."""
        try:
            database_name = prompt.strip()
            # Check the very file the vector store is going to open: the
            # configured path is interpreted relative to the personal data folder.
            candidate = self.personality.lollms_paths.personal_data_path/database_name
            if not candidate.is_file():
                message = "Database file not found.\nGoing back to default state."
                self._notify(message)
                return message
            self.personality_config["database_path"] = database_name
            self.personality_config.save()
            self.vector_store = self._build_vector_store()
            self.ready = len(self.vector_store.embeddings) > 0
            self.save_config_file(self.personality.lollms_paths.personal_configuration_path/f"personality_{self.personality.name}.yaml", self.personality_config)
            return ""
        except Exception as ex:
            ASCIIColors.error(f"Exception: {ex}")
            self._notify(f"Couldn't load the database.\nThe following exception was thrown: {ex}")
            return str(ex)

    def _answer(self, prompt):
        """Answer ``prompt`` using the most relevant chunks; return the answer."""
        if not self.ready:
            message = "No data to discuss. Please upload a document first"
            ASCIIColors.error(message)
            self._notify(message)
            return message

        docs = self.vector_store.recover_text(self.vector_store.embed_query(prompt), top_k=3)
        docs = '\n'.join([f"Doc{i}:\n{v}" for i,v in enumerate(docs)])
        full_text = self.personality.personality_conditioning+"\n### Docs:\n"+docs+"\n### Question: "+prompt+"\n### Answer:"

        ASCIIColors.blue("-------------- Documentation -----------------------")
        ASCIIColors.blue(full_text)
        ASCIIColors.blue("----------------------------------------------------")
        ASCIIColors.blue("Thinking")
        self._notify("Thinking")
        output = self.generate(full_text, self.personality_config["max_answer_size"])
        self._notify(output)
        return output

    def run_workflow(self, prompt, previous_discussion_text="", callback=None):
        """
        Process one user prompt: either a command or a question about the documents.

        Args:
            prompt (str): The user input. ``send_file``, ``help``,
                ``show_database``, ``set_database`` and ``clear_database`` are
                handled as commands; after ``send_file`` / ``set_database`` the
                next prompt is taken as a path.
            previous_discussion_text (str, optional): The text of the previous
                discussion. Not used by this personality. Default is an empty string.
            callback (callable, optional): Called as ``callback(text, MSG_TYPE)``
                to report progress and results.

        Returns:
            str: The text produced for this prompt (possibly empty).
        """
        output = ""
        self.callback = callback
        command = prompt.strip().lower()

        if command == "send_file":
            self.state = 1
            print("Please provide the file name")
            self._notify("Please provide the file path")
            output = "Please provide the file name"
        elif command == "help":
            self._notify(self.personality.help)
            # Log the personality help text (not the ``help`` builtin).
            ASCIIColors.info(self.personality.help)
            self.state = 0
        elif command == "show_database":
            try:
                self.vector_store.show_document()
            except Exception as ex:
                self._notify(f"Couldn't show the database\nMake sure you have already uploaded a database.\nReceived exception is: {ex}")
            self.state = 0
        elif command == "set_database":
            print("Please provide the database file name")
            self._notify("Please provide the database file path")
            output = "Please provide the database file name"
            self.state = 2
        elif command == "clear_database":
            self._clear_database()
            self.state = 0
        elif self.state == 1:
            output = self._receive_file_path(prompt)
            self.state = 0
        elif self.state == 2:
            output = self._receive_database_path(prompt)
            self.state = 0
        else:
            output = self._answer(prompt)

        return output