File size: 3,877 Bytes
5481095
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from fastapi import FastAPI, Form, Depends, Request, File, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import os

from langchain.text_splitter import RecursiveCharacterTextSplitter
from pymilvus import MilvusClient, db, utility, Collection, CollectionSchema, FieldSchema, DataType
from sentence_transformers import SentenceTransformer
import torch
from .milvus_singleton import MilvusClientSingleton


# NOTE(review): huggingface_hub / transformers usually read HF_HOME and
# HF_MODULES_CACHE when they are first imported. sentence_transformers is
# imported above this point, so these assignments may not affect those
# libraries (model weights still go to cache_folder below). Consider setting
# them in the container environment or before the third-party imports.
_CACHE_DIR = '/app/cache'
os.environ['HF_HOME'] = _CACHE_DIR
os.environ['HF_MODULES_CACHE'] = f"{_CACHE_DIR}/hf_modules"

# Place the embedding model on the GPU whenever torch can see one.
_EMBEDDING_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
embedding_model = SentenceTransformer(
    'Alibaba-NLP/gte-large-en-v1.5',
    trust_remote_code=True,
    device=_EMBEDDING_DEVICE,
    cache_folder=_CACHE_DIR,
)

# Milvus collection shared by the /insert and /rag endpoints.
collection_name = "rag"

def setup_milvus():
    """Bind the module-level ``milvus_client`` to the shared Milvus client.

    The client comes from ``MilvusClientSingleton`` and points at the database
    file path below. The /insert and /rag handlers read the global assigned here.
    """
    global milvus_client
    database_uri = "/app/milvus_data/milvus_demo.db"
    milvus_client = MilvusClientSingleton.get_instance(uri=database_uri)

def document_to_embeddings(content: str, show_progress_bar: bool = True) -> list:
    """Embed ``content`` with the module-level sentence-transformer model.

    Args:
        content: Text to embed.
        show_progress_bar: Forwarded to ``SentenceTransformer.encode``. It
            defaults to True, which keeps the previous behaviour.

    Returns:
        The embedding as a plain Python list of floats.
    """
    embedding = embedding_model.encode(content, show_progress_bar=show_progress_bar)
    # encode() yields a NumPy array by default. Convert it so the result matches
    # the annotated ``list`` type and pymilvus / JSON receive built-in floats.
    return embedding.tolist()

setup_milvus()

app = FastAPI()

