Skip to content

Commit 25c172a

Browse files
committed
Isomorphic node mapping for GumTree algorithm
1 parent d413b08 commit 25c172a

File tree

5 files changed

+307
-7
lines changed

5 files changed

+307
-7
lines changed

code_diff/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from code_tokenize.config import load_from_lang_config
22
from code_tokenize.tokens import match_type
33

4-
from .ast import parse_ast
5-
from .utils import cached_property
6-
from .sstubs import SStubPattern, classify_sstub
4+
from .ast import parse_ast
5+
from .utils import cached_property
6+
from .sstubs import SStubPattern, classify_sstub
7+
from .gumtree import compute_edit_script
78

89

910
# Main method --------------------------------------------------------
@@ -84,13 +85,13 @@ def sstub_pattern(self):
8485
return classify_sstub(*diff_search(self.source_ast, self.target_ast))
8586

8687
def edit_script(self):
87-
pass
88+
return compute_edit_script(self.source_ast, self.target_ast)
8889

8990
def __repr__(self):
9091
return "%s -> %s" % (self.source_text, self.target_text)
9192

9293

93-
94+
9495

9596
# AST Utils -----------------------------------------------------------
9697

code_diff/ast.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@ def __init__(self, type, text = None, parent = None, children = None):
1717

1818
# Tree based attributes
1919
self.subtree_hash = None
20-
self.subtree_height = 0
20+
self.subtree_height = 1
2121
self.subtree_weight = 1
2222

2323
def isomorph(self, other):
2424
return ((self.subtree_hash, self.type, self.subtree_height, self.subtree_weight) ==
2525
(other.subtree_hash, other.type, other.subtree_height, other.subtree_weight))
2626

27+
def descandents(self):
28+
return (t for t in self if t != self)
2729

2830
def sexp(self):
2931
name = self.text if self.text is not None else self.type
@@ -39,6 +41,16 @@ def sexp(self):
3941

4042
return "%s {\n%s\n}" % (name, " ".join(child_sexp))
4143

44+
def __iter__(self):
45+
def _self_bfs_search():
46+
queue = [self]
47+
while len(queue) > 0:
48+
current = queue.pop(0)
49+
yield current
50+
queue.extend(current.children)
51+
52+
return _self_bfs_search()
53+
4254
def __repr__(self):
4355
attrs = {"type": self.type, "text": self.text}
4456
return "ASTNode(%s)" % (", ".join(["%s=%s" % (k, v) for k, v in attrs.items() if v is not None]))
@@ -48,7 +60,7 @@ def default_create_node(type, children, text = None):
4860
new_node = ASTNode(type, text = text, children = children)
4961

5062
# Subtree metrics
51-
height = 0
63+
height = 1
5264
weight = 1
5365
hash_str = []
5466

code_diff/gumtree/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from .isomap import gumtree_isomap
2+
3+
# Edit script ----------------------------------------------------------------
4+
5+
def compute_edit_script(source_ast, target_ast):
6+
7+
isomap = gumtree_isomap(source_ast, target_ast)
8+
print(isomap)

