Spaces:
Running
Running
import hashlib | |
from datetime import datetime | |
from multiprocessing.pool import ThreadPool | |
from typing import List | |
import requests | |
from clickhouse_sqlalchemy import types, engines | |
from langchain.schema.embeddings import Embeddings | |
from sqlalchemy import Column, Text | |
from streamlit.runtime.uploaded_file_manager import UploadedFile | |
def parse_files(api_key, user_id, files: List[UploadedFile]): | |
def parse_file(file: UploadedFile): | |
headers = { | |
"accept": "application/json", | |
"unstructured-api-key": api_key, | |
} | |
data = {"strategy": "auto", "ocr_languages": ["eng"]} | |
file_hash = hashlib.sha256(file.read()).hexdigest() | |
file_data = {"files": (file.name, file.getvalue(), file.type)} | |
response = requests.post( | |
url="https://api.unstructured.io/general/v0/general", | |
headers=headers, | |
data=data, | |
files=file_data | |
) | |
json_response = response.json() | |
if response.status_code != 200: | |
raise ValueError(str(json_response)) | |
texts = [ | |
{ | |
"text": t["text"], | |
"file_name": t["metadata"]["filename"], | |
"entity_id": hashlib.sha256( | |
(file_hash + t["text"]).encode() | |
).hexdigest(), | |
"user_id": user_id, | |
"created_by": datetime.now(), | |
} | |
for t in json_response | |
if t["type"] == "NarrativeText" and len(t["text"].split(" ")) > 10 | |
] | |
return texts | |
with ThreadPool(8) as p: | |
rows = [] | |
for r in p.imap_unordered(parse_file, files): | |
rows.extend(r) | |
return rows | |
def extract_embedding(embeddings: Embeddings, texts): | |
if len(texts) > 0: | |
embeddings = embeddings.embed_documents( | |
[t["text"] for _, t in enumerate(texts)]) | |
for i, _ in enumerate(texts): | |
texts[i]["vector"] = embeddings[i] | |
return texts | |
raise ValueError("No texts extracted!") | |
def create_message_history_table(table_name: str, base_class): | |
class Message(base_class): | |
__tablename__ = table_name | |
id = Column(types.Float64) | |
session_id = Column(Text) | |
user_id = Column(Text) | |
msg_id = Column(Text, primary_key=True) | |
type = Column(Text) | |
# should be additions, formal developer mistake spell it. | |
addtionals = Column(Text) | |
message = Column(Text) | |
__table_args__ = ( | |
engines.MergeTree( | |
partition_by='session_id', | |
order_by=('id', 'msg_id') | |
), | |
{'comment': 'Store Chat History'} | |
) | |
return Message | |
def create_session_table(table_name: str, DynamicBase): | |
class Session(DynamicBase): | |
__tablename__ = table_name | |
user_id = Column(Text) | |
session_id = Column(Text, primary_key=True) | |
system_prompt = Column(Text) | |
# represent create time. | |
create_by = Column(types.DateTime) | |
# should be additions, formal developer mistake spell it. | |
additionals = Column(Text) | |
__table_args__ = ( | |
engines.MergeTree(order_by=session_id), | |
{'comment': 'Store Session and Prompts'} | |
) | |
return Session | |