-
Notifications
You must be signed in to change notification settings - Fork 175
Backend-native Implementation #2071
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…a into ig/array_api_continue import merge
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I just went over general code style, nothing JAX-related
src/anndata/_core/merge.py
Outdated
| # Force to NumPy (materializes JAX/Cubed); fine for small tests, | ||
| # but may be slow or fail on large/lazy arrays |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This code doesn’t just run for tests though. Also are you sure that this is a good idea for arrays with pandas dtypes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I was initially forcing everything to NumPy, but that’s no longer the case. I’ve updated it so the it should preserve arrays with pandas dtypes.
src/anndata/_core/merge.py
Outdated
| return False | ||
|
|
||
|
|
||
| def _to_numpy_if_array_api(x): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there should be no second copy of this that’s slightly different, only one!
| dest = self._adata_ref._X | ||
| # Handles read-only NumPy views from backend arrays like JAX by | ||
| # making a writable copy so in-place assignment on views can succeed. | ||
| if isinstance(dest, np.ndarray) and not dest.flags.writeable: | ||
| dest = np.array(dest, copy=True) # make a fresh, writable buffer | ||
| self._adata_ref._X = dest |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would actually just let the error be thrown in this case. If something isn't writeable, I don't think that's our responsibility to handle
src/anndata/_core/merge.py
Outdated
| hasattr(x, "dtype") and is_extension_array_dtype(x.dtype) | ||
| ): | ||
| return x | ||
| return np.asarray(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok nice this is the right direction no doubt! So what we want here probably is not to rely on asarray but dlpack to do the conversion. In short:
- We should have a check in
_apply_to_arrayto see if something is array-api compatible but not a numpy ndarray. - If this case is true, dlpack into numpy, recursively call
_apply_to_array - Then use dlpack to take the output of the recursive call to the original type before we went to numpy.
Does that make sense?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is a nice paradigm to follow for situations where we have an existing numpy or cupy implementation and it isn't clear how to use the array-api to achieve our aims. We should still try to use it as much as possible so that we can eventually remove numpy codepaths where possible, but this is a nice first step.
… with copying introduced as an extra precaution
src/anndata/_core/merge.py
Outdated
| def _dlpack_from_numpy(x_np, original_xp): | ||
| # cubed and other array later elif | ||
| if original_xp.__name__.startswith("jax"): | ||
| return jax.dlpack.from_dlpack(x_np) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
src/anndata/_core/merge.py
Outdated
|
|
||
| T = TypeVar("T") | ||
|
|
||
| with suppress(ImportError): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
src/anndata/_core/merge.py
Outdated
| # Use the backend of the first array as the reference | ||
| ref = arrays[0] | ||
| xp = get_namespace(ref) | ||
|
|
||
| # Convert all arrays to the same backend as `ref` | ||
| arrays = [ref] + [_same_backend(ref, x, copy=True)[1] for x in arrays[1:]] | ||
|
|
||
| # Concatenate with the backend’s API | ||
| value = xp.concatenate( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This condition was previously hit by the fact that none of the above checks involving any were True. Instead of changing this last default condition, I would create a new branch here specifically for the array-api, check that they all have the same backend, and the concatenate. If they don't have the same backed, you just proceed to the np condition (which will fail presumably). I wouldn't worry about mixing different backends, especially with the array-api for now. If we use cubed, dlpack won't work there anyway
src/anndata/_core/merge.py
Outdated
|
|
||
| # fallback for known backends that put it elsewhere (JAX and later others) | ||
| if original_xp.__name__.startswith("jax"): | ||
| import jax.dlpack |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
original_xp should have a from_dlpack method! Does my comment here 383c445#r2291442475 not apply?
|
|
||
|
|
||
| def test_write_large_categorical(tmp_path, diskfmt): | ||
| @pytest.mark.parametrize("xp", [np, jnp]) # xp = array namespace |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would revert this - it's just for generating categories which gets pushed into pandas. I don't think this triggers any internal array-api code
| return indexer.data | ||
| return indexer.dat | ||
|
|
||
| elif has_xp(indexer): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks nearly identical to the numpy case. In a case like this, I think you were/would be right to just merge the two if possible. If it's not I would explain why.
src/anndata/_core/merge.py
Outdated
| return pd.api.extensions.take( | ||
| el, indexer, axis=axis, allow_fill=True, fill_value=fill_value | ||
| ) | ||
| if _is_pandas(el): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| if _is_pandas(el): | |
| if isinstance(el, np.ndarray): |
I would have thought that el is a numpy array given that the old function name was _apply_to_array, no?
src/anndata/_core/merge.py
Outdated
| # reverting back to numpy as it is hard to reindex on JAX and others | ||
| return _dlpack_from_numpy(out_np, xp) | ||
|
|
||
| # numpy case |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the logic here is a little confused. This function used to be for numpy, but you added the _is_pandas check above, which I don't think applies here. But the logic you've written "# numpy case" and down works great for non-numpy array-api compatible arrays as well!
So I would leave the numpy case as before (i.e., remove _is_pandas and check isinstance(el, np.ndarray)), and then in the case it is not a numpy array, use this logic under "# numpy case"! You can then get rid of the if not isinstance(el, np.ndarray) and _is_array_api_compatible(el): branch. One reason I cautioned against falling back to numpy behavior is that some things like jax arrays that are API compatible might be on the GPU! You can't transfer a JAX array on the GPU to numpy :/
src/anndata/_core/views.py
Outdated
| return old[new] | ||
|
|
||
| # handle boolean mask; i.e. checking whether old is a boolean array | ||
| if hasattr(old, "dtype") and str(old.dtype) in ("bool", "bool_", "boolean"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If old is not array-api compatible, shouldn't we error out? old.dtype would exist and https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html#isdtype could handle checking for bool, no need for strings or anything
src/anndata/_core/merge.py
Outdated
| # Check: is array-api compatible, but not NumPy | ||
| if not isinstance(el, np.ndarray) and _is_array_api_compatible(el): | ||
| # Convert to NumPy via DLPack | ||
| el_np = _dlpack_to_numpy(el) | ||
| # Recursively call this same function | ||
| out_np = self._apply_to_array_api(el_np, axis=axis, fill_value=fill_value) | ||
| # reverting back to numpy as it is hard to reindex on JAX and others | ||
| return _dlpack_from_numpy(out_np, xp) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why convert to numpy here? Just let the array pass through, that's the whole point of using the array api :) If a jax array is on the GPU you're going to bring it to the CPU here, but why?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies if it wasn't clear in my previous comment - we should only go to numpy if the existing array is on the CPU and we couldn't come up with a generic way of doing this via the array-api. But you made a way of doing this via the array-api which is great, so no need to convert to numpy!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies if this is going in circles, but I'm kind of reviewing locally to changes, and losing the big picture sometimes!
src/anndata/_core/merge.py
Outdated
| # Check: is array-api compatible, but not NumPy | ||
| if not isinstance(el, np.ndarray) and _is_array_api_compatible(el): | ||
| # Convert to NumPy via DLPack | ||
| el_np = _dlpack_to_numpy(el) | ||
| # Recursively call this same function | ||
| out_np = self._apply_to_array_api(el_np, axis=axis, fill_value=fill_value) | ||
| # reverting back to numpy as it is hard to reindex on JAX and others | ||
| return _dlpack_from_numpy(out_np, xp) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies if it wasn't clear in my previous comment - we should only go to numpy if the existing array is on the CPU and we couldn't come up with a generic way of doing this via the array-api. But you made a way of doing this via the array-api which is great, so no need to convert to numpy!
First step in getting anndata concat and test generation to work properly with JAX, (and Cubed potentially), without just converting everything into NumPy.
Random data creation and shape handling use xp.asarray so arrays stay in their original backend where possible. I also updated concat paths to actually check types before converting, added helpers for sparse detection and array API checks, and made sure backend arrays only get turned into NumPy when absolutely necessary. This fixes a bunch of concat-related test failures.
It’s still not perfect. Some pandas calls in concat still force conversion to NumPy, so the data gets copied instead of being used directly. Cubed support is only a placeholder right now. Type detection might still be a bit too broad, which can lead to extra conversions. Works for NumPy and JAX in tests, but I haven’t tried other backends.