Skip to content

Commit e5c17f8

Browse files
committed
Improved parsing of diff [Edit script + SStubs]
1 parent 769bc01 commit e5c17f8

File tree

8 files changed

+355
-85
lines changed

8 files changed

+355
-85
lines changed

code_diff/__init__.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .ast import parse_ast
55
from .utils import cached_property
66
from .sstubs import SStubPattern, classify_sstub
7-
from .gumtree import compute_edit_script
7+
from .gumtree import compute_edit_script, EditScript, Update
88

99

1010
# Main method --------------------------------------------------------
@@ -84,7 +84,11 @@ def statement_diff(self):
8484

8585
return ASTDiff(self.config, source_stmt, target_stmt)
8686

87+
def root_diff(self):
    """Return an ASTDiff between the root nodes of the source and target ASTs.

    Both ``self.source_ast`` and ``self.target_ast`` are lifted to their
    topmost ancestor via ``ast_root`` before the diff object is built.
    """
    config = self.config
    source_root = ast_root(self.source_ast)
    target_root = ast_root(self.target_ast)
    return ASTDiff(config, source_root, target_root)
89+
8790
def sstub_pattern(self):
91+
8892
if (parent_statement(self.config.statement_types, self.source_ast) is None
8993
or parent_statement(self.config.statement_types, self.target_ast) is None):
9094
return SStubPattern.NO_STMT
@@ -98,16 +102,18 @@ def edit_script(self):
98102

99103
source_ast, target_ast = self.source_ast, self.target_ast
100104

101-
def is_statement_type(node_type):
102-
return any(match_type(r, node_type) for r in self.config.statement_types)
105+
if source_ast.type == target_ast.type and len(source_ast.children) == 0 and len(target_ast.children) == 0:
106+
# Both nodes are tokens of the same type
107+
# Only an update is required
108+
return EditScript([Update(source_ast, target_ast.text)])
109+
110+
# We need a common root to add to
111+
while source_ast.type != target_ast.type:
112+
if source_ast.parent is None: break
113+
if target_ast.parent is None: break
103114

104-
if not is_statement_type(source_ast.type) or not is_statement_type(target_ast.type):
105-
# We need something where we can add to (root)
106-
if source_ast.parent is not None:
107-
source_ast = source_ast.parent
108-
109-
if target_ast.parent is not None:
110-
target_ast = target_ast.parent
115+
source_ast = source_ast.parent
116+
target_ast = target_ast.parent
111117

112118
return compute_edit_script(source_ast, target_ast)
113119

@@ -150,6 +156,15 @@ def is_statement_type(node_type):
150156
return parent_node
151157

152158

159+
def ast_root(ast):
    """Return the topmost ancestor of ``ast``.

    Follows ``.parent`` links until a node whose ``parent`` is ``None`` is
    reached; that node is returned. If ``ast`` has no parent, ``ast`` itself
    is returned.
    """
    node = ast

    while True:
        ancestor = node.parent
        if ancestor is None:
            return node
        node = ancestor
166+
167+
153168
def tokenize_tree(ast):
154169
tokens = []
155170

@@ -160,3 +175,18 @@ def tokenize_tree(ast):
160175
tokens.append(tokenize_tree(child))
161176

162177
return " ".join(tokens)
178+
179+
180+
181+
def is_compatible_root(root_candidate, source_ast):
    """Decide whether ``root_candidate`` may serve as a root for ``source_ast``.

    A candidate is rejected when ``equal_text`` reports that it covers the
    same text as ``source_ast``, or when its node type is ``"block"``.
    """
    if equal_text(source_ast, root_candidate):
        return False

    return root_candidate.type != "block"
183+
184+
185+
def equal_text(source_ast, parent_ast):
    """Check whether ``parent_ast`` spans the same text as ``source_ast``.

    Both nodes expose ``position`` as a pair of pairs; presumably
    ``((start_line, start_col), (end_line, end_col))`` -- confirm against
    the AST node class.

    Returns ``False`` if the parent starts on an earlier line or ends on a
    later line than the source; otherwise returns whether the start and end
    columns of both nodes coincide.
    """
    src_start, src_end = source_ast.position[0], source_ast.position[1]
    par_start, par_end = parent_ast.position[0], parent_ast.position[1]

    # Parent begins on an earlier line: it covers more than the source.
    if par_start[0] < src_start[0]:
        return False

    # Parent ends on a later line: it covers more than the source.
    if src_end[0] < par_end[0]:
        return False

    return src_start[1] == par_start[1] and src_end[1] == par_end[1]

code_diff/diff_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55

66
class Hunk:
77

8-
def __init__(self, lines, added_lines, rm_lines):
8+
def __init__(self, lines, added_lines, rm_lines, header = None):
99
self.lines = lines
1010
self.added_lines = set(added_lines)
1111
self.rm_lines = set(rm_lines)
12+
self.header = header
1213

