Spaces · Runtime error

tommymarto committed · Commit e04cd14 · Parent(s): 54abba0

first attempt to hf spaces
Files changed:
- config/config.yaml (+1 -1)
- config/document_retriever/multiquery_retriever.yaml (+1 -0)
- config/gradio_config.yaml (+27 -0)
- data (+1 -0)
- src/demo.py (+1 -1)
- src/document_retriever/multiquery_retriever.py (+37 -0)
- src/gradio.py (+0 -17)
- src/gradio_app.py (+68 -0)
- src/llm4scilit_gradio_interface.py (+508 -0)
- src/question_answering/huggingface.py (+7 -5)
config/config.yaml CHANGED
@@ -3,7 +3,7 @@ defaults:
   - text_splitter: spacy
   - text_embedding: huggingface
   - vector_store: faiss
-  - document_retriever:
+  - document_retriever: multiquery_retriever
   - question_answering: huggingface
   - _self_
   - override hydra/hydra_logging: disabled
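This hunk fills in the previously empty document_retriever default with the new multiquery_retriever config group. A minimal sketch of what composition looks like after the change (illustrative, not part of the commit; assumes it runs from src/ with the repo's config/ layout):

from hydra import compose, initialize

with initialize(version_base=None, config_path="../config"):
    cfg = compose(config_name="config")

# The document_retriever group now resolves to the new file added below:
print(cfg.document_retriever["_target_"])
# document_retriever.multiquery_retriever.MultiQueryDocumentRetriever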
config/document_retriever/multiquery_retriever.yaml ADDED
@@ -0,0 +1 @@
+_target_: document_retriever.multiquery_retriever.MultiQueryDocumentRetriever
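The single _target_ line is the standard Hydra hook for object instantiation. A hedged, self-contained sketch of the mechanism (FakeVectorStore is a hypothetical stand-in; the commit does not show where this group is instantiated):

from hydra.utils import instantiate
from omegaconf import OmegaConf

group = OmegaConf.create(
    {"_target_": "document_retriever.multiquery_retriever.MultiQueryDocumentRetriever"}
)

class FakeVectorStore:  # hypothetical stand-in for the FAISS-backed store
    db = None

# instantiate() imports the dotted path and calls it with the extra kwargs
# (requires src/ on PYTHONPATH so the document_retriever package imports).
retriever = instantiate(group, vector_store=FakeVectorStore())
print(type(retriever).__name__)  # MultiQueryDocumentRetriever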
config/gradio_config.yaml ADDED
@@ -0,0 +1,27 @@
+defaults:
+  - document_loader: grobid
+  - text_splitter: spacy
+  - text_embedding: huggingface
+  - vector_store: faiss
+  - document_retriever: simple_retriever
+  - question_answering: huggingface
+  - _self_
+  - override hydra/hydra_logging: disabled
+  - override hydra/job_logging: disabled
+
+storage_path:
+  base: ./data
+  documents: ${storage_path.base}/papers
+  documents_processed: ${storage_path.documents}_processed
+  vector_store: ${storage_path.base}/vector_store
+
+mode: interactive
+debug:
+  is_debug: false
+  force_rebuild_storage: false
+
+document_parsing:
+  enabled: false
+
+hydra:
+  verbose: false
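The ${...} entries are OmegaConf interpolations, so every storage path derives from storage_path.base. A self-contained sketch of how they resolve:

from omegaconf import OmegaConf

cfg = OmegaConf.create("""
storage_path:
  base: ./data
  documents: ${storage_path.base}/papers
  documents_processed: ${storage_path.documents}_processed
  vector_store: ${storage_path.base}/vector_store
""")

print(cfg.storage_path.documents)            # ./data/papers
print(cfg.storage_path.documents_processed)  # ./data/papers_processed
print(cfg.storage_path.vector_store)         # ./data/vector_store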
data ADDED
@@ -0,0 +1 @@
+/data/tommaso/llm4scilit/data/
src/demo.py CHANGED
@@ -114,7 +114,7 @@ class App:
 
     def ask_chat(self, line, history):
        # print(f"\nLLM4SciLit: a bunch of nonsense\n")
-        return self.qa_model.answer_question(line, {})
+        return self.qa_model.answer_question(line, {})
 
 
 ##################################################################################################
src/document_retriever/multiquery_retriever.py ADDED
@@ -0,0 +1,37 @@
+from langchain.llms.huggingface_pipeline import HuggingFacePipeline
+from langchain.retrievers.multi_query import MultiQueryRetriever
+
+# Set logging for the queries
+import logging
+
+logging.basicConfig()
+
+
+class MultiQueryDocumentRetriever:
+    def __init__(self, vector_store):
+        self.vector_store = vector_store
+        self.retriever = None
+        self.llm = None
+        # self.token = "LL-1kuyxK1z5NQYOiOsf5UdozHJuLhV6udoDGxL8NfM7brWCUbF0uqlii15sso8GNrd"
+
+    def initialize(self):
+        # self.llama = LlamaAPI(self.token)
+        self.llm = HuggingFacePipeline.from_model_id(
+            # model_id="bigscience/bloom-1b7",
+            model_id="bigscience/bloomz-1b7",
+            task="text-generation",
+            # device=1,
+            # model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 4, "top_p": 0.95, "repetition_penalty": 1.25, "length_penalty": 1.2},
+            model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 2},
+            # pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
+            pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
+        )
+
+        logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)
+        self.retriever = MultiQueryRetriever.from_llm(
+            retriever=self.vector_store.db.as_retriever(search_kwargs={"k": 4, "fetch_k": 40}),
+            llm=self.llm
+        )
+
+    def retrieve(self, query: str, k: int = 4):
+        pass
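Note that retrieve() is committed as a stub. A hedged sketch of a plausible body once initialize() has run, using the retriever API LangChain exposed at the time; this is an assumption, not part of the commit:

# Hypothetical completion of MultiQueryDocumentRetriever.retrieve:
def retrieve(self, query: str, k: int = 4):
    # MultiQueryRetriever generates LLM-written variants of the query,
    # retrieves for each variant, and returns the deduplicated union.
    docs = self.retriever.get_relevant_documents(query)
    return docs[:k]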
src/gradio.py DELETED
@@ -1,17 +0,0 @@
-import gradio as gr
-from hydra import compose, initialize
-from omegaconf import OmegaConf
-
-from demo import App
-
-def main():
-    with initialize(version_base=None, config_path="../config", job_name="gradio_app"):
-        cfg = compose(config_name="config", overrides=["document_parsing.enabled=False"])
-
-    app = App(cfg)
-
-    webapp = gr.ChatInterface(fn=app.ask_chat, examples=["hello", "hola", "merhaba"], title="LLM4SciLit")
-    webapp.launch(share=True)
-
-if __name__ == "__main__":
-    main()
src/gradio_app.py ADDED
@@ -0,0 +1,68 @@
+import hydra
+from omegaconf import DictConfig
+from demo import App
+
+from llm4scilit_gradio_interface import LLM4SciLitChatInterface
+
+def echo(text, history):
+    asdf = "asdf"
+    values = [f"{x}\n{x*2}" for x in asdf]
+    return text, *values
+
+
+@hydra.main(version_base=None, config_path="../config", config_name="gradio_config")
+def main(cfg : DictConfig) -> None:
+    cfg.document_parsing['enabled'] = False
+
+    app = App(cfg)
+    app._bootstrap()
+
+    def wrapped_ask_chat(text, history):
+        result = app.ask_chat(text, history)
+        sources = [
+            f"{x.metadata['paper_title']}\n{x.page_content}"
+            for x in result['source_documents']
+        ]
+        return result['result'], *sources
+
+
+    LLM4SciLitChatInterface(wrapped_ask_chat, title="LLM4SciLit").launch()
+    # LLM4SciLitChatInterface(echo, title="LLM4SciLit").launch()
+
+    # textbox = gr.Textbox(placeholder="Ask a question about scientific literature", lines=2, label="Question", elem_id="textbox")
+    # chatbot = gr.Chatbot(label="LLM4SciLit", elem_id="chat")
+    # gr.Interface(fn=echo, inputs=[textbox, chatbot], outputs=[chatbot], title="LLM4SciLit").launch()
+
+    # with gr.Blocks() as demo:
+    #     chatbot = gr.Chatbot()
+    #     msg = gr.Textbox(container=False)
+    #     clear = gr.ClearButton([msg, chatbot])
+
+    #     def respond(message, chat_history):
+    #         bot_message = "How are you?"
+    #         chat_history.append((message, bot_message))
+    #         return "", chat_history
+
+    #     msg.submit(respond, [msg, chatbot], [msg, chatbot])
+
+
+
+    # with gr.Blocks(title="LLM4SciLit") as demo:
+    #     with gr.Row():
+    #         with gr.Column(scale=5):
+    #             with gr.Row():
+    #                 gr.Chatbot(fn=echo)
+    #             with gr.Row():
+    #                 gr.Button("Submit")
+
+    #         with gr.Column(scale=5):
+    #             with gr.Accordion("Retrieved documents"):
+    #                 gr.Label("Document 1")
+
+    # webapp = gr.ChatInterface(fn=app.ask_chat, examples=["hello", "hola", "merhaba"], title="LLM4SciLit")
+    # webapp = gr.ChatInterface(fn=echo, examples=["hello", "hola", "merhaba"], title="LLM4SciLit")
+    # demo.launch()
+    # webapp.launch(share=True)
+
+if __name__ == "__main__":
+    main() # pylint: disable=no-value-for-parameter
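Note the return convention: wrapped_ask_chat yields the answer first, then one string per source document, matching the four extra "Document i" textboxes wired up in the vendored interface below (echo's four values serve the same purpose during testing). A minimal sketch of a function with the expected shape (illustrative names only):

def dummy_chat(text, history):
    # (answer, doc1, doc2, doc3, doc4): one extra string per additional
    # output Textbox that LLM4SciLitChatInterface creates.
    answer = f"You said: {text}"
    docs = [f"Document {i + 1} placeholder" for i in range(4)]
    return answer, *docs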
src/llm4scilit_gradio_interface.py ADDED
@@ -0,0 +1,508 @@
+"""
+This file defines a useful high-level abstraction to build Gradio chatbots: ChatInterface.
+"""
+
+
+from __future__ import annotations
+
+import inspect
+from typing import AsyncGenerator, Callable
+
+import anyio
+from gradio_client import utils as client_utils
+from gradio_client.documentation import document, set_documentation_group
+
+from gradio.blocks import Blocks
+from gradio.components import (
+    Button,
+    Chatbot,
+    IOComponent,
+    Markdown,
+    State,
+    Textbox,
+    get_component_instance,
+)
+from gradio.events import Dependency, EventListenerMethod, on
+from gradio.helpers import create_examples as Examples  # noqa: N812
+from gradio.layouts import Accordion, Column, Group, Row
+from gradio.themes import ThemeClass as Theme
+from gradio.utils import SyncToAsyncIterator, async_iteration
+
+set_documentation_group("chatinterface")
+
+
+@document()
+class LLM4SciLitChatInterface(Blocks):
+    """
+    ChatInterface is Gradio's high-level abstraction for creating chatbot UIs, and allows you to create
+    a web-based demo around a chatbot model in a few lines of code. Only one parameter is required: fn, which
+    takes a function that governs the response of the chatbot based on the user input and chat history. Additional
+    parameters can be used to control the appearance and behavior of the demo.
+
+    Example:
+        import gradio as gr
+
+        def echo(message, history):
+            return message
+
+        demo = gr.ChatInterface(fn=echo, examples=["hello", "hola", "merhaba"], title="Echo Bot")
+        demo.launch()
+    Demos: chatinterface_random_response, chatinterface_streaming_echo
+    Guides: creating-a-chatbot-fast, sharing-your-app
+    """
+
+    def __init__(
+        self,
+        fn: Callable,
+        *,
+        chatbot: Chatbot | None = None,
+        textbox: Textbox | None = None,
+        additional_inputs: str | IOComponent | list[str | IOComponent] | None = None,
+        additional_inputs_accordion_name: str = "Additional Inputs",
+        examples: list[str] | None = None,
+        cache_examples: bool | None = None,
+        title: str | None = None,
+        description: str | None = None,
+        theme: Theme | str | None = None,
+        css: str | None = None,
+        analytics_enabled: bool | None = None,
+        submit_btn: str | None | Button = "Submit",
+        stop_btn: str | None | Button = "Stop",
+        retry_btn: str | None | Button = "🔄 Retry",
+        undo_btn: str | None | Button = "↩️ Undo",
+        clear_btn: str | None | Button = "🗑️ Clear",
+        autofocus: bool = True,
+    ):
+        """
+        Parameters:
+            fn: the function to wrap the chat interface around. Should accept two parameters: a string input message and list of two-element lists of the form [[user_message, bot_message], ...] representing the chat history, and return a string response. See the Chatbot documentation for more information on the chat history format.
+            chatbot: an instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created.
+            textbox: an instance of the gr.Textbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox component will be created.
+            additional_inputs: an instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion.
+            additional_inputs_accordion_name: the label of the accordion to use for additional inputs, only used if additional_inputs is provided.
+            examples: sample inputs for the function; if provided, appear below the chatbot and can be clicked to populate the chatbot input.
+            cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
+            title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window.
+            description: a description for the interface; if provided, appears above the chatbot and beneath the title in regular font. Accepts Markdown and HTML content.
+            theme: Theme to use, loaded from gradio.themes.
+            css: custom css or path to custom css file to use with interface.
+            analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
+            submit_btn: Text to display on the submit button. If None, no button will be displayed. If a Button object, that button will be used.
+            stop_btn: Text to display on the stop button, which replaces the submit_btn when the submit_btn or retry_btn is clicked and response is streaming. Clicking on the stop_btn will halt the chatbot response. If set to None, stop button functionality does not appear in the chatbot. If a Button object, that button will be used as the stop button.
+            retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used.
+            undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used.
+            clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
+            autofocus: If True, autofocuses to the textbox when the page loads.
+        """
+        super().__init__(
+            analytics_enabled=analytics_enabled,
+            mode="chat_interface",
+            css=css,
+            title=title or "Gradio",
+            theme=theme,
+        )
+        self.fn = fn
+        self.is_async = inspect.iscoroutinefunction(
+            self.fn
+        ) or inspect.isasyncgenfunction(self.fn)
+        self.is_generator = inspect.isgeneratorfunction(
+            self.fn
+        ) or inspect.isasyncgenfunction(self.fn)
+        self.examples = examples
+        if self.space_id and cache_examples is None:
+            self.cache_examples = True
+        else:
+            self.cache_examples = cache_examples or False
+        self.buttons: list[Button] = []
+
+        if additional_inputs:
+            if not isinstance(additional_inputs, list):
+                additional_inputs = [additional_inputs]
+            self.additional_inputs = [
+                get_component_instance(i) for i in additional_inputs  # type: ignore
+            ]
+        else:
+            self.additional_inputs = []
+        self.additional_inputs_accordion_name = additional_inputs_accordion_name
+
+        self.additional_outputs = []
+
+        with self:
+            if title:
+                Markdown(
+                    f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
+                )
+            if description:
+                Markdown(description)
+
+            with Row():
+                with Column(variant="panel", scale=1):
+                    if chatbot:
+                        self.chatbot = chatbot.render()
+                    else:
+                        self.chatbot = Chatbot(label="Chatbot")
+
+                    with Group():
+                        with Row():
+                            if textbox:
+                                textbox.container = False
+                                textbox.show_label = False
+                                self.textbox = textbox.render()
+                            else:
+                                self.textbox = Textbox(
+                                    container=False,
+                                    show_label=False,
+                                    label="Message",
+                                    placeholder="Type a message...",
+                                    scale=7,
+                                    autofocus=autofocus,
+                                )
+                            if submit_btn:
+                                if isinstance(submit_btn, Button):
+                                    submit_btn.render()
+                                elif isinstance(submit_btn, str):
+                                    submit_btn = Button(
+                                        submit_btn,
+                                        variant="primary",
+                                        scale=1,
+                                        min_width=150,
+                                    )
+                                else:
+                                    raise ValueError(
+                                        f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
+                                    )
+                            if stop_btn:
+                                if isinstance(stop_btn, Button):
+                                    stop_btn.visible = False
+                                    stop_btn.render()
+                                elif isinstance(stop_btn, str):
+                                    stop_btn = Button(
+                                        stop_btn,
+                                        variant="stop",
+                                        visible=False,
+                                        scale=1,
+                                        min_width=150,
+                                    )
+                                else:
+                                    raise ValueError(
+                                        f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
+                                    )
+                            self.buttons.extend([submit_btn, stop_btn])
+
+                    with Row():
+                        for btn in [retry_btn, undo_btn, clear_btn]:
+                            if btn:
+                                if isinstance(btn, Button):
+                                    btn.render()
+                                elif isinstance(btn, str):
+                                    btn = Button(btn, variant="secondary")
+                                else:
+                                    raise ValueError(
+                                        f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
+                                    )
+                            self.buttons.append(btn)
+
+                    self.fake_api_btn = Button("Fake API", visible=False)
+                    self.fake_response_textbox = Textbox(
+                        label="Response", visible=False
+                    )
+                    (
+                        self.submit_btn,
+                        self.stop_btn,
+                        self.retry_btn,
+                        self.undo_btn,
+                        self.clear_btn,
+                    ) = self.buttons
+
+                with Column(variant="panel", scale=2):
+                    for i in range(4):
+                        self.additional_outputs.append(
+                            Textbox(
+                                interactive=False,
+                                label=f"Document {i+1}"
+                            )
+                        )
+
+            if examples:
+                if self.is_generator:
+                    examples_fn = self._examples_stream_fn
+                else:
+                    examples_fn = self._examples_fn
+
+                self.examples_handler = Examples(
+                    examples=examples,
+                    inputs=[self.textbox] + self.additional_inputs,
+                    outputs=self.chatbot,
+                    fn=examples_fn,
+                )
+
+            any_unrendered_inputs = any(
+                not inp.is_rendered for inp in self.additional_inputs
+            )
+            if self.additional_inputs and any_unrendered_inputs:
+                with Accordion(self.additional_inputs_accordion_name, open=False):
+                    for input_component in self.additional_inputs:
+                        if not input_component.is_rendered:
+                            input_component.render()
+
+            # The example caching must happen after the input components have rendered
+            if cache_examples:
+                client_utils.synchronize_async(self.examples_handler.cache)
+
+            self.saved_input = State()
+            self.chatbot_state = State([])
+
+            self._setup_events()
+            self._setup_api()
+
+    def _setup_events(self) -> None:
+        submit_fn = self._stream_fn if self.is_generator else self._submit_fn
+        submit_triggers = (
+            [self.textbox.submit, self.submit_btn.click]
+            if self.submit_btn
+            else [self.textbox.submit]
+        )
+        submit_event = (
+            on(
+                submit_triggers,
+                self._clear_and_save_textbox,
+                [self.textbox],
+                [self.textbox, self.saved_input],
+                api_name=False,
+                queue=False,
+            )
+            .then(
+                self._display_input,
+                [self.saved_input, self.chatbot_state],
+                [self.chatbot, self.chatbot_state],
+                api_name=False,
+                queue=False,
+            )
+            .then(
+                submit_fn,
+                [self.saved_input, self.chatbot_state] + self.additional_inputs,
+                [self.chatbot, self.chatbot_state] + self.additional_outputs,
+                api_name=False,
+            )
+        )
+        self._setup_stop_events(submit_triggers, submit_event)
+
+        if self.retry_btn:
+            retry_event = (
+                self.retry_btn.click(
+                    self._delete_prev_fn,
+                    [self.chatbot_state],
+                    [self.chatbot, self.saved_input, self.chatbot_state],
+                    api_name=False,
+                    queue=False,
+                )
+                .then(
+                    self._display_input,
+                    [self.saved_input, self.chatbot_state],
+                    [self.chatbot, self.chatbot_state],
+                    api_name=False,
+                    queue=False,
+                )
+                .then(
+                    submit_fn,
+                    [self.saved_input, self.chatbot_state] + self.additional_inputs,
+                    [self.chatbot, self.chatbot_state],
+                    api_name=False,
+                )
+            )
+            self._setup_stop_events([self.retry_btn.click], retry_event)
+
+        if self.undo_btn:
+            self.undo_btn.click(
+                self._delete_prev_fn,
+                [self.chatbot_state],
+                [self.chatbot, self.saved_input, self.chatbot_state],
+                api_name=False,
+                queue=False,
+            ).then(
+                lambda x: x,
+                [self.saved_input],
+                [self.textbox],
+                api_name=False,
+                queue=False,
+            )
+
+        if self.clear_btn:
+            self.clear_btn.click(
+                lambda: ([], [], None),
+                None,
+                [self.chatbot, self.chatbot_state, self.saved_input],
+                queue=False,
+                api_name=False,
+            )
+
+    def _setup_stop_events(
+        self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
+    ) -> None:
+        if self.stop_btn and self.is_generator:
+            if self.submit_btn:
+                for event_trigger in event_triggers:
+                    event_trigger(
+                        lambda: (
+                            Button.update(visible=False),
+                            Button.update(visible=True),
+                        ),
+                        None,
+                        [self.submit_btn, self.stop_btn],
+                        api_name=False,
+                        queue=False,
+                    )
+                event_to_cancel.then(
+                    lambda: (Button.update(visible=True), Button.update(visible=False)),
+                    None,
+                    [self.submit_btn, self.stop_btn],
+                    api_name=False,
+                    queue=False,
+                )
+            else:
+                for event_trigger in event_triggers:
+                    event_trigger(
+                        lambda: Button.update(visible=True),
+                        None,
+                        [self.stop_btn],
+                        api_name=False,
+                        queue=False,
+                    )
+                event_to_cancel.then(
+                    lambda: Button.update(visible=False),
+                    None,
+                    [self.stop_btn],
+                    api_name=False,
+                    queue=False,
+                )
+            self.stop_btn.click(
+                None,
+                None,
+                None,
+                cancels=event_to_cancel,
+                api_name=False,
+            )
+
+    def _setup_api(self) -> None:
+        api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
+
+        self.fake_api_btn.click(
+            api_fn,
+            [self.textbox, self.chatbot_state] + self.additional_inputs,
+            [self.textbox, self.chatbot_state],
+            api_name="chat",
+        )
+
+    def _clear_and_save_textbox(self, message: str) -> tuple[str, str]:
+        return "", message
+
+    def _display_input(
+        self, message: str, history: list[list[str | None]]
+    ) -> tuple[list[list[str | None]], list[list[str | None]]]:
+        history.append([message, None])
+        return history, history
+
+    async def _submit_fn(
+        self,
+        message: str,
+        history_with_input: list[list[str | None]],
+        *args,
+    ) -> tuple[list[list[str | None]], list[list[str | None]]]:
+        history = history_with_input[:-1]
+        if self.is_async:
+            [response, *other_outputs] = await self.fn(message, history, *args)
+        else:
+            [response, *other_outputs] = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
+        history.append([message, response])
+
+        return history, history, *other_outputs
+
+    async def _stream_fn(
+        self,
+        message: str,
+        history_with_input: list[list[str | None]],
+        *args,
+    ) -> AsyncGenerator:
+        history = history_with_input[:-1]
+        if self.is_async:
+            generator = self.fn(message, history, *args)
+        else:
+            generator = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
+            generator = SyncToAsyncIterator(generator, self.limiter)
+        try:
+            first_response = await async_iteration(generator)
+            update = history + [[message, first_response]]
+            yield update, update
+        except StopIteration:
+            update = history + [[message, None]]
+            yield update, update
+        async for response in generator:
+            update = history + [[message, response]]
+            yield update, update
+
+    async def _api_submit_fn(
+        self, message: str, history: list[list[str | None]], *args
+    ) -> tuple[str, list[list[str | None]]]:
+        if self.is_async:
+            response = await self.fn(message, history, *args)
+        else:
+            response = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
+        history.append([message, response])
+        return response, history
+
+    async def _api_stream_fn(
+        self, message: str, history: list[list[str | None]], *args
+    ) -> AsyncGenerator:
+        if self.is_async:
+            generator = self.fn(message, history, *args)
+        else:
+            generator = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
+            generator = SyncToAsyncIterator(generator, self.limiter)
+        try:
+            first_response = await async_iteration(generator)
+            yield first_response, history + [[message, first_response]]
+        except StopIteration:
+            yield None, history + [[message, None]]
+        async for response in generator:
+            yield response, history + [[message, response]]
+
+    async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
+        if self.is_async:
+            response = await self.fn(message, [], *args)
+        else:
+            response = await anyio.to_thread.run_sync(
+                self.fn, message, [], *args, limiter=self.limiter
+            )
+        return [[message, response]]
+
+    async def _examples_stream_fn(
+        self,
+        message: str,
+        *args,
+    ) -> AsyncGenerator:
+        if self.is_async:
+            generator = self.fn(message, [], *args)
+        else:
+            generator = await anyio.to_thread.run_sync(
+                self.fn, message, [], *args, limiter=self.limiter
+            )
+            generator = SyncToAsyncIterator(generator, self.limiter)
+        async for response in generator:
+            yield [[message, response]]
+
+    def _delete_prev_fn(
+        self, history: list[list[str | None]]
+    ) -> tuple[list[list[str | None]], str, list[list[str | None]]]:
+        try:
+            message, _ = history.pop()
+        except IndexError:
+            message = ""
+        return history, message or "", history
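This file is a vendored copy of Gradio's ChatInterface (a gradio 3.x vintage, judging by the imports), extended with self.additional_outputs: four read-only textboxes in a second panel that the submit event also writes to, so retrieved documents appear beside the chat. A minimal hedged usage sketch (the stub function is illustrative, not from the commit):

from llm4scilit_gradio_interface import LLM4SciLitChatInterface

def fake_chat(message, history):
    # One return value for the chat reply plus four for the document panes.
    return f"echo: {message}", "doc 1", "doc 2", "doc 3", "doc 4"

if __name__ == "__main__":
    LLM4SciLitChatInterface(fake_chat, title="LLM4SciLit").launch()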
src/question_answering/huggingface.py CHANGED
@@ -1,15 +1,15 @@
-from langchain import PromptTemplate
+from langchain.prompts.prompt import PromptTemplate
 from langchain.chains import RetrievalQA
-from langchain.llms import HuggingFacePipeline
+from langchain.llms.huggingface_pipeline import HuggingFacePipeline
 
 class HuggingFaceQuestionAnswering:
     def __init__(self, retriever) -> None:
         self.retriever = retriever
         self.llm = HuggingFacePipeline.from_model_id(
             # model_id="bigscience/bloom-1b7",
-            model_id="bigscience/bloomz-
+            model_id="bigscience/bloomz-1b7",
             task="text-generation",
-            device=1,
+            # device=1,
             # model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 4, "top_p": 0.95, "repetition_penalty": 1.25, "length_penalty": 1.2},
             model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 2},
             # pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
@@ -27,6 +27,7 @@ class HuggingFaceQuestionAnswering:
 
     def answer_question(self, question: str, filter_dict):
         retriever = self.retriever.vector_store.db.as_retriever(search_kwargs={"filter": filter_dict, "fetch_k": 150})
+        # retriever = self.retriever.retriever
 
         try:
             self.chain = RetrievalQA.from_chain_type(self.llm, retriever=retriever, return_source_documents=True)
@@ -36,5 +37,6 @@ class HuggingFaceQuestionAnswering:
             Retrieved Documents:
             {docs if docs != "" else "No documents found."}""")
             return result
-        except:
+        except Exception as e:
+            print(e)
             return {"result": "Error generating answer."}
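The last hunk replaces a bare except: with except Exception as e plus a print, so failures show up in the Space logs instead of vanishing. A short sketch of the pattern in isolation (the chain argument is a stand-in for the RetrievalQA chain built above):

def answer_safely(chain, question):
    try:
        # RetrievalQA chains of this era are callable with a query dict.
        return chain({"query": question})
    except Exception as e:  # unlike a bare except, this lets Ctrl-C through
        print(e)
        return {"result": "Error generating answer."}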