code_diff/gumtree/isomap.py

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
2+
import heapq
3+
import itertools
4+
5+
from collections import defaultdict
6+
7+
# API method ----------------------------------------------------------------
8+
9+
def gumtree_isomap(source_ast, target_ast, min_height = 1):
10+
11+
isomorphic_mapping = NodeMapping()
12+
candidate_mapping = NodeMapping()
13+
14+
source_index = _index_iso_nodes(source_ast)
15+
target_index = _index_iso_nodes(target_ast)
16+
17+
source_open = HeightPriorityHeap(source_ast)
18+
target_open = HeightPriorityHeap(target_ast)
19+
20+
while max(source_open.max(), target_open.max()) > min_height:
21+
22+
if source_open.max() > target_open.max():
23+
for c in list(source_open.pop()):
24+
_open_node(source_open, c)
25+
continue
26+
27+
if source_open.max() < target_open.max():
28+
for c in list(target_open.pop()):
29+
_open_node(target_open, c)
30+
continue
31+
32+
source_candidates, target_candidates = list(source_open.pop()), list(target_open.pop())
33+
34+
for source_node, target_node in itertools.product(source_candidates, target_candidates):
35+
# Source node and Target node have the same height
36+
# Check if source node is isomorph to target node
37+
38+
if source_node.isomorph(target_node):
39+
# Check if there exists more candidates
40+
if (source_index[source_node] > 1
41+
or target_index[target_node] > 1):
42+
candidate_mapping.add(source_node, target_node)
43+
else:
44+
# We can savely map both nodes and all descandents
45+
_map_recursively(isomorphic_mapping, source_node, target_node)
46+
47+
# Open all unmapped nodes
48+
for source_node in source_candidates:
49+
if ((source_node, None) not in isomorphic_mapping
50+
and (source_node, None) not in candidate_mapping):
51+
_open_node(source_open, source_node)
52+
53+
for target_node in target_candidates:
54+
if ((None, target_node) not in isomorphic_mapping
55+
and (None, target_node) not in candidate_mapping):
56+
_open_node(target_open, target_node)
57+
58+
# Select the heuristically best mapping for all isomorphic pairs
59+
selection_heuristic = create_default_heuristic(isomorphic_mapping)
60+
for source_node, target_node in _select_candidates(candidate_mapping, selection_heuristic):
61+
_map_recursively(isomorphic_mapping, source_node, target_node)
62+
63+
return isomorphic_mapping
64+
65+
66+
# Collections ----------------------------------------------------------------
67+
68+
class NodeMapping:
69+
70+
def __init__(self):
71+
self._src_to_dst = defaultdict(set)
72+
self._dst_to_src = defaultdict(set)
73+
self._length = 0
74+
75+
def __getitem__(self, key):
76+
if not isinstance(key, tuple): key = (key, None)
77+
78+
src_key, dst_key = key
79+
80+
if src_key is not None and dst_key is not None:
81+
return dst_key in self._src_to_dst[src_key]
82+
83+
if src_key is None and dst_key is None:
84+
return self.__iter__()
85+
86+
if src_key is None:
87+
return ((src, dst_key) for src in self._dst_to_src[dst_key])
88+
89+
if dst_key is None:
90+
return ((src_key, dst) for dst in self._src_to_dst[src_key])
91+
92+
def __iter__(self):
93+
94+
def _iter_maps():
95+
for k, V in self._src_to_dst.items():
96+
for v in V: yield (k, v)
97+
98+
return _iter_maps()
99+
100+
def __contains__(self, key):
101+
if not isinstance(key, tuple): key = (key, None)
102+
103+
src_key, dst_key = key
104+
105+
if src_key is not None and dst_key is not None:
106+
return self[src_key, dst_key]
107+
108+
return next(self[src_key, dst_key], None) is not None
109+
110+
def __len__(self):
111+
return self._length
112+
113+
def add(self, src, dst):
114+
if not self[src, dst]:
115+
self._src_to_dst[src].add(dst)
116+
self._dst_to_src[dst].add(src)
117+
self._length += 1
118+
119+
def __copy__(self):
120+
output = NodeMapping()
121+
122+
for a, b in self:
123+
output.add(a, b)
124+
125+
return output
126+
127+
def __str__(self):
128+
approx_str = []
129+
130+
for src, dst in self:
131+
approx_str.append("%s ≈ %s" % (str(src), str(dst)))
132+
133+
return "\n".join(approx_str)
134+
135+
136+
class NodeCounter:
137+
138+
def __init__(self):
139+
self._counter = defaultdict(int)
140+
141+
def _node_key(self, node):
142+
return (node.subtree_hash, node.subtree_weight)
143+
144+
def __getitem__(self, node):
145+
return self._counter[self._node_key(node)]
146+
147+
def __setitem__(self, node, value):
148+
self._counter[self._node_key(node)] = value
149+
150+
151+
class HeightPriorityHeap:
152+
153+
def __init__(self, start_node = None):
154+
self._heap = []
155+
self.element_count = 0
156+
157+
if start_node is not None:
158+
self.push(start_node)
159+
160+
def __len__(self):
161+
return len(self._heap)
162+
163+
def push(self, x, seed = 0):
164+
try:
165+
heapq.heappush(self._heap, (-x.subtree_height, x.subtree_hash, self.element_count, seed, x))
166+
self.element_count += 1
167+
except TypeError:
168+
# Typically the type error occurs if we compare with the last element in tuple (Node)
169+
# If this happens the node is already contained in the heap and we skip this push
170+
return
171+
172+
def max(self):
173+
if len(self) == 0: return 0
174+
return -self._heap[0][0]
175+
176+
def pop(self):
177+
current_head = self.max()
178+
179+
while len(self) > 0 and self.max() == current_head:
180+
yield heapq.heappop(self._heap)[-1]
181+
182+
# Helper methods -----------------------------------------------------------
183+
184+
def _index_iso_nodes(ast):
185+
result = NodeCounter()
186+
for node in ast: result[node] += 1
187+
188+
return result
189+
190+
def _open_node(heap, node):
191+
for n, child in enumerate(node.children):
192+
heap.push(child, seed = n)
193+
194+
def _map_recursively(mapping, source_node, target_node):
195+
mapping.add(source_node, target_node)
196+
197+
for i, source_child in enumerate(source_node.children):
198+
target_child = target_node.children[i]
199+
assert source_node.type == target_node.type
200+
201+
_map_recursively(mapping, source_child, target_child)
202+
203+
# Heuristic selection ----------------------------------------------------------------
204+
205+
def _subtree_dice(A, B, mapping):
206+
207+
if A is None or B is None:
208+
return 1.0 if all(x is None for x in [A, B]) else 0.0
209+
210+
DA, DB = set(A.descandents()), set(B.descandents())
211+
212+
norm = len(DA) + len(DB)
213+
214+
if norm == 0: return 1.0
215+
216+
mapped = defaultdict(set)
217+
for a, b in mapping: mapped[a].add(b)
218+
219+
mapped_children = set(m for t1 in DA if t1 in mapped for m in mapped[t1])
220+
dice_score = len(set.intersection(mapped_children, DB))
221+
222+
return 2 * dice_score / norm
223+
224+
225+
def create_default_heuristic(isomorphic_mapping):
226+
227+
def _heuristic(source_node, target_node):
228+
dice = _subtree_dice(source_node, target_node, isomorphic_mapping)
229+
return dice
230+
231+
return _heuristic
232+
233+
234+
def _select_candidates(candidate_mapping, heuristic = None):
235+
if len(candidate_mapping) == 0: return
236+
237+
candidate_pairs = [(s, t) for s, t in candidate_mapping]
238+
239+
if heuristic is not None:
240+
candidate_pairs = sorted(candidate_pairs,
241+
key=lambda p: heuristic(*p),
242+
reverse=True)
243+
244+
source_seen = set()
245+
target_seen = set()
246+
247+
while len(candidate_pairs) > 0:
248+
source_node, target_node = candidate_pairs.pop(0)
249+
250+
if source_node in source_seen:
251+
continue
252+
source_seen.add(source_node)
253+
254+
if target_node in target_seen:
255+
continue
256+
target_seen.add(target_node)
257+
258+
yield source_node, target_node
259+

code_diff/gumtree/ops.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
2+
3+
class ASTOperation:
4+
pass
5+
6+
7+
class UpdateOperation(ASTOperation):
8+
pass
9+
10+
11+
class InsertOperation(ASTOperation):
12+
pass
13+
14+
15+
class DeleteOperation(ASTOperation):
16+
pass
17+
18+
19+
class MoveOperation(ASTOperation):
20+
pass

0 commit comments

Comments
 (0)