coffeeee's picture
switch to using hf model
fe765e3
raw
history blame contribute delete
No virus
3.57 kB
import gradio as gr
import nltk
import string
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GenerationConfig, set_seed
import random
# Fetch the Punkt sentence-splitter data used by nltk.tokenize.sent_tokenize
# in clean_paragraph.
# NOTE(review): newer NLTK releases look up the 'punkt_tab' resource for
# sent_tokenize — confirm which one the pinned NLTK version needs.
nltk.download('punkt')
# Number of new tokens generated per request. The same amount is reserved out
# of the 1024-token input budget when the prompt is encoded.
response_length = 200
# NOTE(review): not referenced anywhere else in this file (sent_tokenize is
# used instead); candidate for removal.
sentence_detector = nltk.data.load('tokenizers/punkt/english.pickle')
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
# With truncation=True, 'right' drops tokens from the END of an over-long
# input, i.e. the newest text (the user's prompt).
# NOTE(review): 'left' is presumably what story continuation wants — confirm.
tokenizer.truncation_side = 'right'
# Local checkpoint alternative, kept for reference.
# model = GPT2LMHeadModel.from_pretrained('checkpoint-10000')
# Fine-tuned GPT-2 weights from the Hugging Face Hub, used with the stock
# gpt2-medium tokenizer loaded above.
model = GPT2LMHeadModel.from_pretrained('coffeeee/nsfw-story-generator')
# Start from gpt2-medium's default generation settings, then cap output length.
generation_config = GenerationConfig.from_pretrained('gpt2-medium')
generation_config.max_new_tokens = response_length
# GPT-2 defines no pad token, so EOS is reused as the pad id
# (presumably to silence generate()'s missing-pad-token warning).
generation_config.pad_token_id = generation_config.eos_token_id
def generate_response(outputs, new_prompt):
    """Continue the story from the latest passages plus the user's prompt.

    The most recent passages in ``outputs`` are joined with newlines and
    followed by ``new_prompt``. The model samples a continuation, and
    ``new_prompt`` plus that continuation (cleaned by ``clean_paragraph``)
    becomes a new passage.

    Args:
        outputs: list of previously generated passages (the ``user_outputs``
            session state); may be empty or None. It is not mutated.
        new_prompt: text typed by the user; may be empty or None.

    Returns:
        Dict mapping Gradio components to new values: ``user_outputs`` -> the
        extended passage list, ``story`` -> every passage joined by newlines,
        ``prompt`` -> None (clears the prompt box).
    """
    # Copy so the incoming state object is never modified in place.
    history = list(outputs) if outputs else []
    new_prompt = new_prompt or ""

    # Use the LAST passages as context. The old outputs[:N] slice took the
    # first N, so long stories kept re-feeding the opening and dropped every
    # passage after the N-th from the displayed story.
    context_passages = int(1024 / response_length + 1)
    story_so_far = "\n".join(history[-context_passages:])
    text = story_so_far + "\n" + new_prompt if story_so_far else new_prompt

    set_seed(random.randint(0, 4000000000))

    # With nothing to condition on, start from GPT-2's <|endoftext|> token;
    # generate() cannot run on an empty input.
    input_ids = tokenizer.encode(text or tokenizer.bos_token, return_tensors='pt')
    # Keep the most recent tokens so the user's new prompt is never the part
    # that gets truncated, whatever tokenizer.truncation_side is set to.
    max_context_tokens = 1024 - response_length
    input_ids = input_ids[:, -max_context_tokens:]

    output = model.generate(input_ids, do_sample=True, generation_config=generation_config)
    # Decode only the newly generated tokens instead of slicing the decoded
    # text by character count. That slice broke whenever the input had been
    # truncated. skip_special_tokens keeps <|endoftext|> out of the story.
    continuation = tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
    response = clean_paragraph(new_prompt + continuation)

    history.append(response)
    return {
        user_outputs: history,
        story: "\n".join(history),
        prompt: None,
    }
def undo(outputs):
    """Remove the most recent passage and redraw the story.

    Args:
        outputs: passage list from the ``user_outputs`` state (may be empty or
            None).

    Returns:
        Dict mapping ``user_outputs`` to the shortened list and ``story`` to
        the remaining passages joined by newlines, or None when none remain.
    """
    if not outputs:
        remaining = []
    else:
        remaining = outputs[:-1]

    if remaining:
        story_text = "\n".join(remaining)
    else:
        # None rather than "", matching what reset() sends to the Story box.
        story_text = None

    return {user_outputs: remaining, story: story_text}
