Skip to content

Commit 514b9f4

Browse files
committed
Compute Editmap with APTED
1 parent 25c172a commit 514b9f4

File tree

4 files changed

+254
-91
lines changed

4 files changed

+254
-91
lines changed

code_diff/gumtree/__init__.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1-
from .isomap import gumtree_isomap
1+
from .isomap import gumtree_isomap
2+
from .editmap import gumtree_editmap
23

34
# Edit script ----------------------------------------------------------------
45

56
def compute_edit_script(source_ast, target_ast):
67

78
isomap = gumtree_isomap(source_ast, target_ast)
8-
print(isomap)
9+
print(isomap)
10+
print("----------------------------------------------------------------")
11+
12+
editmap = gumtree_editmap(isomap, source_ast, target_ast)
13+
14+
print(editmap)
15+

code_diff/gumtree/editmap.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from apted import APTED, Config
2+
3+
from .utils import subtree_dice, postorder_traversal
4+
5+
# Minimal edit mapping to make source isomorph to target --------------------------------
6+
7+
# We compute a mapping between source and target tree
8+
# If a source node is mapped to a target node with different label,
9+
# the source node has to be updated with the target label
10+
# If a source node is unmapped,
11+
# the source node has to be deleted
12+
# If a target node is unmapped,
13+
# the target node has to be added to the source tree
14+
#
15+
# Edits are chosen to be (approximately) minimal
16+
17+
18+
# APTED for computing a minimal edit --------------------------------
19+
20+
class APTEDConfig(Config):
21+
22+
def rename(self, node1, node2):
23+
24+
if node1.type == node2.type:
25+
return 1 if node1.text != node2.text else 0
26+
27+
return 1
28+
29+
def children(self, node):
30+
return node.children
31+
32+
33+
def _minimal_edit(isomap, source, target, max_size = 1000):
34+
if source.subtree_weight > max_size or target.subtree_weight > max_size: return
35+
36+
apted = APTED(source, target, APTEDConfig())
37+
mapping = apted.compute_edit_mapping()
38+
39+
for source_node, target_node in mapping:
40+
if source_node is None: continue
41+
if target_node is None: continue
42+
if source_node.type != target_node.type: continue
43+
44+
if (source_node, None) in isomap: continue
45+
if (None, target_node) in isomap: continue
46+
47+
yield source_node, target_node
48+
49+
50+
# Select node heuristically that is close to isomorph --------------------
51+
52+
def _select_near_candidate(source_node, mapping):
53+
54+
dst_seeds = []
55+
56+
for src in source_node.descandents():
57+
for _, dst in mapping[src, None]:
58+
dst_seeds.append(dst)
59+
60+
candidates = []
61+
seen = set()
62+
63+
for dst in dst_seeds:
64+
while dst.parent is not None:
65+
parent = dst.parent
66+
if parent in seen: break
67+
seen.add(parent)
68+
69+
if (parent.type == source_node.type
70+
and parent.parent is not None
71+
and (None, parent) not in mapping):
72+
candidates.append(parent)
73+
dst = parent
74+
75+
if len(candidates) == 0: return None, 0.0
76+
77+
candidates = [(x, subtree_dice(source_node, x, mapping)) for x in candidates]
78+
79+
return max(candidates, key=lambda x: x[1])
80+
81+
82+
# Gumtree Edit Mapping ---------------------------------------------------
83+
84+
def gumtree_editmap(isomap, source, target, max_size = 1000, min_dice = 0.5):
85+
# Caution: This method does change the isomap
86+
if len(isomap) == 0: return isomap
87+
88+
for source_node in postorder_traversal(source):
89+
90+
if source_node == source: # source_node is root
91+
isomap.add(source_node, target)
92+
93+
for s, t in _minimal_edit(isomap, source_node, target, max_size):
94+
isomap.add(s, t)
95+
96+
break
97+
98+
if len(source_node.children) == 0: continue # source_node is leaf
99+
if (source_node, None) in isomap: continue # source_node is now mapped
100+
101+
target_node, dice = _select_near_candidate(source_node, isomap)
102+
103+
if target_node is None or dice <= min_dice: continue
104+
105+
for s, t in _minimal_edit(isomap, source_node, target_node, max_size):
106+
isomap.add(s, t)
107+
isomap.add(source_node, target_node)
108+
109+
return isomap

