QueryYourDocs / app_inference.py
LVKinyanjui's picture
Another permission error fix.
ed42113 verified
raw
history blame
No virus
1.95 kB
import streamlit as st
import transformers, torch
import json, os
from huggingface_hub import login
# CONSTANTS
# Upper bound on the number of tokens generated per reply (passed as max_new_tokens).
MAX_NEW_TOKENS = 256
# System prompt placed at the start of the conversation when no stored history exists.
SYSTEM_MESSAGE = "You are a helpful, knowledgeable assistant"
# ENV VARS
# Workaround for a permission error with transformers / HF models (currently disabled):
# os.environ['SENTENCE_TRANSFORMERS_HOME'] = '.'
# Hugging Face access token taken from the environment; None when the variable is unset.
# The original author notes this must be a write token.
token = os.environ.get("HF_TOKEN_WRITE")
# STREAMLIT UI AREA
st.write("## Ask your Local LLM")
# Query box, pre-filled with an example question.
text_input = st.text_input(label="Query", value="Why is the sky Blue")
# Truthy only on the rerun triggered by clicking the button.
submit = st.button(label="Submit")
# MODEL AREA
# Authenticate with the Hugging Face Hub only when a token was supplied.
# login(token=None) falls back to an interactive prompt, which would block a
# Streamlit server. Without a token, huggingface_hub relies on any locally
# saved credentials instead.
if token:
    login(token=token, write_permission=True)
else:
    st.warning("HF_TOKEN_WRITE is not set; skipping Hugging Face login.")
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
@st.cache_resource
def load_model():
    """Build the text-generation pipeline for ``model_id``.

    Decorated with ``st.cache_resource`` so the pipeline is created once and
    reused across Streamlit reruns instead of reloading the model each time.

    Returns:
        The transformers text-generation pipeline, loaded in bfloat16.
    """
    # The pipeline must be returned: callers assign the result to the
    # module-level ``pipeline`` used for inference.
    return transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        # Let transformers choose device placement automatically.
        device_map="auto",
    )
pipeline = load_model()

message_store_path = "messages.jsonl"

# Default conversation: just the system prompt. Replaced below by the persisted
# history when a non-empty store file exists.
messages = [
    {"role": "system", "content": SYSTEM_MESSAGE},
]
if os.path.exists(message_store_path):
    stored = []
    with open(message_store_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines instead of failing in json.loads
            record = json.loads(line)
            # Accept raw text-generation pipeline records ({"generated_text": [...]})
            # as well as plain chat messages, so histories saved in either form load.
            if "generated_text" in record:
                stored.extend(record["generated_text"])
            else:
                stored.append(record)
    # An empty store must not discard the system prompt.
    if stored:
        messages = stored
print(messages)
@st.cache_data
def infer(message: str, messages: list[dict]):
    """
    Run the chat model on the history plus a new user query and persist the result.

    Params:
    message: Most recent query to the llm.
    messages: Chat history up to current point properly formatted like
        {"role": "user", "content": "What is your name?"}
        The list is not modified.

    Returns:
    The content of the assistant's reply.

    Side effects:
    Overwrites ``message_store_path`` with the updated conversation (history,
    the new user message and the assistant reply), one JSON message per line.
    Because of ``st.cache_data``, a repeated call with identical arguments
    returns the cached reply without running the model or rewriting the file.
    """
    # Build a new list so the caller's history is not mutated.
    conversation = [*messages, {"role": "user", "content": message}]

    # Perform inference
    output = pipeline(conversation, max_new_tokens=MAX_NEW_TOKENS)

    # For chat-formatted input, the last record's "generated_text" holds the
    # whole conversation, ending with the new assistant turn.
    updated = output[-1]["generated_text"]

    # Save one chat message per line: the format the history loader reads back
    # into ``messages``.
    with open(message_store_path, "w", encoding="utf-8") as f:
        for msg in updated:
            json.dump(msg, f)
            f.write("\n")

    return updated[-1]["content"]
if submit:
    # Query the model only on the rerun triggered by the button, then render the reply.
    response = infer(text_input, messages)
    st.write(response)