# ZoniaQwen / app.py — ZoniaChatbot (commit 793c82d, verified)
# Gradio Space entry point for a RAG chatbot over PDF corpus files.
import argparse
import os
import gradio as gr
from loguru import logger
from similarities import BertSimilarity, BM25Similarity
from chatpdf import Rag
pwd_path = os.path.abspath(os.path.dirname(__file__))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="RAG chat demo over PDF corpus files.")
    # Embedding model used for similarity search over corpus chunks.
    parser.add_argument("--sim_model_name", type=str, default="shibing624/text2vec-base-multilingual")
    parser.add_argument("--gen_model_type", type=str, default="auto")
    parser.add_argument("--gen_model_name", type=str, default="Qwen/Qwen2-0.5B-Instruct")
    parser.add_argument("--lora_model", type=str, default=None)
    parser.add_argument("--rerank_model_name", type=str, default="")
    # Comma-separated list of corpus files to index.
    parser.add_argument("--corpus_files", type=str, default="Acuerdo009.pdf")
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--chunk_size", type=int, default=220)
    parser.add_argument("--chunk_overlap", type=int, default=0)
    parser.add_argument("--num_expand_context_chunk", type=int, default=1)
    parser.add_argument("--server_name", type=str, default="0.0.0.0")
    parser.add_argument("--server_port", type=int, default=8082)
    # FIX: the original combined action='store_true' with default=True, which made
    # the flag a no-op (share could never be disabled). BooleanOptionalAction keeps
    # the same default while also accepting --no-share.
    parser.add_argument("--share", action=argparse.BooleanOptionalAction, default=True,
                        help="share model")
    args = parser.parse_args()
    logger.info(args)

    # Initialize the similarity model and the RAG pipeline over the corpus files.
    sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
    model = Rag(
        similarity_model=sim_model,
        generate_model_type=args.gen_model_type,
        generate_model_name_or_path=args.gen_model_name,
        lora_model_name_or_path=args.lora_model,
        corpus_files=args.corpus_files.split(','),
        device=args.device,
        chunk_size=args.chunk_size,
        chunk_overlap=args.chunk_overlap,
        num_expand_context_chunk=args.num_expand_context_chunk,
        rerank_model_name_or_path=args.rerank_model_name,
    )
    logger.info(f"chatpdf model: {model}")
def predict_stream(message, history):
    """Stream the model's answer for *message*, seeding it with the chat history.

    Gradio supplies *history* as (user, assistant) pairs; the model expects the
    same turns as two-element lists on its ``history`` attribute.
    """
    model.history = [[user_turn, bot_turn] for user_turn, bot_turn in history]
    yield from model.predict_stream(message)
def predict(message, history):
    """Return the model's answer for *message* followed by its reference snippets.

    *history* is accepted for the gr.ChatInterface signature but not used here;
    the model call is stateless per message.
    """
    logger.debug(message)
    answer, references = model.predict(message)
    output = "{}\n\n{}".format(answer, "\n".join(references))
    logger.debug(output)
    return output
# Chat widget: avatar images for (user, assistant) are loaded from ./assets.
chatbot_stream = gr.Chatbot(
    height=600,
    avatar_images=(
        os.path.join(pwd_path, "assets/user.png"),
        os.path.join(pwd_path, "assets/Logo1.png"),
    ),
    bubble_full_width=False,
)

# NOTE(review): the original title contained mojibake ("馃", a misdecoded UTF-8
# emoji); reconstructed as the robot emoji — confirm against intended branding.
title = " 🤖ChatPDF Zonia🤖 "
# Hide Gradio's toast notifications.
# FIX: the original keyword was "!importante" (Spanish), which is invalid CSS
# and was silently ignored by browsers.
css = """.toast-wrap { display: none !important } """
# Example prompts (Spanish). FIX: repaired mojibake "Introducci贸n" -> "Introducción".
examples = ['Puede hablarme del PNL?', 'Introducción a la PNL']

chat_interface_stream = gr.ChatInterface(
    predict,
    textbox=gr.Textbox(lines=4, placeholder="Ask me question", scale=7),
    title=title,
    chatbot=chatbot_stream,
    css=css,
    examples=examples,
    theme='soft',
)

# Launch without server_name/server_port so the hosting platform (e.g. HF Spaces)
# supplies its own defaults.
chat_interface_stream.launch()