#!/usr/bin/env python
# coding: utf-8
"""Gradio front end that extends a user-supplied story opening with a
Hugging Face text-generation model."""

# importing required libraries
import functools

from transformers import pipeline, GPT2TokenizerFast
from torch import bfloat16
import gradio as gr

WARNING = """Whoooa there, partner! Before you dive in, let's establish some ground rules:\nBy using this application, you are stating that you are the 'Big Cheese', the 'Head Honcho', the 'Master of Your Domain', in short, the sole user of this app. Now, don't go blaming us or any other parties if the results are not to your liking, or lead to any unforeseen circumstances.\nIn the simplest terms, the moment you input any data on this page you accept full responsibility for any and all usage of this application. Just like when you eat that extra slice of pizza at midnight, you're the one who's responsible for the extra workout the next day, not the pizza guy!"""

# Model used when the caller / UI does not name one.
DEFAULT_MODEL = "coffeeee/nsfw-story-generator2"


@functools.lru_cache(maxsize=1)
def _load_generator(model_name):
    """Load a text-generation pipeline for ``model_name`` and its pad token id.

    Loading weights is expensive, so the most recently used model is kept in
    memory and reused across requests. maxsize=1 bounds memory use, because
    users can type arbitrary model ids; switching models evicts the old one.

    Returns:
        A ``(generator, pad_token_id)`` tuple.
    """
    generator = pipeline(model=model_name, torch_dtype=bfloat16, device_map="auto")
    # Pad with the model's own end-of-sequence token. A GPT-2 id is only
    # correct for GPT-2-vocabulary models, so it is used as a fallback only.
    tokenizer = generator.tokenizer
    pad_token_id = tokenizer.eos_token_id if tokenizer is not None else None
    if pad_token_id is None:
        pad_token_id = GPT2TokenizerFast.from_pretrained("gpt2").eos_token_id
    return generator, pad_token_id


def story(prompt="When I was young", model_name=DEFAULT_MODEL, story_length=300):
    """Continue ``prompt`` with text generated by a Hugging Face model.

    Args:
        prompt: user input the story is extended from.
            Default: 'When I was young'.
        model_name: full Hugging Face Hub model id to use.
            Default: DEFAULT_MODEL ('coffeeee/nsfw-story-generator2').
        story_length: maximum number of new tokens to generate. Default: 300.

    Returns:
        The pipeline's ``generated_text`` for the first (only) sequence.
    """
    # strip(): a stray space typed into the textbox would otherwise make the
    # Hub lookup fail.
    generator, pad_token_id = _load_generator(model_name.strip())
    # Gradio sliders may deliver a float (e.g. 300.0); generation needs an int.
    outputs = generator(prompt, max_new_tokens=int(story_length), pad_token_id=pad_token_id)
    return outputs[0]["generated_text"]


# block framework to customize the io page
with gr.Blocks() as app:
    gr.Markdown("# Story Generator, a delightful combo of HuggingFace API and Gradio.io")
    gr.Label(value=WARNING, label="Disclaimer!!!")
    story_start = gr.Textbox(
        label="Begin the storyline",
        value="This is about the time when I was 15 years old and living with",
    )
    selected_model = gr.Textbox(value=DEFAULT_MODEL, label="Hugging Face Model To Use")
    # Integer steps, defaulting to the same 300 tokens as story()'s default.
    story_len = gr.Slider(100, 500, value=300, step=1, label="Arc length")
    gen_story = gr.Textbox(label="Story", lines=15, max_lines=20)
    generate_btn = gr.Button("Entertain")
    # Input order must match story()'s positional parameters.
    generate_btn.click(
        fn=story,
        inputs=[story_start, selected_model, story_len],
        outputs=gen_story,
    )

app.queue(api_open=False)

if __name__ == "__main__":
    app.launch(inline=False, share=False)