-
Couldn't load subscription status.
- Fork 176
Backend-native Implementation #2071
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
21d5882
7cd6d69
eff9dde
b230d11
692a270
b52e5b7
806becf
98d249a
cdc4fdd
030f985
5aff4d6
ba74743
a749637
c716fb5
951c026
d8adf27
0e410a4
8ea34e8
96992d5
460428a
dd3b867
743ebb3
5a6c825
90d9e6a
787feb0
dea107d
cdd3747
b04c8ef
383c445
eef9015
2d2275b
a2a8606
dcbd235
c227265
27a50ff
c2f0539
310e9e9
fe14723
7be3c71
0e01d1d
229dd57
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,7 +10,7 @@ | |
| import pandas as pd | ||
| from scipy.sparse import issparse | ||
|
|
||
| from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray | ||
| from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray, has_xp | ||
| from .xarray import Dataset2D | ||
|
|
||
| if TYPE_CHECKING: | ||
|
|
@@ -37,10 +37,11 @@ def _normalize_indices( | |
| return ax0, ax1 | ||
|
|
||
|
|
||
| def _normalize_index( # noqa: PLR0911, PLR0912 | ||
| def _normalize_index( # noqa: PLR0911, PLR0912, PLR0915 | ||
| indexer: Index1D, index: pd.Index | ||
| ) -> Index1DNorm | int | np.integer: | ||
| # TODO: why is this here? All tests pass without it and it seems at the minimum not strict enough. | ||
| # protect aroound weird numeric index | ||
| if not isinstance(index, pd.RangeIndex) and index.dtype in (np.float64, np.int64): | ||
| msg = f"Don’t call _normalize_index with non-categorical/string names and non-range index {index}" | ||
| raise TypeError(msg) | ||
|
|
@@ -52,6 +53,7 @@ def name_idx(i): | |
| i = index.get_loc(i) | ||
| return i | ||
|
|
||
| # converting start and stop of the slide to the integer positions if they are strings | ||
| if isinstance(indexer, slice): | ||
| start = name_idx(indexer.start) | ||
| stop = name_idx(indexer.stop) | ||
|
|
@@ -67,17 +69,21 @@ def name_idx(i): | |
| elif isinstance( | ||
| indexer, Sequence | np.ndarray | pd.Index | CSMatrix | np.matrix | CSArray | ||
| ): | ||
| # convert to the 1D if it's accidentally 2D column/row vector | ||
| # convert sparse into dense arrays if needed | ||
| if hasattr(indexer, "shape") and ( | ||
| (indexer.shape == (index.shape[0], 1)) | ||
| or (indexer.shape == (1, index.shape[0])) | ||
| ): | ||
| if isinstance(indexer, CSMatrix | CSArray): | ||
| indexer = indexer.toarray() | ||
| indexer = np.ravel(indexer) | ||
| # if it is something else, convert it to numpy | ||
| if not isinstance(indexer, np.ndarray | pd.Index): | ||
| indexer = np.array(indexer) | ||
| if len(indexer) == 0: | ||
| indexer = indexer.astype(int) | ||
| # if it is a float array or something along those lines, convert it to integers | ||
| if isinstance(indexer, np.ndarray) and np.issubdtype( | ||
| indexer.dtype, np.floating | ||
| ): | ||
|
|
@@ -96,7 +102,7 @@ def name_idx(i): | |
| ) | ||
| raise IndexError(msg) | ||
| return indexer | ||
| else: # indexer should be string array | ||
| else: | ||
| positions = index.get_indexer(indexer) | ||
| if np.any(positions < 0): | ||
| not_found = indexer[positions < 0] | ||
|
|
@@ -110,8 +116,65 @@ def name_idx(i): | |
| if isinstance(indexer.data, DaskArray): | ||
| return indexer.data.compute() | ||
| return indexer.data | ||
|
|
||
| elif has_xp(indexer): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks nearly identical to the numpy case. In a case like this, I think you were/would be right to just merge the two if possible. If it's not I would explain why. |
||
| # getting array's namespace | ||
| xp = indexer.__array_namespace__() | ||
|
|
||
| # Flatten to 1D | ||
| if hasattr(indexer, "shape") and ( | ||
| indexer.shape == (index.shape[0], 1) or indexer.shape == (1, index.shape[0]) | ||
| ): | ||
| indexer = xp.ravel( | ||
| indexer | ||
| ) # flattening to 1D, jax.numpy has it, not sure about cubed | ||
|
|
||
| # Get dtype in array-api-style | ||
| dtype = getattr(indexer, "dtype", None) | ||
|
|
||
| # if we have like a jax boolean mask array | ||
| if xp.issubdtype(dtype, xp.bool_): | ||
| if indexer.shape != index.shape: | ||
| msg = ( | ||
| f"Boolean index does not match AnnData’s shape along this dimension. " | ||
| f"Boolean index has shape {indexer.shape}, expected {index.shape}" | ||
| ) | ||
| raise IndexError(msg) | ||
| return indexer | ||
|
|
||
| # all good, you can return it | ||
| elif xp.issubdtype(dtype, xp.integer): | ||
| return indexer | ||
| # float number case | ||
| elif xp.issubdtype(dtype, xp.floating): | ||
| indexer_int = xp.astype(indexer, xp.int32) # jax default to it | ||
| # If all floats were “safe” (like 0.0, 1.0, 2.0), return them cast to integers. | ||
| is_fractional = xp.not_equal(indexer, xp.astype(indexer_int, xp.floating)) | ||
| if xp.any(is_fractional): | ||
| msg = f"Indexer {indexer!r} has non-integer floating point values." | ||
| raise IndexError(msg) | ||
| return indexer_int | ||
|
|
||
| else: | ||
| try: | ||
| values = indexer.tolist() # converting to the list | ||
| except Exception as err: | ||
| msg = f"Could not convert {indexer!r} to list for string lookup." | ||
| raise IndexError(msg) from err | ||
| positions = index.get_indexer(values) | ||
| if np.any(positions < 0): | ||
| not_found = [ | ||
| v for v, p in zip(values, positions, strict=False) if p < 0 | ||
| ] | ||
| msg = ( | ||
| f"Values {not_found}, from {values}, " | ||
| "are not valid obs/ var names or indices." | ||
| ) | ||
| raise KeyError(msg) | ||
| return positions | ||
|
|
||
| msg = f"Unknown indexer {indexer!r} of type {type(indexer)}" | ||
| raise IndexError() | ||
| raise IndexError(msg) | ||
|
|
||
|
|
||
| def _fix_slice_bounds(s: slice, length: int) -> slice: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would actually just let the error be thrown in this case. If something isn't writeable, I don't think that's our responsibility to handle