1314

1415
@property
@@ -41,6 +42,10 @@ def before(self):
4142
return "".join(alines)
4243

4344
def __repr__(self):
45+
46+
if self.header:
47+
return self.header + "".join(self.lines)
48+
4449
return "".join(self.lines)
4550

4651

@@ -55,7 +60,7 @@ def _parse_hunk(lines, start, end):
5560
if hline.startswith("+"): added_lines.append(i)
5661
if hline.startswith("-"): rm_lines.append(i)
5762

58-
return Hunk(hunk_lines, added_lines, rm_lines)
63+
return Hunk(hunk_lines, added_lines, rm_lines, header = lines[start])
5964

6065

6166
# Matches a unified-diff hunk header such as "@@ -12,7 +12,9 @@ context".
# Groups: (old start, optional ",old count", new start, optional ",new count").
# A raw string is used so that "\d" and "\+" reach the regex engine verbatim
# instead of being treated as (invalid) string escape sequences.
hunk_pat = re.compile(r"@@ -(\d+)(,\d+)? \+(\d+)(,\d+)? @@.*")
@@ -136,5 +141,5 @@ def clean_hunk(hunk):
136141
added_lines = [l - start for l in hunk.added_lines if l >= start and l < end]
137142
rm_lines = [l - start for l in hunk.rm_lines if l >= start and l < end]
138143

139-
return Hunk(new_lines, added_lines, rm_lines)
144+
return Hunk(new_lines, added_lines, rm_lines, header = hunk.header)
140145

code_diff/gumtree/__init__.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
from .editmap import gumtree_editmap
33
from .chawathe import compute_chawathe_edit_script
44
from .ops import (Update, Insert, Delete, Move)
5+
from .ops import EditScript
56
from .ops import serialize_script, deserialize_script
7+
from .ops import json_serialize, json_deserialize
68

79
# Edit script ----------------------------------------------------------------
810

