sagar007 commited on
Commit
6c2fd08
1 Parent(s): 8d19e1d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
4
+
5
+ # Load the tokenizer and the model
6
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
7
+ model = GPT2LMHeadModel.from_pretrained('gpt2')
8
+
9
+ # Load the best model weights
10
+ model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu')))
11
+
12
+ # Set the model to evaluation mode
13
+ model.eval()
14
+
15
+ # Define the text generation function
16
+ def generate_text(prompt, max_length=50, num_return_sequences=1):
17
+ inputs = tokenizer(prompt, return_tensors='pt')
18
+ outputs = model.generate(
19
+ inputs.input_ids,
20
+ max_length=max_length,
21
+ num_return_sequences=num_return_sequences,
22
+ do_sample=True,
23
+ top_k=50,
24
+ top_p=0.95,
25
+ temperature=1.0
26
+ )
27
+ return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
28
+
29
+ # Define the Gradio interface
30
+ interface = gr.Interface(
31
+ fn=generate_text,
32
+ inputs=[
33
+ gr.inputs.Textbox(lines=2, placeholder="Enter your prompt here..."),
34
+ gr.inputs.Slider(minimum=10, maximum=200, default=50, label="Max Length"),
35
+ gr.inputs.Slider(minimum=1, maximum=5, default=1, label="Number of Sequences")
36
+ ],
37
+ outputs=gr.outputs.Textbox(),
38
+ title="GPT-2 Text Generator",
39
+ description="Enter a prompt to generate text using GPT-2.",
40
+ )
41
+
42
+ # Launch the Gradio interface
43
+ interface.launch()