# Streamlit app: chat with a local Llama 3.1 Instruct model via a
# transformers text-generation pipeline, persisting chat history to a
# JSONL file between runs.
import json
import os

import streamlit as st
import torch
import transformers
from huggingface_hub import login
# CONSTANTS
MAX_NEW_TOKENS = 256
# Typo fix: "hepful" -> "helpful" in the system prompt.
SYSTEM_MESSAGE = "You are a helpful, knowledgeable assistant"

# ENV VARS
# Cache sentence-transformers models in the working directory to avert
# permission errors when the default cache path is not writable.
os.environ['SENTENCE_TRANSFORMERS_HOME'] = '.'
token = os.getenv("HF_TOKEN_READ")

# STREAMLIT UI AREA
st.write("## Ask your Local LLM")
text_input = st.text_input("Query", value="Why is the sky Blue")
submit = st.button("Submit")

# MODEL AREA
# Authenticate with the Hugging Face Hub. The token comes from
# HF_TOKEN_READ, i.e. a read-only token, so do NOT request write
# permission: the original write_permission=True makes login fail for
# read tokens (and the parameter is deprecated in huggingface_hub).
login(token=token)

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
def load_model():
    """Build and return the text-generation pipeline for ``model_id``.

    Weights are loaded in bfloat16 and spread across available devices
    automatically via ``device_map="auto"``.

    Returns:
        A ``transformers`` text-generation pipeline.
    """
    # BUG FIX: the original built the pipeline but never returned it,
    # so the caller bound None and the app crashed at inference time.
    return transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
    )
# Build the generation pipeline once at startup.
pipeline = load_model()

# Chat history is persisted as one JSON message object per line (JSONL).
message_store_path = "messages.jsonl"

# Default history: just the system prompt. It is replaced by the stored
# history when a previous session saved one.
messages = [
    {"role": "system", "content": SYSTEM_MESSAGE},
]
if os.path.exists(message_store_path):
    with open(message_store_path, "r", encoding="utf-8") as f:
        # Skip blank lines so a trailing newline cannot crash json.loads.
        stored = [json.loads(line) for line in f if line.strip()]
    # Keep the system-prompt default if the store file exists but is empty
    # (the original replaced messages with [] in that case).
    if stored:
        messages = stored
def infer(message: str, messages: list[dict]):
    """Run one chat turn through the pipeline and persist the history.

    Params:
        message: Most recent user query for the LLM.
        messages: Chat history up to the current point, formatted like
            {"role": "user", "content": "What is your name?"}

    Returns:
        The assistant's reply text for this turn.
    """
    messages.append({"role": "user", "content": message})

    # Perform inference. For chat input the pipeline returns
    # [{"generated_text": [<full updated list of role/content dicts>]}].
    output = pipeline(
        messages,
        max_new_tokens=MAX_NEW_TOKENS,
    )
    updated_history = output[-1]['generated_text']

    # BUG FIX: persist the chat messages themselves (role/content dicts),
    # one per line. The original wrote the raw pipeline output wrappers
    # ({"generated_text": ...}) to the store, so the reloaded "history"
    # was not a valid message list on the next run.
    with open(message_store_path, "w", encoding="utf-8") as f:
        for entry in updated_history:
            json.dump(entry, f)
            f.write("\n")

    return updated_history[-1]['content']
# Run one inference per button click and render the assistant's answer.
if submit:
    response = infer(text_input, messages)
    # Display explicitly instead of relying on Streamlit "magic" output
    # (a bare expression), which can be disabled in configuration.
    st.write(response)