history_mistery / pages /📖History_Mystery.py
SaviAnna's picture
Update pages/📖History_Mystery.py
6d51d4f verified
import transformers
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
from PIL import Image
import torch
import base64
import plotly.express as px
df = px.data.iris()
@st.cache_data
def get_img_as_base64(file):
with open(file, "rb") as f:
data = f.read()
return base64.b64encode(data).decode()
page_bg_img = f"""
<style>
[data-testid="stAppViewContainer"] > .main {{
background-image: url(https://slack-imgs.com/?c=1&o1=ro&url=https%3A%2F%2Fwallpapercave.com%2Fwp%2Fwp6480460.jpg");
background-size: 115%;
background-position: top left;
background-repeat: no-repeat;
background-attachment: local;
}}
[data-testid="stSidebar"] > div:first-child {{
background-image: url("https://ibb.co/ZBkdJRg");
background-size: 115%;
background-position: center;
background-repeat: no-repeat;
background-attachment: fixed;
}}
[data-testid="stHeader"] {{
background: rgba(0,0,0,0);
}}
[data-testid="stToolbar"] {{
right: 2rem;
}}
div.css-1n76uvr.e1tzin5v0 {{
background-color: rgba(238, 238, 238, 0.5);
border: 10px solid #EEEEEE;
padding: 5% 5% 5% 10%;
border-radius: 5px;
}}
</style>
"""
st.markdown(page_bg_img, unsafe_allow_html=True)
st.title("""
History Mystery
""")
# Добавление слайдера
temp = st.slider("Градус дичи", 1.0, 20.0, 5.0)
sen_quan = st.slider(" Длина сгенерированного отрывка", 20, 100, 5)
# Загрузка модели и токенизатора
# model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# #Задаем класс модели (уже в streamlit/tg_bot)
#@st.cache_resource(allow_output_mutation=True)
def load_gpt():
model_GPT = GPT2LMHeadModel.from_pretrained(
'sberbank-ai/rugpt3small_based_on_gpt2',
output_attentions=False,
output_hidden_states=False,
)
tokenizer_GPT = GPT2Tokenizer.from_pretrained(
'sberbank-ai/rugpt3small_based_on_gpt2',
output_attentions=False,
output_hidden_states=False,
)
model_GPT.load_state_dict(torch.load('model_history_friday.pt', map_location=torch.device('cpu')))
return model_GPT, tokenizer_GPT
#model, tokenizer = load_gpt()
# # Вешаем сохраненные веса на нашу модель
# Функция для генерации текста
def generate_text(model_GPT, tokenizer_GPT, prompt):
# Преобразование входной строки в токены
input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt')
# Генерация текста
output = model_GPT.generate(input_ids=input_ids, max_length=100, num_beams=2*sen_quan, do_sample=True,
temperature=temp, top_k=60, top_p=0.6, no_repeat_ngram_size=4,
num_return_sequences=sen_quan)
# Декодирование сгенерированного текста
generated_text = tokenizer_GPT.decode(output[0], skip_special_tokens=True)
return generated_text
# Streamlit приложение
def main():
model_GPT, tokenizer_GPT = load_gpt()
st.write("""
# GPT-3 генерация текста
""")
# Ввод строки пользователем
prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века среди рыцарей")
# # Генерация текста по введенной строке
# generated_text = generate_text(prompt)
# Создание кнопки "Сгенерировать"
generate_button = st.button("За работу!")
# Обработка события нажатия кнопки
if generate_button:
# Вывод сгенерированного текста
generated_text = generate_text(model_GPT, tokenizer_GPT, prompt)
st.subheader("Продолжение:")
st.write(generated_text)
if __name__ == "__main__":
main()