Skip to content

Commit 583570b

Browse files
author
Cedric Richter
committed
Add cleaning script + Improve AST parsing
1 parent 942980f commit 583570b

File tree

4 files changed

+153
-4
lines changed

4 files changed

+153
-4
lines changed

code_diff/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ def difference(source, target, lang = "guess", **kwargs):
1515
source_ast = parse_ast(source, lang = lang, **kwargs)
1616
target_ast = parse_ast(target, lang = lang, **kwargs)
1717

18+
if source_ast is None or target_ast is None:
19+
raise ValueError("Source / Target AST seems to be empty: %s" % source)
20+
1821
# Concretize Diff
1922
source_ast, target_ast = diff_search(source_ast, target_ast)
2023

code_diff/ast.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,100 @@ def __call__(self, tokens):
164164

165165
return self.root_node
166166

167+
168+
169+
class BottomUpParser:
170+
171+
def __init__(self, create_node_fn):
172+
173+
self.create_node_fn = create_node_fn
174+
self.waitlist = [] # Invariant: All children have been processed
175+
self.open_index = {}
176+
self.node_index = {} # Nodes that have been processed
177+
178+
def _should_ignore(self, node):
179+
return node.type == "comment"
180+
181+
def _add_to_waitlist(self, node):
182+
if self._should_ignore(node): return
183+
184+
node_key = _node_key(node)
185+
186+
if node_key not in self.node_index and node_key not in self.open_index:
187+
self.open_index[node_key] = node
188+
self.waitlist.append(node)
189+
190+
191+
def _init_lists(self, tokens):
192+
193+
for token in tokens:
194+
if hasattr(token, 'ast_node'):
195+
ast_node = token.ast_node
196+
if self._should_ignore(ast_node): continue
197+
self.open_index[_node_key(ast_node)] = ast_node
198+
self._create_node(ast_node, token.text)
199+
200+
if ast_node is None: return
201+
202+
# Get to root
203+
root = ast_node
204+
while root.parent is not None:
205+
root = root.parent
206+
207+
self._open_descandents(root)
208+
209+
return root
210+
211+
def _open_descandents(self, node):
212+
213+
queue = [node]
214+
while len(queue) > 0:
215+
current_node = queue.pop(0)
216+
217+
has_opened = False
218+
for child in current_node.children:
219+
if _node_key(child) not in self.node_index:
220+
has_opened = True
221+
queue.append(child)
222+
223+
if not has_opened:
224+
self._add_to_waitlist(current_node)
225+
226+
227+
def _open_parent(self, ast_node):
228+
parent = ast_node.parent
229+
230+
if all(_node_key(c) in self.node_index for c in parent.children if not self._should_ignore(c)):
231+
self._add_to_waitlist(parent)
232+
233+
def _create_node(self, ast_node, text = None):
234+
235+
node_key = _node_key(ast_node)
236+
children = [self.node_index[_node_key(c)] for c in ast_node.children
237+
if _node_key(c) in self.node_index]
238+
239+
position = (ast_node.start_point, ast_node.end_point)
240+
current_node = self.create_node_fn(ast_node.type, children, text = text, position = position)
241+
current_node.backend = ast_node
242+
243+
self.node_index[node_key] = current_node
244+
del self.open_index[node_key]
245+
246+
if ast_node.parent: self._open_parent(ast_node)
247+
248+
249+
def __call__(self, tokens):
250+
root_node = self._init_lists(tokens)
251+
252+
while len(self.waitlist) > 0:
253+
self._create_node(self.waitlist.pop(0))
254+
255+
if _node_key(root_node) not in self.node_index:
256+
return None
257+
258+
return self.node_index[_node_key(root_node)]
259+
260+
167261

168262

169263
# Interface ----------------------------------------------------------------
@@ -177,4 +271,4 @@ def parse_ast(source_code, lang = "guess", **kwargs):
177271

178272
ast_tokens = ct.tokenize(source_code, **kwargs)
179273

180-
return TokensToAST(default_create_node)(ast_tokens)
274+
return BottomUpParser(default_create_node)(ast_tokens)

code_diff/diff_utils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,55 @@ def parse_hunks(diff):
8686

8787
return hunks
8888

89+
90+
# Diff cleaning --------------------------------
91+
92+
def _has_incomplete_comment(lines):
93+
is_incomplete2 = False
94+
is_incomplete1 = False
95+
96+
for line in lines:
97+
count2 = line.count("\"\"\"")
98+
if count2 % 2 == 1: is_incomplete2 = not is_incomplete2
99+
100+
count1 = line.count("\'\'\'")
101+
if count1 % 2 == 1: is_incomplete1 = not is_incomplete1
102+
103+
return is_incomplete1 or is_incomplete2
104+
105+
106+
def _determine_incomplete_comment(lines):
107+
last_incomplete2 = -1
108+
last_incomplete1 = -1
109+
110+
for i, line in enumerate(lines):
111+
count2 = line.count("\"\"\"")
112+
if count2 % 2 == 1:
113+
last_incomplete2 = i if last_incomplete2 == -1 else -1
114+
115+
count1 = line.count("\'\'\'")
116+
if count1 % 2 == 1:
117+
last_incomplete1 = i if last_incomplete1 == -1 else -1
118+
119+
assert last_incomplete1 != -1 or last_incomplete2 != -1
120+
121+
last_incomplete = last_incomplete2 if last_incomplete2 != -1 else last_incomplete1
122+
123+
dist_to_end = len(lines) - last_incomplete
124+
125+
if last_incomplete < dist_to_end:
126+
return last_incomplete + 1, len(lines)
127+
else:
128+
return 0, last_incomplete
129+
130+
131+
def clean_hunk(hunk):
132+
if not _has_incomplete_comment(hunk.lines): return hunk
133+
start, end = _determine_incomplete_comment(hunk.lines)
134+
135+
new_lines = hunk.lines[start:end]
136+
added_lines = [l - start for l in hunk.added_lines if l >= start and l < end]
137+
rm_lines = [l - start for l in hunk.rm_lines if l >= start and l < end]
138+
139+
return Hunk(new_lines, added_lines, rm_lines)
140+

code_diff/sstubs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ class SStubPattern(Enum):
3838
ADD_ATTRIBUTE_ACCESS = 21
3939

4040
# If condition
41-
MORE_SPECIFIC_IF = 21
42-
LESS_SPECIFIC_IF = 22
41+
MORE_SPECIFIC_IF = 22
42+
LESS_SPECIFIC_IF = 23
4343

4444

4545
# SStub classification -------------------------------
@@ -143,7 +143,7 @@ def _to_plain_constant(text):
143143

144144

145145
def change_constant_type(source_ast, target_ast):
146-
146+
147147
if source_ast.type == "identifier": return False
148148
if target_ast.type == "identifier": return False
149149

0 commit comments

Comments
 (0)