Text2SQL / pasta.py
vbzvibin's picture
Upload 4 files
6689e6b
raw
history blame
4.9 kB
# -*- coding: utf-8 -*-
"""
Created on Fri May 26 14:07:22 2023
@author: vibin
"""
import streamlit as st
from pandasql import sqldf
import pandas as pd
import re
from typing import List
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import re
@st.cache_resource()
def tapas_model():
    """Build the TAPAS table-question-answering pipeline.

    Decorated with ``st.cache_resource`` so the pipeline is created once and
    the same object is handed back on later calls.

    Returns:
        The ``transformers`` pipeline for the ``table-question-answering``
        task backed by ``google/tapas-base-finetuned-wtq``.
    """
    tqa_pipeline = pipeline(
        task="table-question-answering",
        model="google/tapas-base-finetuned-wtq",
    )
    return tqa_pipeline
@st.cache_resource()
def prepare_input(question: str, table: List[str]):
    """Tokenize a question and its column list into model input ids.

    The prompt sent to the tokenizer has the form
    ``question: <question> table: <col1>,<col2>,...``.

    NOTE: this function reads the module-level ``tokenizer``; it must be
    assigned (see the Text2SQL branch of the script) before the first call.

    Args:
        question: Natural-language question to convert to SQL.
        table: Column names of the target table.

    Returns:
        The ``input_ids`` produced by the tokenizer with
        ``return_tensors="pt"``.
    """
    table_prefix = "table:"
    question_prefix = "question:"
    join_table = ",".join(table)
    inputs = f"{question_prefix} {question} {table_prefix} {join_table}"
    # ``max_length`` alone only takes effect through an implicit fallback that
    # logs a warning; request truncation explicitly so over-long prompts are
    # reliably cut to 512 tokens.
    input_ids = tokenizer(
        inputs,
        max_length=512,
        truncation=True,
        return_tensors="pt",
    ).input_ids
    return input_ids
@st.cache_resource()
def inference(question: str, table: List[str]) -> str:
    """Generate the SQL text for ``question`` against the given columns.

    Uses the module-level ``model`` and ``tokenizer``; both must be assigned
    before this function is called.

    Args:
        question: Natural-language question.
        table: Column names of the target table.

    Returns:
        The decoded first generated sequence, with special tokens removed.
    """
    # Tokenize, then move the ids onto whatever device the model lives on.
    token_ids = prepare_input(question=question, table=table).to(model.device)
    generated = model.generate(
        inputs=token_ids,
        num_beams=10,
        top_k=10,
        max_length=700,
    )
    return tokenizer.decode(token_ids=generated[0], skip_special_tokens=True)
