# Spaces:
# Sleeping
# Sleeping
import textwrap

import streamlit as st
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

st.title('GPT2 trained on tg chat')

model_directory = 'finetuned/'  # Directory where the fine-tuned model weights are located
tokenizer_name = 'sberbank-ai/rugpt3small_based_on_gpt2'  # Tokenizer of the base model


# Streamlit re-executes this whole script on every widget interaction. Without
# caching, the weights would be reloaded each time a slider moves.
# st.cache_resource keeps a single shared instance per server process.
@st.cache_resource
def _load_model(directory: str) -> GPT2LMHeadModel:
    """Load the fine-tuned GPT-2 weights stored (as safetensors) in *directory*."""
    return GPT2LMHeadModel.from_pretrained(directory, use_safetensors=True)


@st.cache_resource
def _load_tokenizer(name: str) -> GPT2Tokenizer:
    """Load the GPT-2 tokenizer identified by *name*."""
    return GPT2Tokenizer.from_pretrained(name)


model = _load_model(model_directory)
tokenizer = _load_tokenizer(tokenizer_name)
def predict(text, max_len=100, num_beams=10, temperature=1.5, top_p=0.7,
            no_repeat_ngram_size=1):
    """Generate a continuation of *text* with the module-level ``model``.

    Sampling is combined with beam search (``do_sample=True`` together with
    ``num_beams``), and a single sequence is returned.

    Args:
        text: Prompt to continue. An empty string returns ``''`` without
            calling the model.
        max_len: Upper bound on the total sequence length in tokens, prompt
            included (passed to ``generate`` as ``max_length``). If the prompt
            alone already reaches this length, the bound is raised to leave
            room for at least one new token instead of failing.
        num_beams: Number of beams used by ``generate``.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        no_repeat_ngram_size: Every n-gram of this size may occur only once;
            the default of 1 forbids repeating any single token.

    Returns:
        The prompt followed by the generated continuation, decoded without
        special tokens and re-wrapped by ``textwrap.fill`` (default width 70,
        newlines in the text are replaced by spaces).
    """
    if not text:
        # An empty prompt encodes to zero tokens; there is nothing for
        # generate() to continue from.
        return ''
    with torch.inference_mode():
        # Calling the tokenizer (rather than .encode) also yields the
        # attention_mask, so generate() does not have to infer it.
        inputs = tokenizer(text, return_tensors='pt')
        prompt_len = inputs['input_ids'].shape[-1]
        out = model.generate(
            **inputs,
            # max_length counts prompt tokens too; never let it be <= the
            # prompt length, which generate() rejects.
            max_length=max(max_len, prompt_len + 1),
            num_beams=num_beams,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            no_repeat_ngram_size=no_repeat_ngram_size,
            num_return_sequences=1,
        )
    return textwrap.fill(tokenizer.decode(out[0], skip_special_tokens=True))
# --- User interface: prompt box, generation settings, submit button ---
prompt = st.text_input("Твоя фраза")

col = st.columns(4)
with col[0]:
    max_len = st.slider("Text len", 20, 200, 100)
with col[1]:
    # Integer beam count: values 2..20, default 10.
    num_beams = st.slider("Beams", 2, 20, 10)
with col[2]:
    # The slider value is passed to the model unchanged, so moving it right
    # raises the sampling temperature. Range 0.5..4.5 in steps of 0.05,
    # default 3.25.
    temperature = st.slider("Temperature", 0.5, 4.5, 3.25, step=0.05)
with col[3]:
    top_p = st.slider("Top-p", 0.1, 1.0, 0.7)

submit = st.button('Сгенерировать ответ')
if submit:
    if prompt:
        pred = predict(
            prompt,
            max_len=max_len,
            num_beams=num_beams,
            temperature=temperature,
            top_p=top_p,
        )
        st.write(pred)
    else:
        # Tell the user why nothing was generated instead of silently ignoring the click.
        st.warning('Введите фразу')