# Allowed CORS origins, taken from the comma-separated CORS_ALLOW_ORIGINS
# environment variable. Unset (or empty) keeps the previous default of "*"
# (any origin), which suits development; set an explicit list for production.
# NOTE(review): with a wildcard origin plus allow_credentials=True, Starlette's
# CORSMiddleware echoes the caller's Origin for credentialed requests, so any
# site could make them. Verify against the installed Starlette version.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins or ["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def split_documents(document_data, chunk_size=512, chunk_overlap=10):
    """Split input into LangChain ``Document`` chunks for embedding.

    Args:
        document_data: Either an iterable of LangChain ``Document`` objects, or
            raw text as ``str`` / UTF-8 ``bytes`` (such as an uploaded file body).
        chunk_size: Target maximum length of each chunk (default 512).
        chunk_overlap: Overlap between consecutive chunks (default 10).

    Returns:
        A list of ``Document`` chunks, each exposing ``page_content``.

    Raises:
        UnicodeDecodeError: If ``document_data`` is bytes that are not valid UTF-8.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    if isinstance(document_data, (bytes, bytearray)):
        document_data = document_data.decode("utf-8")
    if isinstance(document_data, str):
        # split_documents() expects Document objects. Iterating raw bytes/str
        # would yield ints/characters with no ``page_content``, so wrap the
        # text into Documents instead.
        return splitter.create_documents([document_data])
    return splitter.split_documents(document_data)

def create_a_collection(milvus_client, collection_name, dim=1024, max_length=4096):
    """Create the Milvus collection holding text chunks and their embeddings.

    Fields:
        id      INT64 primary key, generated by Milvus (auto_id), so inserted
                rows only need ``content`` and ``vector``.
        content VARCHAR chunk text, up to ``max_length``.
        vector  FLOAT_VECTOR of length ``dim``. This must equal the embedding
                model's output size (1024 is assumed for gte-large-en-v1.5;
                confirm).

    An IVF_FLAT index with the COSINE metric is built on ``vector`` in the
    same ``create_collection`` call.

    Args:
        milvus_client: A ``pymilvus.MilvusClient`` (or compatible) instance.
        collection_name: Name of the collection to create.
        dim: Embedding dimensionality (default 1024).
        max_length: Maximum length of the ``content`` field (default 4096).
    """
    # Milvus requires every collection schema to have a primary-key field.
    id_field = FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True)
    content = FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=max_length)
    vector = FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dim)

    schema = CollectionSchema([id_field, content, vector])

    # MilvusClient.create_collection validates index_params as an IndexParams
    # object (not a plain dict), and each index must name its target field.
    index_params = MilvusClient.prepare_index_params()
    index_params.add_index(
        field_name="vector",
        index_type="IVF_FLAT",
        metric_type="COSINE",
        params={"nlist": 128},
    )

    milvus_client.create_collection(
        collection_name=collection_name,
        schema=schema,
        index_params=index_params,
    )

@app.get("/")
async def root():
    """Respond to ``GET /`` with a fixed greeting payload."""
    payload = {"message": "Hello World"}
    return payload

@app.post("/insert")
async def insert(file: UploadFile = File(...)):
    """Chunk an uploaded file, embed each chunk and store it in Milvus.

    Creates the collection on first use. The raw upload bytes are handed to
    ``split_documents`` for chunking, and each chunk is embedded with
    ``document_to_embeddings``.

    Returns:
        HTTP 200 with ``{"result": "good"}`` when the Milvus insert succeeds.
        HTTP 500 with ``{"error": <message>}`` when the Milvus insert raises.

    Errors raised before the insert (collection creation, splitting,
    embedding) are not caught here and propagate to FastAPI.
    """
    contents = await file.read()

    if not milvus_client.has_collection(collection_name):
        create_a_collection(milvus_client, collection_name)

    chunks = split_documents(contents)

    # NOTE(review): embedding runs synchronously inside this async handler, so
    # the event loop is blocked until every chunk has been encoded.
    data_objects = [
        {
            "vector": document_to_embeddings(chunk.page_content),
            "content": chunk.page_content,
        }
        for chunk in chunks
    ]

    try:
        milvus_client.insert(collection_name=collection_name, data=data_objects)
    except Exception as e:
        # Return (not raise): JSONResponse is not an Exception subclass, and
        # raising it would produce a TypeError instead of this 500 response.
        return JSONResponse(status_code=500, content={"error": str(e)})
    else:
        return JSONResponse(status_code=200, content={"result": 'good'})
    
@app.post("/rag")
async def insert(question):
    """Return the stored chunk most similar to ``question``.

    ``question`` has no annotation, so FastAPI reads it as a query parameter.

    NOTE(review): this handler reuses the name ``insert``, rebinding the
    module-level name of the /insert handler defined above. Routing is
    unaffected because each function was registered when decorated, but a
    distinct name (e.g. ``rag``) would be clearer. The name is kept here for
    interface compatibility.

    Returns:
        HTTP 200 with ``{"result": <chunk text>}`` for the best hit.
        HTTP 400 with ``{"message": ...}`` if ``question`` is empty.
        HTTP 400 with ``{"error": <message>}`` if embedding or search raises.
        HTTP 404 with ``{"error": ...}`` if the search returns no hits.
    """
    if not question:
        return JSONResponse(status_code=400, content={"message": "Please provide a question!"})

    try:
        search_res = milvus_client.search(
            collection_name=collection_name,
            data=[document_to_embeddings(question)],
            limit=5,  # Fetch the top 5 hits; only the best one is returned below.
            search_params={"metric_type": "COSINE"},  # Same metric the index is built with.
            output_fields=["content"],  # Include the stored chunk text in each hit.
        )
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})

    # One query vector was sent, so search_res holds a single list of hits.
    hits = search_res[0] if search_res else []
    if not hits:
        return JSONResponse(status_code=404, content={"error": "No matching content found."})

    # Milvus ranks hits best-first, so hits[0] is the closest chunk.
    return JSONResponse(status_code=200, content={"result": hits[0]["entity"]["content"]})