Spaces:
Runtime error
Runtime error
File size: 4,543 Bytes
2d141af 0907806 2d141af 27e63ab c6524f1 27e63ab 829134c 27e63ab 829134c 27e63ab abed9bd 27e63ab 2d141af 3f8d823 2d141af c6524f1 2d141af 3f8d823 2d141af 50c1955 2d141af abed9bd 50c1955 abed9bd 50c1955 abed9bd c6524f1 2d141af 3f8d823 2d141af 491ed03 2d141af 61d090a 2d141af 80cd783 491ed03 2d141af 50c1955 491ed03 50c1955 3f8d823 2d141af abed9bd 2d141af abed9bd 2d141af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import tree_sitter
from tree_sitter import Language, Parser
# Module-level setup: the statements below run once, when this module is imported.
# Build a shared library at ./build/my-languages.so from the grammar checkout in
# ./tree-sitter-glsl (that directory must exist relative to the working directory).
# NOTE(review): build_library / set_language look like the older py-tree-sitter
# API -- confirm the pinned tree_sitter version still provides them.
Language.build_library("./build/my-languages.so", ['./tree-sitter-glsl'])
# Load the 'glsl' language out of the library that was just built.
GLSL_LANGUAGE = Language('./build/my-languages.so', 'glsl')
# Shared parser instance used by give_tree() and parse_functions() below.
parser = Parser()
parser.set_language(GLSL_LANGUAGE)
def replace_function(old_func_node, new_func_node):
    """
    Return the source text with old_func_node replaced by new_func_node.

    The splice is done on the bytes of the root node that contains
    old_func_node, so the node's byte offsets always refer to the text being
    sliced, even when the function is not a direct child of the root.

    Args:
        old_func_node: node to cut out; needs .parent, .start_byte, .end_byte.
        new_func_node: node whose .text (bytes) is inserted instead.

    Returns:
        The patched source as a str.
    """
    # Walk up to the root: start_byte / end_byte are offsets into the whole
    # parsed source, not into the direct parent's text.
    root = old_func_node
    while root.parent is not None:
        root = root.parent
    source = root.text
    # A node's text begins at its own start_byte, so shift the offsets by the
    # root's start (normally 0) before slicing.
    shift = root.start_byte
    cut_start = old_func_node.start_byte - shift
    cut_end = old_func_node.end_byte - shift
    # Splice as bytes and decode once, so multi-byte characters are never
    # split across separately decoded pieces.
    patched = source[:cut_start] + new_func_node.text + source[cut_end:]
    return patched.decode()
def get_root(node):
    """
    Climb the parent links of the given node and return the topmost node,
    i.e. the first ancestor (or the node itself) whose parent is None.
    """
    current = node
    while current.parent is not None:
        current = current.parent
    return current
def node_str_idx(node):
    """
    Return the (start, end) byte offsets of a node.

    The values are taken straight from node.start_byte and node.end_byte, so
    they index into the encoded (bytes) source. They only coincide with
    character indices of the decoded string when the text before the node is
    pure ASCII.

    Args:
        node: any object exposing .start_byte and .end_byte.

    Returns:
        A (start_byte, end_byte) tuple.
    """
    return node.start_byte, node.end_byte
def give_tree(func_node):
    """
    Parse the text of the function node's direct parent with the module-level
    parser and return the resulting tree.
    """
    enclosing_source = func_node.parent.text
    return parser.parse(enclosing_source)
def parse_functions(in_code):
    """
    Parse the given source string and return its top-level function nodes.

    Only direct children of the root whose type is "function_definition" are
    returned, in source order.
    """
    encoded = bytes(in_code, encoding="utf-8")
    root = parser.parse(encoded).root_node
    return [child for child in root.children if child.type == "function_definition"]
