-
Notifications
You must be signed in to change notification settings - Fork 11
gh-405: array API support for glass.core.algorithm
#423
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Perhaps, before merging any changes, we should figure out a way to
|
glass/core/algorithm.py
Outdated
msg = "input arrays should belong to the same array library" | ||
raise ValueError(msg) | ||
|
||
xp = a.__array_namespace__() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was talking to @matt-graham and he wondered if we should be using array-api-compat here. That should solve the cupy
issue. I believe only NumPy
and JAX
are stable at this point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The CuPy support will become official in their next major release (v14
) where they will start supporting NumPy v2
. I hope they push the release in the next 3 months, so that we don't have to experiment with array-api-compat.
6e8ef45
to
761795c
Compare
fa75b7a
to
1b0d92b
Compare
ee54338
to
536f1ab
Compare
536f1ab
to
038bbbf
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Took some time to revive this PR 😅
def nnls( | ||
a: NDArray[np.float64], | ||
b: NDArray[np.float64], | ||
a: NDArray[np.float64] | ArrayLike, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See #415
a = rng.uniform(low=-10, high=10, size=[50, 10]) | ||
b = np.abs(rng.uniform(low=-2, high=2, size=[10])) | ||
b = xp.abs(rng.uniform(low=-2, high=2, size=[10])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The functions might not use RNGs but their tests might, so we do need to come up with a consensus on RNGs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess for tests we can always do -
xp.asarray(rng(...)) # numpy rng output casted to xp.array
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use the workaround (NumPy-like - rng-jax) for now, and see if we need a more detailed (JAX-like) solution in the future.
p[m] = True | ||
while True: | ||
ap = a[:, p] | ||
xp = x[p] | ||
sp = np.linalg.solve(ap.T @ ap, b @ ap) | ||
x_new = x[p] | ||
sp = xp.linalg.solve(ap.T @ ap, b @ ap) | ||
t = sp <= 0 | ||
if not np.any(t): | ||
if not xp.any(t): | ||
break | ||
alpha = -np.min(xp[t] / (xp[t] - sp[t])) | ||
alpha = -xp.min(x_new[t] / (x_new[t] - sp[t])) | ||
x[p] += alpha * (sp - xp) | ||
p[x <= 0] = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See #432 (comment)
if a.__array_namespace__() != b.__array_namespace__(): | ||
msg = "input arrays should belong to the same array library" | ||
raise ValueError(msg) | ||
|
||
xp = a.__array_namespace__() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we make this into a generic helper? Seeing as all functions will need it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That sounds good. I'll open a new issue and take this up.
@paddyroddy I think I've now done this in #643 |
Agreed, sorry @Saransh-cpp |
WIP (not draft to test array API in the CI)
Has some caveats that I am working on, but opening this PR for visibility.
Closes: #405