@@ -27,15 +29,4 @@ def compute_edit_script(source_ast, target_ast, min_height = 1, max_size = 1000,
2729
# Update leaf ----------------------------------------------------------------
2830

2931
def _update_leaf(source_ast, target_ast):
30-
return Update(source_ast, target_ast.text)
31-
32-
33-
# Edit script ----------------------------------------------------------------
34-
35-
class EditScript(list):
36-
37-
def __init__(self, operations):
38-
super().__init__(operations)
39-
40-
def __repr__(self):
41-
return serialize_script(self, indent = 2)
32+
return Update(source_ast, target_ast.text)

code_diff/gumtree/isomap.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
import heapq
33
import itertools
4+
import math
45

56
from collections import defaultdict
67

@@ -137,10 +138,25 @@ def _map_recursively(mapping, source_node, target_node):
137138
# Heuristic selection ----------------------------------------------------------------
138139

139140

141+
def source_distance(source_node, target_node):
    """Score how close two nodes are in the source text (higher is closer).

    The score is the negated distance, so it can be used as a tie-breaker
    where larger values are preferred. The line distance is measured between
    ``source_node.position[0][0]`` and ``target_node.position[1][0]`` and is
    weighted by 1000 per line. Only when both are on the same line is the
    distance refined by the gap between ``position[0][1]`` of both nodes,
    capped at 999 so it can never outweigh a one-line difference.

    NOTE(review): ``position`` is presumably
    ``((start_line, start_col), (end_line, end_col))`` -- confirm against the
    AST node class.

    Returns:
        A non-positive integer; ``0`` means the same line and the same
        starting offset.
    """
    max_token_mover = 1000

    # Use the absolute line gap: a signed difference would turn targets that
    # lie *below* the source into large positive scores after the negation,
    # ranking far-away nodes above same-line ones.
    line_gap = abs(source_node.position[0][0] - target_node.position[1][0])
    distance = line_gap * max_token_mover

    if distance == 0:
        token_gap = abs(source_node.position[0][1] - target_node.position[0][1])
        distance += min(token_gap, max_token_mover - 1)

    return -distance
153+
154+
155+
140156
def create_default_heuristic(isomorphic_mapping):
141157

142158
def _heuristic(source_node, target_node):
143-
return subtree_dice(source_node, target_node, isomorphic_mapping)
159+
return (subtree_dice(source_node, target_node, isomorphic_mapping), source_distance(source_node, target_node))
144160

145161
return _heuristic
146162

code_diff/gumtree/ops.py

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@ class Move(EditOperation):
2525
class Delete(EditOperation):
2626
pass
2727

28+
# Edit script ----------------------------------------------------------------
29+
30+
class EditScript(list):
    """A list of edit operations with a serialized textual representation."""

    def __init__(self, operations):
        """Initialise the script from an iterable of edit operations."""
        super().__init__(operations)

    def __repr__(self):
        """Render the script via ``serialize_script`` with an indent of 2."""
        return serialize_script(self, indent=2)
37+
2838

2939
# Serialization --------------------------------
3040

@@ -285,11 +295,11 @@ def json_serialize(edit_script):
285295
else: # Leaf node
286296
new_node_str = ["%s:%s" % new_node, "T"]
287297

288-
edit_ops.append([operation_name, target_node_str, new_node_str])
298+
edit_ops.append([operation_name, target_node_str, new_node_str, operation.position])
289299

290300
elif operation_name == "Move":
291301

292-
new_node_str = _serialize_node(new_node_index, operation.node)
302+
new_node_str = _json_serialize_node(new_node_index, operation.node)
293303

294304
edit_ops.append([operation_name, target_node_str, new_node_str, operation.position])
295305

@@ -299,4 +309,81 @@ def json_serialize(edit_script):
299309
return json.dumps(edit_ops)
300310

301311

302-
# Fast deserialize ----------------------------------------------------------------------
312+
# Fast deserialize ----------------------------------------------------------------------
313+
314+
def _json_deserialize_node(node_index, node_info):
315+
316+
if not isinstance(node_info, list) and node_info != "T":
317+
node_id = int(node_info[1:])
318+
return node_index[node_id]
319+
320+
node_type, position = node_info[0], node_info[1:]
321+
node_text = None
322+
323+
if ":" in node_type:
324+
node_type, node_text = node_type.split(":", 1)
325+
326+
if len(position) == 4:
327+
return DASTNode(node_type, ((position[0], position[1]), (position[2], position[3])), node_text)
328+
329+
return InsertNode(position[0], node_type, node_text)
330+
331+
332+
def _json_deserialize_node_constructor(node_index, cn_info):
    """Create the ``InsertNode`` described by a ``(type, id)`` pair.

    The type may carry text as ``"type:text"``. An id of ``"T"`` yields an
    ``InsertNode`` that is not recorded in ``node_index``. Any other id has
    its first character dropped and the rest parsed as an integer; the new
    node is stored in ``node_index`` under that integer so later operations
    can refer to it.
    """
    label, raw_id = cn_info
    text = None

    if ":" in label:
        label, _, text = label.partition(":")

    if raw_id == "T":
        return InsertNode(raw_id, label, text)

    numeric_id = int(raw_id[1:])
    node_index[numeric_id] = InsertNode(numeric_id, label, text)
    return node_index[numeric_id]
345+
346+
347+
def _json_deserialize_update(node_index, operation):
    """Rebuild an ``Update`` from a three-element ``[name, target, value]`` entry."""
    _, target_info, new_value = operation
    target_node = _json_deserialize_node(node_index, target_info)
    return Update(target_node, new_value)
351+
352+
353+
def _json_deserialize_insert(node_index, operation):
    """Rebuild an ``Insert`` from a ``[name, target, new_node, position]`` entry.

    The target is resolved first, then the inserted node is constructed
    (which may register it in ``node_index``).
    """
    _, target_info, constructor_info, position = operation
    target_node = _json_deserialize_node(node_index, target_info)
    inserted = _json_deserialize_node_constructor(node_index, constructor_info)

    return Insert(
        target_node,
        (inserted.type, inserted.text),
        position,
        inserted.node_id,
    )
359+
360+
361+
def _json_deserialize_delete(node_index, operation):
    """Rebuild a ``Delete`` whose target is described by ``operation[1]``."""
    target_node = _json_deserialize_node(node_index, operation[1])
    return Delete(target_node)
363+
364+
365+
def _json_deserialize_move(node_index, operation):
    """Rebuild a ``Move`` from a ``[name, target, moved_node, position]`` entry."""
    _, target_info, moved_info, position = operation
    target_node = _json_deserialize_node(node_index, target_info)
    moved_node = _json_deserialize_node(node_index, moved_info)
    return Move(target_node, moved_node, position)
370+
371+
372+
# Dispatch table: operation name (first element of a JSON entry) -> the
# function that rebuilds that operation.
DESERIALIZE = {
    "Update": _json_deserialize_update,
    "Insert": _json_deserialize_insert,
    "Delete": _json_deserialize_delete,
    "Move": _json_deserialize_move,
}
378+
379+
380+
def json_deserialize(edit_json):
    """Parse a JSON-encoded edit script back into an ``EditScript``.

    ``edit_json`` must decode to a sequence of entries whose first element is
    an operation name found in ``DESERIALIZE``. Entries are processed in
    order and share one node index, so nodes created by earlier operations
    can be referenced by later ones.

    Raises:
        KeyError: if an entry names an operation missing from ``DESERIALIZE``.
    """
    node_index = {}

    return EditScript([
        DESERIALIZE[entry[0]](node_index, entry)
        for entry in json.loads(edit_json)
    ])

0 commit comments

Comments
 (0)