code_diff/gumtree/isomap.py

Lines changed: 3 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from collections import defaultdict
66

7+
from .utils import NodeMapping, subtree_dice
8+
79
# API method ----------------------------------------------------------------
810

911
def gumtree_isomap(source_ast, target_ast, min_height = 1):
@@ -65,74 +67,6 @@ def gumtree_isomap(source_ast, target_ast, min_height = 1):
6567

6668
# Collections ----------------------------------------------------------------
6769

68-
class NodeMapping:
69-
70-
def __init__(self):
71-
self._src_to_dst = defaultdict(set)
72-
self._dst_to_src = defaultdict(set)
73-
self._length = 0
74-
75-
def __getitem__(self, key):
76-
if not isinstance(key, tuple): key = (key, None)
77-
78-
src_key, dst_key = key
79-
80-
if src_key is not None and dst_key is not None:
81-
return dst_key in self._src_to_dst[src_key]
82-
83-
if src_key is None and dst_key is None:
84-
return self.__iter__()
85-
86-
if src_key is None:
87-
return ((src, dst_key) for src in self._dst_to_src[dst_key])
88-
89-
if dst_key is None:
90-
return ((src_key, dst) for dst in self._src_to_dst[src_key])
91-
92-
def __iter__(self):
93-
94-
def _iter_maps():
95-
for k, V in self._src_to_dst.items():
96-
for v in V: yield (k, v)
97-
98-
return _iter_maps()
99-
100-
def __contains__(self, key):
101-
if not isinstance(key, tuple): key = (key, None)
102-
103-
src_key, dst_key = key
104-
105-
if src_key is not None and dst_key is not None:
106-
return self[src_key, dst_key]
107-
108-
return next(self[src_key, dst_key], None) is not None
109-
110-
def __len__(self):
111-
return self._length
112-
113-
def add(self, src, dst):
114-
if not self[src, dst]:
115-
self._src_to_dst[src].add(dst)
116-
self._dst_to_src[dst].add(src)
117-
self._length += 1
118-
119-
def __copy__(self):
120-
output = NodeMapping()
121-
122-
for a, b in self:
123-
output.add(a, b)
124-
125-
return output
126-
127-
def __str__(self):
128-
approx_str = []
129-
130-
for src, dst in self:
131-
approx_str.append("%s ≈ %s" % (str(src), str(dst)))
132-
133-
return "\n".join(approx_str)
134-
135-
13670
class NodeCounter:
13771

13872
def __init__(self):
@@ -202,31 +136,11 @@ def _map_recursively(mapping, source_node, target_node):
202136

203137
# Heuristic selection ----------------------------------------------------------------
204138

205-
def _subtree_dice(A, B, mapping):
206-
207-
if A is None or B is None:
208-
return 1.0 if all(x is None for x in [A, B]) else 0.0
209-
210-
DA, DB = set(A.descandents()), set(B.descandents())
211-
212-
norm = len(DA) + len(DB)
213-
214-
if norm == 0: return 1.0
215-
216-
mapped = defaultdict(set)
217-
for a, b in mapping: mapped[a].add(b)
218-
219-
mapped_children = set(m for t1 in DA if t1 in mapped for m in mapped[t1])
220-
dice_score = len(set.intersection(mapped_children, DB))
221-
222-
return 2 * dice_score / norm
223-
224139

225140
def create_default_heuristic(isomorphic_mapping):
226141

227142
def _heuristic(source_node, target_node):
228-
dice = _subtree_dice(source_node, target_node, isomorphic_mapping)
229-
return dice
143+
return subtree_dice(source_node, target_node, isomorphic_mapping)
230144

