shimizukawa commited on
Commit
c56ab56
โ€ข
1 Parent(s): 2f682e6

initial modify

Browse files
Files changed (6) hide show
  1. README.md +1 -1
  2. app.py +33 -121
  3. config.py +2 -2
  4. gh_issue_loader.py โ†’ doc_loader.py +22 -31
  5. model.py +2 -4
  6. store.py +7 -7
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Github Issue Search
3
  emoji: ๐Ÿ 
4
  colorFrom: green
5
  colorTo: purple
 
1
  ---
2
+ title: Document Search
3
  emoji: ๐Ÿ 
4
  colorFrom: green
5
  colorTo: purple
app.py CHANGED
@@ -1,5 +1,5 @@
1
  from time import time
2
- from datetime import datetime, date, timedelta
3
  from typing import Iterable
4
  import streamlit as st
5
  import torch
@@ -13,7 +13,7 @@ from langchain.chains import RetrievalQA
13
  from openai.error import InvalidRequestError
14
  from langchain.chat_models import ChatOpenAI
15
  from config import DB_CONFIG
16
- from model import Issue
17
 
18
 
19
  @st.cache_resource
@@ -108,12 +108,12 @@ def get_similay(query: str, filter: Filter):
108
  db = Qdrant(
109
  client=client, collection_name=db_collection_name, embeddings=EMBEDDINGS
110
  )
111
- docs = db.similarity_search_with_score(
112
  query,
113
  k=20,
114
  filter=filter,
115
  )
116
- return docs
117
 
118
 
119
  def get_retrieval_qa(filter: Filter, llm):
@@ -150,49 +150,20 @@ def _get_related_url(metadata) -> Iterable[str]:
150
 
151
  def _get_query_str_filter(
152
  query: str,
153
- repo_name: str,
154
- query_options: str,
155
- start_date: date,
156
- end_date: date,
157
- include_comments: bool,
158
  ) -> tuple[str, Filter]:
159
- options = [{"key": "metadata.repo_name", "value": repo_name}]
160
- if start_date is not None and end_date is not None:
161
- options.append(
162
- {
163
- "key": "metadata.created_at",
164
- "range": {
165
- "gte": int(datetime.fromisoformat(str(start_date)).timestamp()),
166
- "lte": int(
167
- datetime.fromisoformat(
168
- str(end_date + timedelta(days=1))
169
- ).timestamp()
170
- ),
171
- },
172
- }
173
- )
174
- if not include_comments:
175
- options.append({"key": "metadata.type_", "value": "issue"})
176
  filter = make_filter_obj(options=options)
177
- if query_options == "Empty":
178
- query_options = ""
179
- query_str = f"{query_options}{query}"
180
- return query_str, filter
181
 
182
 
183
  def run_qa(
184
  llm,
185
  query: str,
186
- repo_name: str,
187
- query_options: str,
188
- start_date: date,
189
- end_date: date,
190
- include_comments: bool,
191
  ) -> tuple[str, str]:
192
  now = time()
193
- query_str, filter = _get_query_str_filter(
194
- query, repo_name, query_options, start_date, end_date, include_comments
195
- )
196
  qa = get_retrieval_qa(filter, llm)
197
  try:
198
  result = qa(query_str)
