Skip to content

Commit 1fd4b97

Browse files
nandikaKjoernhees
authored and committed
WIP: narrow deep paths mutation
1 parent 5942bb5 commit 1fd4b97

File tree

5 files changed

+293
-4
lines changed

5 files changed

+293
-4
lines changed

config/defaults.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,13 @@
7878
MUTPB_FV_SAMPLE_MAXN = 32 # max n of instantiations to sample from top k
7979
MUTPB_FV_QUERY_LIMIT = 256 # SPARQL query limit for the top k instantiations
8080
MUTPB_SP = 0.05 # prob to simplify pattern (warning: can restrict exploration)
81+
MUTPB_DN = 0.05 # prob to try a deep and narrow paths mutation
82+
MUTPB_DN_MIN_LEN = 2 # minimum length of the deep and narrow paths
83+
MUTPB_DN_MAX_LEN = 10 # absolute max of path length if not stopped by term_pb
84+
MUTPB_DN_TERM_PB = 0.3 # prob to terminate node expansion each step once hop >= min_len
85+
MUTPB_DN_FILTER_NODE_COUNT = 10 # narrow path query only keeps edges with max node count < this
86+
MUTPB_DN_FILTER_EDGE_COUNT = 1 # narrow path query only keeps edges with edge count > this
87+
MUTPB_DN_QUERY_LIMIT = 32 # SPARQL query limit for the deep and narrow path query
8188

8289
# for import in helpers and __init__
8390
__all__ = [_v for _v in globals().keys() if _v.isupper()]

gp_learner.py

