diff --git a/src/pubget/_data/stylesheets/text_extraction.xsl b/src/pubget/_data/stylesheets/text_extraction.xsl index 885254c..5ea266a 100644 --- a/src/pubget/_data/stylesheets/text_extraction.xsl +++ b/src/pubget/_data/stylesheets/text_extraction.xsl @@ -7,6 +7,7 @@ + @@ -305,6 +306,13 @@ - + + + + + + + + diff --git a/src/pubget/_data_extraction.py b/src/pubget/_data_extraction.py index 0a6911d..b4822fc 100644 --- a/src/pubget/_data_extraction.py +++ b/src/pubget/_data_extraction.py @@ -139,6 +139,7 @@ def extract_data_to_csv( output_dir: Optional[PathLikeOrStr] = None, *, articles_with_coords_only: bool = False, + preserve_cross_references: bool = True, n_jobs: int = 1, ) -> Tuple[Path, ExitCode]: """Extract text and coordinates from articles and store in csv files. @@ -158,6 +159,9 @@ def extract_data_to_csv( `articles_with_coords_only`. articles_with_coords_only If True, articles that contain no stereotactic coordinates are ignored. + preserve_cross_references + If True, text from inline cross-reference elements is preserved in + extracted text. If False, those elements are removed. n_jobs Number of processes to run in parallel. `-1` means using all processors. @@ -190,7 +194,11 @@ def extract_data_to_csv( ) n_jobs = _utils.check_n_jobs(n_jobs) n_articles = _do_extract_data_to_csv( - articles_dir, output_dir, articles_with_coords_only, n_jobs=n_jobs + articles_dir, + output_dir, + articles_with_coords_only, + preserve_cross_references=preserve_cross_references, + n_jobs=n_jobs, ) is_complete = bool(status["previous_step_complete"]) _utils.write_info( @@ -204,11 +212,13 @@ def extract_data_to_csv( return output_dir, exit_code -def _get_data_extractors() -> List[Extractor]: +def _get_data_extractors( + preserve_cross_references: bool, +) -> List[Extractor]: return [ MetadataExtractor(), AuthorsExtractor(), - TextExtractor(), + TextExtractor(preserve_cross_references=preserve_cross_references), TableInfoExtractor(), CoordinateExtractor(), CoordinateSpaceExtractor(), @@ -221,6 +231,7 @@ def _do_extract_data_to_csv( articles_dir: Path, output_dir: Path, articles_with_coords_only: bool, + preserve_cross_references: bool, n_jobs: int, ) -> int: """Do the data extraction and return the number of articles whose data was @@ -228,7 +239,7 @@ def _do_extract_data_to_csv( sterotactic coordinate triplet have their data saved. """ n_to_process = _utils.get_n_articles(articles_dir) - data_extractors = _get_data_extractors() + data_extractors = _get_data_extractors(preserve_cross_references) all_writers = [ CSVWriter.from_extractor(extractor, output_dir) for extractor in data_extractors @@ -289,6 +300,11 @@ def _edit_argument_parser( help="Only keep data for articles in which stereotactic coordinates " "are found.", ) + argument_parser.add_argument( + "--strip-cross-references", + action="store_true", + help="Remove inline cross-reference text from extracted article text.", + ) _utils.add_n_jobs_argument(argument_parser) @@ -309,6 +325,7 @@ def run( output_dir, exit_code = extract_data_to_csv( previous_steps_output["extract_articles"], articles_with_coords_only=args.articles_with_coords_only, + preserve_cross_references=not args.strip_cross_references, n_jobs=args.n_jobs, ) if not _utils.get_n_articles(output_dir): @@ -343,5 +360,6 @@ def run(self, args: argparse.Namespace) -> ExitCode: return extract_data_to_csv( args.articles_dir, articles_with_coords_only=args.articles_with_coords_only, + preserve_cross_references=not args.strip_cross_references, n_jobs=args.n_jobs, )[1] diff --git a/src/pubget/_text.py b/src/pubget/_text.py index 847544d..ea8e84d 100644 --- a/src/pubget/_text.py +++ b/src/pubget/_text.py @@ -17,6 +17,9 @@ class TextExtractor(Extractor): fields = ("pmcid", "title", "keywords", "abstract", "body") name = "text" + def __init__(self, preserve_cross_references: bool = True) -> None: + self.preserve_cross_references = preserve_cross_references + def extract( self, article: etree.ElementTree, @@ -30,7 +33,14 @@ def extract( # multiprocessing map. Parsing is cached. stylesheet = _utils.load_stylesheet("text_extraction.xsl") try: - transformed = stylesheet(article) + transformed = stylesheet( + article, + **{ + "preserve-crossrefs": etree.XSLT.strparam( + "true" if self.preserve_cross_references else "false" + ) + }, + ) except Exception: _LOG.exception( f"failed to transform article: {stylesheet.error_log}" diff --git a/tests/test_text.py b/tests/test_text.py index 19a30ed..73c570e 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -1,3 +1,4 @@ +from pathlib import Path from unittest.mock import Mock from lxml import etree @@ -12,3 +13,31 @@ def test_text_extractor_transform_failure(monkeypatch): etree, "XSLT", Mock(return_value=Mock(side_effect=ValueError)) ) assert extractor.extract(Mock(), Mock(), {}) == {} + + +def test_text_extractor_preserves_xref_text_by_default(): + extractor = _text.TextExtractor() + article = etree.fromstring( + b"""
+ +

Example ( Doe et al. ) test.

+ +
""" + ) + result = extractor.extract(article, Path("."), {}) + assert "Doe et al." in result["body"] + assert "( )" not in result["body"] + + +def test_text_extractor_strips_xref_text_when_disabled(): + extractor = _text.TextExtractor(preserve_cross_references=False) + article = etree.fromstring( + b"""
+ +

Example ( Doe et al. ) test.

+ +
""" + ) + result = extractor.extract(article, Path("."), {}) + assert "Doe et al." not in result["body"] + assert "( )" in result["body"]