Spaces: Runtime error
Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- README.md +2 -8
- __init__.py +9 -0
- __main__.py +7 -0
- __pycache__/__init__.cpython-310.pyc +0 -0
- __pycache__/__main__.cpython-310.pyc +0 -0
- __pycache__/_config.cpython-310.pyc +0 -0
- __pycache__/enums.cpython-310.pyc +0 -0
- __pycache__/main.cpython-310.pyc +0 -0
- __pycache__/paths.cpython-310.pyc +0 -0
- _config.py +62 -0
- components/__init__.py +0 -0
- components/__pycache__/__init__.cpython-310.pyc +0 -0
- components/embedding/__init__.py +0 -0
- components/embedding/__pycache__/__init__.cpython-310.pyc +0 -0
- components/embedding/__pycache__/component.cpython-310.pyc +0 -0
- components/embedding/component.py +38 -0
- components/ingest/__init__.py +0 -0
- components/ingest/__pycache__/__init__.cpython-310.pyc +0 -0
- components/ingest/__pycache__/component.cpython-310.pyc +0 -0
- components/ingest/__pycache__/helpers.cpython-310.pyc +0 -0
- components/ingest/component.py +143 -0
- components/ingest/helpers.py +61 -0
- components/llm/__init__.py +0 -0
- components/llm/__pycache__/__init__.cpython-310.pyc +0 -0
- components/llm/__pycache__/component.cpython-310.pyc +0 -0
- components/llm/component.py +50 -0
- components/node_store/__init__.py +0 -0
- components/node_store/__pycache__/__init__.cpython-310.pyc +0 -0
- components/node_store/__pycache__/component.cpython-310.pyc +0 -0
- components/node_store/component.py +31 -0
- components/vector_store/__init__.py +0 -0
- components/vector_store/__pycache__/__init__.cpython-310.pyc +0 -0
- components/vector_store/__pycache__/component.cpython-310.pyc +0 -0
- components/vector_store/component.py +51 -0
- enums.py +39 -0
- main.py +38 -0
- paths.py +15 -0
- server/__init__.py +0 -0
- server/__pycache__/__init__.cpython-310.pyc +0 -0
- server/chat/__init__.py +0 -0
- server/chat/__pycache__/__init__.cpython-310.pyc +0 -0
- server/chat/__pycache__/router.cpython-310.pyc +0 -0
- server/chat/__pycache__/schemas.cpython-310.pyc +0 -0
- server/chat/__pycache__/service.cpython-310.pyc +0 -0
- server/chat/__pycache__/utils.cpython-310.pyc +0 -0
- server/chat/router.py +70 -0
- server/chat/schemas.py +45 -0
- server/chat/service.py +122 -0
- server/chat/utils.py +68 -0
- server/embedding/__init__.py +0 -0
README.md
CHANGED
@@ -1,12 +1,6 @@
 ---
-title:
-
-colorFrom: gray
-colorTo: blue
+title: discord-bot
+app_file: __main__.py
 sdk: gradio
 sdk_version: 4.33.0
-app_file: app.py
-pinned: false
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__init__.py
ADDED
@@ -0,0 +1,9 @@
import logging

ROOT_LOG_LEVEL = "INFO"

PRETTY_LOG_FORMAT = (
    "%(asctime)s.%(msecs)03d [%(levelname)-8s] %(name)+25s - %(message)s"
)
logging.basicConfig(level=ROOT_LOG_LEVEL, format=PRETTY_LOG_FORMAT, datefmt="%H:%M:%S")
logging.captureWarnings(True)
__main__.py
ADDED
@@ -0,0 +1,7 @@
import uvicorn

from app._config import settings
from app.main import app

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=settings.PORT)
__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (378 Bytes).
__pycache__/__main__.cpython-310.pyc
ADDED
Binary file (327 Bytes).
__pycache__/_config.cpython-310.pyc
ADDED
Binary file (2.27 kB).
__pycache__/enums.cpython-310.pyc
ADDED
Binary file (1.63 kB).
__pycache__/main.cpython-310.pyc
ADDED
Binary file (1.24 kB).
__pycache__/paths.cpython-310.pyc
ADDED
Binary file (643 Bytes).
_config.py
ADDED
@@ -0,0 +1,62 @@
import os
from typing import Literal, Optional

from pydantic import Field
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    ENVIRONMENT: str
    PORT: int = 8000
    VECTOR_DATABASE: Literal["weaviate"] = "weaviate"

    OPENAI_API_KEY: Optional[str] = None
    OPENAI_MODEL: str = "gpt-3.5-turbo"

    WEAVIATE_CLIENT_URL: str = "http://localhost:8080"

    LLM_MODE: Literal["openai", "mock", "local"] = "mock"
    EMBEDDING_MODE: Literal["openai", "mock", "local"] = "mock"

    LOCAL_DATA_FOLDER: str = "local_data/test"

    DEFAULT_QUERY_SYSTEM_PROMPT: str = "You can only answer questions about the provided context. If you know the answer but it is not based in the provided context, don't provide the answer, just state the answer is not in the context provided."

    LOCAL_HF_EMBEDDING_MODEL_NAME: str = "BAAI/bge-small-en-v1.5"

    LOCAL_HF_LLM_REPO_ID: str = "TheBloke/Llama-2-7B-Chat-GGUF"
    LOCAL_HF_LLM_MODEL_FILE: str = "llama-2-7b-chat.Q4_K_M.gguf"

    # LLM config
    LLM_TEMPERATURE: float = Field(
        default=0.1, description="The temperature to use for sampling."
    )
    LLM_MAX_NEW_TOKENS: int = Field(
        default=256,
        description="The maximum number of tokens to generate.",
    )
    LLM_CONTEXT_WINDOW: int = Field(
        default=3900,
        description="The maximum number of context tokens for the model.",
    )

    # UI
    IS_UI_ENABLED: bool = True
    UI_PATH: str = "/"

    # Rerank
    IS_RERANK_ENABLED: bool = True
    RERANK_TOP_N: int = 3
    RERANK_MODEL_NAME: str = "cross-encoder/ms-marco-MiniLM-L-2-v2"

    class Config:
        case_sensitive = True
        env_file_encoding = "utf-8"


environment = os.environ.get("ENVIRONMENT", "local")
settings = Settings(
    ENVIRONMENT=environment,
    # ".env.{environment}" takes priority over ".env"
    _env_file=[".env", f".env.{environment}"],
)
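For context (not part of this commit), a minimal sketch of how the layered env files are meant to resolve, assuming a ".env" and a ".env.local" sit next to the project root:

import os

# Hypothetical contents: ".env" sets LLM_MODE=mock, ".env.local" sets LLM_MODE=openai.
os.environ.setdefault("ENVIRONMENT", "local")

from app._config import settings  # Settings(...) is built at import time

print(settings.ENVIRONMENT)  # "local"
print(settings.LLM_MODE)     # "openai", since ".env.local" takes priority over ".env"
print(settings.PORT)         # 8000 unless overridden by the shell or an env file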
components/__init__.py
ADDED
File without changes
components/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (149 Bytes).
components/embedding/__init__.py
ADDED
File without changes
components/embedding/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (159 Bytes).
components/embedding/__pycache__/component.cpython-310.pyc
ADDED
Binary file (1.35 kB).
components/embedding/component.py
ADDED
@@ -0,0 +1,38 @@
import logging

from llama_index import MockEmbedding
from llama_index.embeddings.base import BaseEmbedding

from app._config import settings
from app.enums import EmbeddingMode
from app.paths import models_cache_path

logger = logging.getLogger(__name__)

MOCK_EMBEDDING_DIM = 1536


class EmbeddingComponent:
    embedding_model: BaseEmbedding

    def __init__(self) -> None:
        embedding_mode = settings.EMBEDDING_MODE
        logger.info("Initializing the embedding model in mode=%s", embedding_mode)
        match embedding_mode:
            case EmbeddingMode.OPENAI:
                from llama_index import OpenAIEmbedding

                self.embedding_model = OpenAIEmbedding(api_key=settings.OPENAI_API_KEY)

            case EmbeddingMode.MOCK:
                # Not a random number, is the dimensionality used by
                # the default embedding model
                self.embedding_model = MockEmbedding(MOCK_EMBEDDING_DIM)

            case EmbeddingMode.LOCAL:
                from llama_index.embeddings import HuggingFaceEmbedding

                self.embedding_model = HuggingFaceEmbedding(
                    model_name=settings.LOCAL_HF_EMBEDDING_MODEL_NAME,
                    cache_folder=str(models_cache_path),
                )
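A quick smoke-test sketch (not part of this commit), assuming EMBEDDING_MODE=mock so no model is downloaded:

from app.components.embedding.component import EmbeddingComponent

embedding_component = EmbeddingComponent()
vector = embedding_component.embedding_model.get_text_embedding("hello world")
print(len(vector))  # 1536, the MOCK_EMBEDDING_DIM used above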
components/ingest/__init__.py
ADDED
File without changes
components/ingest/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (156 Bytes).
components/ingest/__pycache__/component.cpython-310.pyc
ADDED
Binary file (5.11 kB).
components/ingest/__pycache__/helpers.cpython-310.pyc
ADDED
Binary file (2.21 kB).
components/ingest/component.py
ADDED
@@ -0,0 +1,143 @@
import abc
import logging
import threading
from pathlib import Path
from typing import Any

from llama_index import (
    Document,
    ServiceContext,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.data_structs import IndexDict
from llama_index.indices.base import BaseIndex

from app.components.ingest.helpers import IngestionHelper
from app.paths import local_data_path

logger = logging.getLogger(__name__)


class BaseIngestComponent(abc.ABC):
    def __init__(
        self,
        storage_context: StorageContext,
        service_context: ServiceContext,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        logger.debug(f"Initializing base ingest component type={type(self).__name__}")
        self.storage_context = storage_context
        self.service_context = service_context

    @abc.abstractmethod
    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
        pass

    @abc.abstractmethod
    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
        pass

    @abc.abstractmethod
    def delete(self, doc_id: str) -> None:
        pass


class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
    def __init__(
        self,
        storage_context: StorageContext,
        service_context: ServiceContext,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(storage_context, service_context, *args, **kwargs)

        self.show_progress = True
        self._index_thread_lock = (
            threading.Lock()
        )  # Thread lock! Not Multiprocessing lock
        self._index = self._initialize_index()

    def _initialize_index(self) -> BaseIndex[IndexDict]:
        """Initialize the index from the storage context."""
        try:
            # Load the index with store_nodes_override=True to be able to delete them
            index = load_index_from_storage(
                storage_context=self.storage_context,
                service_context=self.service_context,
                store_nodes_override=True,  # Force store nodes in index and document stores
                show_progress=self.show_progress,
            )
        except ValueError:
            # There are no index in the storage context, creating a new one
            logger.info("Creating a new vector store index")
            index = VectorStoreIndex.from_documents(
                [],
                storage_context=self.storage_context,
                service_context=self.service_context,
                store_nodes_override=True,  # Force store nodes in index and document stores
                show_progress=self.show_progress,
            )
            index.storage_context.persist(persist_dir=local_data_path)
        return index

    def _save_index(self) -> None:
        self._index.storage_context.persist(persist_dir=local_data_path)

    def delete(self, doc_id: str) -> None:
        with self._index_thread_lock:
            # Delete the document from the index
            self._index.delete_ref_doc(doc_id, delete_from_docstore=True)

            # Save the index
            self._save_index()


class SimpleIngestComponent(BaseIngestComponentWithIndex):
    def __init__(
        self,
        storage_context: StorageContext,
        service_context: ServiceContext,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(storage_context, service_context, *args, **kwargs)

    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
        logger.info("Ingesting file_name=%s", file_name)
        documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
        logger.info(
            "Transformed file=%s into count=%s documents", file_name, len(documents)
        )
        logger.debug("Saving the documents in the index and doc store")
        return self._save_docs(documents)

    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
        saved_documents = []
        for file_name, file_data in files:
            documents = IngestionHelper.transform_file_into_documents(
                file_name, file_data
            )
            saved_documents.extend(self._save_docs(documents))
        return saved_documents

    def _save_docs(self, documents: list[Document]) -> list[Document]:
        logger.debug("Transforming count=%s documents into nodes", len(documents))
        with self._index_thread_lock:
            for document in documents:
                self._index.insert(document, show_progress=True)
            logger.debug("Persisting the index and nodes")
            # persist the index and nodes
            self._save_index()
            logger.debug("Persisted the index and nodes")
        return documents


def get_ingestion_component(
    storage_context: StorageContext,
    service_context: ServiceContext,
) -> BaseIngestComponent:
    return SimpleIngestComponent(storage_context, service_context)
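A wiring sketch (not part of this commit) of how the ingestion component can be assembled from the other components; "report.txt" is a hypothetical file, and the configured Weaviate instance must be reachable:

from pathlib import Path

from llama_index import ServiceContext, StorageContext

from app.components.embedding.component import EmbeddingComponent
from app.components.ingest.component import get_ingestion_component
from app.components.llm.component import LLMComponent
from app.components.node_store.component import NodeStoreComponent
from app.components.vector_store.component import VectorStoreComponent

llm = LLMComponent()
vector_store = VectorStoreComponent()
embedding = EmbeddingComponent()
node_store = NodeStoreComponent()

storage_context = StorageContext.from_defaults(
    vector_store=vector_store.vector_store,
    docstore=node_store.doc_store,
    index_store=node_store.index_store,
)
service_context = ServiceContext.from_defaults(
    llm=llm.llm, embed_model=embedding.embedding_model
)

ingest = get_ingestion_component(storage_context, service_context)
docs = ingest.ingest("report.txt", Path("report.txt"))  # hypothetical input file
print(len(docs), "documents ingested")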
components/ingest/helpers.py
ADDED
@@ -0,0 +1,61 @@
import logging
from pathlib import Path

from llama_index import Document
from llama_index.readers import JSONReader, StringIterableReader
from llama_index.readers.file.base import DEFAULT_FILE_READER_CLS

logger = logging.getLogger(__name__)

# Patching the default file reader to support other file types
FILE_READER_CLS = DEFAULT_FILE_READER_CLS.copy()
FILE_READER_CLS.update(
    {
        ".json": JSONReader,
    }
)


class IngestionHelper:
    """Helper class to transform a file into a list of documents.

    This class should be used to transform a file into a list of documents.
    These methods are thread-safe (and multiprocessing-safe).
    """

    @staticmethod
    def transform_file_into_documents(
        file_name: str, file_data: Path
    ) -> list[Document]:
        documents = IngestionHelper._load_file_to_documents(file_name, file_data)
        for document in documents:
            document.metadata["file_name"] = file_name
        IngestionHelper._exclude_metadata(documents)
        return documents

    @staticmethod
    def _load_file_to_documents(file_name: str, file_data: Path) -> list[Document]:
        logger.debug("Transforming file_name=%s into documents", file_name)
        extension = Path(file_name).suffix
        reader_cls = FILE_READER_CLS.get(extension)
        if reader_cls is None:
            logger.debug(
                "No reader found for extension=%s, using default string reader",
                extension,
            )
            # Read as a plain text
            string_reader = StringIterableReader()
            return string_reader.load_data([file_data.read_text()])

        logger.debug("Specific reader found for extension=%s", extension)
        return reader_cls().load_data(file_data)

    @staticmethod
    def _exclude_metadata(documents: list[Document]) -> None:
        logger.debug("Excluding metadata from count=%s documents", len(documents))
        for document in documents:
            document.metadata["doc_id"] = document.doc_id
            # We don't want the Embeddings search to receive this metadata
            document.excluded_embed_metadata_keys = ["doc_id"]
            # We don't want the LLM to receive these metadata in the context
            document.excluded_llm_metadata_keys = ["file_name", "doc_id", "page_label"]
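A standalone sketch (not part of this commit); the ".xyz" extension is hypothetical and has no registered reader, so the file goes through the plain-text fallback:

from pathlib import Path

from app.components.ingest.helpers import IngestionHelper

path = Path("example.xyz")  # hypothetical file
path.write_text("Maize yields rose 12% after the irrigation upgrade.")

docs = IngestionHelper.transform_file_into_documents(path.name, path)
print(docs[0].metadata["file_name"])        # "example.xyz"
print(docs[0].excluded_llm_metadata_keys)   # ["file_name", "doc_id", "page_label"]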
components/llm/__init__.py
ADDED
File without changes
components/llm/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (153 Bytes).
components/llm/__pycache__/component.cpython-310.pyc
ADDED
Binary file (1.49 kB).
components/llm/component.py
ADDED
@@ -0,0 +1,50 @@
import logging

from llama_index.llms import LLM, MockLLM

from app._config import settings
from app.enums import LLMMode
from app.paths import models_path

logger = logging.getLogger(__name__)


class LLMComponent:
    llm: LLM

    def __init__(self) -> None:
        llm_mode = settings.LLM_MODE
        logger.info(f"Initializing the LLM in mode={llm_mode}")
        match settings.LLM_MODE:
            case LLMMode.OPENAI:
                from llama_index.llms import OpenAI

                self.llm = OpenAI(
                    api_key=settings.OPENAI_API_KEY,
                    model=settings.OPENAI_MODEL,
                )
            case LLMMode.MOCK:
                self.llm = MockLLM()

            case LLMMode.LOCAL:
                from llama_index.llms import LlamaCPP
                from llama_index.llms.llama_utils import (
                    completion_to_prompt,
                    messages_to_prompt,
                )

                self.llm = LlamaCPP(
                    model_path=str(models_path / settings.LOCAL_HF_LLM_MODEL_FILE),
                    temperature=settings.LLM_TEMPERATURE,
                    max_new_tokens=settings.LLM_MAX_NEW_TOKENS,
                    context_window=settings.LLM_CONTEXT_WINDOW,
                    generate_kwargs={},
                    # set to at least 1 to use GPU
                    # set to -1 for all gpu
                    # set to 0 for cpu
                    model_kwargs={"n_gpu_layers": 0},
                    # transform inputs into Llama2 format
                    messages_to_prompt=messages_to_prompt,
                    completion_to_prompt=completion_to_prompt,
                    verbose=True,
                )
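A smoke-test sketch (not part of this commit), assuming LLM_MODE=mock so no API key or GGUF weights are needed:

from app.components.llm.component import LLMComponent

llm_component = LLMComponent()
# MockLLM returns placeholder text, which is enough to verify the wiring.
print(llm_component.llm.complete("Say hi").text)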
components/node_store/__init__.py
ADDED
File without changes
components/node_store/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (160 Bytes).
components/node_store/__pycache__/component.cpython-310.pyc
ADDED
Binary file (1.24 kB).
components/node_store/component.py
ADDED
@@ -0,0 +1,31 @@
import logging

from llama_index.storage.docstore import BaseDocumentStore, SimpleDocumentStore
from llama_index.storage.index_store import SimpleIndexStore
from llama_index.storage.index_store.types import BaseIndexStore

from app.paths import local_data_path

logger = logging.getLogger(__name__)


class NodeStoreComponent:
    index_store: BaseIndexStore
    doc_store: BaseDocumentStore

    def __init__(self) -> None:
        try:
            self.index_store = SimpleIndexStore.from_persist_dir(
                persist_dir=str(local_data_path)
            )
        except FileNotFoundError:
            logger.debug("Local index store not found, creating a new one")
            self.index_store = SimpleIndexStore()

        try:
            self.doc_store = SimpleDocumentStore.from_persist_dir(
                persist_dir=str(local_data_path)
            )
        except FileNotFoundError:
            logger.debug("Local document store not found, creating a new one")
            self.doc_store = SimpleDocumentStore()
components/vector_store/__init__.py
ADDED
File without changes
components/vector_store/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (162 Bytes).
components/vector_store/__pycache__/component.cpython-310.pyc
ADDED
Binary file (1.81 kB).
components/vector_store/component.py
ADDED
@@ -0,0 +1,51 @@
import logging
import typing

from llama_index import VectorStoreIndex
from llama_index.indices.vector_store import VectorIndexRetriever
from llama_index.vector_stores.types import VectorStore

from app._config import settings
from app.enums import WEAVIATE_INDEX_NAME, VectorDatabase

logger = logging.getLogger(__name__)


class VectorStoreComponent:
    vector_store: VectorStore

    def __init__(self) -> None:
        match settings.VECTOR_DATABASE:
            case VectorDatabase.WEAVIATE:
                import weaviate
                from llama_index.vector_stores import WeaviateVectorStore

                client = weaviate.Client(settings.WEAVIATE_CLIENT_URL)
                self.vector_store = typing.cast(
                    VectorStore,
                    WeaviateVectorStore(
                        weaviate_client=client, index_name=WEAVIATE_INDEX_NAME
                    ),
                )
            case _:
                # Should be unreachable
                # The settings validator should have caught this
                raise ValueError(
                    f"Vectorstore database {settings.VECTOR_DATABASE} not supported"
                )

    @staticmethod
    def get_retriever(
        index: VectorStoreIndex,
        doc_ids: list[str] | None = None,
        similarity_top_k: int = 2,
    ) -> VectorIndexRetriever:
        return VectorIndexRetriever(
            index=index,
            similarity_top_k=similarity_top_k,
            doc_ids=doc_ids,
        )

    def close(self) -> None:
        if hasattr(self.vector_store.client, "close"):
            self.vector_store.client.close()
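A retrieval sketch (not part of this commit), assuming a Weaviate instance is reachable at WEAVIATE_CLIENT_URL and the index already holds documents; the query string is illustrative only:

from llama_index import ServiceContext, VectorStoreIndex

from app.components.embedding.component import EmbeddingComponent
from app.components.llm.component import LLMComponent
from app.components.vector_store.component import VectorStoreComponent

vector_store_component = VectorStoreComponent()
service_context = ServiceContext.from_defaults(
    llm=LLMComponent().llm,
    embed_model=EmbeddingComponent().embedding_model,
)
index = VectorStoreIndex.from_vector_store(
    vector_store_component.vector_store, service_context=service_context
)

retriever = vector_store_component.get_retriever(index=index, similarity_top_k=2)
for node in retriever.retrieve("irrigation schedule"):
    print(node.score, node.get_content()[:80])

vector_store_component.close()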
enums.py
ADDED
@@ -0,0 +1,39 @@
from enum import Enum, auto, unique
from pathlib import Path

PROJECT_ROOT_PATH: Path = Path(__file__).parents[1]


@unique
class BaseEnum(str, Enum):
    @staticmethod
    def _generate_next_value_(name: str, *_):
        """
        Automatically generate values for enum.
        Enum values are lower-cased enum member names.
        """
        return name.lower()

    @classmethod
    def get_values(cls) -> list[str]:
        # noinspection PyUnresolvedReferences
        return [m.value for m in cls]


class LLMMode(BaseEnum):
    MOCK = auto()
    OPENAI = auto()
    LOCAL = auto()


class EmbeddingMode(BaseEnum):
    MOCK = auto()
    OPENAI = auto()
    LOCAL = auto()


class VectorDatabase(BaseEnum):
    WEAVIATE = auto()


WEAVIATE_INDEX_NAME = "LlamaIndex"
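A short sketch (not part of this commit) of what the auto() values resolve to:

from app.enums import EmbeddingMode, LLMMode, VectorDatabase

print(LLMMode.OPENAI.value)                    # "openai"
print(EmbeddingMode.get_values())              # ["mock", "openai", "local"]
print(VectorDatabase.WEAVIATE == "weaviate")   # True, since BaseEnum subclasses str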
main.py
ADDED
@@ -0,0 +1,38 @@
import logging

from fastapi import FastAPI

from app._config import settings
from app.components.embedding.component import EmbeddingComponent
from app.components.llm.component import LLMComponent
from app.components.node_store.component import NodeStoreComponent
from app.components.vector_store.component import VectorStoreComponent
from app.server.chat.router import chat_router
from app.server.chat.service import ChatService
from app.server.embedding.router import embedding_router
from app.server.ingest.service import IngestService

logger = logging.getLogger(__name__)

app = FastAPI()
app.include_router(chat_router)
app.include_router(embedding_router)

if settings.IS_UI_ENABLED:
    logger.debug("Importing the UI module")
    from app.ui.ui import PrivateGptUi

    llm_component = LLMComponent()
    vector_store_component = VectorStoreComponent()
    embedding_component = EmbeddingComponent()
    node_store_component = NodeStoreComponent()

    ingest_service = IngestService(
        llm_component, vector_store_component, embedding_component, node_store_component
    )
    chat_service = ChatService(
        llm_component, vector_store_component, embedding_component, node_store_component
    )

    ui = PrivateGptUi(ingest_service, chat_service)
    ui.mount_in_app(app, settings.UI_PATH)
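An in-process test sketch (not part of this commit) using FastAPI's test client; it assumes mock LLM/embedding modes, a reachable Weaviate instance, and that the UI, ingest, and embedding modules outside this 50-file view are present, since they are imported when app.main loads:

from fastapi.testclient import TestClient

from app.main import app

client = TestClient(app)
response = client.post(
    "/chat",
    json={"messages": [{"role": "user", "content": "What is in the context?"}]},
)
print(response.status_code)
print(response.json()["choices"][0]["message"]["content"])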
paths.py
ADDED
@@ -0,0 +1,15 @@
from pathlib import Path

from app._config import settings
from app.enums import PROJECT_ROOT_PATH


def _absolute_or_from_project_root(path: str) -> Path:
    if path.startswith("/"):
        return Path(path)
    return PROJECT_ROOT_PATH / path


local_data_path: Path = _absolute_or_from_project_root(settings.LOCAL_DATA_FOLDER)
models_path: Path = PROJECT_ROOT_PATH / "models"
models_cache_path: Path = models_path / "cache"
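A small sketch (not part of this commit) of how the helper resolves relative versus absolute values:

from app.paths import local_data_path, models_cache_path

# With the default LOCAL_DATA_FOLDER="local_data/test" the path is anchored at the
# project root; an absolute value such as "/data/docs" would be returned unchanged.
print(local_data_path)
print(models_cache_path)  # <project root>/models/cache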
server/__init__.py
ADDED
File without changes
server/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (145 Bytes).
server/chat/__init__.py
ADDED
File without changes
server/chat/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (150 Bytes).
server/chat/__pycache__/router.cpython-310.pyc
ADDED
Binary file (2.36 kB).
server/chat/__pycache__/schemas.cpython-310.pyc
ADDED
Binary file (1.68 kB).
server/chat/__pycache__/service.cpython-310.pyc
ADDED
Binary file (3.82 kB).
server/chat/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (2.41 kB).
server/chat/router.py
ADDED
@@ -0,0 +1,70 @@
from fastapi import APIRouter
from llama_index.llms import ChatMessage, MessageRole
from pydantic import BaseModel

from app.components.embedding.component import EmbeddingComponent
from app.components.llm.component import LLMComponent
from app.components.node_store.component import NodeStoreComponent
from app.components.vector_store.component import VectorStoreComponent
from app.server.chat.service import ChatService
from app.server.chat.utils import OpenAICompletion, OpenAIMessage, to_openai_response

chat_router = APIRouter()


class ChatBody(BaseModel):
    messages: list[OpenAIMessage]
    include_sources: bool = True

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "messages": [
                        {
                            "role": "system",
                            "content": "You are a rapper. Always answer with a rap.",
                        },
                        {
                            "role": "user",
                            "content": "How do you fry an egg?",
                        },
                    ],
                    "include_sources": True,
                }
            ]
        }
    }


@chat_router.post(
    "/chat",
    response_model=None,
    responses={200: {"model": OpenAICompletion}},
    tags=["Contextual Completions"],
)
def chat_completion(body: ChatBody) -> OpenAICompletion:
    """Given a list of messages comprising a conversation, return a response.

    Optionally include an initial `role: system` message to influence the way
    the LLM answers.

    When using `'include_sources': true`, the API will return the source Chunks used
    to create the response, which come from the context provided.
    """
    llm_component = LLMComponent()
    vector_store_component = VectorStoreComponent()
    embedding_component = EmbeddingComponent()
    node_store_component = NodeStoreComponent()

    chat_service = ChatService(
        llm_component, vector_store_component, embedding_component, node_store_component
    )
    all_messages = [
        ChatMessage(content=m.content, role=MessageRole(m.role)) for m in body.messages
    ]

    completion = chat_service.chat(messages=all_messages)
    return to_openai_response(
        completion.response, completion.sources if body.include_sources else None
    )
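An HTTP client sketch (not part of this commit), assuming the server is running locally on the default port 8000; the payload mirrors the schema example above:

import requests

payload = {
    "messages": [
        {"role": "system", "content": "You are a rapper. Always answer with a rap."},
        {"role": "user", "content": "How do you fry an egg?"},
    ],
    "include_sources": True,
}
completion = requests.post("http://localhost:8000/chat", json=payload).json()
print(completion["choices"][0]["message"]["content"])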
server/chat/schemas.py
ADDED
@@ -0,0 +1,45 @@
from typing import Literal

from llama_index.schema import NodeWithScore
from pydantic import BaseModel, Field

from app.server.ingest.schemas import IngestedDoc


class Chunk(BaseModel):
    object: Literal["context.chunk"]
    score: float = Field(examples=[0.023])
    document: IngestedDoc
    text: str = Field(examples=["Outbound sales increased 20%, driven by new leads."])
    previous_texts: list[str] | None = Field(
        default=None,
        examples=[["SALES REPORT 2023", "Inbound didn't show major changes."]],
    )
    next_texts: list[str] | None = Field(
        default=None,
        examples=[
            [
                "New leads came from Google Ads campaign.",
                "The campaign was run by the Marketing Department",
            ]
        ],
    )

    @classmethod
    def from_node(cls: type["Chunk"], node: NodeWithScore) -> "Chunk":
        doc_id = node.node.ref_doc_id if node.node.ref_doc_id is not None else "-"
        return cls(
            object="context.chunk",
            score=node.score or 0.0,
            document=IngestedDoc(
                object="ingest.document",
                doc_id=doc_id,
                doc_metadata=node.metadata,
            ),
            text=node.get_content(),
        )


class Completion(BaseModel):
    response: str
    sources: list[Chunk] | None = None
server/chat/service.py
ADDED
@@ -0,0 +1,122 @@
from dataclasses import dataclass

from llama_index import ServiceContext, StorageContext, VectorStoreIndex
from llama_index.chat_engine import ContextChatEngine
from llama_index.chat_engine.types import BaseChatEngine
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.indices.postprocessor import MetadataReplacementPostProcessor
from llama_index.llms import ChatMessage, MessageRole

from app._config import settings
from app.components.embedding.component import EmbeddingComponent
from app.components.llm.component import LLMComponent
from app.components.node_store.component import NodeStoreComponent
from app.components.vector_store.component import VectorStoreComponent
from app.server.chat.schemas import Chunk, Completion


@dataclass
class ChatEngineInput:
    system_message: ChatMessage | None = None
    last_message: ChatMessage | None = None
    chat_history: list[ChatMessage] | None = None

    @classmethod
    def from_messages(cls, messages: list[ChatMessage]) -> "ChatEngineInput":
        # Detect if there is a system message, extract the last message and chat history
        system_message = (
            messages[0]
            if len(messages) > 0 and messages[0].role == MessageRole.SYSTEM
            else None
        )
        last_message = (
            messages[-1]
            if len(messages) > 0 and messages[-1].role == MessageRole.USER
            else None
        )
        # Remove from messages list the system message and last message,
        # if they exist. The rest is the chat history.
        if system_message:
            messages.pop(0)
        if last_message:
            messages.pop(-1)
        chat_history = messages if len(messages) > 0 else None

        return cls(
            system_message=system_message,
            last_message=last_message,
            chat_history=chat_history,
        )


class ChatService:
    def __init__(
        self,
        llm_component: LLMComponent,
        vector_store_component: VectorStoreComponent,
        embedding_component: EmbeddingComponent,
        node_store_component: NodeStoreComponent,
    ) -> None:
        self.llm_service = llm_component
        self.vector_store_component = vector_store_component
        self.storage_context = StorageContext.from_defaults(
            vector_store=vector_store_component.vector_store,
            docstore=node_store_component.doc_store,
            index_store=node_store_component.index_store,
        )
        self.service_context = ServiceContext.from_defaults(
            llm=llm_component.llm, embed_model=embedding_component.embedding_model
        )
        self.index = VectorStoreIndex.from_vector_store(
            vector_store_component.vector_store,
            storage_context=self.storage_context,
            service_context=self.service_context,
            show_progress=True,
        )

    def _chat_engine(self, system_prompt: str | None = None) -> BaseChatEngine:
        vector_index_retriever = self.vector_store_component.get_retriever(
            index=self.index
        )

        node_postprocessors = [
            MetadataReplacementPostProcessor(target_metadata_key="window")
        ]
        if settings.IS_RERANK_ENABLED:
            rerank = SentenceTransformerRerank(
                top_n=settings.RERANK_TOP_N, model=settings.RERANK_MODEL_NAME
            )
            node_postprocessors.append(rerank)

        return ContextChatEngine.from_defaults(
            system_prompt=system_prompt,
            retriever=vector_index_retriever,
            service_context=self.service_context,
            node_postprocessors=node_postprocessors,
        )

    def chat(self, messages: list[ChatMessage]):
        chat_engine_input = ChatEngineInput.from_messages(messages)
        last_message = (
            chat_engine_input.last_message.content
            if chat_engine_input.last_message
            else None
        )
        system_prompt = (
            chat_engine_input.system_message.content
            if chat_engine_input.system_message
            else None
        )
        chat_history = (
            chat_engine_input.chat_history if chat_engine_input.chat_history else None
        )

        chat_engine = self._chat_engine(system_prompt=system_prompt)
        wrapped_response = chat_engine.chat(
            message=last_message if last_message is not None else "",
            chat_history=chat_history,
        )

        sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
        completion = Completion(response=wrapped_response.response, sources=sources)
        return completion
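A self-contained sketch (not part of this commit) of how ChatEngineInput splits a conversation into system prompt, history, and last user turn; the message contents are illustrative:

from llama_index.llms import ChatMessage, MessageRole

from app.server.chat.service import ChatEngineInput

messages = [
    ChatMessage(role=MessageRole.SYSTEM, content="Answer only from the context."),
    ChatMessage(role=MessageRole.USER, content="What was the maize yield?"),
    ChatMessage(role=MessageRole.ASSISTANT, content="About 4 t/ha."),
    ChatMessage(role=MessageRole.USER, content="And for wheat?"),
]
parsed = ChatEngineInput.from_messages(messages)
print(parsed.system_message.content)              # the leading system prompt
print(parsed.last_message.content)                # "And for wheat?"
print([m.content for m in parsed.chat_history])   # the two middle messages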
server/chat/utils.py
ADDED
@@ -0,0 +1,68 @@
import time
import uuid
from typing import Literal

from llama_index.llms import ChatResponse
from pydantic import BaseModel, Field

from app.server.chat.schemas import Chunk


class OpenAIMessage(BaseModel):
    """Inference result, with the source of the message.

    Role could be the assistant or system
    (providing a default response, not AI generated).
    """

    role: Literal["assistant", "system", "user"] = Field(default="user")
    content: str | None


class OpenAIChoice(BaseModel):
    """Response from AI."""

    finish_reason: str | None = Field(examples=["stop"])
    message: OpenAIMessage | None = None
    sources: list[Chunk] | None = None
    index: int = 0


class OpenAICompletion(BaseModel):
    """Clone of OpenAI Completion model.

    For more information see: https://platform.openai.com/docs/api-reference/chat/object
    """

    id: str
    object: Literal["completion", "completion.chunk"] = Field(default="completion")
    created: int = Field(..., examples=[1623340000])
    model: Literal["llm-agriculture"]
    choices: list[OpenAIChoice]

    @classmethod
    def from_text(
        cls,
        text: str | None,
        finish_reason: str | None = None,
        sources: list[Chunk] | None = None,
    ) -> "OpenAICompletion":
        return OpenAICompletion(
            id=str(uuid.uuid4()),
            object="completion",
            created=int(time.time()),
            model="llm-agriculture",
            choices=[
                OpenAIChoice(
                    message=OpenAIMessage(role="assistant", content=text),
                    finish_reason=finish_reason,
                    sources=sources,
                )
            ],
        )


def to_openai_response(
    response: str | ChatResponse, sources: list[Chunk] | None = None
) -> OpenAICompletion:
    return OpenAICompletion.from_text(response, finish_reason="stop", sources=sources)
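A short sketch (not part of this commit) of the response wrapper on its own, with an illustrative answer string:

from app.server.chat.utils import to_openai_response

completion = to_openai_response("The yield increased by 12%.")
print(completion.object)                      # "completion"
print(completion.model)                       # "llm-agriculture"
print(completion.choices[0].message.content)  # "The yield increased by 12%."
print(completion.choices[0].finish_reason)    # "stop"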
server/embedding/__init__.py
ADDED
File without changes