Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ classifiers = [
dependencies = [
"numpy",
"packaging",
"requests"
]

[project.urls]
Expand Down
8 changes: 7 additions & 1 deletion src/probeinterface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,11 @@
generate_multi_columns_probe,
generate_multi_shank,
)
from .library import get_probe
from .library import (
get_probe,
list_manufacturers_in_library,
list_probes_in_library,
get_tags_in_library,
cache_full_library,
)
from .wiring import get_available_pathways
206 changes: 187 additions & 19 deletions src/probeinterface/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,33 @@

from __future__ import annotations
import os
import warnings
from pathlib import Path
from urllib.request import urlopen
import requests
from typing import Optional

from .io import read_probeinterface

# OLD URL on gin
# public_url = "https://web.gin.g-node.org/spikeinterface/probeinterface_library/raw/master/"

# Now on github since 2023/06/15
public_url = "https://raw.githubusercontent.com/SpikeInterface/probeinterface_library/main/"
public_url = "https://raw.githubusercontent.com/SpikeInterface/probeinterface_library/"


# check this for windows and osx
cache_folder = Path(os.path.expanduser("~")) / ".config" / "probeinterface" / "library"
def get_cache_folder() -> Path:
"""Get the cache folder for probeinterface library files.

Returns
-------
cache_folder : Path
The path to the cache folder.
"""
return Path(os.path.expanduser("~")) / ".config" / "probeinterface" / "library"

def download_probeinterface_file(manufacturer: str, probe_name: str):

def download_probeinterface_file(manufacturer: str, probe_name: str, tag: Optional[str] = None) -> None:
"""Download the probeinterface file to the cache directory.
Note that the file is itself a ProbeGroup but on the repo each file
represents one probe.
Expand All @@ -38,16 +48,24 @@ def download_probeinterface_file(manufacturer: str, probe_name: str):
The probe manufacturer
probe_name : str (see probeinterface_libary for options)
The probe name
tag : str | None, default: None
Optional tag for the probe
"""
os.makedirs(cache_folder / manufacturer, exist_ok=True)
localfile = cache_folder / manufacturer / (probe_name + ".json")
distantfile = public_url + f"{manufacturer}/{probe_name}/{probe_name}.json"
dist = urlopen(distantfile)
with open(localfile, "wb") as f:
f.write(dist.read())
cache_folder = get_cache_folder()
if tag is not None:
assert tag in get_tags_in_library(), f"Tag {tag} not found in library"
else:
tag = "main"

os.makedirs(cache_folder / tag / manufacturer, exist_ok=True)
local_file = cache_folder / tag / manufacturer / (probe_name + ".json")
remote_file = public_url + tag + f"/{manufacturer}/{probe_name}/{probe_name}.json"
rem = urlopen(remote_file)
with open(local_file, "wb") as f:
f.write(rem.read())

def get_from_cache(manufacturer: str, probe_name: str) -> Optional["Probe"]:

def get_from_cache(manufacturer: str, probe_name: str, tag: Optional[str] = None) -> Optional["Probe"]:
"""
Get Probe from local cache

Expand All @@ -57,24 +75,72 @@ def get_from_cache(manufacturer: str, probe_name: str) -> Optional["Probe"]:
The probe manufacturer
probe_name : str (see probeinterface_libary for options)
The probe name
tag : str | None, default: None
Optional tag for the probe

Returns
-------
probe : Probe object, or None if no probeinterface JSON file is found

"""
cache_folder = get_cache_folder()
if tag is not None:
cache_folder_tag = cache_folder / tag
if not cache_folder_tag.is_dir():
return None
cache_folder = cache_folder_tag
else:
cache_folder_tag = cache_folder / "main"

localfile = cache_folder / manufacturer / (probe_name + ".json")
if not localfile.is_file():
local_file = cache_folder_tag / manufacturer / (probe_name + ".json")
if not local_file.is_file():
return None
else:
probegroup = read_probeinterface(localfile)
probegroup = read_probeinterface(local_file)
probe = probegroup.probes[0]
probe._probe_group = None
return probe


def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None) -> "Probe":
def remove_from_cache(manufacturer: str, probe_name: str, tag: Optional[str] = None) -> Optional["Probe"]:
"""
Remove Probe from local cache

Parameters
----------
manufacturer : "cambridgeneurotech" | "neuronexus" | "plexon" | "imec" | "sinaps"
The probe manufacturer
probe_name : str (see probeinterface_libary for options)
The probe name
tag : str | None, default: None
Optional tag for the probe

Returns
-------
probe : Probe object, or None if no probeinterface JSON file is found

"""
cache_folder = get_cache_folder()
if tag is not None:
cache_folder_tag = cache_folder / tag
if not cache_folder_tag.is_dir():
return None
cache_folder = cache_folder_tag
else:
cache_folder_tag = cache_folder / "main"

local_file = cache_folder_tag / manufacturer / (probe_name + ".json")
if local_file.is_file():
os.remove(local_file)


def get_probe(
manufacturer: str,
probe_name: str,
name: Optional[str] = None,
tag: Optional[str] = None,
force_download: bool = False,
) -> "Probe":
"""
Get probe from ProbeInterface library

Expand All @@ -86,21 +152,123 @@ def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None) ->
The probe name
name : str | None, default: None
Optional name for the probe
tag : str | None, default: None
Optional tag for the probe
force_download : bool, default: False
If True, force re-download of the probe file.

Returns
----------
probe : Probe object

"""

probe = get_from_cache(manufacturer, probe_name)
if not force_download:
probe = get_from_cache(manufacturer, probe_name, tag=tag)
else:
probe = None