def get_docstrings(func_node):
    """
    Collect the docstring-like text of a function node.

    Comments that are direct children of the function are appended verbatim.
    Inside the compound_statement body, the opening "{" and every comment
    that follows it are appended one per line, each indented by its start
    column; the first other kind of body child ends the collection.
    """
    pieces = []
    for child in func_node.children:
        kind = child.type
        if kind == "comment":
            # comment sitting next to the declarator
            pieces.append(child.text.decode())
            continue
        if kind != "compound_statement":
            continue
        for inner in child.children:
            if inner.type not in ("comment", "{"):
                # first real body code: the docstring is complete
                return "".join(pieces)
            indent = " " * inner.start_point[1]
            pieces.append(indent + inner.text.decode() + "\n")
    return "".join(pieces)
def full_func_head(func_node) -> str:
    """
    Return the function head including docstrings before any real body code.

    The head is the function's text from its start up to the end of the last
    leading "{" or comment child of its body.

    Args:
        func_node: function node with a "body" field; needs .text (bytes),
            .start_byte and .child_by_field_name().
            NOTE(review): a node without a "body" field is not handled and
            would fail on attribute access -- confirm callers never pass one.

    Returns:
        The decoded head. If the body has no leading "{" or comment, this is
        everything before the body.
    """
    body = func_node.child_by_field_name("body")
    # Default to the start of the body so the result is defined even when no
    # leading "{" or comment child exists.
    head_end = body.start_byte
    # Iterating the children (instead of stepping a cursor) always terminates,
    # even if the body ends on a comment or "{".
    for child in body.children:
        if child.type not in ("comment", "{"):
            break
        head_end = child.end_byte
    # Child offsets are absolute; func_node.text starts at func_node.start_byte.
    return func_node.text[:head_end - func_node.start_byte].decode()
def grab_before_comments(func_node):
    """
    Return the comments that sit directly above a function node.

    The siblings of func_node are scanned in order. Comments are collected as
    long as each following node starts on the line right after the previous
    comment ends; any gap discards what was collected so far.

    Args:
        func_node: node with .parent, .start_byte and .start_point; its
            siblings additionally need .type, .text and .end_point.

    Returns:
        A (precomment, start_byte) tuple: the collected comment text (one
        comment per line, "" if none) and the byte offset where that comment
        run starts, or the function's own start_byte when there is none.
    """
    precomment = ""
    last_comment_line = 0
    start_byte = func_node.start_byte
    for node in func_node.parent.children:
        # Anything that does not start on the line right after the previous
        # comment breaks the run of adjacent comments.
        if node.start_point[0] != last_comment_line + 1:
            precomment = ""
        if node.type == "comment":
            if precomment == "":
                # first comment of a new run
                start_byte = node.start_byte
            precomment += node.text.decode() + "\n"
            # Use the line the comment ends on, so a multi-line block comment
            # directly above the function still counts as adjacent.
            last_comment_line = node.end_point[0]
        elif node == func_node:
            if precomment == "":
                start_byte = node.start_byte
            return precomment, start_byte
    return precomment, start_byte
def has_docstrings(func_node):
    """
    Tell whether a function node carries any documentation: either
    get_docstrings() yields more than a lone "{", or grab_before_comments()
    finds comment text in front of the function.
    """
    inner_doc = get_docstrings(func_node).strip()
    if inner_doc != "{":
        return True
    # only look at the preceding comments when the body has none
    return grab_before_comments(func_node)[0] != ""
def line_chr2char(text, line_idx, chr_idx):
    """
    ## prefer node.start_byte / node.end_byte over this helper!
    Convert a (line, column) position into a flat character index into text.

    Every line before line_idx contributes its length plus one for the
    newline; chr_idx is then added on top. Raises IndexError when text has
    fewer lines than line_idx asks to skip.
    """
    lines = text.split("\n")
    offset = 0
    for i in range(line_idx):
        try:
            skipped = lines[i]
        except IndexError as e:
            raise IndexError(f"{i=} of {line_idx=} does not exist in {text=}") from e
        offset += len(skipped) + 1  # +1 for the newline that was split away
    return offset + chr_idx
|