"""Streamlit demo: generate replies with a ruGPT-3 model fine-tuned on a Telegram chat.

Run with: streamlit run <this_file>.py
"""

import textwrap

import streamlit as st
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

st.title('GPT2 trained on tg chat')

model_directory = 'finetuned/'  # directory where the fine-tuned model is located
model = GPT2LMHeadModel.from_pretrained(model_directory, use_safetensors=True)
tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')


def predict(text, max_len=100, num_beams=10, temperature=1.5, top_p=0.7):
    """Generate a continuation of `text` and return it word-wrapped."""
    with torch.inference_mode():
        input_ids = tokenizer.encode(text, return_tensors='pt')
        out = model.generate(
            input_ids=input_ids,
            max_length=max_len,
            num_beams=num_beams,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            no_repeat_ngram_size=1,
            num_return_sequences=1,
        ).cpu().numpy()
    return textwrap.fill(tokenizer.decode(out[0]))


prompt = st.text_input('Your phrase')

col = st.columns(4)
with col[0]:
    max_len = st.slider('Text len', 20, 200, 100)
with col[1]:
    # Normalized 0-1 value; rescaled to an integer beam count below.
    num_beams = st.slider('Beams', 0.1, 1.0, 0.5)
with col[2]:
    # Note: the mapping below inverts this slider, so a higher value
    # here yields a *lower* sampling temperature.
    temperature = st.slider('Temperature', 0.1, 0.9, 0.35)
with col[3]:
    top_p = st.slider('Top-p', 0.1, 1.0, 0.7)

submit = st.button('Generate reply')

if submit:
    if prompt:
        pred = predict(
            prompt,
            max_len=max_len,
            num_beams=int(num_beams * 20),      # 0.1-1.0 -> 2-20 beams
            temperature=(1 - temperature) * 5,  # 0.1-0.9 -> 4.5-0.5
            top_p=top_p,
        )
        st.write(pred)