@@ -207,71 +178,29 @@ def run_qa(
207
 
208
  def run_search(
209
  query: str,
210
- repo_name: str,
211
- query_options: str,
212
- start_date: date,
213
- end_date: date,
214
- include_comments: bool,
215
- ) -> Iterable[tuple[Issue, float, str]]:
216
- query_str, filter = _get_query_str_filter(
217
- query, repo_name, query_options, start_date, end_date, include_comments
218
- )
219
- docs = get_similay(query_str, filter)
220
- for doc, score in docs:
221
- text = doc.page_content
222
- metadata = doc.metadata
223
  # print(metadata)
224
- issue = Issue(
225
- repo_name=repo_name,
226
  id=metadata.get("id"),
227
  title=metadata.get("title"),
228
- created_at=metadata.get("created_at"),
229
  user=metadata.get("user"),
230
  url=metadata.get("url"),
231
- labels=metadata.get("labels"),
232
- type_=metadata.get("type_"),
233
  )
234
- yield issue, score, text
235
 
236
 
237
  with st.form("my_form"):
238
- st.title("GitHub Issue Search")
239
  query = st.text_input(label="query")
240
- repo_name = st.radio(
241
- options=[
242
- "cpython",
243
- "pyvista",
244
- "plone",
245
- "volto",
246
- "plone.restapi",
247
- "nvda",
248
- "nvdajp",
249
- "cocoa",
250
- ],
251
- label="Repo name",
252
- )
253
- query_options = st.radio(
254
- options=[
255
- "query: ",
256
- "query: passage: ",
257
- "Empty",
258
- ],
259
- label="Query options",
260
- )
261
- date_min = date(2022, 1, 1)
262
- date_max = date.today()
263
- date_col1, date_col2 = st.columns(2)
264
- start_date = date_col1.date_input(
265
- label="Select a start date",
266
- value=date_min,
267
- format="YYYY-MM-DD",
268
- )
269
- end_date = date_col2.date_input(
270
- label="Select a end date",
271
- value=date_max,
272
- format="YYYY-MM-DD",
273
- )
274
- include_comments = st.checkbox(label="Include Issue comments", value=True)
275
 
276
  submit_col1, submit_col2 = st.columns(2)
277
  searched = submit_col1.form_submit_button("Search")
@@ -280,28 +209,19 @@ with st.form("my_form"):
280
  st.header("Search Results")
281
  st.divider()
282
  with st.spinner("Searching..."):
283
- results = run_search(
284
- query, repo_name, query_options, start_date, end_date, include_comments
285
- )
286
- for issue, score, text in results:
287
- title = issue.title
288
- url = issue.url
289
- id_ = issue.id
290
  score = round(score, 3)
291
- created_at = datetime.fromtimestamp(issue.created_at)
292
- user = issue.user
293
- labels = issue.labels
294
- is_comment = issue.type_ == "comment"
295
  with st.container():
296
- if not is_comment:
297
- st.subheader(f"#{id_} - {title}")
298
- else:
299
- st.subheader(f"comment with {title}")
300
  st.write(url)
301
  st.write(text)
302
- st.write("score:", score, "Date:", created_at.date(), "User:", user)
303
- st.write(f"{labels=}")
304
- # st.markdown(html, unsafe_allow_html=True)
305
  st.divider()
306
  qa_searched = submit_col2.form_submit_button("QA Search by OpenAI")
307
  if qa_searched:
@@ -312,11 +232,7 @@ with st.form("my_form"):
312
  results = run_qa(
313
  LLM,
314
  query,
315
- repo_name,
316
- query_options,
317
- start_date,
318
- end_date,
319
- include_comments,
320
  )
321
  answer, html = results
322
  with st.container():
@@ -333,11 +249,7 @@ with st.form("my_form"):
333
  results = run_qa(
334
  VICUNA_LLM,
335
  query,
336
- repo_name,
337
- query_options,
338
- start_date,
339
- end_date,
340
- include_comments,
341
  )
342
  answer, html = results
343
  with st.container():
 
1
  from time import time
2
+ from datetime import datetime
3
  from typing import Iterable
4
  import streamlit as st
5
  import torch
 
13
  from openai.error import InvalidRequestError
14
  from langchain.chat_models import ChatOpenAI
15
  from config import DB_CONFIG
16
+ from model import Doc
17
 
18
 
19
  @st.cache_resource
 
108
  db = Qdrant(
109
  client=client, collection_name=db_collection_name, embeddings=EMBEDDINGS
110
  )
111
+ qdocs = db.similarity_search_with_score(
112
  query,
113
  k=20,
114
  filter=filter,
115
  )
116
+ return qdocs
117
 
118
 
119
  def get_retrieval_qa(filter: Filter, llm):
 
150
 
151
  def _get_query_str_filter(
152
  query: str,
153
+ project_name: str,
 
 
 
 
154
  ) -> tuple[str, Filter]:
155
+ options = [{"key": "metadata.project_name", "value": project_name}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  filter = make_filter_obj(options=options)
157
+ return query, filter
 
 
 
158
 
159
 
160
  def run_qa(
161
  llm,
162
  query: str,
163
+ project_name: str,
 
 
 
 
164
  ) -> tuple[str, str]:
165
  now = time()
166
+ query_str, filter = _get_query_str_filter(query, project_name)
 
 
167
  qa = get_retrieval_qa(filter, llm)
168
  try:
169
  result = qa(query_str)
 
178
 
179
  def run_search(
180
  query: str,
181
+ project_name: str,
182
+ ) -> Iterable[tuple[Doc, float, str]]:
183
+ query_str, filter = _get_query_str_filter(query, project_name)
184
+ qdocs = get_similay(query_str, filter)
185
+ for qdoc, score in qdocs:
186
+ text = qdoc.page_content
187
+ metadata = qdoc.metadata
 
 
 
 
 
 
188
  # print(metadata)
189
+ doc = Doc(
190
+ project_name=project_name,
191
  id=metadata.get("id"),
192
  title=metadata.get("title"),
193
+ ctime=metadata.get("ctime"),
194
  user=metadata.get("user"),
195
  url=metadata.get("url"),
 
 
196
  )
197
+ yield doc, score, text
198
 
199
 
200
  with st.form("my_form"):
201
+ st.title("Document Search")
202
  query = st.text_input(label="query")
203
+ project_name = st.text_input(label="project")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
  submit_col1, submit_col2 = st.columns(2)
206
  searched = submit_col1.form_submit_button("Search")
 
209
  st.header("Search Results")
210
  st.divider()
211
  with st.spinner("Searching..."):
212
+ results = run_search(query, project_name)
213
+ for doc, score, text in results:
214
+ title = doc.title
215
+ url = doc.url
216
+ id_ = doc.id
 
 
217
  score = round(score, 3)
218
+ ctime = datetime.fromtimestamp(doc.ctime)
219
+ user = doc.user
 
 
220
  with st.container():
221
+ st.subheader(f"#{id_} - {title}")
 
 
 
222
  st.write(url)
223
  st.write(text)
224
+ st.write("score:", score, "Date:", ctime.date(), "User:", user)
 
 
225
  st.divider()
226
  qa_searched = submit_col2.form_submit_button("QA Search by OpenAI")
227
  if qa_searched:
 
232
  results = run_qa(
233
  LLM,
234
  query,
235
+ project_name,
 
 
 
 
236
  )
237
  answer, html = results
238
  with st.container():
 
249
  results = run_qa(
250
  VICUNA_LLM,
251
  query,
252
+ project_name,
 
 
 
 
253
  )
254
  answer, html = results
255
  with st.container():
config.py CHANGED
@@ -7,14 +7,14 @@ SAAS = True
7
  def get_db_config():
8
  url = os.environ["QDRANT_URL"]
9
  api_key = os.environ["QDRANT_API_KEY"]
10
- collection_name = "gh-issue-search"
11
  return url, api_key, collection_name
12
 
13
 
14
  def get_local_db_congin():
15
  url = "localhost"
16
  # api_key = os.environ["QDRANT_API_KEY"]
17
- collection_name = "gh-issues"
18
  return url, None, collection_name
19
 
20
 
 
7
  def get_db_config():
8
  url = os.environ["QDRANT_URL"]
9
  api_key = os.environ["QDRANT_API_KEY"]
10
+ collection_name = "document-search"
11
  return url, api_key, collection_name
12
 
13
 
14
  def get_local_db_congin():
15
  url = "localhost"
16
  # api_key = os.environ["QDRANT_API_KEY"]
17
+ collection_name = "document-search"
18
  return url, None, collection_name
19
 
20
 
gh_issue_loader.py โ†’ doc_loader.py RENAMED
@@ -4,7 +4,8 @@ from typing import Iterator
4
  from dateutil.parser import parse
5
  from langchain.docstore.document import Document
6
  from langchain.document_loaders.base import BaseLoader
7
- from gh_issue_loader import Issue
 
8
 
9
 
10
  def date_to_int(dt_str: str) -> int:
@@ -12,49 +13,39 @@ def date_to_int(dt_str: str) -> int:
12
  return int(dt.timestamp())
13
 
14
 
15
- def get_contents(repo_name: str, filename: str) -> Iterator[tuple[Issue, str]]:
 
 
 
 
 
16
  with open(filename, "r") as f:
17
  obj = [json.loads(line) for line in f]
18
  for data in obj:
19
  title = data["title"]
20
  body = data["body"]
21
- issue = Issue(
22
- repo_name=repo_name,
23
- id=data["number"],
24
  title=title,
25
- created_at=date_to_int(data["created_at"]),
26
- user=data["user.login"],
27
- url=data["html_url"],
28
- labels=data["labels_"],
29
- type_="issue",
30
  )
31
  text = title
32
  if body:
33
  text += "\n\n" + body
34
- yield issue, text
35
- comments = data["comments_"]
36
- for comment in comments:
37
- issue = Issue(
38
- repo_name=repo_name,
39
- id=comment["id"],
40
- title=data["title"],
41
- created_at=date_to_int(comment["created_at"]),
42
- user=comment["user.login"],
43
- url=comment["html_url"],
44
- labels=data["labels_"],
45
- type_="comment",
46
- )
47
- yield issue, comment["body"]
48
-
49
-
50
- class GHLoader(BaseLoader):
51
- def __init__(self, repo_name: str, filename: str):
52
- self.repo_name = repo_name
53
  self.filename = filename
54
 
55
  def lazy_load(self) -> Iterator[Document]:
56
- for issue, text in get_contents(self.repo_name, self.filename):
57
- metadata = asdict(issue)
58
  yield Document(page_content=text, metadata=metadata)
59
 
60
  def load(self) -> list[Document]:
 
4
  from dateutil.parser import parse
5
  from langchain.docstore.document import Document
6
  from langchain.document_loaders.base import BaseLoader
7
+
8
+ from model import Doc
9
 
10
 
11
  def date_to_int(dt_str: str) -> int:
 
13
  return int(dt.timestamp())
14
 
15
 
16
+ def get_contents(project_name: str, filename: str) -> Iterator[tuple[Doc, str]]:
17
+ """filename for file with ndjson
18
+
19
+ {"title": <page title>, "body": <page body>, "id": <page_id>, "ctime": ..., "user": <name>, "url": "https:..."}
20
+ {"title": ...}
21
+ """
22
  with open(filename, "r") as f:
23
  obj = [json.loads(line) for line in f]
24
  for data in obj:
25
  title = data["title"]
26
  body = data["body"]
27
+ doc = Doc(
28
+ project_name=project_name,
29
+ id=data["id"],
30
  title=title,
31
+ created_at=date_to_int(data["ctime"]),
32
+ user=data["user"],
33
+ url=data["url"],
 
 
34
  )
35
  text = title
36
  if body:
37
  text += "\n\n" + body
38
+ yield doc, text
39
+
40
+
41
+ class DocLoader(BaseLoader):
42
+ def __init__(self, project_name: str, filename: str):
43
+ self.project_name = project_name
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  self.filename = filename
45
 
46
  def lazy_load(self) -> Iterator[Document]:
47
+ for doc, text in get_contents(self.project_name, self.filename):
48
+ metadata = asdict(doc)
49
  yield Document(page_content=text, metadata=metadata)
50
 
51
  def load(self) -> list[Document]:
model.py CHANGED
@@ -2,12 +2,10 @@ from dataclasses import dataclass
2
 
3
 
4
  @dataclass(frozen=True)
5
- class Issue:
6
- repo_name: str
7
  id: int
8
  title: str
9
  created_at: int
10
  user: str
11
  url: str
12
- labels: list[str]
13
- type_: str
 
2
 
3
 
4
  @dataclass(frozen=True)
5
+ class Doc:
6
+ project_name: str
7
  id: int
8
  title: str
9
  created_at: int
10
  user: str
11
  url: str
 
 
store.py CHANGED
@@ -2,7 +2,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
2
  from langchain.embeddings import HuggingFaceEmbeddings
3
  from langchain.vectorstores import Qdrant
4
 
5
- from gh_issue_loader import GHLoader
6
  from config import DB_CONFIG
7
 
8
 
@@ -36,8 +36,8 @@ def store(texts):
36
  )
37
 
38
 
39
- def main(repo_name: str, path: str) -> None:
40
- loader = GHLoader(repo_name, path)
41
  docs = loader.load()
42
  texts = get_text_chunk(docs)
43
  store(texts)
@@ -45,8 +45,8 @@ def main(repo_name: str, path: str) -> None:
45
 
46
  if __name__ == "__main__":
47
  """
48
- $ python store.py "REPO_NAME" "FILE_PATH"
49
- $ python store.py cocoa data/cocoa-issues.json
50
  """
51
  import sys
52
 
@@ -54,6 +54,6 @@ if __name__ == "__main__":
54
  if len(args) != 3:
55
  print("No args, you need two args for repo_name, json_file_path")
56
  else:
57
- repo_name = args[1]
58
  path = args[2]
59
- main(repo_name, path)
 
2
  from langchain.embeddings import HuggingFaceEmbeddings
3
  from langchain.vectorstores import Qdrant
4
 
5
+ from doc_loader import DocLoader
6
  from config import DB_CONFIG
7
 
8
 
 
36
  )
37
 
38
 
39
+ def main(project_name: str, path: str) -> None:
40
+ loader = DocLoader(project_name, path)
41
  docs = loader.load()
42
  texts = get_text_chunk(docs)
43
  store(texts)
 
45
 
46
  if __name__ == "__main__":
47
  """
48
+ $ python store.py "PROJECT_NAME" "FILE_PATH"
49
+ $ python store.py hoge data/hoge-docs.json
50
  """
51
  import sys
52
 
 
54
  if len(args) != 3:
55
  print("No args, you need two args for repo_name, json_file_path")
56
  else:
57
+ project_name = args[1]
58
  path = args[2]
59
+ main(project_name, path)