hehetest / app.py
hewoo's picture
Update app.py
5be7074 verified
raw
history blame
2.85 kB
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
from langchain.vectorstores import Chroma
import os
import psutil
import time
# Hugging Face model ID of the causal LM used for answer generation.
model_id = "hewoo/hehehehe"
# Load the model and tokenizer (public repo — no auth token required).
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
# Text-generation pipeline. Low temperature (0.3) plus nucleus/top-k sampling
# and a repetition penalty keep answers focused and reduce loops.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.3, top_p=0.85, top_k=40, repetition_penalty=1.2)
# Adapter exposing a SentenceTransformer through the embedding interface
# that LangChain's Chroma wrapper expects (embed_query / embed_documents).
class CustomEmbedding:
    """Wrap a sentence-embedding model for use as a Chroma embedding function.

    The wrapped object only needs an ``encode`` method that accepts a single
    string or a list of strings and returns array-like vectors.
    """

    def __init__(self, model):
        # Embedding backend (e.g. a SentenceTransformer instance).
        self.model = model

    def embed_query(self, text):
        """Embed one query string and return a plain list of floats.

        Default (numpy) output is used: ``convert_to_tensor=True`` followed
        by ``.tolist()`` was a needless tensor round-trip.
        """
        return self.model.encode(text).tolist()

    def embed_documents(self, texts):
        """Embed a batch of documents and return a list of float lists.

        Encodes the whole batch in ONE ``encode`` call instead of one call
        per text — same output, far fewer forward passes.
        """
        if not texts:
            # encode([]) shapes vary by backend; the empty batch is trivially [].
            return []
        return self.model.encode(list(texts)).tolist()
# Embedding model (Korean sentence encoder) wrapped for the vector store.
embedding_model = SentenceTransformer("jhgan/ko-sroberta-multitask")
embedding_function = CustomEmbedding(embedding_model)
# Chroma vector store backed by pre-built vectors persisted on local disk.
persist_directory = "./chroma_batch_vectors"
vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embedding_function)
# Retriever returning the top-3 most similar documents per query.
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
# μ§ˆλ¬Έμ— λŒ€ν•œ 응닡 생성 ν•¨μˆ˜
def generate_response(user_input):
start_time = time.time() # μ‹œμž‘ μ‹œκ°„ 기둝
# λ¬Έμ„œ 검색 및 λ§₯락 생성
search_results = retriever.get_relevant_documents(user_input)
context = "\n".join([result.page_content for result in search_results])
input_text = f"""μ•„λž˜λŠ” ν•œκ΅­μ–΄λ‘œλ§Œ λ‹΅λ³€ν•˜λŠ” μ–΄μ‹œμŠ€ν„΄νŠΈμž…λ‹ˆλ‹€.
μ‚¬μš©μžμ˜ μ§ˆλ¬Έμ— λŒ€ν•΄ 제곡된 λ§₯락을 λ°”νƒ•μœΌλ‘œ μ •ν™•ν•˜κ³  μžμ„Έν•œ 닡변을 ν•œκ΅­μ–΄λ‘œ μž‘μ„±ν•˜μ„Έμš”.
λ§₯락: {context}
질문: {user_input}
λ‹΅λ³€:"""
# 응닡 생성
response = pipe(input_text)[0]["generated_text"]
end_time = time.time() # λλ‚œ μ‹œκ°„ 기둝
response_time = end_time - start_time # 응닡 μ‹œκ°„ 계산
# λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰ λͺ¨λ‹ˆν„°λ§
memory_info = psutil.virtual_memory()
memory_usage = memory_info.percent # λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰ λΉ„μœ¨(%)
return response, response_time, memory_usage
# Streamlit app UI: title, description, question input, and the response
# readout with latency and memory-usage metrics.
st.title("챗봇 데λͺ¨")
st.write("Llama 3.2-3B λͺ¨λΈμ„ μ‚¬μš©ν•œ μ±—λ΄‡μž…λ‹ˆλ‹€. μ§ˆλ¬Έμ„ μž…λ ₯ν•΄ μ£Όμ„Έμš”.")
# Read the user's question; generate and display a response once non-empty.
user_input = st.text_input("질문")
if user_input:
    response, response_time, memory_usage = generate_response(user_input)
    st.write("챗봇 응닡:", response)
    st.write(f"응닡 μ‹œκ°„: {response_time:.2f}초")
    st.write(f"ν˜„μž¬ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {memory_usage}%")