Skip to content

Commit 5406404

Browse files
authored
Merge pull request #37 from HzaCode/fix-mypy-type-annotations
Fix mypy type annotations
2 parents 4e3f78d + e33c787 commit 5406404

File tree

5 files changed

+92
-47
lines changed

5 files changed

+92
-47
lines changed

src/ChemInformant/api_helpers.py

Lines changed: 25 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,7 @@ def get_session() -> requests_cache.CachedSession:
8282
global _session
8383
if _session is None:
8484
setup_cache()
85+
assert _session is not None # Type narrowing for mypy
8586
return _session
8687

8788

@@ -125,18 +126,20 @@ def _fetch_with_ratelimit_and_retry(
125126
time.sleep(MIN_REQUEST_INTERVAL - elapsed)
126127
last_api_call_time = time.time()
127128

128-
retries, backoff = 0, INITIAL_BACKOFF
129+
retries = 0
130+
backoff = float(INITIAL_BACKOFF)
129131
while retries < MAX_RETRIES:
130132
try:
131133
resp = _execute_fetch(url)
132134

133135
# Bust cache for stale 503 errors. If a "server busy" response
134136
# came from the cache, delete it and try a live request.
135137
if getattr(resp, "from_cache", False) and resp.status_code == 503:
136-
key = get_session().cache.create_key(resp.request)
137-
get_session().cache.delete(key)
138-
with get_session().cache.disabled():
139-
resp = _execute_fetch(url)
138+
session = get_session()
139+
key = session.cache.create_key(resp.request)
140+
session.cache.delete(key)
141+
# Directly fetch without cache
142+
resp = session.get(url, timeout=REQUEST_TIMEOUT)
140143

141144
if resp.status_code == 200:
142145
ctype = resp.headers.get("Content-Type", "").lower()
@@ -162,7 +165,7 @@ def _fetch_with_ratelimit_and_retry(
162165
time.sleep(backoff)
163166
# Note: random.uniform is safe here as it's only used for jitter in retry delays
164167
backoff = min(MAX_BACKOFF, backoff * 2) + random.uniform(
165-
0, 1
168+
0.0, 1.0
166169
) # Exponential backoff with jitter
167170
retries += 1
168171

@@ -201,7 +204,10 @@ def get_cids_by_name(name: str) -> list[int] | None:
201204
"""
202205
url = f"{PUBCHEM_API_BASE}/compound/name/{quote(name)}/cids/JSON"
203206
data = _fetch_with_ratelimit_and_retry(url)
204-
return data.get("IdentifierList", {}).get("CID") if isinstance(data, dict) else None
207+
if isinstance(data, dict):
208+
cid_list = data.get("IdentifierList", {}).get("CID")
209+
return cid_list if isinstance(cid_list, list) else None
210+
return None
205211

206212

207213
def get_cids_by_smiles(smiles: str) -> list[int] | None:
@@ -231,7 +237,10 @@ def get_cids_by_smiles(smiles: str) -> list[int] | None:
231237
"""
232238
url = f"{PUBCHEM_API_BASE}/compound/smiles/{quote(smiles)}/cids/JSON"
233239
data = _fetch_with_ratelimit_and_retry(url)
234-
return data.get("IdentifierList", {}).get("CID") if isinstance(data, dict) else None
240+
if isinstance(data, dict):
241+
cid_list = data.get("IdentifierList", {}).get("CID")
242+
return cid_list if isinstance(cid_list, list) else None
243+
return None
235244

236245

237246
def get_batch_properties(
@@ -352,7 +361,12 @@ def get_cas_for_cid(cid: int) -> str | None:
352361
"StringWithMarkup"
353362
)
354363
if markup and isinstance(markup, list) and markup:
355-
return markup[0].get("String")
364+
string_val = markup[0].get("String")
365+
return (
366+
string_val
367+
if isinstance(string_val, str)
368+
else None
369+
)
356370
return None
357371

358372

@@ -386,5 +400,6 @@ def get_synonyms_for_cid(cid: int) -> list[str]:
386400
if isinstance(data, dict):
387401
info_list = data.get("InformationList", {}).get("Information", [])
388402
if info_list and isinstance(info_list[0].get("Synonym"), list):
389-
return info_list[0]["Synonym"]
403+
synonyms = info_list[0]["Synonym"]
404+
return synonyms if isinstance(synonyms, list) else []
390405
return []

src/ChemInformant/cheminfo_api.py

Lines changed: 59 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ def get_properties(
6161
namespace: str = "cid",
6262
include_3d: bool = False,
6363
all_properties: bool = False,
64-
**kwargs,
64+
**kwargs: Any,
6565
) -> pd.DataFrame:
6666
"""
6767
Retrieve chemical properties for one or more compounds from PubChem.
@@ -185,19 +185,22 @@ def get_properties(
185185
if properties is not None and len(resolved_props) == 0:
186186
return pd.DataFrame()
187187

188+
identifiers_list: list[int | str]
188189
if not isinstance(identifiers, list):
189-
identifiers = [identifiers]
190+
identifiers_list = [identifiers]
191+
else:
192+
identifiers_list = list(identifiers) # Type-safe conversion
190193

191194
# --- Step 2: Create base DataFrame with resolved CIDs ---
192195
meta: dict[Any, dict[str, Any]] = {}
193-
for ident in identifiers:
196+
for ident in identifiers_list:
194197
try:
195198
cid = _resolve_to_single_cid(ident)
196199
meta[ident] = {"status": "OK", "cid": str(cid)}
197200
except Exception as exc:
198201
meta[ident] = {"status": type(exc).__name__, "cid": pd.NA}
199202
df = pd.DataFrame(
200-
[{"input_identifier": ident, **meta[ident]} for ident in identifiers]
203+
[{"input_identifier": ident, **meta[ident]} for ident in identifiers_list]
201204
)
202205

203206
# --- Early return for empty inputs ---
@@ -232,9 +235,10 @@ def get_properties(
232235

233236
for prop_snake in standard_props_snake:
234237
prop_camel = SNAKE_TO_CAMEL[prop_snake]
235-
values = []
238+
values: list[Any] = []
236239
for _, row in df.iterrows():
237-
cid = int(row["cid"]) if pd.notna(row["cid"]) else None
240+
cid_val = row["cid"]
241+
cid = int(cid_val) if pd.notna(cid_val) else None
238242
val = None
239243
if row["status"] == "OK" and cid:
240244
api_row = fetched_data.get(cid, {})
@@ -255,7 +259,7 @@ def get_properties(
255259
if prop_snake == "cas"
256260
else api_helpers.get_synonyms_for_cid
257261
)
258-
fail_value = None if prop_snake == "cas" else []
262+
fail_value: Any = None if prop_snake == "cas" else []
259263
prop_data = [
260264
fetch_func(int(cid)) if pd.notna(cid) and status == "OK" else fail_value
261265
for cid, status in zip(df["cid"], df["status"])
@@ -281,7 +285,7 @@ def get_properties(
281285
# --- Convenience Functions (now simple and consistent) ---
282286

283287

284-
def _fetch_scalar(id_, prop_snake):
288+
def _fetch_scalar(id_: str | int, prop_snake: str) -> float | int | str | None:
285289
"""Internal helper for single-value convenience functions."""
286290
try:
287291
cid = _resolve_to_single_cid(id_)
@@ -312,9 +316,13 @@ def _fetch_scalar(id_, prop_snake):
312316
props = api_helpers.get_batch_properties([cid], props_to_fetch)
313317
data = props.get(cid, {})
314318

315-
return data.get(prop_camel) or (
319+
result = data.get(prop_camel) or (
316320
data.get(fallback_camel) if fallback_camel else None
317321
)
322+
# Ensure we return the correct type
323+
if result is not None and not isinstance(result, (int, float, str)):
324+
return None
325+
return result
318326
except (NotFoundError, AmbiguousIdentifierError):
319327
return None
320328

@@ -335,7 +343,8 @@ def get_weight(id_: str | int) -> float | None:
335343
>>> get_weight(2244) # Same as above using CID
336344
180.16
337345
"""
338-
return _fetch_scalar(id_, "molecular_weight")
346+
result = _fetch_scalar(id_, "molecular_weight")
347+
return float(result) if isinstance(result, (int, float)) else None
339348

340349

341350
def get_formula(id_: str | int) -> str | None:
@@ -354,7 +363,8 @@ def get_formula(id_: str | int) -> str | None:
354363
>>> get_formula("water")
355364
'H2O'
356365
"""
357-
return _fetch_scalar(id_, "molecular_formula")
366+
result = _fetch_scalar(id_, "molecular_formula")
367+
return str(result) if result is not None else None
358368

359369

360370
def get_canonical_smiles(id_: str | int) -> str | None:
@@ -376,7 +386,8 @@ def get_canonical_smiles(id_: str | int) -> str | None:
376386
>>> get_canonical_smiles(2244)
377387
'CC(=O)OC1=CC=CC=C1C(=O)O'
378388
"""
379-
return _fetch_scalar(id_, "canonical_smiles")
389+
result = _fetch_scalar(id_, "canonical_smiles")
390+
return str(result) if result is not None else None
380391

381392

382393
def get_isomeric_smiles(id_: str | int) -> str | None:
@@ -396,7 +407,8 @@ def get_isomeric_smiles(id_: str | int) -> str | None:
396407
>>> get_isomeric_smiles("glucose")
397408
'C([C@@H]1[C@H]([C@@H]([C@H]([C@H](O1)O)O)O)O)O'
398409
"""
399-
return _fetch_scalar(id_, "isomeric_smiles")
410+
result = _fetch_scalar(id_, "isomeric_smiles")
411+
return str(result) if result is not None else None
400412

401413

402414
def get_iupac_name(id_: str | int) -> str | None:
@@ -415,7 +427,8 @@ def get_iupac_name(id_: str | int) -> str | None:
415427
>>> get_iupac_name("water")
416428
'oxidane'
417429
"""
418-
return _fetch_scalar(id_, "iupac_name")
430+
result = _fetch_scalar(id_, "iupac_name")
431+
return str(result) if result is not None else None
419432

420433

421434
def get_xlogp(id_: str | int) -> float | None:
@@ -437,7 +450,8 @@ def get_xlogp(id_: str | int) -> float | None:
437450
>>> get_xlogp("water")
438451
-0.7
439452
"""
440-
return _fetch_scalar(id_, "xlogp")
453+
result = _fetch_scalar(id_, "xlogp")
454+
return float(result) if isinstance(result, (int, float)) else None
441455

442456

443457
def get_cas(id_: str | int) -> str | None:
@@ -508,7 +522,8 @@ def get_exact_mass(id_: str | int) -> float | None:
508522
>>> get_exact_mass("aspirin")
509523
180.04225873
510524
"""
511-
return _fetch_scalar(id_, "exact_mass")
525+
result = _fetch_scalar(id_, "exact_mass")
526+
return float(result) if isinstance(result, (int, float)) else None
512527

513528

514529
def get_monoisotopic_mass(id_: str | int) -> float | None:
@@ -528,7 +543,8 @@ def get_monoisotopic_mass(id_: str | int) -> float | None:
528543
>>> get_monoisotopic_mass("aspirin")
529544
180.04225873
530545
"""
531-
return _fetch_scalar(id_, "monoisotopic_mass")
546+
result = _fetch_scalar(id_, "monoisotopic_mass")
547+
return float(result) if isinstance(result, (int, float)) else None
532548

533549

534550
def get_tpsa(id_: str | int) -> float | None:
@@ -549,7 +565,8 @@ def get_tpsa(id_: str | int) -> float | None:
549565
>>> get_tpsa("aspirin")
550566
63.6
551567
"""
552-
return _fetch_scalar(id_, "tpsa")
568+
result = _fetch_scalar(id_, "tpsa")
569+
return float(result) if isinstance(result, (int, float)) else None
553570

554571

555572
def get_complexity(id_: str | int) -> float | None:
@@ -569,7 +586,8 @@ def get_complexity(id_: str | int) -> float | None:
569586
>>> get_complexity("aspirin")
570587
212
571588
"""
572-
return _fetch_scalar(id_, "complexity")
589+
result = _fetch_scalar(id_, "complexity")
590+
return float(result) if isinstance(result, (int, float)) else None
573591

574592

575593
def get_h_bond_donor_count(id_: str | int) -> int | None:
@@ -589,7 +607,8 @@ def get_h_bond_donor_count(id_: str | int) -> int | None:
589607
>>> get_h_bond_donor_count("aspirin")
590608
1
591609
"""
592-
return _fetch_scalar(id_, "h_bond_donor_count")
610+
result = _fetch_scalar(id_, "h_bond_donor_count")
611+
return int(result) if isinstance(result, (int, float)) else None
593612

594613

595614
def get_h_bond_acceptor_count(id_: str | int) -> int | None:
@@ -609,7 +628,8 @@ def get_h_bond_acceptor_count(id_: str | int) -> int | None:
609628
>>> get_h_bond_acceptor_count("aspirin")
610629
4
611630
"""
612-
return _fetch_scalar(id_, "h_bond_acceptor_count")
631+
result = _fetch_scalar(id_, "h_bond_acceptor_count")
632+
return int(result) if isinstance(result, (int, float)) else None
613633

614634

615635
def get_rotatable_bond_count(id_: str | int) -> int | None:
@@ -629,7 +649,8 @@ def get_rotatable_bond_count(id_: str | int) -> int | None:
629649
>>> get_rotatable_bond_count("aspirin")
630650
3
631651
"""
632-
return _fetch_scalar(id_, "rotatable_bond_count")
652+
result = _fetch_scalar(id_, "rotatable_bond_count")
653+
return int(result) if isinstance(result, (int, float)) else None
633654

634655

635656
def get_heavy_atom_count(id_: str | int) -> int | None:
@@ -649,7 +670,8 @@ def get_heavy_atom_count(id_: str | int) -> int | None:
649670
>>> get_heavy_atom_count("aspirin")
650671
13
651672
"""
652-
return _fetch_scalar(id_, "heavy_atom_count")
673+
result = _fetch_scalar(id_, "heavy_atom_count")
674+
return int(result) if isinstance(result, (int, float)) else None
653675

654676

655677
def get_charge(id_: str | int) -> int | None:
@@ -669,7 +691,8 @@ def get_charge(id_: str | int) -> int | None:
669691
>>> get_charge("aspirin")
670692
0
671693
"""
672-
return _fetch_scalar(id_, "charge")
694+
result = _fetch_scalar(id_, "charge")
695+
return int(result) if isinstance(result, (int, float)) else None
673696

674697

675698
def get_atom_stereo_count(id_: str | int) -> int | None:
@@ -689,7 +712,8 @@ def get_atom_stereo_count(id_: str | int) -> int | None:
689712
>>> get_atom_stereo_count("glucose")
690713
4
691714
"""
692-
return _fetch_scalar(id_, "atom_stereo_count")
715+
result = _fetch_scalar(id_, "atom_stereo_count")
716+
return int(result) if isinstance(result, (int, float)) else None
693717

694718

695719
def get_bond_stereo_count(id_: str | int) -> int | None:
@@ -709,7 +733,8 @@ def get_bond_stereo_count(id_: str | int) -> int | None:
709733
>>> get_bond_stereo_count("retinol")
710734
4
711735
"""
712-
return _fetch_scalar(id_, "bond_stereo_count")
736+
result = _fetch_scalar(id_, "bond_stereo_count")
737+
return int(result) if isinstance(result, (int, float)) else None
713738

714739

715740
def get_covalent_unit_count(id_: str | int) -> int | None:
@@ -729,7 +754,8 @@ def get_covalent_unit_count(id_: str | int) -> int | None:
729754
>>> get_covalent_unit_count("aspirin")
730755
1
731756
"""
732-
return _fetch_scalar(id_, "covalent_unit_count")
757+
result = _fetch_scalar(id_, "covalent_unit_count")
758+
return int(result) if isinstance(result, (int, float)) else None
733759

734760

735761
def get_inchi(id_: str | int) -> str | None:
@@ -749,7 +775,8 @@ def get_inchi(id_: str | int) -> str | None:
749775
>>> get_inchi("aspirin")
750776
'InChI=1S/C9H8O4/c1-6(10)13-8-5-3-2-4-7(8)9(11)12/h2-5H,1H3,(H,11,12)'
751777
"""
752-
return _fetch_scalar(id_, "in_ch_i")
778+
result = _fetch_scalar(id_, "in_ch_i")
779+
return str(result) if result is not None else None
753780

754781

755782
def get_inchi_key(id_: str | int) -> str | None:
@@ -769,7 +796,8 @@ def get_inchi_key(id_: str | int) -> str | None:
769796
>>> get_inchi_key("aspirin")
770797
'BSYNRYMUTXBXSQ-UHFFFAOYSA-N'
771798
"""
772-
return _fetch_scalar(id_, "in_ch_i_key")
799+
result = _fetch_scalar(id_, "in_ch_i_key")
800+
return str(result) if result is not None else None
773801

774802

775803
def get_compound(identifier: str | int) -> Compound:
@@ -802,7 +830,7 @@ def get_compound(identifier: str | int) -> Compound:
802830
This function uses CamelCase property names to match the Compound model.
803831
For DataFrame output with snake_case names, use get_properties() instead.
804832
"""
805-
df = get_properties([identifier], all_properties=True)
833+
df = get_properties(identifier, all_properties=True)
806834
if df.empty or df["status"].iat[0] != "OK":
807835
raise RuntimeError(f"Failed to fetch compound for {identifier!r}")
808836

@@ -843,7 +871,7 @@ def get_compounds(identifiers: Iterable[str | int]) -> list[Compound]:
843871
return [get_compound(x) for x in identifiers]
844872

845873

846-
def draw_compound(identifier: str | int):
874+
def draw_compound(identifier: str | int) -> None:
847875
"""
848876
Draw the 2D chemical structure of a compound.
849877

0 commit comments

Comments
 (0)