gsvc commited on
Commit
ff35878
1 Parent(s): 481bdb6

Update edubot.py

Browse files
Files changed (1) hide show
  1. edubot.py +6 -5
edubot.py CHANGED
@@ -4,7 +4,7 @@ from langchain.vectorstores import FAISS
4
  from langchain.llms import CTransformers
5
  from langchain.chains import RetrievalQA
6
  from config import *
7
-
8
  class EduBotCreator:
9
 
10
  def __init__(self):
@@ -18,12 +18,12 @@ class EduBotCreator:
18
  self.model_type = MODEL_TYPE
19
  self.max_new_tokens = MAX_NEW_TOKENS
20
  self.temperature = TEMPERATURE
21
-
22
  def create_custom_prompt(self):
23
  custom_prompt_temp = PromptTemplate(template=self.prompt_temp,
24
  input_variables=self.input_variables)
25
  return custom_prompt_temp
26
-
27
  def load_llm(self):
28
  llm = CTransformers(
29
  model = self.model_ckpt,
@@ -32,7 +32,7 @@ class EduBotCreator:
32
  temperature = self.temperature
33
  )
34
  return llm
35
-
36
  def load_vectordb(self):
37
  hfembeddings = HuggingFaceEmbeddings(
38
  model_name=self.embedder,
@@ -41,7 +41,8 @@ class EduBotCreator:
41
 
42
  vector_db = FAISS.load_local(self.vector_db_path, hfembeddings)
43
  return vector_db
44
-
 
45
  def create_bot(self, custom_prompt, vectordb, llm):
46
  retrieval_qa_chain = RetrievalQA.from_chain_type(
47
  llm=llm,
 
4
  from langchain.llms import CTransformers
5
  from langchain.chains import RetrievalQA
6
  from config import *
7
+ import streamlit as st
8
  class EduBotCreator:
9
 
10
  def __init__(self):
 
18
  self.model_type = MODEL_TYPE
19
  self.max_new_tokens = MAX_NEW_TOKENS
20
  self.temperature = TEMPERATURE
21
+
22
  def create_custom_prompt(self):
23
  custom_prompt_temp = PromptTemplate(template=self.prompt_temp,
24
  input_variables=self.input_variables)
25
  return custom_prompt_temp
26
+ @st.cache_resource()
27
  def load_llm(self):
28
  llm = CTransformers(
29
  model = self.model_ckpt,
 
32
  temperature = self.temperature
33
  )
34
  return llm
35
+ @st.cache_resource()
36
  def load_vectordb(self):
37
  hfembeddings = HuggingFaceEmbeddings(
38
  model_name=self.embedder,
 
41
 
42
  vector_db = FAISS.load_local(self.vector_db_path, hfembeddings)
43
  return vector_db
44
+
45
+
46
  def create_bot(self, custom_prompt, vectordb, llm):
47
  retrieval_qa_chain = RetrievalQA.from_chain_type(
48
  llm=llm,