tommymarto committed
Commit e04cd14
1 Parent(s): 54abba0

first attempt to hf spaces
config/config.yaml CHANGED
@@ -3,7 +3,7 @@ defaults:
  - text_splitter: spacy
  - text_embedding: huggingface
  - vector_store: faiss
- - document_retriever: simple_retriever
+ - document_retriever: multiquery_retriever
  - question_answering: huggingface
  - _self_
  - override hydra/hydra_logging: disabled
config/document_retriever/multiquery_retriever.yaml ADDED
@@ -0,0 +1 @@
+ _target_: document_retriever.multiquery_retriever.MultiQueryDocumentRetriever
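Note: Hydra resolves the _target_ key above by importing the dotted path and calling it. A minimal sketch of that mechanism (an assumption for illustration: it presumes the src directory is on PYTHONPATH, and the vector_store argument mirrors the constructor in src/document_retriever/multiquery_retriever.py, which the app normally injects):

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Hydra imports the dotted _target_ path and calls it, forwarding any
    # extra keyword arguments to the constructor.
    cfg = OmegaConf.create(
        {"_target_": "document_retriever.multiquery_retriever.MultiQueryDocumentRetriever"}
    )
    retriever = instantiate(cfg, vector_store=None)  # a real vector store is injected at runtime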
config/gradio_config.yaml ADDED
@@ -0,0 +1,27 @@
+ defaults:
+ - document_loader: grobid
+ - text_splitter: spacy
+ - text_embedding: huggingface
+ - vector_store: faiss
+ - document_retriever: simple_retriever
+ - question_answering: huggingface
+ - _self_
+ - override hydra/hydra_logging: disabled
+ - override hydra/job_logging: disabled
+
+ storage_path:
+   base: ./data
+   documents: ${storage_path.base}/papers
+   documents_processed: ${storage_path.documents}_processed
+   vector_store: ${storage_path.base}/vector_store
+
+ mode: interactive
+ debug:
+   is_debug: false
+   force_rebuild_storage: false
+
+ document_parsing:
+   enabled: false
+
+ hydra:
+   verbose: false
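Note: outside a @hydra.main entry point, this config can also be composed programmatically, as the deleted src/gradio.py did. A minimal sketch using the same override syntax (the config_name here is an assumption matching this new file):

    from hydra import compose, initialize

    # Compose gradio_config without the @hydra.main decorator; overrides use
    # the standard Hydra key=value syntax.
    with initialize(version_base=None, config_path="../config"):
        cfg = compose(config_name="gradio_config", overrides=["document_parsing.enabled=False"])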
data ADDED
@@ -0,0 +1 @@
+ /data/tommaso/llm4scilit/data/
src/demo.py CHANGED
@@ -114,7 +114,7 @@ class App:
 
      def ask_chat(self, line, history):
          # print(f"\nLLM4SciLit: a bunch of nonsense\n")
-         return self.qa_model.answer_question(line, {})['result']
+         return self.qa_model.answer_question(line, {})
 
 
      ##################################################################################################
src/document_retriever/multiquery_retriever.py ADDED
@@ -0,0 +1,37 @@
+ from langchain.llms.huggingface_pipeline import HuggingFacePipeline
+ from langchain.retrievers.multi_query import MultiQueryRetriever
+
+ # Set logging for the queries
+ import logging
+
+ logging.basicConfig()
+
+
+ class MultiQueryDocumentRetriever:
+     def __init__(self, vector_store):
+         self.vector_store = vector_store
+         self.retriever = None
+         self.llm = None
+         # self.token = "LL-1kuyxK1z5NQYOiOsf5UdozHJuLhV6udoDGxL8NfM7brWCUbF0uqlii15sso8GNrd"
+
+     def initialize(self):
+         # self.llama = LlamaAPI(self.token)
+         self.llm = HuggingFacePipeline.from_model_id(
+             # model_id="bigscience/bloom-1b7",
+             model_id="bigscience/bloomz-1b7",
+             task="text-generation",
+             # device=1,
+             # model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 4, "top_p": 0.95, "repetition_penalty": 1.25, "length_penalty": 1.2},
+             model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 2},
+             # pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
+             pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
+         )
+
+         logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)
+         self.retriever = MultiQueryRetriever.from_llm(
+             retriever=self.vector_store.db.as_retriever(search_kwargs={"k": 4, "fetch_k": 40}),
+             llm=self.llm
+         )
+
+     def retrieve(self, query: str, k: int = 4):
+         pass
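Note: retrieve() is left as a stub in this commit. A hypothetical completion (not part of the commit), assuming the standard LangChain retriever API and that initialize() has run:

    # Sketch only: delegate to the underlying MultiQueryRetriever.
    def retrieve(self, query: str, k: int = 4):
        # get_relevant_documents aggregates and deduplicates results across
        # the LLM-generated query variants.
        docs = self.retriever.get_relevant_documents(query)
        return docs[:k]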
src/gradio.py DELETED
@@ -1,17 +0,0 @@
- import gradio as gr
- from hydra import compose, initialize
- from omegaconf import OmegaConf
-
- from demo import App
-
- def main():
-     with initialize(version_base=None, config_path="../config", job_name="gradio_app"):
-         cfg = compose(config_name="config", overrides=["document_parsing.enabled=False"])
-
-     app = App(cfg)
-
-     webapp = gr.ChatInterface(fn=app.ask_chat, examples=["hello", "hola", "merhaba"], title="LLM4SciLit")
-     webapp.launch(share=True)
-
- if __name__ == "__main__":
-     main()
src/gradio_app.py ADDED
@@ -0,0 +1,68 @@
+ import hydra
+ from omegaconf import DictConfig
+ from demo import App
+
+ from llm4scilit_gradio_interface import LLM4SciLitChatInterface
+
+ def echo(text, history):
+     asdf = "asdf"
+     values = [f"{x}\n{x*2}" for x in asdf]
+     return text, *values
+
+
+ @hydra.main(version_base=None, config_path="../config", config_name="gradio_config")
+ def main(cfg: DictConfig) -> None:
+     cfg.document_parsing['enabled'] = False
+
+     app = App(cfg)
+     app._bootstrap()
+
+     def wrapped_ask_chat(text, history):
+         result = app.ask_chat(text, history)
+         sources = [
+             f"{x.metadata['paper_title']}\n{x.page_content}"
+             for x in result['source_documents']
+         ]
+         return result['result'], *sources
+
+
+     LLM4SciLitChatInterface(wrapped_ask_chat, title="LLM4SciLit").launch()
+     # LLM4SciLitChatInterface(echo, title="LLM4SciLit").launch()
+
+     # textbox = gr.Textbox(placeholder="Ask a question about scientific literature", lines=2, label="Question", elem_id="textbox")
+     # chatbot = gr.Chatbot(label="LLM4SciLit", elem_id="chat")
+     # gr.Interface(fn=echo, inputs=[textbox, chatbot], outputs=[chatbot], title="LLM4SciLit").launch()
+
+     # with gr.Blocks() as demo:
+     #     chatbot = gr.Chatbot()
+     #     msg = gr.Textbox(container=False)
+     #     clear = gr.ClearButton([msg, chatbot])
+
+     #     def respond(message, chat_history):
+     #         bot_message = "How are you?"
+     #         chat_history.append((message, bot_message))
+     #         return "", chat_history
+
+     #     msg.submit(respond, [msg, chatbot], [msg, chatbot])
+
+
+
+     # with gr.Blocks(title="LLM4SciLit") as demo:
+     #     with gr.Row():
+     #         with gr.Column(scale=5):
+     #             with gr.Row():
+     #                 gr.Chatbot(fn=echo)
+     #             with gr.Row():
+     #                 gr.Button("Submit")
+
+     #         with gr.Column(scale=5):
+     #             with gr.Accordion("Retrieved documents"):
+     #                 gr.Label("Document 1")
+
+     # webapp = gr.ChatInterface(fn=app.ask_chat, examples=["hello", "hola", "merhaba"], title="LLM4SciLit")
+     # webapp = gr.ChatInterface(fn=echo, examples=["hello", "hola", "merhaba"], title="LLM4SciLit")
+     # demo.launch()
+     # webapp.launch(share=True)
+
+ if __name__ == "__main__":
+     main() # pylint: disable=no-value-for-parameter
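Note: LLM4SciLitChatInterface (below) unpacks the wrapped function's return value as [response, *other_outputs] and routes the extras into four fixed "Document i" textboxes, so wrapped_ask_chat should return exactly four source strings. A hedged sketch of padding/truncating to that contract (illustrative, not part of the commit):

    # Sketch: keep the number of source outputs equal to the interface's
    # four additional output textboxes.
    def wrapped_ask_chat(text, history):
        result = app.ask_chat(text, history)
        sources = [
            f"{d.metadata['paper_title']}\n{d.page_content}"
            for d in result["source_documents"]
        ]
        sources = (sources + [""] * 4)[:4]  # pad or truncate to four
        return result["result"], *sources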
src/llm4scilit_gradio_interface.py ADDED
@@ -0,0 +1,508 @@
+ """
+ This file defines a useful high-level abstraction to build Gradio chatbots: ChatInterface.
+ """
+
+
+ from __future__ import annotations
+
+ import inspect
+ from typing import AsyncGenerator, Callable
+
+ import anyio
+ from gradio_client import utils as client_utils
+ from gradio_client.documentation import document, set_documentation_group
+
+ from gradio.blocks import Blocks
+ from gradio.components import (
+     Button,
+     Chatbot,
+     IOComponent,
+     Markdown,
+     State,
+     Textbox,
+     get_component_instance,
+ )
+ from gradio.events import Dependency, EventListenerMethod, on
+ from gradio.helpers import create_examples as Examples  # noqa: N812
+ from gradio.layouts import Accordion, Column, Group, Row
+ from gradio.themes import ThemeClass as Theme
+ from gradio.utils import SyncToAsyncIterator, async_iteration
+
+ set_documentation_group("chatinterface")
+
+
+ @document()
+ class LLM4SciLitChatInterface(Blocks):
+     """
+     ChatInterface is Gradio's high-level abstraction for creating chatbot UIs, and allows you to create
+     a web-based demo around a chatbot model in a few lines of code. Only one parameter is required: fn, which
+     takes a function that governs the response of the chatbot based on the user input and chat history. Additional
+     parameters can be used to control the appearance and behavior of the demo.
+
+     Example:
+         import gradio as gr
+
+         def echo(message, history):
+             return message
+
+         demo = gr.ChatInterface(fn=echo, examples=["hello", "hola", "merhaba"], title="Echo Bot")
+         demo.launch()
+     Demos: chatinterface_random_response, chatinterface_streaming_echo
+     Guides: creating-a-chatbot-fast, sharing-your-app
+     """
+
+     def __init__(
+         self,
+         fn: Callable,
+         *,
+         chatbot: Chatbot | None = None,
+         textbox: Textbox | None = None,
+         additional_inputs: str | IOComponent | list[str | IOComponent] | None = None,
+         additional_inputs_accordion_name: str = "Additional Inputs",
+         examples: list[str] | None = None,
+         cache_examples: bool | None = None,
+         title: str | None = None,
+         description: str | None = None,
+         theme: Theme | str | None = None,
+         css: str | None = None,
+         analytics_enabled: bool | None = None,
+         submit_btn: str | None | Button = "Submit",
+         stop_btn: str | None | Button = "Stop",
+         retry_btn: str | None | Button = "🔄 Retry",
+         undo_btn: str | None | Button = "↩️ Undo",
+         clear_btn: str | None | Button = "🗑️ Clear",
+         autofocus: bool = True,
+     ):
+         """
+         Parameters:
+             fn: the function to wrap the chat interface around. Should accept two parameters: a string input message and list of two-element lists of the form [[user_message, bot_message], ...] representing the chat history, and return a string response. See the Chatbot documentation for more information on the chat history format.
+             chatbot: an instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created.
+             textbox: an instance of the gr.Textbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox component will be created.
+             additional_inputs: an instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion.
+             additional_inputs_accordion_name: the label of the accordion to use for additional inputs, only used if additional_inputs is provided.
+             examples: sample inputs for the function; if provided, appear below the chatbot and can be clicked to populate the chatbot input.
+             cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
+             title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window.
+             description: a description for the interface; if provided, appears above the chatbot and beneath the title in regular font. Accepts Markdown and HTML content.
+             theme: Theme to use, loaded from gradio.themes.
+             css: custom css or path to custom css file to use with interface.
+             analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
+             submit_btn: Text to display on the submit button. If None, no button will be displayed. If a Button object, that button will be used.
+             stop_btn: Text to display on the stop button, which replaces the submit_btn when the submit_btn or retry_btn is clicked and response is streaming. Clicking on the stop_btn will halt the chatbot response. If set to None, stop button functionality does not appear in the chatbot. If a Button object, that button will be used as the stop button.
+             retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used.
+             undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used.
+             clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
+             autofocus: If True, autofocuses to the textbox when the page loads.
+         """
+         super().__init__(
+             analytics_enabled=analytics_enabled,
+             mode="chat_interface",
+             css=css,
+             title=title or "Gradio",
+             theme=theme,
+         )
+         self.fn = fn
+         self.is_async = inspect.iscoroutinefunction(
+             self.fn
+         ) or inspect.isasyncgenfunction(self.fn)
+         self.is_generator = inspect.isgeneratorfunction(
+             self.fn
+         ) or inspect.isasyncgenfunction(self.fn)
+         self.examples = examples
+         if self.space_id and cache_examples is None:
+             self.cache_examples = True
+         else:
+             self.cache_examples = cache_examples or False
+         self.buttons: list[Button] = []
+
+         if additional_inputs:
+             if not isinstance(additional_inputs, list):
+                 additional_inputs = [additional_inputs]
+             self.additional_inputs = [
+                 get_component_instance(i) for i in additional_inputs  # type: ignore
+             ]
+         else:
+             self.additional_inputs = []
+         self.additional_inputs_accordion_name = additional_inputs_accordion_name
+
+         self.additional_outputs = []
+
+         with self:
+             if title:
+                 Markdown(
+                     f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
+                 )
+             if description:
+                 Markdown(description)
+
+             with Row():
+                 with Column(variant="panel", scale=1):
+                     if chatbot:
+                         self.chatbot = chatbot.render()
+                     else:
+                         self.chatbot = Chatbot(label="Chatbot")
+
+                     with Group():
+                         with Row():
+                             if textbox:
+                                 textbox.container = False
+                                 textbox.show_label = False
+                                 self.textbox = textbox.render()
+                             else:
+                                 self.textbox = Textbox(
+                                     container=False,
+                                     show_label=False,
+                                     label="Message",
+                                     placeholder="Type a message...",
+                                     scale=7,
+                                     autofocus=autofocus,
+                                 )
+                             if submit_btn:
+                                 if isinstance(submit_btn, Button):
+                                     submit_btn.render()
+                                 elif isinstance(submit_btn, str):
+                                     submit_btn = Button(
+                                         submit_btn,
+                                         variant="primary",
+                                         scale=1,
+                                         min_width=150,
+                                     )
+                                 else:
+                                     raise ValueError(
+                                         f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
+                                     )
+                             if stop_btn:
+                                 if isinstance(stop_btn, Button):
+                                     stop_btn.visible = False
+                                     stop_btn.render()
+                                 elif isinstance(stop_btn, str):
+                                     stop_btn = Button(
+                                         stop_btn,
+                                         variant="stop",
+                                         visible=False,
+                                         scale=1,
+                                         min_width=150,
+                                     )
+                                 else:
+                                     raise ValueError(
+                                         f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
+                                     )
+                             self.buttons.extend([submit_btn, stop_btn])
+
+                     with Row():
+                         for btn in [retry_btn, undo_btn, clear_btn]:
+                             if btn:
+                                 if isinstance(btn, Button):
+                                     btn.render()
+                                 elif isinstance(btn, str):
+                                     btn = Button(btn, variant="secondary")
+                                 else:
+                                     raise ValueError(
+                                         f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
+                                     )
+                             self.buttons.append(btn)
+
+                     self.fake_api_btn = Button("Fake API", visible=False)
+                     self.fake_response_textbox = Textbox(
+                         label="Response", visible=False
+                     )
+                     (
+                         self.submit_btn,
+                         self.stop_btn,
+                         self.retry_btn,
+                         self.undo_btn,
+                         self.clear_btn,
+                     ) = self.buttons
+
+                 with Column(variant="panel", scale=2):
+                     for i in range(4):
+                         self.additional_outputs.append(
+                             Textbox(
+                                 interactive=False,
+                                 label=f"Document {i+1}"
+                             )
+                         )
+
+             if examples:
+                 if self.is_generator:
+                     examples_fn = self._examples_stream_fn
+                 else:
+                     examples_fn = self._examples_fn
+
+                 self.examples_handler = Examples(
+                     examples=examples,
+                     inputs=[self.textbox] + self.additional_inputs,
+                     outputs=self.chatbot,
+                     fn=examples_fn,
+                 )
+
+             any_unrendered_inputs = any(
+                 not inp.is_rendered for inp in self.additional_inputs
+             )
+             if self.additional_inputs and any_unrendered_inputs:
+                 with Accordion(self.additional_inputs_accordion_name, open=False):
+                     for input_component in self.additional_inputs:
+                         if not input_component.is_rendered:
+                             input_component.render()
+
+             # The example caching must happen after the input components have rendered
+             if cache_examples:
+                 client_utils.synchronize_async(self.examples_handler.cache)
+
+             self.saved_input = State()
+             self.chatbot_state = State([])
+
+             self._setup_events()
+             self._setup_api()
+
+     def _setup_events(self) -> None:
+         submit_fn = self._stream_fn if self.is_generator else self._submit_fn
+         submit_triggers = (
+             [self.textbox.submit, self.submit_btn.click]
+             if self.submit_btn
+             else [self.textbox.submit]
+         )
+         submit_event = (
+             on(
+                 submit_triggers,
+                 self._clear_and_save_textbox,
+                 [self.textbox],
+                 [self.textbox, self.saved_input],
+                 api_name=False,
+                 queue=False,
+             )
+             .then(
+                 self._display_input,
+                 [self.saved_input, self.chatbot_state],
+                 [self.chatbot, self.chatbot_state],
+                 api_name=False,
+                 queue=False,
+             )
+             .then(
+                 submit_fn,
+                 [self.saved_input, self.chatbot_state] + self.additional_inputs,
+                 [self.chatbot, self.chatbot_state] + self.additional_outputs,
+                 api_name=False,
+             )
+         )
+         self._setup_stop_events(submit_triggers, submit_event)
+
+         if self.retry_btn:
+             retry_event = (
+                 self.retry_btn.click(
+                     self._delete_prev_fn,
+                     [self.chatbot_state],
+                     [self.chatbot, self.saved_input, self.chatbot_state],
+                     api_name=False,
+                     queue=False,
+                 )
+                 .then(
+                     self._display_input,
+                     [self.saved_input, self.chatbot_state],
+                     [self.chatbot, self.chatbot_state],
+                     api_name=False,
+                     queue=False,
+                 )
+                 .then(
+                     submit_fn,
+                     [self.saved_input, self.chatbot_state] + self.additional_inputs,
+                     # additional_outputs must be included here so the output count
+                     # matches _submit_fn's return of (history, history, *other_outputs)
+                     [self.chatbot, self.chatbot_state] + self.additional_outputs,
+                     api_name=False,
+                 )
+             )
+             self._setup_stop_events([self.retry_btn.click], retry_event)
+
+         if self.undo_btn:
+             self.undo_btn.click(
+                 self._delete_prev_fn,
+                 [self.chatbot_state],
+                 [self.chatbot, self.saved_input, self.chatbot_state],
+                 api_name=False,
+                 queue=False,
+             ).then(
+                 lambda x: x,
+                 [self.saved_input],
+                 [self.textbox],
+                 api_name=False,
+                 queue=False,
+             )
+
+         if self.clear_btn:
+             self.clear_btn.click(
+                 lambda: ([], [], None),
+                 None,
+                 [self.chatbot, self.chatbot_state, self.saved_input],
+                 queue=False,
+                 api_name=False,
+             )
+
+     def _setup_stop_events(
+         self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
+     ) -> None:
+         if self.stop_btn and self.is_generator:
+             if self.submit_btn:
+                 for event_trigger in event_triggers:
+                     event_trigger(
+                         lambda: (
+                             Button.update(visible=False),
+                             Button.update(visible=True),
+                         ),
+                         None,
+                         [self.submit_btn, self.stop_btn],
+                         api_name=False,
+                         queue=False,
+                     )
+                 event_to_cancel.then(
+                     lambda: (Button.update(visible=True), Button.update(visible=False)),
+                     None,
+                     [self.submit_btn, self.stop_btn],
+                     api_name=False,
+                     queue=False,
+                 )
+             else:
+                 for event_trigger in event_triggers:
+                     event_trigger(
+                         lambda: Button.update(visible=True),
+                         None,
+                         [self.stop_btn],
+                         api_name=False,
+                         queue=False,
+                     )
+                 event_to_cancel.then(
+                     lambda: Button.update(visible=False),
+                     None,
+                     [self.stop_btn],
+                     api_name=False,
+                     queue=False,
+                 )
+             self.stop_btn.click(
+                 None,
+                 None,
+                 None,
+                 cancels=event_to_cancel,
+                 api_name=False,
+             )
+
+     def _setup_api(self) -> None:
+         api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
+
+         self.fake_api_btn.click(
+             api_fn,
+             [self.textbox, self.chatbot_state] + self.additional_inputs,
+             [self.textbox, self.chatbot_state],
+             api_name="chat",
+         )
+
+     def _clear_and_save_textbox(self, message: str) -> tuple[str, str]:
+         return "", message
+
+     def _display_input(
+         self, message: str, history: list[list[str | None]]
+     ) -> tuple[list[list[str | None]], list[list[str | None]]]:
+         history.append([message, None])
+         return history, history
+
+     async def _submit_fn(
+         self,
+         message: str,
+         history_with_input: list[list[str | None]],
+         *args,
+     ) -> tuple[list[list[str | None]], list[list[str | None]]]:
+         history = history_with_input[:-1]
+         if self.is_async:
+             [response, *other_outputs] = await self.fn(message, history, *args)
+         else:
+             [response, *other_outputs] = await anyio.to_thread.run_sync(
+                 self.fn, message, history, *args, limiter=self.limiter
+             )
+         history.append([message, response])
+
+         return history, history, *other_outputs
+
+     async def _stream_fn(
+         self,
+         message: str,
+         history_with_input: list[list[str | None]],
+         *args,
+     ) -> AsyncGenerator:
+         history = history_with_input[:-1]
+         if self.is_async:
+             generator = self.fn(message, history, *args)
+         else:
+             generator = await anyio.to_thread.run_sync(
+                 self.fn, message, history, *args, limiter=self.limiter
+             )
+             generator = SyncToAsyncIterator(generator, self.limiter)
+         try:
+             first_response = await async_iteration(generator)
+             update = history + [[message, first_response]]
+             yield update, update
+         except StopIteration:
+             update = history + [[message, None]]
+             yield update, update
+         async for response in generator:
+             update = history + [[message, response]]
+             yield update, update
+
+     async def _api_submit_fn(
+         self, message: str, history: list[list[str | None]], *args
+     ) -> tuple[str, list[list[str | None]]]:
+         if self.is_async:
+             response = await self.fn(message, history, *args)
+         else:
+             response = await anyio.to_thread.run_sync(
+                 self.fn, message, history, *args, limiter=self.limiter
+             )
+         history.append([message, response])
+         return response, history
+
+     async def _api_stream_fn(
+         self, message: str, history: list[list[str | None]], *args
+     ) -> AsyncGenerator:
+         if self.is_async:
+             generator = self.fn(message, history, *args)
+         else:
+             generator = await anyio.to_thread.run_sync(
+                 self.fn, message, history, *args, limiter=self.limiter
+             )
+             generator = SyncToAsyncIterator(generator, self.limiter)
+         try:
+             first_response = await async_iteration(generator)
+             yield first_response, history + [[message, first_response]]
+         except StopIteration:
+             yield None, history + [[message, None]]
+         async for response in generator:
+             yield response, history + [[message, response]]
+
+     async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
+         if self.is_async:
+             response = await self.fn(message, [], *args)
+         else:
+             response = await anyio.to_thread.run_sync(
+                 self.fn, message, [], *args, limiter=self.limiter
+             )
+         return [[message, response]]
+
+     async def _examples_stream_fn(
+         self,
+         message: str,
+         *args,
+     ) -> AsyncGenerator:
+         if self.is_async:
+             generator = self.fn(message, [], *args)
+         else:
+             generator = await anyio.to_thread.run_sync(
+                 self.fn, message, [], *args, limiter=self.limiter
+             )
+             generator = SyncToAsyncIterator(generator, self.limiter)
+         async for response in generator:
+             yield [[message, response]]
+
+     def _delete_prev_fn(
+         self, history: list[list[str | None]]
+     ) -> tuple[list[list[str | None]], str, list[list[str | None]]]:
+         try:
+             message, _ = history.pop()
+         except IndexError:
+             message = ""
+         return history, message or "", history
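Note: a minimal standalone usage sketch for the customized interface (illustrative only): fn returns the chat response plus one string per "Document i" textbox.

    from llm4scilit_gradio_interface import LLM4SciLitChatInterface

    def demo_fn(message, history):
        # One string per additional output textbox defined by the interface.
        docs = [f"Document {i + 1}: no retrieval in this demo" for i in range(4)]
        return f"You said: {message}", *docs

    LLM4SciLitChatInterface(demo_fn, title="LLM4SciLit").launch()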
src/question_answering/huggingface.py CHANGED
@@ -1,15 +1,15 @@
- from langchain import PromptTemplate
+ from langchain.prompts.prompt import PromptTemplate
  from langchain.chains import RetrievalQA
- from langchain.llms import HuggingFacePipeline
+ from langchain.llms.huggingface_pipeline import HuggingFacePipeline
 
  class HuggingFaceQuestionAnswering:
      def __init__(self, retriever) -> None:
          self.retriever = retriever
          self.llm = HuggingFacePipeline.from_model_id(
              # model_id="bigscience/bloom-1b7",
-             model_id="bigscience/bloomz-1b1",
+             model_id="bigscience/bloomz-1b7",
              task="text-generation",
-             device=1,
+             # device=1,
              # model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 4, "top_p": 0.95, "repetition_penalty": 1.25, "length_penalty": 1.2},
              model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 2},
              # pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
@@ -27,6 +27,7 @@ class HuggingFaceQuestionAnswering:
 
      def answer_question(self, question: str, filter_dict):
          retriever = self.retriever.vector_store.db.as_retriever(search_kwargs={"filter": filter_dict, "fetch_k": 150})
+         # retriever = self.retriever.retriever
 
          try:
              self.chain = RetrievalQA.from_chain_type(self.llm, retriever=retriever, return_source_documents=True)
@@ -36,5 +37,6 @@ class HuggingFaceQuestionAnswering:
              Retrieved Documents:
              {docs if docs != "" else "No documents found."}""")
              return result
-         except:
+         except Exception as e:
+             print(e)
              return {"result": "Error generating answer."}