Lyte committed
Commit 433e378
1 parent: e11397d

Create app.py

Files changed (1): app.py (+83, -0)
app.py ADDED
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the model and tokenizer
model_name = "Lyte/Llama-3.2-3B-Overthinker"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
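# Note: device_map="auto" requires the accelerate package to be installed, and
# torch_dtype="auto" loads the weights in whatever dtype the checkpoint was saved in.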

def generate_response_stream(prompt, max_tokens, temperature, top_p, repeat_penalty, num_steps=4):
    messages = [{"role": "user", "content": prompt}]
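    # This function is a generator: each yield streams a partial
    # (reasoning, thinking, answer) tuple to the three Gradio output boxes.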

    # Generate reasoning
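    # (add_reasoning_prompt, and the add_thinking_prompt / add_answer_prompt flags
    # used below, are not standard transformers arguments: apply_chat_template
    # forwards extra keyword arguments to the model's Jinja chat template, which
    # this model's custom template presumably consumes.)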
    reasoning_template = tokenizer.apply_chat_template(messages, tokenize=False, add_reasoning_prompt=True)
    reasoning_inputs = tokenizer(reasoning_template, return_tensors="pt").to(model.device)
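
    # Each of the three stages below gets roughly a third of the overall token budget.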
    reasoning_ids = model.generate(
        **reasoning_inputs,
        max_new_tokens=max_tokens // 3,
        do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repeat_penalty,
    )
    # Decode only the newly generated tokens, not the echoed prompt
    reasoning_output = tokenizer.decode(reasoning_ids[0, reasoning_inputs.input_ids.shape[1]:], skip_special_tokens=True)
    yield reasoning_output, "", ""

    # Generate thinking (step-by-step work and verifications)
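    # ("reasoning" and, below, "thinking" are custom roles rather than the standard
    # system/user/assistant set; they rely on this model's chat template defining them.)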
    messages.append({"role": "reasoning", "content": reasoning_output})
    thinking_template = tokenizer.apply_chat_template(messages, tokenize=False, add_thinking_prompt=True, num_steps=num_steps)
    thinking_inputs = tokenizer(thinking_template, return_tensors="pt").to(model.device)

    thinking_ids = model.generate(
        **thinking_inputs,
        max_new_tokens=max_tokens // 3,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repeat_penalty,
    )
    thinking_output = tokenizer.decode(thinking_ids[0, thinking_inputs.input_ids.shape[1]:], skip_special_tokens=True)
    yield reasoning_output, thinking_output, ""

    # Generate final answer
    messages.append({"role": "thinking", "content": thinking_output})
    answer_template = tokenizer.apply_chat_template(messages, tokenize=False, add_answer_prompt=True)
    answer_inputs = tokenizer(answer_template, return_tensors="pt").to(model.device)

    answer_ids = model.generate(
        **answer_inputs,
        max_new_tokens=max_tokens // 3,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repeat_penalty,
    )
    answer_output = tokenizer.decode(answer_ids[0, answer_inputs.input_ids.shape[1]:], skip_special_tokens=True)
    yield reasoning_output, thinking_output, answer_output

with gr.Blocks() as iface:
    gr.Markdown("# Llama-3.2-3B Overthinker (Customizable Steps). Please duplicate this Space and run it on a GPU if you can; a T4 is fine!")
    gr.Markdown("Generate responses using the Llama-3.2-3B Overthinker model.")

    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(lines=5, label="Prompt")
            generate_button = gr.Button("Generate Response")
        with gr.Column(scale=1):
            max_tokens = gr.Slider(minimum=512, maximum=32768, value=8192, label="Max Number of Tokens")
            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.8, label="Temperature")
            top_p = gr.Slider(minimum=0.01, maximum=0.99, value=0.95, label="Top P")
            repeat_penalty = gr.Slider(minimum=0.5, maximum=2.0, value=1.1, label="Repeat Penalty")
            num_steps = gr.Slider(minimum=1, maximum=10, value=4, step=1, label="Max Number of Steps")

    reasoning_output = gr.Textbox(lines=5, label="Reasoning")
    with gr.Accordion("Thinking Process", open=False):
        thinking_output = gr.Textbox(lines=10, label="Step-by-Step Thinking")
    answer_output = gr.Textbox(lines=5, label="Final Answer")

    generate_button.click(
        fn=generate_response_stream,
        inputs=[prompt, max_tokens, temperature, top_p, repeat_penalty, num_steps],
        outputs=[reasoning_output, thinking_output, answer_output],
    )

iface.launch()
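
To run the app locally, something like the following should work (a minimal sketch; no requirements file is part of this commit, and accelerate is assumed because of device_map="auto"):

    pip install gradio transformers torch accelerate
    python app.py

Gradio serves the UI at http://localhost:7860 by default.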