@st.cache_resource()
def tokmod(tok_md):
    """Load the tokenizer and seq2seq model for one checkpoint.

    Args:
        tok_md: Checkpoint identifier passed to both ``from_pretrained`` calls.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained(tok_md),
        AutoModelForSeq2SeqLM.from_pretrained(tok_md),
    )
### SQL post-processing helpers
# The generated query is repaired in three steps: the placeholder table name
# is swapped for the real one, bare aggregates ("SUM Version") get their
# parentheses, and the right-hand side of every "=" comparison is turned into
# a quoted string literal.  Each step substitutes by position (re.sub) so a
# value that happens to repeat, or to appear inside another token, is never
# rewritten twice.
_TABLE_RE = re.compile(r"(\bFROM\s+)table\b", re.IGNORECASE)
_WHERE_RE = re.compile(r"\bWHERE\b", re.IGNORECASE)
# A value runs from just after "=" up to the next upper-case clause keyword
# (whole word only) or the end of the query.
_VALUE_RE = re.compile(
    r"(=\s*)(\S.*?)(?=\s+(?:AND|OR|GROUP\s+BY|ORDER\s+BY|LIMIT)\b|\s*$)"
)
_AGGREGATE_RE = re.compile(r"\b(AVG|COUNT|MAX|MIN|SUM)\s+(DISTINCT\s+)?(\*|\w+)")
# Skip a DISTINCT that already sits right after "(" (e.g. COUNT(DISTINCT x)).
_DISTINCT_RE = re.compile(r"(?<!\()\bDISTINCT\s+(\w+)")


def _quote_literal(match):
    """Return a ``= value`` match with the value as a SQL string literal."""
    operator, value = match.group(1), match.group(2)
    already_quoted = (
        len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"')
    )
    if already_quoted:
        return match.group(0)
    # Double embedded single quotes so the literal stays well-formed.
    escaped = value.replace("'", "''")
    return f"{operator}'{escaped}'"


def _wrap_aggregate(match):
    """Return ``FUNC col`` / ``FUNC DISTINCT col`` rewritten as ``FUNC(...)``."""
    func = match.group(1)
    distinct = match.group(2) or ""
    column = match.group(3)
    return f"{func}({distinct}{column})"


def _fix_generated_sql(sql, table_name):
    """Turn the raw model output into a query runnable against ``table_name``.

    Args:
        sql: Query text produced by the text2sql model.
        table_name: Name the DataFrame is registered under for ``sqldf``.

    Returns:
        The repaired query string.
    """
    sql = _TABLE_RE.sub(lambda m: m.group(1) + table_name, sql)
    # Aggregates are only rewritten before WHERE and literals only after it,
    # so neither step can corrupt text that belongs to the other.
    where = _WHERE_RE.search(sql)
    if where:
        head, tail = sql[:where.end()], sql[where.end():]
    else:
        head, tail = sql, ""
    head = _AGGREGATE_RE.sub(_wrap_aggregate, head)
    head = _DISTINCT_RE.sub(r"DISTINCT(\1)", head)
    tail = _VALUE_RE.sub(_quote_literal, tail)
    return head + tail


### Main
nav = st.sidebar.radio("Navigation",["TAPAS","Text2SQL"])
if nav == "TAPAS":
    # Three equal columns; only the middle one is used, to centre the title.
    _, title_col, _ = st.columns(3)
    title_col.title("TAPAS")
    _, subtitle_col = st.columns([3,12])
    subtitle_col.text("Tabular Data Text Extraction using text")
    table = pd.read_csv("data.csv")
    # Every cell is cast to str before the frame is passed to the pipeline.
    table = table.astype(str)
    st.text("DataSet - ")
    st.dataframe(table,width=3000,height= 400)
    st.title("")  # empty title used as a vertical spacer
    lst_q = [
        "Which country has low medicare",
        "Who are the patients from india",
        "Patients who have Edema",
        "CUI code for diabetes patients",
        "Patients having oxygen less than 94 but 91",
    ]
    v2 = st.selectbox("Choose your text",lst_q,index = 0)
    st.title("")
    sql_txt = st.text_area("TAPAS Input",v2)
    if st.button("Predict"):
        tqa = tapas_model()
        txt_sql = tqa(table=table, query=sql_txt)["answer"]
        st.text("Output - ")
        st.success(f"{txt_sql}")
elif nav == "Text2SQL":
    _, title_col, _ = st.columns(3)
    title_col.title("Text2SQL")
    _, subtitle_col = st.columns([1,20])
    subtitle_col.text("Text will be converted to SQL Query and can extract the data from DataSet")
    # Import Data
    df_qna = pd.read_csv("qnacsv.csv", encoding= 'unicode_escape')
    st.title("")
    st.text("DataSet - ")
    st.dataframe(df_qna,width=3000,height= 500)
    st.title("")
    lst_q = [
        "what interface is measure indicator code = 72_HR_ABX and version is 1 and source is TD",
        "get class code with measure = 72_HR_ABX",
        "get sum of version for Class_Code is Antibiotic Stewardship",
        "what interface is measure indicator code = 72_HR_ABX",
    ]
    v2 = st.selectbox("Choose your text",lst_q,index = 0)
    st.title("")
    sql_txt = st.text_area("Text for SQL Conversion",v2)
    if st.button("Predict"):
        tok_model = "juierror/flan-t5-text2sql-with-schema"
        # ``tokenizer`` and ``model`` must stay module-level names: the
        # prepare_input() and inference() functions read them as globals.
        tokenizer,model = tokmod(tok_model)
        table_name = "df_qna"
        table_col = ["Type","Class_Code", "Version","Measure_Indicator_Code","Measure_Indicator_Name","Description_Definition", "Source", "Interfaces"]
        txt_sql = inference(question=sql_txt, table=table_col)
        ### SQL Modification
        txt_sql = _fix_generated_sql(txt_sql, table_name)
        st.success(f"{txt_sql}")
        try:
            # Pass the frame explicitly instead of relying on pandasql's
            # lookup of variables in the calling frames.
            query_result = sqldf(txt_sql, {table_name: df_qna})
        except Exception as exc:  # UI boundary: report instead of crashing the page
            st.error(f"Could not run the generated query: {exc}")
        else:
            st.text("Output - ")
            st.write(query_result)