diff --git a/.github/workflows/test_pypi.yml b/.github/workflows/test_pypi.yml
index 546df40..938488e 100644
--- a/.github/workflows/test_pypi.yml
+++ b/.github/workflows/test_pypi.yml
@@ -26,5 +26,6 @@ jobs:
         python -c "import paperscraper.server_dumps"
         python -c "import paperscraper.tests"
         python -c "import paperscraper.impact"
+        python -c "import paperscraper.citations"
diff --git a/.github/workflows/test_tip.yml b/.github/workflows/test_tip.yml
index 599f068..7dd2dd6 100644
--- a/.github/workflows/test_tip.yml
+++ b/.github/workflows/test_tip.yml
@@ -5,6 +5,8 @@ on: [push, release]
 jobs:
   test-source-install:
     runs-on: ubuntu-latest
+    env:
+      SS_API_KEY: ${{ secrets.SS_API_KEY }}
     strategy:
       max-parallel: 3
       matrix:
@@ -49,6 +51,8 @@ jobs:
   test-potential-wheel-install:
     runs-on: ubuntu-latest
+    env:
+      SS_API_KEY: ${{ secrets.SS_API_KEY }}
     steps:
       - name: Checkout code
        uses: actions/checkout@v2
diff --git a/paperscraper/async_utils.py b/paperscraper/async_utils.py
index a4bc596..816ce27 100644
--- a/paperscraper/async_utils.py
+++ b/paperscraper/async_utils.py
@@ -49,14 +49,20 @@ def wrapper(*args, **kwargs) -> Union[T, Awaitable[T]]:
 
 
 def retry_with_exponential_backoff(
-    *, max_retries: int = 5, base_delay: float = 1.0
+    *,
+    max_retries: int = 5,
+    base_delay: float = 1.0,
+    factor: float = 1.3,
+    constant_delay: float = 0.2,
 ) -> Callable[[F], F]:
     """
     Decorator factory that retries an `async def` on HTTP 429, with exponential backoff.
 
     Args:
         max_retries: how many times to retry before giving up.
-        base_delay: initial delay in seconds; next delays will be duplication of previous.
+        base_delay: initial delay in seconds; next delays will be multiplied by `factor`.
+        factor: multiplier for the delay after each retry.
+        constant_delay: fixed delay before each attempt.
 
     Usage:
@@ -70,18 +76,39 @@ def decorator(func: F) -> F:
         @wraps(func)
         async def wrapper(*args, **kwargs) -> Any:
             delay = base_delay
-            for attempt in range(max_retries):
+            last_exception: BaseException | None = None
+            for attempt in range(1, max_retries + 1):
+                await asyncio.sleep(constant_delay)
                 try:
                     return await func(*args, **kwargs)
                 except httpx.HTTPStatusError as e:
-                    # only retry on 429
                     status = e.response.status_code if e.response is not None else None
-                    if status != 429 or attempt == max_retries - 1:
+                    if status != 429:
                         raise
-                    # backoff
-                    await asyncio.sleep(delay)
-                    delay *= 2
-            # in theory we never reach here
+                    last_exception = e
+                    sleep_for = delay
+                    if e.response is not None:
+                        ra = e.response.headers.get("Retry-After")
+                        if ra is not None:
+                            try:
+                                sleep_for = float(ra)
+                            except ValueError:
+                                pass
+                    delay *= factor
+
+                except httpx.ReadError as e:
+                    last_exception = e
+                    sleep_for = delay
+                    delay *= factor
+
+                if attempt == max_retries:
+                    msg = (
+                        f"{func.__name__} failed after {attempt} attempts with "
+                        f"last delay {sleep_for:.2f}s"
+                    )
+                    raise RuntimeError(msg) from last_exception
+
+                await asyncio.sleep(sleep_for)
 
         return wrapper
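The reworked decorator separates a fixed per-attempt pause (`constant_delay`) from the growing backoff (`base_delay` scaled by `factor`) and honors a server-sent Retry-After header. A minimal usage sketch; the `fetch_json` coroutine and URL are illustrative, not part of this diff:

import httpx

from paperscraper.async_utils import retry_with_exponential_backoff


@retry_with_exponential_backoff(
    max_retries=5, base_delay=1.0, factor=1.3, constant_delay=0.2
)
async def fetch_json(url: str) -> dict:
    # Hypothetical helper: any coroutine raising httpx.HTTPStatusError on 429 works.
    async with httpx.AsyncClient(timeout=10) as client:
        response = await client.get(url)
        response.raise_for_status()
        return response.json()

# Backoff waits grow as base_delay * factor**n (1.0s, 1.3s, 1.69s, ...), unless the
# 429 response carries a Retry-After header, which overrides the wait for that attempt.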
diff --git a/paperscraper/citations/entity/core.py b/paperscraper/citations/entity/core.py
index b25b035..04b14ca 100644
--- a/paperscraper/citations/entity/core.py
+++ b/paperscraper/citations/entity/core.py
@@ -5,14 +5,15 @@
 
 
 class EntityResult(BaseModel):
-    num_citations: int
-    num_references: int
-    # keys are authors or papers and values are absolute self links
-    self_citations: Dict[str, int] = {}
-    self_references: Dict[str, int] = {}
     # aggregated results
     self_citation_ratio: float = 0
     self_reference_ratio: float = 0
+    # total number of author citations/references
+    num_citations: int
+    num_references: int
+    # keys are papers and values are percentage of self citations/references
+    self_citations: Dict[str, float] = {}
+    self_references: Dict[str, float] = {}
 
 
 class Entity:
diff --git a/paperscraper/citations/entity/paper.py b/paperscraper/citations/entity/paper.py
index 8530507..ad51a11 100644
--- a/paperscraper/citations/entity/paper.py
+++ b/paperscraper/citations/entity/paper.py
@@ -68,14 +68,14 @@ def self_references(self):
         Extracts the self references of a paper, for each author.
         """
         if isinstance(self.doi, str):
-            self.ref_result: ReferenceResult = self_references_paper(self.doi)
+            self.self_ref: ReferenceResult = self_references_paper(self.doi)
 
     def self_citations(self):
         """
         Extracts the self citations of a paper, for each author.
         """
         if isinstance(self.doi, str):
-            self.citation_result: CitationResult = self_citations_paper(self.doi)
+            self.self_cite: CitationResult = self_citations_paper(self.doi)
 
     def get_result(self) -> Optional[PaperResult]:
         """
@@ -83,18 +83,26 @@ def get_result(self) -> Optional[PaperResult]:
         Returns:
             PaperResult if available.
         """
-        if not hasattr(self, "ref_result"):
+        if not hasattr(self, "self_ref"):
             logger.warning(
                 f"Can't get result since no referencing result for {self.input} exists. Run `.self_references` first."
             )
             return
-        elif not hasattr(self, "citation_result"):
+        elif not hasattr(self, "self_cite"):
             logger.warning(
                 f"Can't get result since no citation result for {self.input} exists. Run `.self_citations` first."
             )
             return
-        ref_result = self.ref_result.model_dump()
-        ref_result.pop("ssid", None)
         return PaperResult(
-            title=self.title, **ref_result, **self.citation_result.model_dump()
+            title=self.title,
+            **{
+                k: v
+                for k, v in self.self_ref.model_dump().items()
+                if k not in ["ssid", "title"]
+            },
+            **{
+                k: v
+                for k, v in self.self_cite.model_dump().items()
+                if k not in ["title"]
+            },
         )
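With `ref_result`/`citation_result` renamed to `self_ref`/`self_cite`, a typical call sequence might look like the sketch below. The import path and constructor argument are assumptions, since the `Paper` class definition is not shown in this diff:

from paperscraper.citations.entity.paper import Paper  # import path assumed

paper = Paper("10.1038/s42256-023-00639-z")  # assuming the constructor accepts a DOI
paper.self_references()      # populates paper.self_ref (a ReferenceResult)
paper.self_citations()       # populates paper.self_cite (a CitationResult)
result = paper.get_result()  # merges both dumps, dropping the duplicated ssid/title keys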
diff --git a/paperscraper/citations/entity/researcher.py b/paperscraper/citations/entity/researcher.py
index 4a27312..e6a256f 100644
--- a/paperscraper/citations/entity/researcher.py
+++ b/paperscraper/citations/entity/researcher.py
@@ -1,11 +1,12 @@
+import asyncio
 import os
-from typing import List, Literal, Optional
+from typing import Any, List, Literal, Optional, Tuple
 
 from semanticscholar import SemanticScholar
-from tqdm import tqdm
 
 from ..orcid import orcid_to_author_name
-from ..self_references import ReferenceResult
+from ..self_citations import CitationResult
+from ..self_references import ReferenceResult, self_references_paper
 from ..utils import author_name_to_ssaid, get_papers_for_author
 from .core import Entity, EntityResult
 
@@ -14,7 +15,27 @@ class ResearcherResult(EntityResult):
     name: str
     ssid: int
     orcid: Optional[str] = None
-    # TODO: the ratios will be averaged across all papers for that author
+
+    def _ordered_items(self) -> List[Tuple[str, Any]]:
+        # enforce specific ordering
+        return [
+            ("name", self.name),
+            ("self_reference_ratio", self.self_reference_ratio),
+            ("self_citation_ratio", self.self_citation_ratio),
+            ("num_references", self.num_references),
+            ("num_citations", self.num_citations),
+            ("self_references", self.self_references),
+            ("self_citations", self.self_citations),
+            ("ssid", self.ssid),
+            ("orcid", self.orcid),
+        ]
+
+    def __repr__(self) -> str:
+        inner = ", ".join(f"{k}={v!r}" for k, v in self._ordered_items())
+        return f"{self.__class__.__name__}({inner})"
+
+    def __str__(self) -> str:
+        return " ".join(f"{k}={v!r}" for k, v in self._ordered_items())
 
 
 ModeType = Literal[tuple(MODES := ("name", "orcid", "ssaid", "infer"))]
@@ -32,7 +53,7 @@ def __init__(self, input: str, mode: ModeType = "infer"):
         Construct researcher object for self citation/reference analysis.
 
         Args:
-            input: A researcher to search for.
+            input: A researcher to search for, identified by name, ORCID iD, or Semantic Scholar Author ID.
             mode: This can be a `name` `orcid` (ORCID iD) or `ssaid` (Semantic Scholar Author ID).
                 Defaults to "infer".
@@ -53,32 +74,73 @@ def __init__(self, input: str, mode: ModeType = "infer"):
         ):
             mode = "orcid"
         else:
-            mode = "author"
-
+            mode = "name"
         if mode == "ssaid":
-            self.author = sch.get_author(input)
+            self.name = sch.get_author(input)._name
             self.ssid = input
         elif mode == "orcid":
-            self.author = orcid_to_author_name(input)
+            orcid_name = orcid_to_author_name(input)
             self.orcid = input
-            self.ssid = author_name_to_ssaid(input)
-        elif mode == "author":
-            self.author = input
-            self.ssid = author_name_to_ssaid(input)
-
-        # TODO: Skip over erratum / corrigendum
-        self.ssids = get_papers_for_author(self.ssid)
-
-    def self_references(self):
+            self.ssid, self.name = author_name_to_ssaid(orcid_name)
+        elif mode == "name":
+            self.ssid, self.name = author_name_to_ssaid(input)
+
+    async def _self_references_async(
+        self, verbose: bool = False
+    ) -> List[ReferenceResult]:
+        """Async version of self_references."""
+        self.ssids = await get_papers_for_author(self.ssid)
+
+        results: List[ReferenceResult] = await self_references_paper(
+            self.ssids, verbose=verbose
+        )
+        # Remove papers with zero references or that are erratum/corrigendum
+        results = [
+            r
+            for r in results
+            if r.num_references > 0
+            and "erratum" not in r.title.lower()
+            and "corrigendum" not in r.title.lower()
+        ]
+
+        return results
+
+    def self_references(self, verbose: bool = False) -> ResearcherResult:
         """
         Sifts through all papers of a researcher and extracts the self references.
-        """
-        # TODO: Asynchronous call to self_references
-        print("Going through SSIDs", self.ssids)
-        # TODO: Aggregate results
 
-    def self_citations(self):
+        Args:
+            verbose: If True, logs detailed information for each paper.
+
+        Returns:
+            A ResearcherResult containing aggregated self-reference data.
+        """
+        reference_results = asyncio.run(self._self_references_async(verbose=verbose))
+
+        individual_self_references = {
+            result.title: result.self_references.get(self.name, 0.0)
+            for result in reference_results
+        }
+        reference_ratio = sum(individual_self_references.values()) / max(
+            1, len(individual_self_references)
+        )
+        return ResearcherResult(
+            name=self.name,
+            ssid=int(self.ssid),
+            orcid=self.orcid,
+            num_references=sum(r.num_references for r in reference_results),
+            num_citations=-1,
+            self_references=dict(
+                sorted(
+                    individual_self_references.items(), key=lambda x: x[1], reverse=True
+                )
+            ),
+            self_citations={},
+            self_reference_ratio=round(reference_ratio, 3),
+            self_citation_ratio=-1.0,
+        )
+
+    def self_citations(self) -> ResearcherResult:
         """
         Sifts through all papers of a researcher and finds how often they are self-cited.
         """
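The new aggregation can be exercised end to end, as the tests further down in this diff do; mode inference accepts a plain name, an ORCID iD, or a Semantic Scholar author ID:

from paperscraper.citations.entity import Researcher

researcher = Researcher("0000-0003-4221-6988")  # inferred as an ORCID iD
result = researcher.self_references(verbose=True)

print(result.name)                  # Semantic Scholar's canonical form of the name
print(result.self_reference_ratio)  # mean self-reference percentage across papers
print(result.num_citations)         # -1: citation counts are not computed by this method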
diff --git a/paperscraper/citations/self_citations.py b/paperscraper/citations/self_citations.py
index 57cef83..9d0b704 100644
--- a/paperscraper/citations/self_citations.py
+++ b/paperscraper/citations/self_citations.py
@@ -18,6 +18,7 @@
 
 class CitationResult(BaseModel):
     ssid: str  # semantic scholar paper id
+    title: str
     num_citations: int
     self_citations: Dict[str, float] = {}
     citation_score: float
@@ -87,6 +88,7 @@ async def _process_single(client: httpx.AsyncClient, identifier: str) -> CitationResult:
 
     return CitationResult(
         ssid=identifier,
+        title=paper.get("title", ""),
         num_citations=total_cites,
         self_citations=ratios,
         citation_score=avg_score,
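`CitationResult` now carries the paper title, so downstream consumers no longer need a second lookup. A short sketch, calling the module-level helper synchronously the way `paper.py` above does; the import path is assumed from the package layout:

from paperscraper.citations.self_citations import self_citations_paper  # path assumed

result = self_citations_paper("10.1038/s42256-023-00639-z")  # illustrative DOI
print(result.title)  # newly added field, filled from the API response
print(result.num_citations)
for author, pct in result.self_citations.items():
    print(f"{author}: {pct}% self-citations")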
""" diff --git a/paperscraper/citations/self_citations.py b/paperscraper/citations/self_citations.py index 57cef83..9d0b704 100644 --- a/paperscraper/citations/self_citations.py +++ b/paperscraper/citations/self_citations.py @@ -18,6 +18,7 @@ class CitationResult(BaseModel): ssid: str # semantic scholar paper id + title: str num_citations: int self_citations: Dict[str, float] = {} citation_score: float @@ -87,6 +88,7 @@ async def _process_single(client: httpx.AsyncClient, identifier: str) -> Citatio return CitationResult( ssid=identifier, + title=paper.get("title", ""), num_citations=total_cites, self_citations=ratios, citation_score=avg_score, diff --git a/paperscraper/citations/self_references.py b/paperscraper/citations/self_references.py index 5014377..3ddffd0 100644 --- a/paperscraper/citations/self_references.py +++ b/paperscraper/citations/self_references.py @@ -1,5 +1,6 @@ import asyncio import logging +import os import re import sys from typing import Any, Dict, List, Literal, Union @@ -7,6 +8,7 @@ import httpx import numpy as np from pydantic import BaseModel +from tqdm import tqdm from ..async_utils import optional_async, retry_with_exponential_backoff from .utils import DOI_PATTERN, find_matching @@ -17,14 +19,25 @@ ModeType = Literal[tuple(MODES := ("doi", "infer", "ssid"))] +SS_API_KEY = os.getenv("SS_API_KEY") +HEADERS: Dict[str, str] = {} +if SS_API_KEY: + HEADERS["x-api-key"] = SS_API_KEY + +CONCURRENCY_LIMIT = 10 +_SEM = asyncio.Semaphore(CONCURRENCY_LIMIT) + + class ReferenceResult(BaseModel): ssid: str # semantic scholar paper id + title: str num_references: int self_references: Dict[str, float] = {} reference_score: float -async def _fetch_reference_data( +@retry_with_exponential_backoff(max_retries=14, base_delay=1.0) +async def _fetch_paper_with_references( client: httpx.AsyncClient, suffix: str ) -> Dict[str, Any]: """ @@ -40,6 +53,7 @@ async def _fetch_reference_data( response = await client.get( f"https://api.semanticscholar.org/graph/v1/paper/{suffix}", params={"fields": "title,authors,references.authors"}, + headers=HEADERS, ) response.raise_for_status() return response.json() @@ -58,47 +72,49 @@ async def _process_single_reference( Returns: A ReferenceResult containing counts and percentages of self-references. 
""" - # Determine prefix for API - if len(identifier) > 15 and identifier.isalnum() and identifier.islower(): - prefix = "" - elif len(re.findall(DOI_PATTERN, identifier, re.IGNORECASE)) == 1: - prefix = "DOI:" - else: - prefix = "" - - suffix = f"{prefix}{identifier}" - paper = await _fetch_reference_data(client, suffix) - - # Initialize counters - author_counts: Dict[str, int] = {a["name"]: 0 for a in paper.get("authors", [])} - references = paper.get("references", []) - total_refs = len(references) - - # Tally self-references - for ref in references: - matched = find_matching(paper.get("authors", []), ref.get("authors", [])) - for name in matched: - author_counts[name] += 1 - - # Compute percentages per author - ratios: Dict[str, float] = { - name: round((count / total_refs * 100), 2) if total_refs > 0 else 0.0 - for name, count in author_counts.items() - } - - # Compute average score - avg_score = round(float(np.mean(list(ratios.values()))) if ratios else 0.0, 3) - - return ReferenceResult( - ssid=identifier, - num_references=total_refs, - self_references=ratios, - reference_score=avg_score, - ) + async with _SEM: + # Determine prefix for API + if len(identifier) > 15 and identifier.isalnum() and identifier.islower(): + prefix = "" + elif len(re.findall(DOI_PATTERN, identifier, re.IGNORECASE)) == 1: + prefix = "DOI:" + else: + prefix = "" + + suffix = f"{prefix}{identifier}" + paper = await _fetch_paper_with_references(client, suffix) + + # Initialize counters + author_counts: Dict[str, int] = {a["name"]: 0 for a in paper.get("authors", [])} + references = paper.get("references", []) or [] + + total_refs = len(references) + + # Tally self-references + for ref in references: + matched = find_matching(paper.get("authors", []), ref.get("authors", [])) + for name in matched: + author_counts[name] += 1 + + # Compute percentages per author + ratios: Dict[str, float] = { + name: round((count / total_refs * 100), 2) if total_refs > 0 else 0.0 + for name, count in author_counts.items() + } + + # Compute average score + avg_score = round(float(np.mean(list(ratios.values()))) if ratios else 0.0, 3) + + return ReferenceResult( + ssid=identifier, + title=paper.get("title", ""), + num_references=total_refs, + self_references=ratios, + reference_score=avg_score, + ) @optional_async -@retry_with_exponential_backoff(max_retries=4, base_delay=1.0) async def self_references_paper( inputs: Union[str, List[str]], verbose: bool = False ) -> Union[ReferenceResult, List[ReferenceResult]]: @@ -120,15 +136,25 @@ async def self_references_paper( async with httpx.AsyncClient(timeout=httpx.Timeout(20)) as client: tasks = [_process_single_reference(client, ident) for ident in identifiers] - results = await asyncio.gather(*tasks) + results: List[ReferenceResult] = [] + + iterator = asyncio.as_completed(tasks) + if verbose: + iterator = tqdm( + iterator, total=len(tasks), desc="Collecting self-references" + ) + + for coro in iterator: + res = await coro + results.append(res) if verbose: for res in results: logger.info( - f'Self-references in "{res.ssid}": N={res.num_references}, ' + f'Self-references in "{res.title}": N={res.num_references}, ' f"Score={res.reference_score}%" ) for author, pct in res.self_references.items(): - logger.info(f" {author}: {pct}% self-reference") + logger.info(f" {author}: {pct}% self-references") return results[0] if single_input else results diff --git a/paperscraper/citations/tests/test_citations.py b/paperscraper/citations/tests/test_citations.py index 4a6902c..93bd27f 100644 --- 
diff --git a/paperscraper/citations/tests/test_self_references.py b/paperscraper/citations/tests/test_self_references.py
index ff07c15..95763dd 100644
--- a/paperscraper/citations/tests/test_self_references.py
+++ b/paperscraper/citations/tests/test_self_references.py
@@ -5,6 +5,7 @@
 import pytest
 
 from paperscraper.citations import self_references_paper
+from paperscraper.citations.entity import Researcher
 from paperscraper.citations.self_references import ReferenceResult
 
 logging.disable(logging.INFO)
@@ -75,3 +76,49 @@ def test_compare_async_and_sync_performance(self, dois):
             f"Async execution ({async_duration:.2f}s) is slower than sync execution "
             f"({sync_duration:.2f}s)"
         )
+
+    def test_researcher(self):
+        """
+        Tests calculation of self-references for all papers of an author.
+        """
+        ssaid = "2326988211"
+        researcher = Researcher(ssaid)
+        result = researcher.self_references(verbose=True)
+        assert result.ssid == int(ssaid)
+        assert isinstance(result.name, str)
+        assert result.name == "Patrick Soga"
+        assert isinstance(result.num_references, int)
+        assert result.num_references > 0
+        assert isinstance(result.num_citations, int)
+        assert result.num_citations == -1
+        assert isinstance(result.self_references, dict)
+        for title, ratio in result.self_references.items():
+            assert isinstance(title, str)
+            assert isinstance(ratio, float)
+            assert ratio >= 0 and ratio <= 100
+
+        assert result.self_reference_ratio >= 0 and result.self_reference_ratio <= 100
+        print(result)
+
+    def test_researcher_from_orcid(self):
+        """
+        Tests calculation of self-references for all papers of an author identified by ORCID iD.
+        """
+        orcid = "0000-0003-4221-6988"
+        researcher = Researcher(orcid)
+        result = researcher.self_references(verbose=True)
+        assert result.orcid == orcid
+        assert isinstance(result.name, str)
+        assert result.name == "Juan M. Galeazzi"
Galeazzi" + assert isinstance(result.num_references, int) + assert result.num_references > 0 + assert isinstance(result.num_citations, int) + assert result.num_citations == -1 + assert isinstance(result.self_references, Dict) + for title, ratio in result.self_references.items(): + assert isinstance(title, str) + assert isinstance(ratio, float) + assert ratio >= 0 and ratio <= 100 + + assert result.self_reference_ratio >= 0 and result.self_reference_ratio <= 100 + print(result) \ No newline at end of file diff --git a/paperscraper/citations/utils.py b/paperscraper/citations/utils.py index 9710206..fa5c402 100644 --- a/paperscraper/citations/utils.py +++ b/paperscraper/citations/utils.py @@ -1,14 +1,16 @@ import logging +import os import re import sys -from time import sleep -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Tuple import httpx import requests from tqdm import tqdm from unidecode import unidecode +from ..async_utils import optional_async, retry_with_exponential_backoff + logging.basicConfig(stream=sys.stdout, level=logging.INFO) logger = logging.getLogger(__name__) logging.getLogger("httpx").setLevel(logging.WARNING) @@ -18,6 +20,12 @@ AUTHOR_URL: str = "https://api.semanticscholar.org/graph/v1/author/search" +SS_API_KEY = os.getenv("SS_API_KEY") +HEADERS: Dict[str, str] = {} +if SS_API_KEY: + HEADERS["x-api-key"] = SS_API_KEY + + def get_doi_from_title(title: str) -> Optional[str]: """ Searches the DOI of a paper based on the paper title @@ -31,6 +39,7 @@ def get_doi_from_title(title: str) -> Optional[str]: response = requests.get( PAPER_URL + "search", params={"query": title, "fields": "externalIds", "limit": 1}, + headers=HEADERS, ) data = response.json() @@ -42,7 +51,8 @@ def get_doi_from_title(title: str) -> Optional[str]: logger.warning(f"Did not find DOI for title={title}") -def get_doi_from_ssid(ssid: str, max_retries: int = 10) -> Optional[str]: +@optional_async +async def get_doi_from_ssid(ssid: str, max_retries: int = 10) -> Optional[str]: """ Given a Semantic Scholar paper ID, returns the corresponding DOI if available. @@ -52,31 +62,34 @@ def get_doi_from_ssid(ssid: str, max_retries: int = 10) -> Optional[str]: Returns: str or None: The DOI of the paper, or None if not found or in case of an error. """ - logger.warning( - "Semantic Scholar API is easily overloaded when passing SS IDs, provide DOIs to improve throughput." - ) - attempts = 0 - for attempt in tqdm( - range(1, max_retries + 1), desc=f"Fetching DOI for {ssid}", unit="attempt" - ): - # Make the GET request to Semantic Scholar. - response = requests.get( - f"{PAPER_URL}{ssid}", params={"fields": "externalIds", "limit": 1} + async with httpx.AsyncClient(timeout=httpx.Timeout(20)) as client: + logger.warning( + "Semantic Scholar API is easily overloaded when passing SS IDs, provide DOIs to improve throughput." ) + attempts = 0 + for attempt in tqdm( + range(1, max_retries + 1), desc=f"Fetching DOI for {ssid}", unit="attempt" + ): + # Make the GET request to Semantic Scholar. + response = await client.get( + f"{PAPER_URL}{ssid}", + params={"fields": "externalIds", "limit": 1}, + headers=HEADERS, + ) - # If successful, try to extract and return the DOI. - if response.status_code == 200: - data = response.json() - doi = data.get("externalIds", {}).get("DOI") - return doi - attempts += 1 - sleep(10) - logger.warning( - f"Did not find DOI for paper ID {ssid}. 
diff --git a/requirements.txt b/requirements.txt
index a07e1f5..944f024 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,7 +8,7 @@ seaborn>=0.11.0
 matplotlib>=3.3.2
 matplotlib-venn>=0.11.5
 bs4>=0.0.1
-impact-factor>=1.1.1
+impact-factor>=1.1.1,<1.1.3
 thefuzz>=0.20.0
 pytest
 tldextract
diff --git a/setup.py b/setup.py
index 8c0e29b..5fd1152 100644
--- a/setup.py
+++ b/setup.py
@@ -37,7 +37,7 @@
         "matplotlib",
         "matplotlib_venn",
         "bs4",
-        "impact-factor>=1.1.1",
+        "impact-factor>=1.1.1,<1.1.3",
         "thefuzz",
         "pytest",
         "tldextract",