Spaces:
Runtime error
Runtime error
refactor generation utils
Browse files- app.py +9 -57
- utils/__init__.py +5 -3
- utils/generation.py +58 -0
app.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
import gradio as gr
|
2 |
-
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
|
3 |
import datasets
|
4 |
import numpy as np
|
5 |
import torch
|
6 |
-
from threading import Thread
|
7 |
|
8 |
from utils.tree_utils import parse_functions, get_docstrings, grab_before_comments, line_chr2char, node_str_idx, replace_function
|
9 |
from utils.html_utils import make_iframe, construct_embed
|
|
|
10 |
PIPE = None
|
11 |
|
12 |
intro_text = """
|
@@ -99,35 +99,6 @@ def _make_pipeline(model_cp = "Vipitis/santacoder-finetuned-Shadertoys-fine"): #
|
|
99 |
print(f"loaded model {model_cp} as a pipline")
|
100 |
return pipe
|
101 |
|
102 |
-
def _run_generation(model_ctx:str, pipe, gen_kwargs:dict):
|
103 |
-
"""
|
104 |
-
Text generation function
|
105 |
-
Args:
|
106 |
-
model_ctx (str): The context to start generation from.
|
107 |
-
pipe (Pipeline): The pipeline to use for generation.
|
108 |
-
gen_kwargs (dict): The generation kwargs.
|
109 |
-
Returns:
|
110 |
-
str: The generated text. (it iterates over time)
|
111 |
-
"""
|
112 |
-
# Tokenize the model_context
|
113 |
-
model_inputs = pipe.tokenizer(model_ctx, return_tensors="pt")
|
114 |
-
|
115 |
-
# Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
|
116 |
-
# in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
|
117 |
-
streamer = TextIteratorStreamer(pipe.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15.0)
|
118 |
-
generate_kwargs = dict(model_inputs, streamer=streamer, **gen_kwargs)
|
119 |
-
t = Thread(target=pipe.model.generate, kwargs=generate_kwargs)
|
120 |
-
t.start()
|
121 |
-
|
122 |
-
# Pull the generated text from the streamer, and update the model output.
|
123 |
-
model_output = ""
|
124 |
-
for new_text in streamer:
|
125 |
-
# print("step", end="")
|
126 |
-
model_output += new_text
|
127 |
-
yield model_output
|
128 |
-
streamer.on_finalized_text("stream reached the end.")
|
129 |
-
return model_output #is this ever reached?
|
130 |
-
|
131 |
def process_retn(retn):
|
132 |
return retn.split(";")[0].strip()
|
133 |
|
@@ -167,7 +138,7 @@ def alter_return(orig_code, func_idx, temperature, max_new_tokens, top_p, repeti
|
|
167 |
else:
|
168 |
raise gr.Error(f"func_idx must be int or str, not {type(func_idx)}")
|
169 |
|
170 |
-
generation_kwargs =
|
171 |
|
172 |
retrns = []
|
173 |
retrn_start_idx = orig_code.find("return")
|
@@ -189,14 +160,6 @@ def alter_return(orig_code, func_idx, temperature, max_new_tokens, top_p, repeti
|
|
189 |
return altered_code
|
190 |
|
191 |
|
192 |
-
def _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty):
|
193 |
-
gen_kwargs = {}
|
194 |
-
gen_kwargs["temperature"] = temperature
|
195 |
-
gen_kwargs["max_new_tokens"] = max_new_tokens
|
196 |
-
gen_kwargs["top_p"] = top_p
|
197 |
-
gen_kwargs["repetition_penalty"] = repetition_penalty
|
198 |
-
return gen_kwargs
|
199 |
-
|
200 |
def alter_body(old_code, func_id, funcs_list: list, prompt="", temperature=0.2, max_new_tokens=512, top_p=.95, repetition_penalty=1.2, pipeline=PIPE):
|
201 |
"""
|
202 |
Replaces the body of a function with a generated one.
|
@@ -223,27 +186,16 @@ def alter_body(old_code, func_id, funcs_list: list, prompt="", temperature=0.2,
|
|
223 |
func_node = funcs_list[func_id]
|
224 |
print(f"using for generation: {func_node=}")
|
225 |
|
226 |
-
generation_kwargs =
|
|
|
|
|
227 |
|
228 |
-
func_start_idx = line_chr2char(old_code, func_node.start_point[0], func_node.start_point[1])
|
229 |
-
identifier_str = func_node.child_by_field_name("type").text.decode() + " " + func_node.child_by_field_name("declarator").text.decode() #func_start_idx:body_start_idx?
|
230 |
body_node = func_node.child_by_field_name("body")
|
231 |
body_start_idx, body_end_idx = node_str_idx(body_node)
|
232 |
-
model_context = identifier_str # base case
|
233 |
-
|
234 |
-
docstring = get_docstrings(func_node) #might be empty?
|
235 |
-
if docstring:
|
236 |
-
model_context = model_context + "\n" + docstring
|
237 |
-
model_context = grab_before_comments(func_node) + model_context #prepend comments
|
238 |
-
if prompt != "":
|
239 |
-
model_context = f"//avialable functions: {','.join([n.child_by_field_name('declarator').text.decode() for n in funcs_list])}\n" + model_context #prepend available functions
|
240 |
-
model_context = "//Title: " + prompt + "\n" + model_context #prepend user prompt/title
|
241 |
-
model_context = "//Language: Shadertoy GLSL fragment shader\n" + model_context #prepend system prompt, language hint
|
242 |
-
print(f"{model_context=}")
|
243 |
# generation = pipeline(model_context, return_full_text=False, **generation_kwargs)[0]["generated_text"]
|
244 |
-
generation =
|
245 |
for i in generation:
|
246 |
-
print(f"{i=}")
|
247 |
yield model_context + i #fix in between, do all the stuff in the end?
|
248 |
generation = i[:] #seems to work
|
249 |
print(f"{generation=}")
|
@@ -253,7 +205,7 @@ def alter_body(old_code, func_id, funcs_list: list, prompt="", temperature=0.2,
|
|
253 |
first_gened_func = parse_functions(ctx_with_generation)[0] # truncate generation to a single function?
|
254 |
except IndexError:
|
255 |
print("generation wasn't a full function.")
|
256 |
-
altered_code = old_code[:
|
257 |
return altered_code
|
258 |
altered_code = replace_function(func_node, first_gened_func)
|
259 |
yield altered_code #yield once so it updates? -> works... gg but doesn't seem to do it for the dropdown
|
|
|
1 |
import gradio as gr
|
2 |
+
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
|
3 |
import datasets
|
4 |
import numpy as np
|
5 |
import torch
|
|
|
6 |
|
7 |
from utils.tree_utils import parse_functions, get_docstrings, grab_before_comments, line_chr2char, node_str_idx, replace_function
|
8 |
from utils.html_utils import make_iframe, construct_embed
|
9 |
+
from utils.generation import combine_generation_kwargs, stream_generation, construct_model_context
|
10 |
PIPE = None
|
11 |
|
12 |
intro_text = """
|
|
|
99 |
print(f"loaded model {model_cp} as a pipline")
|
100 |
return pipe
|
101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
def process_retn(retn):
    """
    Trims a generated return statement down to its first statement.

    Args:
        retn (str): Generated text, possibly running past the first ";".
    Returns:
        str: Everything before the first ";" (the whole text if there is none),
            with surrounding whitespace stripped.
    """
    first_statement, _, _ = retn.partition(";")
    return first_statement.strip()
|
104 |
|
|
|
138 |
else:
|
139 |
raise gr.Error(f"func_idx must be int or str, not {type(func_idx)}")
|
140 |
|
141 |
+
generation_kwargs = combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty)
|
142 |
|
143 |
retrns = []
|
144 |
retrn_start_idx = orig_code.find("return")
|
|
|
160 |
return altered_code
|
161 |
|
162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
def alter_body(old_code, func_id, funcs_list: list, prompt="", temperature=0.2, max_new_tokens=512, top_p=.95, repetition_penalty=1.2, pipeline=PIPE):
|
164 |
"""
|
165 |
Replaces the body of a function with a generated one.
|
|
|
186 |
func_node = funcs_list[func_id]
|
187 |
print(f"using for generation: {func_node=}")
|
188 |
|
189 |
+
generation_kwargs = combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty)
|
190 |
+
model_context = construct_model_context(func_node, prompt=prompt)
|
191 |
+
print(f"{model_context=}")
|
192 |
|
|
|
|
|
193 |
body_node = func_node.child_by_field_name("body")
|
194 |
body_start_idx, body_end_idx = node_str_idx(body_node)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
# generation = pipeline(model_context, return_full_text=False, **generation_kwargs)[0]["generated_text"]
|
196 |
+
generation = stream_generation(model_context, pipeline, generation_kwargs)
|
197 |
for i in generation:
|
198 |
+
# print(f"{i=}")
|
199 |
yield model_context + i #fix in between, do all the stuff in the end?
|
200 |
generation = i[:] #seems to work
|
201 |
print(f"{generation=}")
|
|
|
205 |
first_gened_func = parse_functions(ctx_with_generation)[0] # truncate generation to a single function?
|
206 |
except IndexError:
|
207 |
print("generation wasn't a full function.")
|
208 |
+
altered_code = old_code[:body_start_idx] + generation + "//the generation didn't complete the function!\n" + old_code[body_end_idx:] #needs a newline to break out of the comment.
|
209 |
return altered_code
|
210 |
altered_code = replace_function(func_node, first_gened_func)
|
211 |
yield altered_code #yield once so it updates? -> works... gg but doesn't seem to do it for the dropdown
|
utils/__init__.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
-
from .tree_utils import (parse_functions, get_docstrings, grab_before_comments, line_chr2char)
|
2 |
from .html_utils import (make_iframe, make_script, construct_embed)
|
|
|
3 |
|
4 |
-
tree_funcs = ["parse_functions", "get_docstrings", "grab_before_comments", "line_chr2char"]
|
5 |
html_funcs = ["make_iframe", "make_script", "construct_embed"]
|
|
|
6 |
|
7 |
-
__all__ = tree_funcs + html_funcs
|
|
|
1 |
+
from .tree_utils import (parse_functions, get_docstrings, grab_before_comments, line_chr2char, replace_function, get_root, node_str_idx, give_tree)
|
2 |
from .html_utils import (make_iframe, make_script, construct_embed)
|
3 |
+
from .generation import (combine_generation_kwargs, stream_generation, construct_model_context)
|
4 |
|
5 |
+
tree_funcs = ["parse_functions", "get_docstrings", "grab_before_comments", "line_chr2char", "replace_function", "get_root", "node_str_idx", "give_tree"]
|
6 |
html_funcs = ["make_iframe", "make_script", "construct_embed"]
|
7 |
+
gen_funcs = ["combine_generation_kwargs", "stream_generation", "construct_model_context"]
|
8 |
|
9 |
+
__all__ = tree_funcs + html_funcs + gen_funcs
|
utils/generation.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import TextIteratorStreamer
|
2 |
+
from threading import Thread
|
3 |
+
from utils.tree_utils import get_docstrings, grab_before_comments
|
4 |
+
|
5 |
+
def combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty):
    """
    Combines the generation kwargs into a single dict.

    Args:
        temperature: Stored under the "temperature" key.
        max_new_tokens: Stored under the "max_new_tokens" key.
        top_p: Stored under the "top_p" key.
        repetition_penalty: Stored under the "repetition_penalty" key.
    Returns:
        dict: A new dict holding the four values unchanged, keyed by their
            parameter names, ready to be unpacked as keyword arguments.
    """
    return {
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
    }
|
15 |
+
|
16 |
+
|
17 |
+
def stream_generation(prompt:str, pipe, gen_kwargs:dict):
    """
    Text generation function that streams its output while it is being produced.

    Args:
        prompt (str): The context to start generation from.
        pipe (Pipeline): The pipeline to use for generation; its ``tokenizer`` and ``model`` attributes are used.
        gen_kwargs (dict): The generation kwargs, forwarded to ``pipe.model.generate``.
    Yields:
        str: The text generated so far (accumulated output, not only the newest piece).
    Returns:
        str: The complete generated text. Because this is a generator, the value is only
            visible as ``StopIteration.value`` (e.g. through ``yield from``); a plain
            ``for`` loop over the generator never sees it.
    """
    # Tokenize the prompt
    model_inputs = pipe.tokenizer(prompt, return_tensors="pt")

    # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
    # in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
    streamer = TextIteratorStreamer(pipe.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15.0)
    generate_kwargs = dict(model_inputs, streamer=streamer, **gen_kwargs)
    worker = Thread(target=pipe.model.generate, kwargs=generate_kwargs)
    worker.start()

    # Pull the generated text from the streamer, and hand out the accumulated output after every piece.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output

    # The loop above only ends after the streamer received its end-of-stream signal, so nothing
    # is left to push into it; just wait for the worker so the thread is not left dangling.
    worker.join()
    return model_output
|
45 |
+
|
46 |
+
def construct_model_context(func_node, prompt="") -> str:
    """
    Constructs the model context (the text generation starts from) for a function node.

    Args:
        func_node: Node exposing ``child_by_field_name``; the text of its "type" and
            "declarator" children forms the function header.
        prompt (str): Optional user prompt/title. An empty string or None means no prompt.
    Returns:
        str: Preceding comments, the function header and (if present) its docstring,
            optionally prefixed by the title and language hint lines.
    """
    return_type = func_node.child_by_field_name("type").text.decode()
    declarator = func_node.child_by_field_name("declarator").text.decode()
    model_context = return_type + " " + declarator  # base case: only the function header

    docstring = get_docstrings(func_node)  # might be empty
    if docstring:
        model_context = model_context + "\n" + docstring
    model_context = grab_before_comments(func_node) + model_context  # prepend comments

    # Truthiness check instead of `!= ""`: a None prompt is treated as "no prompt"
    # rather than failing in the string concatenation below.
    if prompt:
        model_context = "//Title: " + prompt + "\n" + model_context  # prepend user prompt/title
        # NOTE(review): assumed to belong to the prompt branch together with the title line;
        # confirm the intended nesting against the repository.
        model_context = "//Language: Shadertoy GLSL fragment shader\n" + model_context  # prepend system prompt, language hint
    return model_context
|