Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/pubget/_data/stylesheets/text_extraction.xsl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

<xsl:output method="xml" version="1.0" encoding="UTF-8" omit-xml-declaration="no"/>
<xsl:strip-space elements="*"/>
<xsl:param name="preserve-crossrefs" select="'true'"/>

<xsl:template match="/">
<extracted-text>
Expand Down Expand Up @@ -305,6 +306,13 @@
<xsl:template match="volume-series" />
<xsl:template match="word-count" />
<xsl:template match="x" />
<xsl:template match="xref" />
<xsl:template match="xref">
<xsl:choose>
<xsl:when test="$preserve-crossrefs = 'true'">
<xsl:apply-templates />
</xsl:when>
<xsl:otherwise />
</xsl:choose>
</xsl:template>
<xsl:template match="year" />
</xsl:transform>
26 changes: 22 additions & 4 deletions src/pubget/_data_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def extract_data_to_csv(
output_dir: Optional[PathLikeOrStr] = None,
*,
articles_with_coords_only: bool = False,
preserve_cross_references: bool = True,
n_jobs: int = 1,
) -> Tuple[Path, ExitCode]:
"""Extract text and coordinates from articles and store in csv files.
Expand All @@ -158,6 +159,9 @@ def extract_data_to_csv(
`articles_with_coords_only`.
articles_with_coords_only
If True, articles that contain no stereotactic coordinates are ignored.
preserve_cross_references
If True, text from inline cross-reference elements is preserved in
extracted text. If False, those elements are removed.
n_jobs
Number of processes to run in parallel. `-1` means using all
processors.
Expand Down Expand Up @@ -190,7 +194,11 @@ def extract_data_to_csv(
)
n_jobs = _utils.check_n_jobs(n_jobs)
n_articles = _do_extract_data_to_csv(
articles_dir, output_dir, articles_with_coords_only, n_jobs=n_jobs
articles_dir,
output_dir,
articles_with_coords_only,
preserve_cross_references=preserve_cross_references,
n_jobs=n_jobs,
)
is_complete = bool(status["previous_step_complete"])
_utils.write_info(
Expand All @@ -204,11 +212,13 @@ def extract_data_to_csv(
return output_dir, exit_code


def _get_data_extractors() -> List[Extractor]:
def _get_data_extractors(
preserve_cross_references: bool,
) -> List[Extractor]:
return [
MetadataExtractor(),
AuthorsExtractor(),
TextExtractor(),
TextExtractor(preserve_cross_references=preserve_cross_references),
TableInfoExtractor(),
CoordinateExtractor(),
CoordinateSpaceExtractor(),
Expand All @@ -221,14 +231,15 @@ def _do_extract_data_to_csv(
articles_dir: Path,
output_dir: Path,
articles_with_coords_only: bool,
preserve_cross_references: bool,
n_jobs: int,
) -> int:
"""Do the data extraction and return the number of articles whose data was
saved. If `articles_with_coords_only` is True, only articles with at least one
stereotactic coordinate triplet have their data saved.
"""
n_to_process = _utils.get_n_articles(articles_dir)
data_extractors = _get_data_extractors()
data_extractors = _get_data_extractors(preserve_cross_references)
all_writers = [
CSVWriter.from_extractor(extractor, output_dir)
for extractor in data_extractors
Expand Down Expand Up @@ -289,6 +300,11 @@ def _edit_argument_parser(
help="Only keep data for articles in which stereotactic coordinates "
"are found.",
)
argument_parser.add_argument(
"--strip-cross-references",
action="store_true",
help="Remove inline cross-reference text from extracted article text.",
)
_utils.add_n_jobs_argument(argument_parser)


Expand All @@ -309,6 +325,7 @@ def run(
output_dir, exit_code = extract_data_to_csv(
previous_steps_output["extract_articles"],
articles_with_coords_only=args.articles_with_coords_only,
preserve_cross_references=not args.strip_cross_references,
n_jobs=args.n_jobs,
)
if not _utils.get_n_articles(output_dir):
Expand Down Expand Up @@ -343,5 +360,6 @@ def run(self, args: argparse.Namespace) -> ExitCode:
return extract_data_to_csv(
args.articles_dir,
articles_with_coords_only=args.articles_with_coords_only,
preserve_cross_references=not args.strip_cross_references,
n_jobs=args.n_jobs,
)[1]
12 changes: 11 additions & 1 deletion src/pubget/_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class TextExtractor(Extractor):
fields = ("pmcid", "title", "keywords", "abstract", "body")
name = "text"

def __init__(self, preserve_cross_references: bool = True) -> None:
    """Create a text extractor.

    Parameters
    ----------
    preserve_cross_references
        If True (the default), text contained in inline cross-reference
        (``xref``) elements is kept in the extracted text; if False, that
        text is dropped by the XSLT stylesheet.
    """
    # Stored flag is read in `extract` and forwarded to the stylesheet as
    # the 'preserve-crossrefs' XSLT parameter.
    self.preserve_cross_references = preserve_cross_references

def extract(
self,
article: etree.ElementTree,
Expand All @@ -30,7 +33,14 @@ def extract(
# multiprocessing map. Parsing is cached.
stylesheet = _utils.load_stylesheet("text_extraction.xsl")
try:
transformed = stylesheet(article)
transformed = stylesheet(
article,
**{
"preserve-crossrefs": etree.XSLT.strparam(
"true" if self.preserve_cross_references else "false"
)
},
)
except Exception:
_LOG.exception(
f"failed to transform article: {stylesheet.error_log}"
Expand Down
29 changes: 29 additions & 0 deletions tests/test_text.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from unittest.mock import Mock

from lxml import etree
Expand All @@ -12,3 +13,31 @@ def test_text_extractor_transform_failure(monkeypatch):
etree, "XSLT", Mock(return_value=Mock(side_effect=ValueError))
)
assert extractor.extract(Mock(), Mock(), {}) == {}


def test_text_extractor_preserves_xref_text_by_default():
    """By default, inline xref text must survive text extraction."""
    doc = etree.fromstring(
        b"""<article>
            <body>
                <p>Example ( <xref ref-type='bibr'>Doe et al.</xref> ) test.</p>
            </body>
        </article>"""
    )
    te = _text.TextExtractor()
    extracted = te.extract(doc, Path("."), {})
    body_text = extracted["body"]
    # Citation text kept; no empty parentheses left behind.
    assert "Doe et al." in body_text
    assert "( )" not in body_text


def test_text_extractor_strips_xref_text_when_disabled():
    """With preserve_cross_references=False, xref text is removed."""
    doc = etree.fromstring(
        b"""<article>
            <body>
                <p>Example ( <xref ref-type='bibr'>Doe et al.</xref> ) test.</p>
            </body>
        </article>"""
    )
    te = _text.TextExtractor(preserve_cross_references=False)
    extracted = te.extract(doc, Path("."), {})
    body_text = extracted["body"]
    # Citation text gone; the surrounding parentheses remain empty.
    assert "Doe et al." not in body_text
    assert "( )" in body_text