docx / streamlit_app.py
RishuD7's picture
added application startup command
87ad181
raw
history blame
2.92 kB
import os
import streamlit as st
import requests
from typing import List
import json
import socket
from urllib3.connection import HTTPConnection
from app import embed_documents, retrieve_documents
# The UI used to call a REST backend whose address came from
# os.environ.get("API_BASE_URL"); it now imports embed_documents /
# retrieve_documents from the local `app` module instead.
#
# Model and storage settings. NOTE(review): none of these names is read
# elsewhere in this file -- confirm whether another module imports them
# before changing or removing them.
embeddings_model_name: str = "all-MiniLM-L6-v2"
persist_directory: str = "db"
model: str = "tiiuae/falcon-7b-instruct"
from constants import CHROMA_SETTINGS
import chromadb
def list_of_collections():
    """Return the collections held by the Chroma store.

    A fresh ``chromadb.Client`` is constructed from ``CHROMA_SETTINGS`` on
    every call, and whatever its ``list_collections()`` returns is passed
    straight back to the caller.
    """
    chroma_client = chromadb.Client(CHROMA_SETTINGS)
    found = chroma_client.list_collections()
    return found
def main():
    """Render the Streamlit page for embedding documents and querying them.

    The page has two sections:

    * Document Upload -- the user picks one or more files and, on "Embed",
      they are passed to ``app.embed_documents``.
    * Document Retrieval -- existing Chroma collections are listed; once one
      is selected, the typed query is sent to ``app.retrieve_documents`` on
      "Retrieve".

    Streamlit reruns the script on every widget interaction, so each button
    branch below executes only on the rerun triggered by its own click.
    """
    st.title("PrivateGPT App: Document Embedding and Retrieval")

    # Document upload section
    st.header("Document Upload")
    files = st.file_uploader("Upload document", accept_multiple_files=True)
    # TODO: let the user choose the collection name. A
    # st.text_input("Collection Name") was tried here but reportedly did not
    # work, so every upload currently goes into the fixed collection
    # "collection_name".
    if st.button("Embed"):
        if files:
            embed_documents(files, "collection_name")
        else:
            # With accept_multiple_files=True the uploader yields an empty
            # list until a file is chosen; don't hand an empty batch to the
            # embedder.
            st.warning("Please upload at least one document before embedding.")

    # Query section
    st.header("Document Retrieval")
    collection_names = get_collection_names()
    selected_collection = st.selectbox("Select a document", collection_names)
    if selected_collection:
        query = st.text_input("Query")
        if st.button("Retrieve"):
            # Skip blank queries rather than sending an empty prompt to the
            # retrieval pipeline.
            if query.strip():
                retrieve_documents(query, selected_collection)
            else:
                st.warning("Please enter a query before retrieving.")
# def embed_documents(files:List[st.runtime.uploaded_file_manager.UploadedFile], collection_name:str):
# endpoint = f"{API_BASE_URL}/embed"
# files_data = [("files", file) for file in files]
# data = {"collection_name": collection_name}
# response = requests.post(endpoint, files=files_data, data=data)
# if response.status_code == 200:
# st.success("Documents embedded successfully!")
# else:
# st.error("Document embedding failed.")
# st.write(response.text)
def get_collection_names():
    """Return the ``name`` of each collection from ``list_of_collections()``.

    Order follows whatever order ``list_of_collections()`` yields.
    """
    available = list_of_collections()
    return [entry.name for entry in available]
# def retrieve_documents(query: str, collection_name: str):
# endpoint = f"{API_BASE_URL}/retrieve"
# data = {"query": query, "collection_name": collection_name}
# # Modify socket options for the HTTPConnection class
# HTTPConnection.default_socket_options = (
# HTTPConnection.default_socket_options + [
# (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
# (socket.SOL_TCP, socket.TCP_KEEPIDLE, 45),
# (socket.SOL_TCP, socket.TCP_KEEPINTVL, 10),
# (socket.SOL_TCP, socket.TCP_KEEPCNT, 6)
# ]
# )
# response = requests.post(endpoint, params=data)
# if response.status_code == 200:
# result = response.json()
# st.subheader("Results")
# st.text(result["results"])
# st.subheader("Documents")
# for doc in result["docs"]:
# st.text(doc)
# else:
# st.error("Failed to retrieve documents.")
# st.write(response.text)
if __name__ == "__main__":
    # Build the page only when this file is executed as a script, not when
    # it is imported as a module.
    main()