def clean_paragraph(entry):
    """Trim a trailing unfinished sentence and capitalize the passage.

    Only the last newline-separated paragraph of ``entry`` is inspected. If
    its final sentence (as split by NLTK's ``sent_tokenize``) does not end in
    an ASCII punctuation character, that sentence is dropped. The first letter
    of the result is then upper-cased via ``capitalize_first_char``.

    Args:
        entry: generated text, possibly spanning several paragraphs.

    Returns:
        The cleaned text.
    """
    paragraphs = entry.split('\n')
    # Generation is cut off at max_new_tokens, so only the final paragraph can
    # end mid-sentence; it is the only one worth tokenizing.
    sentences = nltk.tokenize.sent_tokenize(paragraphs[-1], language='english')
    # Look at the last character of the LAST sentence. The previous
    # split_sentences[:1][-1] tested the whole first sentence as a substring of
    # string.punctuation, which dropped complete sentences too. An empty final
    # paragraph (text ending in '\n') yields no sentences and used to raise
    # IndexError.
    last = sentences[-1].rstrip() if sentences else ""
    if last and last[-1] not in string.punctuation:
        paragraphs[-1] = " ".join(sentences[:-1])
    return capitalize_first_char("\n".join(paragraphs))
def reset():
    """Start over: empty the passage list and set the Story box to None.

    Returns:
        Dict mapping ``user_outputs`` to an empty list and ``story`` to None.
    """
    cleared = {}
    cleared[user_outputs] = []
    cleared[story] = None
    return cleared
def capitalize_first_char(entry):
    """Upper-case the first alphabetic character of *entry*.

    Any leading non-letters (whitespace, quotes, digits, punctuation) are kept
    unchanged. A string with no alphabetic character is returned as-is.
    """
    position = next((idx for idx, ch in enumerate(entry) if ch.isalpha()), None)
    if position is None:
        return entry
    head, letter, tail = entry[:position], entry[position], entry[position + 1:]
    return head + letter.upper() + tail
# Build the UI. Inside gr.Blocks, components appear on the page in the order
# they are created, so the statement order below defines the layout.
with gr.Blocks(theme=gr.themes.Default(text_size='lg', font=[gr.themes.GoogleFont("Bitter"), "Arial", "sans-serif"])) as demo:
    # Placeholder shown while the Story box is empty.
    placeholder_text = '''
Disclaimer: everything this model generates is a work of fiction.
Content from this model WILL generate inappropriate and potentially offensive content.
Use at your own discretion. Please respect the Huggingface code of conduct.'''
    # Read-only display of the story; only the event handlers write to it.
    story = gr.Textbox(label="Story", interactive=False, lines=20, placeholder=placeholder_text)
    # NOTE(review): Component.style() is deprecated in later Gradio 3.x and
    # removed in 4.x, where show_copy_button is a Textbox constructor argument
    # — confirm against the pinned Gradio version.
    story.style(show_copy_button=True)
    # Session state: list of story passages, replaced by each handler's return.
    user_outputs = gr.State([])
    prompt = gr.Textbox(label="Prompt", placeholder="Start a new story, or continue your current one!", lines=3, max_lines=3)
    with gr.Row():
        gen_button = gr.Button('Generate')
        undo_button = gr.Button("Undo")
        res_button = gr.Button("Reset")
    # Submitting the prompt box and clicking Generate run the same handler.
    # Each handler returns a dict keyed by component, so the output lists here
    # must include every component a handler may update.
    prompt.submit(generate_response, [user_outputs, prompt], [user_outputs, story, prompt], scroll_to_output=True)
    gen_button.click(generate_response, [user_outputs, prompt], [user_outputs, story, prompt], scroll_to_output=True)
    undo_button.click(undo, user_outputs, [user_outputs, story], scroll_to_output=True)
    res_button.click(reset, [], [user_outputs, story], scroll_to_output=True)
# NOTE(review): launches at import time (no `if __name__ == "__main__":`
# guard). inbrowser=True only has an effect when run on a desktop machine.
demo.launch(inbrowser=True)