# Spaces: Sleeping
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from itertools import chain | |
import gradio as gr | |
import torch | |
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# The tokenizer and the model weights are loaded from the same checkpoint id.
MODEL_ID = "uer/gpt2-chinese-cluecorpussmall"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
model = model.to(device)
def generate_text(prompt, length=500):
    """Continue ``prompt`` with the GPT-2 model and return the decoded text.

    Args:
        prompt: Opening text of the story. The returned text starts with
            this prompt, followed by the model's continuation.
        length: Passed to ``generate`` as ``max_length``. It is the maximum
            total sequence length in tokens, prompt tokens included.

    Returns:
        The decoded prompt plus continuation, with the tokenizer's special
        tokens removed. Returns ``""`` when the prompt is empty or
        whitespace-only.
    """
    # An empty or whitespace-only prompt tokenizes to a zero-length sequence,
    # which ``generate`` cannot start from.
    if not prompt or not prompt.strip():
        return ""

    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to(device)
    output_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=length,
        num_beams=2,
        no_repeat_ngram_size=2,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
    )
    # Special tokens are dropped at the token level. The previous
    # character-level filter also deleted every literal '[', ']', 'S', 'E',
    # 'P', 'U', 'N' and 'K' from the story, and left fragments such as "AD"
    # from "[PAD]".
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Gradio UI: the user types the opening of a story and generate_text continues it.
with gr.Blocks() as web:
    gr.Markdown("<h1><center>Andrew Lim Chinese stories </center></h1>")
    # Header text "让人工智能讲故事" means "Let artificial intelligence tell a story".
    # The image URL is quoted because it contains '=' and '&', which are not
    # valid inside an unquoted HTML attribute value.
    gr.Markdown("""<h2><center>让人工智能讲故事:<br><br>
<img src="https://images.unsplash.com/photo-1550450339-e7a4787a2074?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1252&q=80"></center></h2>""")
    gr.Markdown("""<center>******</center>""")
    # Input label "故事的开始" = "The beginning of the story";
    # default prompt "在空中飞翔" = "Flying in the sky".
    input_text = gr.Textbox(label="故事的开始", value="在空中飞翔", lines=6)
    submit_button = gr.Button("Submit ")
    # Output label "人工智能讲一个故事" = "The AI tells a story".
    output_text = gr.Textbox(lines=6, label="人工智能讲一个故事 :")
    # Only the prompt is wired in, so generate_text runs with its default length.
    submit_button.click(generate_text, inputs=[input_text], outputs=output_text)
# Launch at import time, as before; the Space runs this file directly.
web.launch()