if probe is None:
download_probeinterface_file(manufacturer, probe_name)
probe = get_from_cache(manufacturer, probe_name)
download_probeinterface_file(manufacturer, probe_name, tag=tag)
probe = get_from_cache(manufacturer, probe_name, tag=tag)
if probe.manufacturer == "":
probe.manufacturer = manufacturer
if name is not None:
probe.name = name

return probe


def cache_full_library(tag=None) -> None:
"""
Download all probes from the library to the cache directory.
"""
manufacturers = list_manufacturers_in_library(tag=tag)

for manufacturer in manufacturers:
probes = list_probes_in_library(manufacturer, tag=tag)
for probe_name in probes:
try:
download_probeinterface_file(manufacturer, probe_name, tag=tag)
except Exception as e:
warnings.warn(f"Could not download {manufacturer}/{probe_name} (tag: {tag}): {e}")


def list_manufacturers_in_library(tag=None) -> list[str]:
"""
Get the list of available manufacturers in the library

Returns
-------
manufacturers : list of str
List of available manufacturers
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""
"""
if tag is not None:
assert tag in get_tags_in_library(), (
f"Tag {tag} not found in library. Available tags are {get_tags_in_library()}."
)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

return list_github_folders("SpikeInterface", "probeinterface_library", ref=tag)


def list_probes_in_library(manufacturer: str, tag=None) -> list[str]:
"""
Get the list of available probes for a given manufacturer

Parameters
----------
manufacturer : str
The probe manufacturer

Returns
-------
probes : list of str
List of available probes for the given manufacturer
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""
"""
if tag is not None:
assert tag in get_tags_in_library(), (
f"Tag {tag} not found in library. Available tags are {get_tags_in_library()}."
)
if manufacturer is not None:
assert manufacturer in list_manufacturers(tag=tag), (
f"Manufacturer {manufacturer} not found in library. Available manufacturers are {list_manufacturers()}."
)

I'd be keen for some better error messaging due to my own spelling errors...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and done

return list_github_folders("SpikeInterface", "probeinterface_library", path=manufacturer, ref=tag)


def get_tags_in_library() -> list[str]:
"""
Get the list of available tags in the library

Returns
-------
tags : list of str
List of available tags
"""
tags = get_all_tags("SpikeInterface", "probeinterface_library")
return tags


### UTILS
def get_all_tags(owner: str, repo: str, token: str = None):
"""
Get all tags for a repo.
Returns a list of tag names, or an empty list if no tags exist.
"""
url = f"https://api.github.com/repos/{owner}/{repo}/tags"
headers = {}
if token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN"):
token = token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN")
headers["Authorization"] = f"token {token}"
resp = requests.get(url, headers=headers)
if resp.status_code != 200:
raise RuntimeError(f"GitHub API returned {resp.status_code}: {resp.text}")
tags = resp.json()
return [tag["name"] for tag in tags]


def list_github_folders(owner: str, repo: str, path: str = "", ref: str = None, token: str = None):
"""
Return a list of directory names in the given repo at the specified path.
You can pass a branch, tag, or commit SHA via `ref`.
If token is provided, use it for authenticated requests (higher rate limits).
"""
url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}"
params = {}
if ref:
params["ref"] = ref
headers = {}
if token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN"):
token = token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN")
headers["Authorization"] = f"token {token}"
resp = requests.get(url, headers=headers, params=params)
if resp.status_code != 200:
raise RuntimeError(f"GitHub API returned status {resp.status_code}: {resp.text}")
items = resp.json()
return [item["name"] for item in items if item.get("type") == "dir" and item["name"][0] != "."]
52 changes: 44 additions & 8 deletions tests/test_library.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,37 @@
from probeinterface import Probe
from probeinterface.library import download_probeinterface_file, get_from_cache, get_probe


from pathlib import Path
import numpy as np

import pytest
from probeinterface.library import (
download_probeinterface_file,
get_from_cache,
remove_from_cache,
get_probe,
get_tags_in_library,
list_manufacturers_in_library,
list_probes_in_library,
get_cache_folder,
)


manufacturer = "neuronexus"
probe_name = "A1x32-Poly3-10mm-50-177"


def test_download_probeinterface_file():
download_probeinterface_file(manufacturer, probe_name)
download_probeinterface_file(manufacturer, probe_name, tag=None)


def test_get_from_cache():
download_probeinterface_file(manufacturer, probe_name)
probe = get_from_cache(manufacturer, probe_name)
assert isinstance(probe, Probe)

tag = get_tags_in_library()[0]
probe = get_from_cache(manufacturer, probe_name, tag=tag)
assert probe is None # because we did not download with this tag
download_probeinterface_file(manufacturer, probe_name, tag=tag)
probe = get_from_cache(manufacturer, probe_name, tag=tag)
remove_from_cache(manufacturer, probe_name, tag=tag)
assert isinstance(probe, Probe)

probe = get_from_cache("yep", "yop")
assert probe is None

Expand All @@ -31,7 +42,32 @@ def test_get_probe():
assert probe.get_contact_count() == 32


def test_available_tags():
tags = get_tags_in_library()
if len(tags) > 0:
for tag in tags:
assert isinstance(tag, str)
assert len(tag) > 0


def test_list_manufacturers_in_library():
manufacturers = list_manufacturers_in_library()
assert isinstance(manufacturers, list)
assert "neuronexus" in manufacturers
assert "imec" in manufacturers


def test_list_probes_in_library():
manufacturers = list_manufacturers_in_library()
for manufacturer in manufacturers:
probes = list_probes_in_library(manufacturer)
assert isinstance(probes, list)
assert len(probes) > 0


if __name__ == "__main__":
test_download_probeinterface_file()
test_get_from_cache()
test_get_probe()
test_list_manufacturers_in_library()
test_list_probes_in_library()