hoshingakag committed
Commit 4919a07
1 Parent(s): 5927179

Update app.py

Files changed (1)
  1. app.py +78 -415
app.py CHANGED
@@ -1,434 +1,97 @@
- import os
- import datetime
- import asyncio
-
- from typing import Any, List, Dict, Union
- from pydantic import Extra
-
- import wandb
- from wandb.sdk.data_types.trace_tree import Trace
-
- import pinecone
- import google.generativeai as genai
-
- from llama_index import (
-     ServiceContext,
-     PromptHelper,
-     VectorStoreIndex
- )
- from llama_index.vector_stores import PineconeVectorStore
- from llama_index.storage.storage_context import StorageContext
- from llama_index.node_parser import SimpleNodeParser
- from llama_index.text_splitter import TokenTextSplitter
- from llama_index.embeddings.base import BaseEmbedding
- from llama_index.llms import (
-     CustomLLM,
-     CompletionResponse,
-     CompletionResponseGen,
-     LLMMetadata,
- )
- from llama_index.llms.base import llm_completion_callback
-
- from llama_index.evaluation import SemanticSimilarityEvaluator
- from llama_index.embeddings import SimilarityMode
  import logging
- logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y-%m-%d %I:%M:%S %p', level=logging.INFO)
- logger = logging.getLogger('llm')
-
- prompt_template = """
- [System]
- You are in a role play of Gerard Lee.
- Reply in no more than 7 complete sentences using content from [Context] only. Refer to [History] for seamless conversation.
-
- [History]
- {context_history}
-
- [Context]
- {context_from_index}
- """
-
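The removed template above is filled with plain str.format(); a minimal sketch (not part of the commit, with hypothetical values):

    prompt = prompt_template.format(
        context_history="User: Hi\nBot: Hello, I'm Gerard.",    # [History] block
        context_from_index="Gerard Lee is a data scientist.",   # [Context] block
    )

generate_text() further down also passes user_query=query; since the template has no {user_query} placeholder, str.format() silently ignores the extra keyword argument.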
- class LlamaIndexPaLMEmbeddings(BaseEmbedding, extra=Extra.allow):
-     def __init__(
-         self,
-         model_name: str = 'models/embedding-gecko-001',
-         **kwargs: Any,
-     ) -> None:
-         super().__init__(**kwargs)
-         self._model_name = model_name
-
-     @classmethod
-     def class_name(cls) -> str:
-         return 'PaLMEmbeddings'
-
-     def gen_embeddings(self, text: str) -> List[float]:
-         return genai.generate_embeddings(self._model_name, text)
-
-     def _get_query_embedding(self, query: str) -> List[float]:
-         embeddings = self.gen_embeddings(query)
-         return embeddings['embedding']
-
-     def _get_text_embedding(self, text: str) -> List[float]:
-         embeddings = self.gen_embeddings(text)
-         return embeddings['embedding']
-
-     def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
-         embeddings = [
-             self.gen_embeddings(text)['embedding'] for text in texts
-         ]
-         return embeddings
-
-     async def _aget_query_embedding(self, query: str) -> List[float]:
-         return self._get_query_embedding(query)
-
-     async def _aget_text_embedding(self, text: str) -> List[float]:
-         return self._get_text_embedding(text)
-
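The subclass only implements the private hooks; callers go through BaseEmbedding's public wrappers, which route to them. A usage sketch, assuming genai.configure(api_key=...) has already run and that genai.generate_embeddings() returns a dict with an 'embedding' key, as the class expects:

    emb_model = LlamaIndexPaLMEmbeddings()
    vector = emb_model.get_text_embedding("Gerard is a data scientist.")  # -> _get_text_embedding()
    print(len(vector))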
- class LlamaIndexPaLMText(CustomLLM, extra=Extra.allow):
-     def __init__(
-         self,
-         model_name: str = 'models/text-bison-001',
-         model_kwargs: dict = {},
-         context_window: int = 8196,
-         num_output: int = 1024,
-         **kwargs: Any,
-     ) -> None:
-         super().__init__(**kwargs)
-         self._model_name = model_name
-         self._model_kwargs = model_kwargs
-         self._context_window = context_window
-         self._num_output = num_output
-
-     @property
-     def metadata(self) -> LLMMetadata:
-         """Get LLM metadata."""
-         return LLMMetadata(
-             context_window=self._context_window,
-             num_output=self._num_output,
-             model_name=self._model_name
-         )
-
-     def gen_texts(self, prompt):
-         logging.debug(f"prompt: {prompt}")
-         response = genai.generate_text(
-             model=self._model_name,
-             prompt=prompt,
-             safety_settings=[
-                 {
-                     'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
-                     'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
-                 },
-             ],
-             **self._model_kwargs
          )
-         logging.debug(f"response:\n{response}")
-         return response.candidates[0]['output']
-
-     @llm_completion_callback()
-     def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
-         text = self.gen_texts(prompt)
-         return CompletionResponse(text=text)
-
-     @llm_completion_callback()
-     def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
-         raise NotImplementedError()
-
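CustomLLM needs exactly what this class supplies: a metadata property and a complete() implementation (streaming is explicitly unsupported here). A usage sketch, under the same assumption that the PaLM API key is configured:

    palm_llm = LlamaIndexPaLMText(model_kwargs={'temperature': 0.8})
    print(palm_llm.metadata.context_window)    # 8196, the constructor default
    response = palm_llm.complete("Introduce Gerard in one sentence.")
    print(response.text)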
- class LlamaIndexPaLM():
-     def __init__(
-         self,
-         emb_model: LlamaIndexPaLMEmbeddings = LlamaIndexPaLMEmbeddings(),
-         model: LlamaIndexPaLMText = LlamaIndexPaLMText(),
-         # prompt_template: str = prompt_template
-     ) -> None:
-         self.emb_model = emb_model
-         self.llm = model
-         self.prompt_template = prompt_template
-
-         # Google Generative AI
-         genai.configure(api_key=os.environ['PALM_API_KEY'])
-
-         # Pinecone
-         pinecone.init(
-             api_key=os.environ['PINECONE_API_KEY'],
-             environment=os.getenv('PINECONE_ENV')
-         )
-
-         # W&B
-         wandb.init(project=os.getenv('WANDB_PROJECT'))
-
-         # model metadata
-         CONTEXT_WINDOW = os.getenv('CONTEXT_WINDOW', 8196)
-         NUM_OUTPUT = os.getenv('NUM_OUTPUT', 1024)
-         TEXT_CHUNK_SIZE = os.getenv('TEXT_CHUNK_SIZE', 512)
-         TEXT_CHUNK_OVERLAP = os.getenv('TEXT_CHUNK_OVERLAP', 20)
-         TEXT_CHUNK_OVERLAP_RATIO = os.getenv('TEXT_CHUNK_OVERLAP_RATIO', 0.1)
-         TEXT_CHUNK_SIZE_LIMIT = os.getenv('TEXT_CHUNK_SIZE_LIMIT', None)
-
-         self.node_parser = SimpleNodeParser.from_defaults(
-             text_splitter=TokenTextSplitter(
-                 chunk_size=TEXT_CHUNK_SIZE, chunk_overlap=TEXT_CHUNK_OVERLAP
              )
-         )
-
-         self.prompt_helper = PromptHelper(
-             context_window=CONTEXT_WINDOW,
-             num_output=NUM_OUTPUT,
-             chunk_overlap_ratio=TEXT_CHUNK_OVERLAP_RATIO,
-             chunk_size_limit=TEXT_CHUNK_SIZE_LIMIT
-         )
-
-         self.service_context = ServiceContext.from_defaults(
-             llm=self.llm,
-             embed_model=self.emb_model,
-             node_parser=self.node_parser,
-             prompt_helper=self.prompt_helper,
-         )
-
-         self.emd_evaluator = SemanticSimilarityEvaluator(
-             service_context=self.service_context,
-             similarity_mode=SimilarityMode.DEFAULT,
-             similarity_threshold=os.getenv('SIMILARITY_THRESHOLD', 0.7),
-         )
-
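One subtlety in the configuration block above: os.getenv() returns a string whenever the variable is set, so the numeric defaults (8196, 1024, 512, ...) only keep their types when the variables are absent. Casting at read time would keep TokenTextSplitter and PromptHelper from receiving strings; a sketch, not part of the commit:

    CONTEXT_WINDOW = int(os.getenv('CONTEXT_WINDOW', 8196))
    NUM_OUTPUT = int(os.getenv('NUM_OUTPUT', 1024))
    TEXT_CHUNK_SIZE = int(os.getenv('TEXT_CHUNK_SIZE', 512))
    TEXT_CHUNK_OVERLAP = int(os.getenv('TEXT_CHUNK_OVERLAP', 20))
    TEXT_CHUNK_OVERLAP_RATIO = float(os.getenv('TEXT_CHUNK_OVERLAP_RATIO', 0.1))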
-     def get_index_from_pinecone(
-         self,
-         index_name: str = os.getenv('PINECONE_INDEX'),
-         index_namespace: str = os.getenv('PINECONE_NAMESPACE')
-     ) -> None:
-         # Pinecone VectorStore
-         pinecone_index = pinecone.Index(index_name)
-         self.vector_store = PineconeVectorStore(pinecone_index=pinecone_index, add_sparse_vector=True, namespace=index_namespace)
-         self.pinecone_index = VectorStoreIndex.from_vector_store(self.vector_store, self.service_context)
-         self._index_name = index_name
-         self._index_namespace = index_namespace
-         return None
-
-     async def retrieve_context(
-         self,
-         query: str
-     ) -> Dict[str, Union[str, int]]:
-         start_time = round(datetime.datetime.now().timestamp() * 1000)
-         response = await self.pinecone_index.as_query_engine(similarity_top_k=3).query(query)
-         end_time = round(datetime.datetime.now().timestamp() * 1000)
-         return {"result": response.response, "start": start_time, "end": end_time}
-
-     async def evaluate_context(
-         self,
-         query: str,
-         returned_context: str
-     ) -> float:
-         result = await self.emd_evaluator.aevaluate(
-             response=returned_context,
-             reference=query,
-         )
-         return float(result.score)
-
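Two things worth flagging in retrieve_context(): it awaits query(), although the async entry point on a llama_index query engine is aquery(); and it returns a plain dict while generate_text() below reads attribute-style fields (result_query_only.start, .end, .result). A types.SimpleNamespace return value would satisfy both call sites; a sketch of the idea, not part of the commit:

    from types import SimpleNamespace

    async def retrieve_context(self, query: str) -> SimpleNamespace:
        start_time = round(datetime.datetime.now().timestamp() * 1000)
        response = await self.pinecone_index.as_query_engine(similarity_top_k=3).aquery(query)
        end_time = round(datetime.datetime.now().timestamp() * 1000)
        # dot access matches how generate_text() consumes the result
        return SimpleNamespace(result=response.response, start=start_time, end=end_time)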
-     def format_history_as_context(
-         self,
-         history: List[str],
-     ) -> str:
-         format_chat_history = "\n".join(list(filter(None, history)))
-         return format_chat_history
-
-     async def generate_text(
-         self,
-         query: str,
-         history: List[str],
-     ) -> str:
-         # get history
-         context_history = self.format_history_as_context(history=history)
-
-         # w&b trace start
-         start_time_ms = round(datetime.datetime.now().timestamp() * 1000)
-         root_span = Trace(
-             name="MetaAgent",
-             kind="agent",
-             start_time_ms=start_time_ms,
-             metadata={"user": "🤗 Space"},
-         )
-
-         # get retrieval context(s) from llama-index vectorstore index
-         # w&b trace retrieval & select agent
-         agent_span = Trace(
-             name="LlamaIndexAgent",
-             kind="agent",
-             start_time_ms=start_time_ms,
-         )
-         try:
-             # No history, single context retrieval without evaluation
-             if not history:
-                 # w&b trace retrieval context
-                 context_from_index_selected = self.retrieve_context(query)
-                 agent_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
-                 retrieval_span = Trace(
-                     name="QueryRetrieval",
-                     kind="chain",
-                     status_code="success",
-                     metadata={
-                         "framework": "Llama-Index",
-                         "index_type": "VectorStoreIndex",
-                         "vector_store": "Pinecone",
-                         "vector_store_index": self._index_name,
-                         "vector_store_namespace": self._index_namespace,
-                         "model_name": self.llm._model_name,
-                         "custom_kwargs": self.llm._model_kwargs,
-                     },
-                     start_time_ms=start_time_ms,
-                     end_time_ms=agent_end_time_ms,
-                     inputs={"query": query},
-                     outputs={"response": context_from_index_selected},
-                 )
-                 agent_span.add_child(retrieval_span)
-             # Has history, multiple context retrieval with async, then evaluation to determine which context to choose
-             else:
-                 extended_query = f"{history[-1]}\n{query}"
-                 result_query_only, result_extended_query = await asyncio.gather(
-                     self.retrieve_context(query),
-                     self.retrieve_context(extended_query)
                  )
 
-                 # w&b trace retrieval context query only
-                 retrieval_query_span = Trace(
-                     name="QueryRetrieval",
-                     kind="chain",
-                     status_code="success",
-                     metadata={
-                         "framework": "Llama-Index",
-                         "index_type": "VectorStoreIndex",
-                         "vector_store": "Pinecone",
-                         "vector_store_index": self._index_name,
-                         "vector_store_namespace": self._index_namespace,
-                         "model_name": self.llm._model_name,
-                         "custom_kwargs": self.llm._model_kwargs,
-                     },
-                     start_time_ms=result_query_only.start,
-                     end_time_ms=result_query_only.end,
-                     inputs={"query": query},
-                     outputs={"response": result_query_only.result},
-                 )
-                 agent_span.add_child(retrieval_query_span)
-
-                 # w&b trace retrieval context extended query
-                 retrieval_extended_query_span = Trace(
-                     name="ExtendedQueryRetrieval",
-                     kind="chain",
-                     status_code="success",
-                     metadata={
-                         "framework": "Llama-Index",
-                         "index_type": "VectorStoreIndex",
-                         "vector_store": "Pinecone",
-                         "vector_store_index": self._index_name,
-                         "vector_store_namespace": self._index_namespace,
-                         "model_name": self.llm._model_name,
-                         "custom_kwargs": self.llm._model_kwargs,
-                     },
-                     start_time_ms=result_extended_query.start,
-                     end_time_ms=result_extended_query.end,
-                     inputs={"query": extended_query},
-                     outputs={"response": result_extended_query.result},
-                 )
-                 agent_span.add_child(retrieval_extended_query_span)
-
-                 # w&b trace select context
-                 eval_start_time_ms = round(datetime.datetime.now().timestamp() * 1000)
-                 eval_context_query_only, eval_context_extended_query = await asyncio.gather(
-                     self.evaluate_context(query, result_query_only.result),
-                     self.evaluate_context(extended_query, result_extended_query.result)
-                 )
-
-                 if eval_context_query_only > eval_context_extended_query:
-                     query_selected, context_from_index_selected = query, result_query_only.result
-                 else:
-                     query_selected, context_from_index_selected = extended_query, result_extended_query.result
-
-                 agent_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
-                 eval_span = Trace(
-                     name="EmbeddingsEvaluator",
-                     kind="tool",
-                     status_code="success",
-                     metadata={
-                         "framework": "Llama-Index",
-                         "evaluator": "SemanticSimilarityEvaluator",
-                         "similarity_mode": "DEFAULT",
-                         "similarity_threshold": 0.7,
-                         "similarity_results": {
-                             "eval_context_query_only": eval_context_query_only,
-                             "eval_context_extended_query": eval_context_extended_query,
-                         },
-                         "model_name": self.emb_model._model_name,
-                     },
-                     start_time_ms=eval_start_time_ms,
-                     end_time_ms=agent_end_time_ms,
-                     inputs={"query": query_selected},
-                     outputs={"response": context_from_index_selected},
-                 )
-                 agent_span.add_child(eval_span)
-
-         except Exception as e:
-             logger.error(f"Exception {e} occurred when retrieving context\n")
-
-             llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
-             result = "Something went wrong. Please try again later."
-             root_span.add_inputs_and_outputs(
-                 inputs={"query": query}, outputs={"result": result, "exception": e}
              )
-             root_span._span.status_code="fail"
-             root_span._span.end_time_ms = llm_end_time_ms
-             root_span.log(name="llm_app_trace")
-             return result
-
-         logger.info(f"Context from Llama-Index:\n{context_from_index_selected}\n")
-
-         agent_span.add_inputs_and_outputs(
-             inputs={"query": query}, outputs={"result": context_from_index_selected}
-         )
-         agent_span._span.status_code="success"
-         agent_span._span.end_time_ms = agent_end_time_ms
-         root_span.add_child(agent_span)
-
-         # generate text with prompt template to roleplay myself
-         prompt_with_context = self.prompt_template.format(context_history=context_history, context_from_index=context_from_index_selected, user_query=query)
-         try:
-             response = genai.generate_text(
-                 prompt=prompt_with_context,
-                 safety_settings=[
-                     {
-                         'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
-                         'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
-                     },
-                 ],
-                 temperature=0.9,
              )
-             result = response.result
-             success_flag = "success"
-             if result is None:
-                 result = "Seems something went wrong. Please try again later."
-                 logger.error(f"Result with 'None' received\n")
-                 success_flag = "fail"
-
-         except Exception as e:
-             result = "Seems something went wrong. Please try again later."
-             logger.error(f"Exception {e} occurred\n")
-             success_flag = "fail"
-
-         # w&b trace llm
-         llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
-         llm_span = Trace(
-             name="LLM",
-             kind="llm",
-             status_code=success_flag,
-             start_time_ms=agent_end_time_ms,
-             end_time_ms=llm_end_time_ms,
-             inputs={"input": prompt_with_context},
-             outputs={"result": result},
-         )
-         root_span.add_child(llm_span)
-
-         # w&b finalize trace
-         root_span.add_inputs_and_outputs(
-             inputs={"query": query}, outputs={"result": result}
-         )
-         root_span._span.end_time_ms = llm_end_time_ms
-         root_span.log(name="llm_app_trace")
-
-         return result
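That closes out the removed file: the PaLM wrappers, the Pinecone-backed index, and the W&B-traced generate_text() pipeline all leave app.py. The rewritten file below keeps only the Gradio UI and imports the rest, so the layout implied by the new import line is presumably:

    # Implied by the import below (an assumption; the moved file is not shown in this commit):
    # src/llamaindex_palm.py  -- LlamaIndexPaLMEmbeddings, LlamaIndexPaLMText, LlamaIndexPaLM
    # app.py                  -- Gradio chat UI (this file)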
 
+ from src.llamaindex_palm import LlamaIndexPaLM, LlamaIndexPaLMText
+
+ import gradio as gr
+
+ from typing import List
+ import time
  import logging
+
+ # Llama-Index LLM
+ llm_backend = LlamaIndexPaLMText(model_kwargs={'temperature': 0.8})
+ llm = LlamaIndexPaLM(model=llm_backend)
+ llm.get_index_from_pinecone()
+
+ # Gradio
+ chat_history = []
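These three lines stand up the entire backend at import time. Assuming the moved classes kept the constructor behaviour shown in the removed code, the Space still needs the same secrets; a preflight check one could add (a sketch, not in the commit):

    import os

    # Variables read by the removed __init__/get_index_from_pinecone code paths
    for key in ('PALM_API_KEY', 'PINECONE_API_KEY', 'PINECONE_ENV',
                'PINECONE_INDEX', 'PINECONE_NAMESPACE', 'WANDB_PROJECT'):
        if not os.getenv(key):
            raise RuntimeError(f'missing environment variable: {key}')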
+
+ def clear_chat() -> None:
+     global chat_history
+     chat_history = []
+     return None
+
+ def get_chat_history(chat_history: List[str]) -> str:
+     ind = 0
+     formatted_chat_history = ""
+     for message in chat_history:
+         formatted_chat_history += f"User: \n{message}\n" if ind % 2 == 0 else f"Bot: \n{message}\n"
+         ind += 1
+     return formatted_chat_history
+
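get_chat_history() labels even-indexed entries as user turns and odd-indexed ones as bot replies using a manual counter; enumerate() expresses the same loop more idiomatically (an equivalent sketch, not part of the commit):

    def get_chat_history(chat_history: List[str]) -> str:
        # even indices: user turns; odd indices: bot replies
        return "".join(
            f"User: \n{msg}\n" if i % 2 == 0 else f"Bot: \n{msg}\n"
            for i, msg in enumerate(chat_history)
        )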
 
+ def generate_text(prompt: str, llamaindex_llm: LlamaIndexPaLM):
+     global chat_history
+
+     logger.info("Generating Message...")
+     logger.info(f"User Message:\n{prompt}\n")
+
+     result = llamaindex_llm.generate_text(prompt, chat_history)
+     chat_history.append(prompt)
+     chat_history.append(result)
+
+     logger.info(f"Replied Message:\n{result}\n")
+     return result
+
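Note that logger is only bound inside the __main__ guard below, so generate_text() works purely because messages arrive after launch. A module-scope binding would make the helpers importable on their own (one-line sketch; basicConfig() can stay in __main__):

    logger = logging.getLogger('app')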
+ if __name__ == "__main__":
+     logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y-%m-%d %I:%M:%S %p', level=logging.INFO)
+     logger = logging.getLogger('app')
+
+     try:
+         with gr.Blocks() as app:
+             chatbot = gr.Chatbot(
+                 bubble_full_width=False,
+                 container=True,
+                 show_share_button=False,
+                 avatar_images=[None, './asset/akag-g-only.png']
              )
+             msg = gr.Textbox(
+                 show_label=False,
+                 label="Type your message...",
+                 placeholder="Hi Gerard, can you introduce yourself?",
+                 container=False,
              )
+             with gr.Row():
+                 clear = gr.Button("Clear", scale=1)
+                 send = gr.Button(
+                     value="",
+                     variant="primary",
+                     icon="./asset/send-message.png",
+                     scale=1
                  )
+
+             def user(user_message, history):
+                 return "", history + [[user_message, None]]
+
+             def bot(history):
+                 bot_message = generate_text(history[-1][0], llm)
+                 history[-1][1] = ""
+                 for character in bot_message:
+                     history[-1][1] += character
+                     time.sleep(0.01)
+                 yield history
+
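bot() fakes streaming: generate_text() has already produced the full reply, and the generator replays it one character every 10 ms so the Chatbot animates typing. Replaying a few characters per yield gives the same effect with fewer UI updates (a hypothetical variant, not in the commit):

    def bot(history, chunk_size: int = 4):
        bot_message = generate_text(history[-1][0], llm)
        history[-1][1] = ""
        # replay the finished reply a few characters at a time
        for i in range(0, len(bot_message), chunk_size):
            history[-1][1] += bot_message[i:i + chunk_size]
            time.sleep(0.04)
            yield history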
 
+             msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+                 bot, chatbot, chatbot
              )
+             send.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+                 bot, chatbot, chatbot
              )
+             clear.click(clear_chat, None, chatbot, queue=False)
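Enter and the send button share one two-step chain: user() echoes the message instantly (queue=False keeps it ahead of queued work), then bot() streams into the same Chatbot, while clear.click() resets both the widget and the module-level history. Any extra control would follow the same shape; a hypothetical retry button, not in the commit:

    retry = gr.Button("Retry")

    def retry_last(history):
        if history:
            history[-1][1] = None   # reopen the last exchange
        return history

    retry.click(retry_last, chatbot, chatbot, queue=False).then(
        bot, chatbot, chatbot
    )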
+
+             gr.HTML("""
+             <p><center><i>Disclaimer: This is a RAG app for demonstration purposes. LLM hallucination might occur.</i></center></p>
+             <p><center>Hosted on 🤗 Spaces. Powered by Google PaLM 🌴</center></p>
+             """)
+
+         app.queue()
+         app.launch()
+     except Exception as e:
+         logger.exception(e)
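A final wiring note: app.queue() is what lets the bot() generator stream at all, since Gradio only runs generator callbacks through the queue; the queue=False flags above exempt only the instant user() echo step.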