terapyon committed
Commit 2f682e6
1 Parent(s): 1f4ac35
Files changed (1):
  app.py +70 -4
app.py CHANGED
@@ -3,6 +3,8 @@ from datetime import datetime, date, timedelta
 from typing import Iterable
 import streamlit as st
 import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+from langchain.llms import HuggingFacePipeline
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import Qdrant
 from qdrant_client import QdrantClient
@@ -33,8 +35,44 @@ def llm_model(model="gpt-3.5-turbo", temperature=0.2):
     return llm
 
 
+@st.cache_resource
+def load_vicuna_model():
+    if torch.cuda.is_available():
+        model_name = "lmsys/vicuna-13b-v1.5"
+        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            load_in_8bit=True,
+            torch_dtype=torch.float16,
+            device_map="auto",
+        )
+        return tokenizer, model
+    else:
+        return None, None
+
+
 EMBEDDINGS = load_embeddings()
 LLM = llm_model()
+VICUNA_TOKENIZER, VICUNA_MODEL = load_vicuna_model()
+
+
+@st.cache_resource
+def _get_vicuna_llm(temperature=0.2) -> HuggingFacePipeline | None:
+    if VICUNA_MODEL is not None:
+        pipe = pipeline(
+            "text-generation",
+            model=VICUNA_MODEL,
+            tokenizer=VICUNA_TOKENIZER,
+            max_new_tokens=1024,
+            temperature=temperature,
+        )
+        llm = HuggingFacePipeline(pipeline=pipe)
+    else:
+        llm = None
+    return llm
+
+
+VICUNA_LLM = _get_vicuna_llm()
 
 
 def make_filter_obj(options: list[dict[str]]):
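Note: load_vicuna_model pulls lmsys/vicuna-13b-v1.5 in 8-bit only when a GPU is present, and @st.cache_resource keeps the weights resident across Streamlit reruns. The bare load_in_8bit flag needs bitsandbytes and accelerate installed at runtime; as a minimal sketch, the same load spelled with the explicit BitsAndBytesConfig (an assumption on this editor's part; the commit itself uses the shorthand flag) would be:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "lmsys/vicuna-13b-v1.5"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # requires bitsandbytes
    torch_dtype=torch.float16,
    device_map="auto",  # let accelerate place layers on the available devices
)
```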
@@ -78,7 +116,7 @@ def get_similay(query: str, filter: Filter):
     return docs
 
 
-def get_retrieval_qa(filter: Filter):
+def get_retrieval_qa(filter: Filter, llm):
     db_url, db_api_key, db_collection_name = DB_CONFIG
     client = QdrantClient(url=db_url, api_key=db_api_key)
     db = Qdrant(
@@ -90,7 +128,7 @@ def get_retrieval_qa(filter: Filter):
         }
     )
     result = RetrievalQA.from_chain_type(
-        llm=LLM,
+        llm=llm,
         chain_type="stuff",
         retriever=retriever,
         return_source_documents=True,
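For readability, here is get_retrieval_qa as it stands after these two hunks, stitched together from the visible context. The Qdrant constructor arguments and the as_retriever call are inferred from the surrounding lines, so details may differ from the actual file:

```python
from langchain.chains import RetrievalQA

def get_retrieval_qa(filter: Filter, llm):
    # Connect to the remote Qdrant collection used as the vector store.
    db_url, db_api_key, db_collection_name = DB_CONFIG
    client = QdrantClient(url=db_url, api_key=db_api_key)
    db = Qdrant(
        client=client,
        collection_name=db_collection_name,
        embeddings=EMBEDDINGS,
    )
    # The metadata filter narrows retrieval before any LLM is involved.
    retriever = db.as_retriever(
        search_kwargs={
            "filter": filter,
        }
    )
    # llm is now injected, so the same chain serves OpenAI or Vicuna.
    result = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
    return result
```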
@@ -143,6 +181,7 @@ def _get_query_str_filter(
 
 
 def run_qa(
+    llm,
     query: str,
     repo_name: str,
     query_options: str,
@@ -154,7 +193,7 @@ def run_qa(
     query_str, filter = _get_query_str_filter(
         query, repo_name, query_options, start_date, end_date, include_comments
     )
-    qa = get_retrieval_qa(filter)
+    qa = get_retrieval_qa(filter, llm)
     try:
         result = qa(query_str)
     except InvalidRequestError as e:
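With llm as the first parameter, the call site decides which backend answers. A hypothetical pair of calls, using the module-level names defined earlier in the file:

```python
# OpenAI-backed answer (always available).
answer, html = run_qa(
    LLM, query, repo_name, query_options, start_date, end_date, include_comments
)

# Vicuna-backed answer, only populated when a GPU was found at startup.
if VICUNA_LLM is not None:
    answer, html = run_qa(
        VICUNA_LLM, query, repo_name, query_options, start_date, end_date, include_comments
    )
```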
@@ -271,10 +310,37 @@ with st.form("my_form"):
         st.divider()
         with st.spinner("QA Searching..."):
             results = run_qa(
-                query, repo_name, query_options, start_date, end_date, include_comments
+                LLM,
+                query,
+                repo_name,
+                query_options,
+                start_date,
+                end_date,
+                include_comments,
             )
             answer, html = results
             with st.container():
                 st.write(answer)
                 st.markdown(html, unsafe_allow_html=True)
                 st.divider()
+    if torch.cuda.is_available():
+        qa_searched_vicuna = submit_col2.form_submit_button("QA Search by Vicuna")
+        if qa_searched_vicuna:
+            st.divider()
+            st.header("QA Search Results by Vicuna-13b-v1.5")
+            st.divider()
+            with st.spinner("QA Searching..."):
+                results = run_qa(
+                    VICUNA_LLM,
+                    query,
+                    repo_name,
+                    query_options,
+                    start_date,
+                    end_date,
+                    include_comments,
+                )
+                answer, html = results
+                with st.container():
+                    st.write(answer)
+                    st.markdown(html, unsafe_allow_html=True)
+                    st.divider()
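The Vicuna button is a second form_submit_button inside the same st.form, rendered only on CUDA hosts; Streamlit permits several submit buttons per form. A minimal sketch of the layout this hunk assumes (the submit_col1/submit_col2 column setup is not shown in the diff, so it is an assumption):

```python
import streamlit as st
import torch

with st.form("my_form"):
    query = st.text_input("Query")
    submit_col1, submit_col2 = st.columns(2)  # assumed two-column button row
    qa_searched = submit_col1.form_submit_button("QA Search")
    if torch.cuda.is_available():
        # Only offer the local-model path when a GPU can host Vicuna-13B.
        qa_searched_vicuna = submit_col2.form_submit_button("QA Search by Vicuna")
```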