# lqhl's picture
# Synced repo using 'sync_with_huggingface' GitHub Action
# e931b70 verified
import hashlib
from datetime import datetime
from multiprocessing.pool import ThreadPool
from typing import List
import requests
from clickhouse_sqlalchemy import types, engines
from langchain.schema.embeddings import Embeddings
from sqlalchemy import Column, Text
from streamlit.runtime.uploaded_file_manager import UploadedFile
def parse_files(api_key, user_id, files: List[UploadedFile]):
    """Partition uploaded files with the Unstructured API and collect text rows.

    Every file is posted to the hosted Unstructured "general" endpoint from a
    pool of 8 worker threads. Only ``NarrativeText`` elements whose text has
    more than 10 space-separated tokens are kept.

    Args:
        api_key: Value sent in the ``unstructured-api-key`` request header.
        user_id: Identifier copied into every returned row.
        files: Uploaded files to parse.

    Returns:
        A flat list of dicts with the keys ``text``, ``file_name``,
        ``entity_id``, ``user_id`` and ``created_by``. Rows of one file stay
        together, but files are appended in completion order, not input order.

    Raises:
        ValueError: If the API answers with a status code other than 200.
    """

    def parse_file(file: UploadedFile):
        """Send one file to the API and return its filtered text rows."""
        headers = {
            "accept": "application/json",
            "unstructured-api-key": api_key,
        }
        data = {"strategy": "auto", "ocr_languages": ["eng"]}

        # Use getvalue() instead of read(): read() starts at the current
        # stream position and yields b"" for an already-consumed upload, which
        # would give every such file the same hash (and colliding entity_ids).
        content = file.getvalue()
        file_hash = hashlib.sha256(content).hexdigest()
        file_data = {"files": (file.name, content, file.type)}

        response = requests.post(
            url="https://api.unstructured.io/general/v0/general",
            headers=headers,
            data=data,
            files=file_data,
        )

        # Check the status before trusting the body: error responses are not
        # guaranteed to be JSON, and the decode error would hide the real one.
        if response.status_code != 200:
            try:
                detail = str(response.json())
            except ValueError:
                detail = response.text
            raise ValueError(detail)

        json_response = response.json()
        return [
            {
                "text": t["text"],
                "file_name": t["metadata"]["filename"],
                # Stable per-(file content, paragraph) identifier.
                "entity_id": hashlib.sha256(
                    (file_hash + t["text"]).encode()
                ).hexdigest(),
                "user_id": user_id,
                "created_by": datetime.now(),
            }
            for t in json_response
            if t["type"] == "NarrativeText" and len(t["text"].split(" ")) > 10
        ]

    with ThreadPool(8) as p:
        rows = []
        for r in p.imap_unordered(parse_file, files):
            rows.extend(r)
    return rows
def extract_embedding(embeddings: Embeddings, texts):
    """Attach an embedding vector to every text row, in place.

    Args:
        embeddings: Embedding model; its ``embed_documents`` method is called
            once with the ``"text"`` value of every row.
        texts: Sequence of dicts, each holding a ``"text"`` key.

    Returns:
        The same ``texts`` object, each row now carrying a ``"vector"`` key.

    Raises:
        ValueError: If ``texts`` is empty, or if the model returns a different
            number of vectors than there are rows.
    """
    if not texts:
        raise ValueError("No texts extracted!")

    # Keep the result under its own name instead of rebinding the
    # ``embeddings`` parameter to something of a different type.
    vectors = embeddings.embed_documents([row["text"] for row in texts])

    # Validate before mutating so a misbehaving model cannot leave the rows
    # half-populated or silently pair texts with the wrong vectors.
    if len(vectors) != len(texts):
        raise ValueError(
            f"Expected {len(texts)} embeddings but received {len(vectors)}"
        )

    for row, vector in zip(texts, vectors):
        row["vector"] = vector
    return texts
def create_message_history_table(table_name: str, base_class):
    """Build the declarative model class for the chat-history table.

    Args:
        table_name: Name assigned to the model's ``__tablename__``.
        base_class: Declarative base the generated model inherits from.

    Returns:
        The newly defined ``Message`` model class.
    """
    # MergeTree storage: partitioned by session, sorted by (id, msg_id).
    storage_engine = engines.MergeTree(
        partition_by="session_id",
        order_by=("id", "msg_id"),
    )

    class Message(base_class):
        """One chat message row; ``msg_id`` is the primary key."""

        __tablename__ = table_name

        id = Column(types.Float64)
        session_id = Column(Text)
        user_id = Column(Text)
        msg_id = Column(Text, primary_key=True)
        type = Column(Text)
        # Intended spelling is "additionals"; an earlier developer misspelled
        # it. The attribute name is also the column name, so it is kept as-is.
        addtionals = Column(Text)
        message = Column(Text)

        __table_args__ = (
            storage_engine,
            {"comment": "Store Chat History"},
        )

    return Message
def create_session_table(table_name: str, DynamicBase):
    """Build the declarative model class for the session table.

    Args:
        table_name: Name assigned to the model's ``__tablename__``.
        DynamicBase: Declarative base the generated model inherits from.

    Returns:
        The newly defined ``Session`` model class.
    """
    table_options = {"comment": "Store Session and Prompts"}

    class Session(DynamicBase):
        """One chat session row; ``session_id`` is the primary key."""

        __tablename__ = table_name

        user_id = Column(Text)
        session_id = Column(Text, primary_key=True)
        system_prompt = Column(Text)
        # Creation time of the session.
        create_by = Column(types.DateTime)
        # NOTE(review): the previous comment here complained about a
        # misspelling, but this attribute is spelled "additionals"; that note
        # appears to belong to the message-history model's "addtionals" column.
        additionals = Column(Text)

        # The engine is built inside the class body because it is handed the
        # ``session_id`` Column object itself as the sort key.
        __table_args__ = (
            engines.MergeTree(order_by=session_id),
            table_options,
        )

    return Session