Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 100 additions & 80 deletions src/semra/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ssslm import LiteralMapping
from tqdm.auto import tqdm

from semra.io.graph import _from_digraph_edge, to_digraph
from semra.io.graph import to_digraph
from semra.rules import EXACT_MATCH, FLIP, INVERSION_MAPPING, SubsetConfiguration
from semra.struct import (
Evidence,
Expand Down Expand Up @@ -55,7 +55,6 @@
"get_index",
"get_many_to_many",
"get_observed_terms",
"get_priority_reference",
"get_symmetric_counter",
"get_terms",
"get_test_evidence",
Expand Down Expand Up @@ -509,7 +508,11 @@ def assert_projection(mappings: list[Mapping]) -> None:


def prioritize(
mappings: list[Mapping], priority: list[str], *, progress: bool = True
mappings: list[Mapping],
priority: list[str],
*,
progress: bool = True,
sort: bool = True,
) -> list[Mapping]:
"""Get a priority star graph.

Expand All @@ -523,6 +526,7 @@ def prioritize(
if there exists a mapping ``A, exact, B``, there must be a ``B, exact, A``.

:param priority: A priority list of prefixes, where earlier in the list means the priority is higher.
:param sort: If true, sort the result by object (priority, prefix, identifier), then by subject prefix and identifier.
:return:
A list of mappings representing a "prioritization", meaning that each element only
appears as subject once. This condition means that the prioritization mapping can be applied
Expand All @@ -531,7 +535,13 @@ def prioritize(
This algorithm works in the following way

1. Get the subset of exact matches from the input mapping list
2. Convert the exact matches to an undirected mapping graph
2. Convert the exact matches to an index that behaves like an undirected mapping graph.

.. warning::

This assumes that all evidences have been aggregated into a single mapping!
Make sure you run :func:`assemble_evidences` first

3. Extract connected components.

.. note::
Expand All @@ -558,45 +568,37 @@ def prioritize(
>>> mappings = infer_chains(mappings)
>>> prioritize(mappings, ["mesh", "doid", "umls"])
"""
original_mappings = len(mappings)
mappings = [m for m in mappings if m.predicate == EXACT_MATCH]
exact_mappings = len(mappings)
priority = _clean_priority_prefixes(priority)

graph = to_digraph(mappings).to_undirected()
rv: list[Mapping] = []
for component in tqdm(
nx.connected_components(graph), unit="component", unit_scale=True, disable=not progress
):
o = get_priority_reference(component, priority)
if o is None:

aggregator = Aggregator(priority)
subject_object_mapping, original_mappings, exact_mappings = _get_index(mappings)

for s, object_mapping in subject_object_mapping.items():
o_key, o = aggregator.get_priority_reference(object_mapping)
if o == s:
continue
for s in component:
if s == o: # don't add self-edges
continue
if not graph.has_edge(s, o):
# TODO should this work even if s-o edge not exists?
# can also do "inference" here, but also might be
# because of negative edge filtering
raise NotImplementedError(
"prioritize() should only be called on fully inferred graphs, meaning "
"that in a given component, it is a full clique (i.e., there are edges "
"in both directions between all nodes)"
)
rv.extend(_from_digraph_edge(graph, s, o))

# sort such that the mappings are ordered by object by priority order
# then identifier of object, then subject prefix in alphabetical order
pos = {prefix: i for i, prefix in enumerate(priority)}
rv = sorted(
rv,
key=lambda m: (
pos[m.object.prefix],
m.object.identifier,
m.subject.prefix,
m.subject.identifier,
),
)

s_key = aggregator.get_reference_key(s)

# when the object key is smaller than the subject key,
# we prioritized in the right direction
if s_key > o_key:
rv.append(subject_object_mapping[s][o])
elif o in subject_object_mapping and s in subject_object_mapping[o]:
raise NotImplementedError
else:
flipped_mapping = flip(subject_object_mapping[s][o], strict=True)
rv.append(flipped_mapping)

if sort:
rv = sorted(
rv,
key=lambda m: (
aggregator.get_reference_key(m.object),
m.subject.prefix,
m.subject.identifier,
),
)

end_mappings = len(rv)
logger.info(
Expand All @@ -605,46 +607,64 @@ def prioritize(
return rv


def _clean_priority_prefixes(priority: list[str]) -> list[str]:
return [bioregistry.normalize_prefix(prefix, strict=True) for prefix in priority]


def get_priority_reference(
component: t.Iterable[Reference], priority: list[str]
) -> Reference | None:
"""Get the priority reference from a component.

:param component: A set of references with the pre-condition that they're all "equivalent"
:param priority: A priority list of prefixes, where earlier in the list means the priority is higher
:returns:
Returns the reference with the prefix that has the highest priority.
If multiple references have the highest priority prefix, returns the first one encountered.
If none have a priority prefix, return None.

>>> from semra import Reference
>>> curies = ["DOID:0050577", "mesh:C562966", "umls:C4551571"]
>>> references = [Reference.from_curie(curie) for curie in curies]
>>> get_priority_reference(references, ["mesh", "umls"]).curie
'mesh:C562966'
>>> get_priority_reference(references, ["DOID", "mesh", "umls"]).curie
'doid:0050577'
>>> get_priority_reference(references, ["hpo", "ordo", "symp"])
def _get_index(
mappings: Iterable[Mapping],
) -> tuple[dict[Reference, dict[Reference, Mapping]], int, int]:
original_mappings = 0
exact_mappings = 0

"""
prefix_to_references: defaultdict[str, list[Reference]] = defaultdict(list)
for reference in component:
prefix_to_references[reference.prefix].append(reference)
for prefix in _clean_priority_prefixes(priority):
references = prefix_to_references.get(prefix, [])
if not references:
continue
if len(references) == 1:
return references[0]
# TODO multiple - I guess let's just return the first
logger.debug("multiple references for %s", prefix)
return references[0]
# nothing found in priority, don't return at all.
return None
subject_object_mapping: defaultdict[Reference, dict[Reference, Mapping]] = defaultdict(dict)
for mapping in mappings:
original_mappings += 1
if mapping.predicate == EXACT_MATCH:
exact_mappings += 1
subject_object_mapping[mapping.subject][mapping.object] = mapping

# need to convert to a plain dict, otherwise the dictionary size could
# change during iteration, since looking up a missing key in a
# defaultdict inserts it
return dict(subject_object_mapping), original_mappings, exact_mappings


# Sort key for a reference: (priority rank of its prefix, prefix, identifier)
ReferenceKey: TypeAlias = tuple[int, str, str]


class Aggregator:
    """A class for aggregating nodes based on a priority list.

    References are ranked by the position of their prefix in the priority list.
    Prefixes that are not in the list rank after every listed prefix. Ties are
    broken by prefix, then by identifier, so the choice is deterministic.
    """

    def __init__(self, priority: Iterable[str]) -> None:
        """Initialize an aggregator.

        :param priority:
            Prefixes ordered from highest to lowest priority. Each one is passed
            through :func:`bioregistry.normalize_prefix` with ``strict=True``.
            If a prefix appears more than once after normalization, its first
            (highest-priority) position is used.
        """
        normalized = (bioregistry.normalize_prefix(prefix, strict=True) for prefix in priority)
        # maps each prefix to its rank, where a lower rank means a higher priority
        self.pos = self._rank_prefixes(normalized)
        # rank given to prefixes that aren't in the priority list; it is larger
        # than every assigned rank, so those prefixes sort last
        self.n = len(self.pos) + 1

    @staticmethod
    def _rank_prefixes(prefixes: Iterable[str]) -> dict[str, int]:
        """Map each distinct prefix to its rank, in order of first appearance."""
        # dict.fromkeys() drops repeats while keeping first-seen order, so a repeated
        # prefix keeps its best rank and the ranks stay contiguous (0, 1, 2, ...)
        return {prefix: rank for rank, prefix in enumerate(dict.fromkeys(prefixes))}

    def get_reference_key(self, node: Reference) -> ReferenceKey:
        """Get a sort key for a node based on priority, prefix, then identifier."""
        # sort by prefix priority first, then by prefix to tie-break when none
        # are prioritized, then by identifier within the vocabulary
        return self.pos.get(node.prefix, self.n), node.prefix, node.identifier

    def get_priority_reference(self, nodes: Iterable[Reference]) -> tuple[ReferenceKey, Reference]:
        """Get a unique priority reference from a set of references.

        :param nodes: The collection of references to get the priority reference from
        :returns: A pair of the "reference key" and the priority reference
        :raises ValueError: If ``nodes`` is empty

        Example:
        >>> from semra import Reference
        >>> from semra.api import Aggregator
        >>> curies = ["DOID:0050577", "mesh:C562966", "umls:C4551571"]
        >>> references = [Reference.from_curie(curie) for curie in curies]
        >>> Aggregator(["mesh", "umls"]).get_priority_reference(references)[1].curie
        'mesh:C562966'
        >>> Aggregator(["DOID", "mesh", "umls"]).get_priority_reference(references)[1].curie
        'doid:0050577'
        >>> Aggregator(["hpo", "ordo", "symp"]).get_priority_reference(references)[1].curie
        'doid:0050577'
        """
        # rank on the key alone so references never have to be compared to each other
        best = min(nodes, key=self.get_reference_key)
        return self.get_reference_key(best), best


def unindex(index: Index, *, progress: bool = True) -> list[Mapping]:
Expand Down
37 changes: 36 additions & 1 deletion tests/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
"""Reusable assets for testing."""

from semra import Reference
from __future__ import annotations

import unittest

from semra import Mapping, Reference
from semra.api import Index, get_index

a1_curie = "CHEBI:10084" # Xylopinine
a2_curie = "CHEBI:10100" # zafirlukast
Expand All @@ -16,3 +21,33 @@
)

TEST_CURIES = {a1, a2, b1, b2}


class BaseTestCase(unittest.TestCase):
    """Base test case that can compare collections of mappings by their triples."""

    def assert_same_triples(
        self,
        expected_mappings: Index | list[Mapping],
        actual_mappings: Index | list[Mapping],
        msg: str | None = None,
    ) -> None:
        """Check that both inputs contain exactly the same triples.

        Inputs that aren't already dictionaries are first indexed with
        :func:`get_index`. Only the keys of each index are compared.

        :param expected_mappings: The expected index or list of mappings
        :param actual_mappings: The index or list of mappings under test
        :param msg: An optional message passed through to the assertion
        """
        expected_index = (
            expected_mappings
            if isinstance(expected_mappings, dict)
            else get_index(expected_mappings, progress=False)
        )
        actual_index = (
            actual_mappings
            if isinstance(actual_mappings, dict)
            else get_index(actual_mappings, progress=False)
        )

        expected_rendered = self._clean_index(expected_index)
        actual_rendered = self._clean_index(actual_index)
        self.assertEqual(expected_rendered, actual_rendered, msg=msg)

    @staticmethod
    def _clean_index(index: Index) -> list[str]:
        """Render the unique keys of an index as sorted, readable strings."""
        rendered: list[str] = []
        # the triples themselves are sorted, not their rendered strings
        for triple in sorted(set(index)):
            rendered.append(
                f"<{triple.subject.curie}, {triple.predicate.curie}, {triple.object.curie}>"
            )
        return rendered
Loading