File size: 2,072 Bytes
28ebfe4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
from urllib.parse import quote_plus

from dotenv import load_dotenv
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX, _mysql_prompt
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import FewShotPromptTemplate
from langchain.prompts import SemanticSimilarityExampleSelector
from langchain.prompts.prompt import PromptTemplate
from langchain.utilities import SQLDatabase
from langchain.vectorstores import Chroma
from langchain_community.llms import GooglePalm
from langchain_experimental.sql import SQLDatabaseChain

from few_shot import fewshots
# Populate os.environ from a .env file at import time so the os.environ[...]
# lookups in get_few_shot_db_chain() can see values defined there.
load_dotenv()

def _build_db():
    """Connect to the MySQL database described by environment variables.

    Reads ``db_user`` (default ``"root"``), ``db_password`` (required),
    ``db_host`` (default ``"localhost"``) and ``db_name`` (default
    ``"atliq_tshirts"``). The naming follows the existing lowercase
    ``google_api_key`` variable, so all of them can live in the same ``.env``.

    Raises:
        KeyError: if ``db_password`` is not set in the environment.
    """
    db_user = os.environ.get("db_user", "root")
    # The password used to be hard-coded in this file; it must now come from
    # the environment so it is never committed.
    # NOTE(review): the old password is still in version-control history —
    # rotate it.
    db_password = os.environ["db_password"]
    db_host = os.environ.get("db_host", "localhost")
    db_name = os.environ.get("db_name", "atliq_tshirts")

    # Percent-encode credentials so characters such as '@', ':' or '/' in the
    # password cannot break the URL structure.
    uri = (
        f"mysql+pymysql://{quote_plus(db_user)}:{quote_plus(db_password)}"
        f"@{db_host}/{db_name}"
    )
    return SQLDatabase.from_uri(uri)


def _build_example_selector():
    """Index the few-shot examples in Chroma and return a top-2 semantic selector."""
    # Canonical Hugging Face model id (lowercase "v2").
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    # Each example is embedded as the space-joined text of all of its fields.
    # str() guards against example values that are not already strings.
    vectorize = [
        " ".join(str(value) for value in example.values()) for example in fewshots
    ]
    vectorstore = Chroma.from_texts(vectorize, embeddings, metadatas=fewshots)
    return SemanticSimilarityExampleSelector(vectorstore=vectorstore, k=2)


def _build_few_shot_prompt(example_selector):
    """Wrap LangChain's MySQL prompt around examples chosen by *example_selector*."""
    example_prompt = PromptTemplate(
        input_variables=["Question", "SQLQuery", "SQLResult", "Answer"],
        template=(
            "\nQuestion:{Question}\nSQLQuery:{SQLQuery}"
            "\nSQLResult:{SQLResult}\nAnswer:{Answer}"
        ),
    )
    return FewShotPromptTemplate(
        example_selector=example_selector,
        example_prompt=example_prompt,
        prefix=_mysql_prompt,
        suffix=PROMPT_SUFFIX,
        input_variables=["input", "table_info", "top_k"],
    )


def get_few_shot_db_chain():
    """Build a few-shot text-to-SQL chain over the configured MySQL database.

    The chain pairs a Google PaLM LLM (temperature 0.2) with the database from
    ``_build_db``. Its prompt prepends the two stored examples most similar to
    the question to LangChain's MySQL prompt.

    Returns:
        A verbose ``SQLDatabaseChain``.

    Raises:
        KeyError: if ``google_api_key`` or ``db_password`` is missing from the
            environment.
    """
    llm = GooglePalm(google_api_key=os.environ["google_api_key"], temperature=0.2)
    db = _build_db()
    few_shot_prompt = _build_few_shot_prompt(_build_example_selector())
    return SQLDatabaseChain.from_llm(llm, db, verbose=True, prompt=few_shot_prompt)

if __name__ == "__main__":
    # Manual smoke test: build the chain and answer one sample question.
    sample_question = "how many white color Levi t-shirts we have?"
    chain = get_few_shot_db_chain()
    answer = chain.run(sample_question)
    print(answer)