# persianllama / app.py
# Author: mostafaamiri (Hugging Face Space)
from langchain.llms import LlamaCpp
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
import gradio as gr
import re
import os
# --- llama.cpp model configuration -------------------------------------------
n_gpu_layers = 40 # Change this value based on your model and your GPU VRAM pool.
n_batch = 512 # Should be between 1 and n_ctx, consider the amount of VRAM in your GPU.
# Context window size (in tokens) for the model.
n_ctx=2048
# Streams generated tokens to stdout as they are produced (server-side log only;
# the Gradio UI streams separately via llm.stream in generate_output below).
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
# Local path to the quantized GGUF weights; the file must sit next to this script.
path = "persian_llama_7b.Q8_K_M.gguf"
# Module-level singleton LLM instance shared by all Gradio requests.
# NOTE(review): low temperature (0.1) + top_p=1 gives near-deterministic output;
# max_tokens=200 caps each reply length.
llm = LlamaCpp(
model_path= path,
n_gpu_layers=n_gpu_layers, n_batch=n_batch,
callback_manager=callback_manager,
verbose=True,
n_ctx=n_ctx,
temperature=0.1,
max_tokens=200,
top_p=1,
)
def generate_output(text):
    """Stream the model's answer to *text*, yielding the growing transcript.

    Every yield contains the full text generated so far, so the bound Gradio
    textbox re-renders with each new token appended.
    """
    pieces = []
    for token in llm.stream(text):
        pieces.append(token)
        yield "".join(pieces)
def clear():
    """Return a pair of empty strings to reset the input and output textboxes."""
    blank = ""
    return blank, blank
# --- Gradio UI: one input row, one button row, one output row ----------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Right-to-left textbox where the user types a question (Persian UI labels).
    with gr.Row():
        question_box = gr.Textbox(label="ورودی", placeholder="سوال خود را وارد کنید", rtl=True)
    # Action buttons: submit the question / clear both textboxes.
    with gr.Row():
        send_button = gr.Button("ارسال", variant="primary")
        reset_button = gr.ClearButton(value="پاک کردن", variant="secondary")
    # Right-to-left textbox that displays the streamed model answer.
    with gr.Row():
        answer_box = gr.Textbox(label="خروجی", rtl=True)

    # generate_output is a generator, so the answer box updates incrementally.
    send_button.click(fn=generate_output, inputs=[question_box], outputs=[answer_box])
    reset_button.click(fn=clear, inputs=[], outputs=[question_box, answer_box])

# Bind on all interfaces; share=True requests a public Gradio link as well.
demo.launch(server_name='0.0.0.0', share=True)