231145
return _heuristic
232146

code_diff/gumtree/utils.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
from collections import defaultdict
2+
3+
# Collections -------------------------------------------------------------------
4+
5+
class NodeMapping:
6+
7+
def __init__(self):
8+
self._src_to_dst = defaultdict(set)
9+
self._dst_to_src = defaultdict(set)
10+
self._length = 0
11+
12+
def __getitem__(self, key):
13+
if not isinstance(key, tuple): key = (key, None)
14+
15+
src_key, dst_key = key
16+
17+
if src_key is not None and dst_key is not None:
18+
return dst_key in self._src_to_dst[src_key]
19+
20+
if src_key is None and dst_key is None:
21+
return self.__iter__()
22+
23+
if src_key is None:
24+
return ((src, dst_key) for src in self._dst_to_src[dst_key])
25+
26+
if dst_key is None:
27+
return ((src_key, dst) for dst in self._src_to_dst[src_key])
28+
29+
def __iter__(self):
30+
31+
def _iter_maps():
32+
for k, V in self._src_to_dst.items():
33+
for v in V: yield (k, v)
34+
35+
return _iter_maps()
36+
37+
def __contains__(self, key):
38+
if not isinstance(key, tuple): key = (key, None)
39+
40+
src_key, dst_key = key
41+
42+
if src_key is not None and dst_key is not None:
43+
return self[src_key, dst_key]
44+
45+
return next(self[src_key, dst_key], None) is not None
46+
47+
def __len__(self):
48+
return self._length
49+
50+
def add(self, src, dst):
51+
if not self[src, dst]:
52+
self._src_to_dst[src].add(dst)
53+
self._dst_to_src[dst].add(src)
54+
self._length += 1
55+
56+
def __copy__(self):
57+
output = NodeMapping()
58+
59+
for a, b in self:
60+
output.add(a, b)
61+
62+
return output
63+
64+
def __str__(self):
65+
approx_str = []
66+
67+
for src, dst in self:
68+
approx_str.append("%s ≈ %s" % (str(src), str(dst)))
69+
70+
return "\n".join(approx_str)
71+
72+
73+
# Tree heuristic ----------------------------------------------------------------
74+
75+
def subtree_dice(A, B, mapping):
76+
77+
if A is None or B is None:
78+
return 1.0 if all(x is None for x in [A, B]) else 0.0
79+
80+
DA, DB = set(A.descandents()), set(B.descandents())
81+
82+
norm = len(DA) + len(DB)
83+
84+
if norm == 0: return 1.0
85+
86+
mapped = defaultdict(set)
87+
for a, b in mapping: mapped[a].add(b)
88+
89+
mapped_children = set(m for t1 in DA if t1 in mapped for m in mapped[t1])
90+
dice_score = len(set.intersection(mapped_children, DB))
91+
92+
return 2 * dice_score / norm
93+
94+
95+
# Tree traversal ----------------------------------------------------------------
96+
97+
def bfs_traversal(tree):
98+
queue = [tree]
99+
100+
while len(queue) > 0:
101+
node = queue.pop(0)
102+
103+
yield node
104+
105+
for c in node.children:
106+
queue.append(c)
107+
108+
109+
def dfs_traversal(tree):
110+
stack = [tree]
111+
112+
while len(stack) > 0:
113+
node = stack.pop(-1)
114+
115+
yield node
116+
117+
for c in node.children:
118+
stack.append(c)
119+
120+
121+
def postorder_traversal(tree):
122+
123+
stack = [(tree, 0)]
124+
125+
while len(stack) > 0:
126+
node, ix = stack.pop(-1)
127+
128+
if ix >= len(node.children):
129+
yield node
130+
else:
131+
stack.append((node, ix + 1))
132+
stack.append((node.children[ix], 0))
133+

0 commit comments

Comments
 (0)