Lines changed: 104 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from gp_query import predict_query
4949
from gp_query import query_time_hard_exceeded
5050
from gp_query import query_time_soft_exceeded
51+
from gp_query import variable_substitution_deep_narrow_mut_query
5152
from gp_query import variable_substitution_query
5253
from graph_pattern import canonicalize
5354
from graph_pattern import gen_random_var
@@ -653,6 +654,105 @@ def mutate_fix_var(
653654
return res
654655

655656

657+
def _mutate_deep_narrow_path_helper(
        sparql,
        timeout,
        gtp_scores,
        child,
        edge_var,
        node_var,
        gtp_sample_n=config.MUTPB_FV_RGTP_SAMPLE_N,
        limit_res=config.MUTPB_DN_QUERY_LIMIT,
        sample_n=config.MUTPB_FV_SAMPLE_MAXN,
):
    """Tries to fix edge_var in child to an edge leading to few nodes.

    Samples ground truth pairs from gtp_scores, queries per-edge counts for
    possible substitutions of edge_var via
    variable_substitution_deep_narrow_mut_query, ranks the candidates by
    ``edge_count / (node_sum / edge_count)`` (edge count over average degree)
    and substitutes edge_var with one of up to sample_n sampled candidates.

    :param sparql: Passed through to the substitution query.
    :param timeout: Passed through to the substitution query.
    :param gtp_scores: GTPScores to sample the ground truth pairs from.
    :param child: GraphPattern containing edge_var and node_var.
    :param edge_var: Edge variable to find a substitution for.
    :param node_var: Node variable whose bindings are counted.
    :param gtp_sample_n: Upper bound for the number of sampled pairs.
    :param limit_res: Result limit handed to the substitution query.
    :param sample_n: Maximum number of substitutions to sample.
    :return: Tuple (pattern, fixed). fixed is True only if the returned
        pattern is a substituted one; otherwise child is returned unchanged.
    """
    assert isinstance(child, GraphPattern)
    assert isinstance(gtp_scores, GTPScores)

    # The further we get, the less gtps are remaining. Sampling too many (all)
    # of them might hurt as common substitutions (> limit ones) which are dead
    # ends could cover less common ones that could actually help
    gtp_sample_n = min(gtp_sample_n, int(gtp_scores.remaining_gain))
    # random.randint(1, 0) raises a ValueError, so keep the upper bound >= 1
    # (int(remaining_gain) is 0 as soon as the remaining gain drops below 1)
    gtp_sample_n = random.randint(1, max(1, gtp_sample_n))

    ground_truth_pairs = gtp_scores.remaining_gain_sample_gtps(
        n=gtp_sample_n)
    t, substitution_counts = variable_substitution_deep_narrow_mut_query(
        sparql, timeout, child, edge_var, node_var, ground_truth_pairs,
        limit_res)
    edge_count, node_sum_count = substitution_counts
    if not node_sum_count:
        # the current pattern is unfit, as we can't find anything fulfilling it
        logger.debug("tried to fix a var %s without result:\n%s"
                     "seems as if the pattern can't be fulfilled!",
                     edge_var, child.to_sparql_select_query())
        return child, False
    mutate_fix_var_filter(node_sum_count)
    mutate_fix_var_filter(edge_count)
    if not node_sum_count:
        # could have happened that we removed the only possible substitution
        return child, False

    prio = Counter()
    for edge, node_sum in node_sum_count.items():
        ec = edge_count[edge]  # Counter: 0 for edges missing in edge_count
        if ec <= 0 or node_sum <= 0:
            # both counters are filtered independently and the counts default
            # to 0, so skip such edges instead of dividing by zero below
            continue
        # float() keeps this a true division independent of the interpreter
        prio[edge] = ec / (float(node_sum) / ec)  # ec / AVG degree
    if not prio:
        # no candidate left to unpack / sample from
        return child, False

    # randomly pick n of the substitutions with a prob ~ to their prios
    edges, prios = zip(*prio.most_common())

    substs = sample_from_list(edges, prios, sample_n)

    logger.info(
        'fixed variable %s in %sto:\n %s\n<%d out of:\n%s\n',
        edge_var.n3(),
        child,
        '\n '.join([subst.n3() for subst in substs]),
        sample_n,
        '\n'.join([
            ' %.3f: %s' % (c, v.n3()) for v, c in prio.most_common()]),
    )
    orig_child = child
    children = [
        GraphPattern(child, mapping={edge_var: subst})
        for subst in substs
    ]
    # candidates that are not fit to live fall back to the unchanged pattern
    children = [
        c if fit_to_live(c) else orig_child
        for c in children
    ]
    if children:
        child = random.choice(children)
    # only report success if we actually return a substituted pattern
    fixed = child is not orig_child
    return child, fixed
728+
729+
730+
def mutate_deep_narrow_path(
        child, sparql, timeout, gtp_scores,
        min_len=config.MUTPB_DN_MIN_LEN,
        max_len=config.MUTPB_DN_MAX_LEN,
        term_pb=config.MUTPB_DN_TERM_PB,
):
    """Extends child by a path of up to max_len hops with fixed edges.

    Starting at a random node of child, each hop adds a new triple via
    _mutate_expand_node_helper and then tries to fix the triple's edge
    variable via _mutate_deep_narrow_path_helper. The next hop continues at
    the new triple's node variable.

    :param child: GraphPattern to mutate.
    :param sparql: Passed through to _mutate_deep_narrow_path_helper.
    :param timeout: Passed through to _mutate_deep_narrow_path_helper.
    :param gtp_scores: Passed through to _mutate_deep_narrow_path_helper.
    :param min_len: Hops to perform before random termination may kick in.
    :param max_len: Absolute maximum number of hops.
    :param term_pb: Probability to stop before each hop once min_len hops
        were performed.
    :return: The extended GraphPattern, or child itself if no hop succeeded.
    """
    assert isinstance(child, GraphPattern)
    nodes = list(child.nodes)
    if not nodes:
        # random.choice raises an IndexError on an empty sequence
        return child
    start_node = random.choice(nodes)
    gp = child
    for hop in range(max_len):
        if hop >= min_len and random.random() < term_pb:
            break
        new_triple, var_node, var_edge = _mutate_expand_node_helper(start_node)
        candidate, fixed = _mutate_deep_narrow_path_helper(
            sparql, timeout, gtp_scores, gp + [new_triple], var_edge, var_node)
        if not fixed:
            # The edge could not be fixed. Expanding further would only issue
            # more queries on top of a hop that already failed, so stop and
            # keep the last pattern whose hops were all fixed.
            break
        gp = candidate
        start_node = var_node
    return gp
754+
755+
656756
def mutate_simplify_pattern(gp):
657757
if len(gp) < 2:
658758
return gp
@@ -757,6 +857,7 @@ def mutate(
757857
pb_dt=config.MUTPB_DT,
758858
pb_en=config.MUTPB_EN,
759859
pb_fv=config.MUTPB_FV,
860+
pb_dn=config.MUTPB_DN,
760861
pb_id=config.MUTPB_ID,
761862
pb_iv=config.MUTPB_IV,
762863
pb_mv=config.MUTPB_MV,
@@ -796,15 +897,15 @@ def mutate(
796897
if random.random() < pb_sp:
797898
child = mutate_simplify_pattern(child)
798899

900+
if random.random() < pb_dn:
901+
child = mutate_deep_narrow_path(child, sparql, timeout, gtp_scores)
902+
799903
if random.random() < pb_fv:
800904
child = canonicalize(child)
801905
children = mutate_fix_var(sparql, timeout, gtp_scores, child)
802906
else:
803907
children = [child]
804908

805-
806-
# TODO: deep & narrow paths mutation
807-
808909
children = {
809910
c if fit_to_live(c) else orig_child
810911
for c in children

gp_query.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
from graph_pattern import TARGET_VAR
3333
from graph_pattern import ASK_VAR
3434
from graph_pattern import COUNT_VAR
35+
from graph_pattern import NODE_VAR_SUM
36+
from graph_pattern import EDGE_VAR_COUNT
3537
from utils import exception_stack_catcher
3638
from utils import sparql_json_result_bindings_to_rdflib
3739
from utils import timer
@@ -279,7 +281,6 @@ def _combined_chunk_res(q_res, _vars, _ret_val_mapping):
279281
return chunk_res
280282

281283

282-
283284
def count_query(sparql, timeout, graph_pattern, source=None,
284285
**kwds):
285286
assert isinstance(graph_pattern, GraphPattern)
@@ -457,6 +458,68 @@ def _var_subst_res_update(res, update, **_):
457458
res += update
458459

459460

461+
def variable_substitution_deep_narrow_mut_query(
        sparql, timeout, graph_pattern, edge_var, node_var,
        source_target_pairs, limit_res, batch_size=config.BATCH_SIZE):
    """Runs the deep and narrow path substitution query via _multi_query.

    Bundles edge_var, node_var and the vars obtained from
    _get_vars_values_mapping into one tuple and hands it, together with the
    _var_subst_dnp_* callbacks, to _multi_query.

    :return: Whatever _multi_query returns for these callbacks (the caller in
        gp_learner unpacks it as a (time, (edge_count, node_sum_count)) pair).
    """
    vars_values_mapping = _get_vars_values_mapping(
        graph_pattern, source_target_pairs)
    _vars, _values, _ret_val_mapping = vars_values_mapping
    return _multi_query(
        sparql, timeout, graph_pattern, source_target_pairs, batch_size,
        (edge_var, node_var, _vars), _values, _ret_val_mapping,
        _var_subst_dnp_res_init, _var_subst_dnp_chunk_q,
        _var_subst_dnp_chunk_result_ext,
        _res_update=_var_subst_dnp_update,
        # non standard, passed via **kwds, see handling in the callbacks
        limit=limit_res,
    )
476+
477+
478+
# noinspection PyUnusedLocal
479+
def _var_subst_dnp_res_init(_, **kwds):
480+
return Counter(), Counter()
481+
482+
483+
def _var_subst_dnp_chunk_q(gp, _edge_var_node_var_and_vars,
484+
values_chunk, limit):
485+
edge_var, node_var, _vars = _edge_var_node_var_and_vars
486+
return gp.to_find_edge_var_for_narrow_path_query(
487+
edge_var=edge_var,
488+
node_var=node_var,
489+
vars_=_vars,
490+
values={_vars: values_chunk},
491+
limit_res=limit)
492+
493+
494+
# noinspection PyUnusedLocal
495+
def _var_subst_dnp_chunk_result_ext(
        q_res, _edge_var_node_var_and_vars, _, **kwds):
    """Extracts per-edge counts from one chunk's SPARQL JSON result.

    For each result row the binding of edge_var is used as key; the row's
    EDGE_VAR_COUNT and NODE_VAR_SUM bindings (defaulting to '0') are added to
    the respective counter under that key.

    :param q_res: Query result, read at path ['results', 'bindings'].
    :param _edge_var_node_var_and_vars: Tuple (edge_var, node_var, vars);
        only edge_var is used here.
    :return: Tuple (chunk_edge_count, chunk_node_sum) of Counters.
    """
    edge_var, _node_var, _fixed_vars = _edge_var_node_var_and_vars
    rows = sparql_json_result_bindings_to_rdflib(
        get_path(q_res, ['results', 'bindings'], default=[])
    )
    chunk_edge_count = Counter()
    chunk_node_sum = Counter()
    for row in rows:
        edge = get_path(row, [edge_var])
        chunk_edge_count[edge] += int(get_path(row, [EDGE_VAR_COUNT], '0'))
        chunk_node_sum[edge] += int(get_path(row, [NODE_VAR_SUM], '0'))
    return chunk_edge_count, chunk_node_sum
511+
512+
513+
def _var_subst_dnp_update(res, up, **_):
514+
edge_count, node_sum_count = res
515+
try:
516+
chunk_edge_count, chunk_node_sum = up
517+
edge_count.update(chunk_edge_count)
518+
node_sum_count.update(chunk_node_sum)
519+
except ValueError:
520+
pass
521+
522+
460523
def generate_stps_from_gp(sparql, gp):
461524
"""Generates a list of source target pairs from a given graph pattern.
462525

graph_pattern.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import six
3232

3333
from utils import URIShortener
34+
import config
3435

3536
logger = logging.getLogger(__name__)
3637

@@ -41,6 +42,8 @@
4142
TARGET_VAR = Variable('target')
4243
ASK_VAR = Variable('ask')
4344
COUNT_VAR = Variable('count')
45+
EDGE_VAR_COUNT = Variable('edge_var_count')
46+
NODE_VAR_SUM = Variable('node_var_sum')
4447

4548

4649
def gen_random_var():
@@ -714,6 +717,86 @@ def to_count_var_over_values_query(self, var, vars_, values, limit):
714717
res += 'LIMIT %d\n' % limit
715718
return self._sparql_prefix(res)
716719

720+
def to_find_edge_var_for_narrow_path_query(
        self, edge_var, node_var, vars_, values, limit_res,
        filter_node_count=config.MUTPB_DN_FILTER_NODE_COUNT,
        filter_edge_count=config.MUTPB_DN_FILTER_EDGE_COUNT,
):
    """Counts possible substitutions for edge_var to get a narrow path

    Builds a query like this (before prefix handling):
    SELECT * WHERE {
      {
        SELECT
          ?edge_var
          (SUM (?node_var_count) AS ?node_var_sum)
          (COUNT(*) AS ?edge_var_count)
          (MAX(?node_var_count) AS ?max_node_count)
        WHERE {
          SELECT DISTINCT
            ?source ?target ?edge_var (COUNT(?node_var) AS ?node_var_count)
          WHERE {
            VALUES (?source ?target) {
              (dbr:Adolescence dbr:Youth)
              (dbr:Adult dbr:Child)
              (dbr:Angel dbr:Heaven)
              (dbr:Arithmetic dbr:Mathematics)
            }
            ?node_var ?edge_var ?source .
            ?source dbo:wikiPageWikiLink ?target .
          }
          GROUP BY ?source ?target ?edge_var
        }
        GROUP BY ?edge_var
      }
      FILTER(?max_node_count < 10 && ?edge_var_count > 1)
    }
    ORDER BY ASC(?node_var_sum)
    LIMIT 32

    :param edge_var: Edge variable to find substitution for.
    :param node_var: Node variable to count.
    :param vars_: List of vars to fix values for (e.g. ?source, ?target).
    :param values: Values for vars_, handed to self._sparql_values_part
        (callers in gp_query pass a dict {vars_: list of value tuples}).
    :param limit_res: Limit result size.
    :param filter_node_count: Only keep edges whose maximum node count is
        below this.
    :param filter_edge_count: Only keep edges whose edge count is above this.
    :return: Query String.
    """
    vars_n3 = ' '.join([v.n3() for v in vars_])

    res = 'SELECT * WHERE {\n'
    # COUNT(*) counts the rows of the inner sub-select per edge. The former
    # COUNT(?source && ?target) applied a logical AND to the fixed variables,
    # which is a type error for IRIs in standard SPARQL.
    res += ' {\n'\
        ' SELECT %s (SUM (?node_var_count) AS %s) (COUNT(*) AS %s) ' \
        '(MAX(?node_var_count) AS ?max_node_count) WHERE {\n' % (
            edge_var.n3(),
            NODE_VAR_SUM.n3(),
            EDGE_VAR_COUNT.n3(), )
    res += ' SELECT DISTINCT %s %s (COUNT(%s) AS ?node_var_count) ' \
        'WHERE {\n ' % (vars_n3, edge_var.n3(), node_var.n3(), )
    res += self._sparql_values_part(values)

    # triples part
    indent = ' ' * 3
    tres = ['%s %s %s .' % (s.n3(), p.n3(), o.n3()) for s, p, o in self]
    res += indent + ('\n' + indent).join(tres) + '\n'
    res += ' }\n'
    # SPARQL 1.1 only allows projecting non-aggregated variables next to an
    # aggregate if they are part of the GROUP BY clause
    res += ' GROUP BY %s %s\n' % (vars_n3, edge_var.n3())
    res += ' }\n'
    res += ' GROUP BY %s\n' % edge_var.n3()
    res += ' }\n'
    res += ' FILTER(?max_node_count < %d && %s > %d)\n' \
        % (filter_node_count, EDGE_VAR_COUNT.n3(),
           filter_edge_count)
    res += '}\n'
    res += 'ORDER BY ASC(%s)\n' % NODE_VAR_SUM.n3()
    res += 'LIMIT %d' % limit_res
    return self._sparql_prefix(res)
799+
717800
def to_dict(self):
718801
return {
719802
'fitness': self.fitness.values if self.fitness.valid else (),

tests/test_gp_learner_offline.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from gp_learner import mutate_increase_dist
1414
from gp_learner import mutate_merge_var
1515
from gp_learner import mutate_simplify_pattern
16+
from gp_learner import mutate_deep_narrow_path
1617
from graph_pattern import GraphPattern
1718
from graph_pattern import SOURCE_VAR
1819
from graph_pattern import TARGET_VAR
@@ -108,6 +109,35 @@ def test_mutate_merge_var():
108109
assert False, "merge never reached one of the cases: %s" % cases
109110

110111

112+
def test_mutate_deep_narrow_path():
    """Offline check of the termination path of mutate_deep_narrow_path."""
    p = Variable('p')
    gp = GraphPattern([
        (SOURCE_VAR, p, TARGET_VAR)
    ])
    # sparql, timeout and gtp_scores are required arguments, calling the
    # mutation with the pattern only raises a TypeError. With max_len=0 the
    # mutation stops before its first hop, so no SPARQL endpoint is needed and
    # the pattern has to come back unchanged.
    child = mutate_deep_narrow_path(
        gp, sparql=None, timeout=None, gtp_scores=None, max_len=0)
    assert child == gp
121+
122+
123+
def test_to_find_edge_var_for_narrow_path_query():
    """Checks that the narrow path query contains the requested parts."""
    from rdflib import URIRef

    node_var = Variable('node_variable')
    edge_var = Variable('edge_variable')
    gp = GraphPattern([
        (node_var, edge_var, SOURCE_VAR),
        (SOURCE_VAR, wikilink, TARGET_VAR)
    ])
    # a tuple (not a set) keeps the variable order stable and is hashable, so
    # it can be used as key of the values dict like gp_query does
    vars_ = (SOURCE_VAR, TARGET_VAR)
    values = {vars_: [
        (URIRef('http://dbpedia.org/resource/Adult'),
         URIRef('http://dbpedia.org/resource/Child')),
    ]}
    # keyword arguments: the signature expects values and limit_res before
    # the two filter counts, passing the counts positionally mixes them up
    res = gp.to_find_edge_var_for_narrow_path_query(
        edge_var=edge_var,
        node_var=node_var,
        vars_=vars_,
        values=values,
        limit_res=32,
        filter_node_count=10,
        filter_edge_count=1,
    )
    print(res)
    assert edge_var.n3() in res
    assert node_var.n3() in res
    assert 'FILTER(?max_node_count < 10 && ?edge_var_count > 1)' in res
    assert 'LIMIT 32' in res
139+
140+
111141
def test_simplify_pattern():
112142
gp = GraphPattern([(SOURCE_VAR, wikilink, TARGET_VAR)])
113143
res = mutate_simplify_pattern(gp)
@@ -270,3 +300,8 @@ def test_remaining_gain_sample_gtps():
270300

271301
def test_gtp_scores():
272302
assert gtp_scores - gtp_scores == 0
303+
304+
305+
if __name__ == '__main__':
306+
# NOTE(review): test_mutate_deep_narrow_path() is left disabled in this WIP
# commit, only the query building test is run when executed as a script
307+
test_to_find_edge_var_for_narrow_path_query()

0 commit comments

Comments
 (0)