Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 10 additions & 33 deletions doc/sources/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,16 @@ in many cases they are.

.. warning::
If array API inputs are passed to an estimator's ``.fit()``, subsequent data passed to methods such as
``.predict()`` or ``.score()`` of the fitted model might be of a different class than the ``X``/``y`` passed to
``.fit()``, but **it must reside on the same device** - meaning: a model that was fitted with GPU arrays cannot
make predictions on CPU arrays, and a model fitted with CPU array API inputs cannot make predictions on GPU
arrays, even if they are of the same class. Attempting to pass data on the wrong device might lead to
process-wide crashes.
``.predict()`` or ``.score()`` of the fitted model **must reside on the same device** - meaning: a model that
was fitted with GPU arrays cannot make predictions on CPU arrays, and a model fitted with CPU array API inputs
cannot make predictions on GPU arrays, even if they are of the same class. Attempting to pass data on the
wrong device might lead to process-wide crashes.

.. note::
An estimator fitted to array API inputs should only be passed objects of the same class that was passed to
``.fit()`` in subsequent calls to ``.predict()``, ``.score()``, and similar. In some cases, it might be
possible to pass a different class at prediction time without errors (particularly when fitting on CPU only),
but this is generally not supported and users should not rely on these interchanges working reliably.

.. note::
The ``target_offload`` option in config contexts and settings is not intended to work with array API
Expand Down Expand Up @@ -145,20 +150,6 @@ GPU operations on GPU arrays
pred = model.predict(X[:5])
assert isinstance(pred, torch.Tensor)

# Fitted models can be passed array API inputs of a different class
# than the training data, as long as their data resides in the same
# device. This now fits a model using a non-NumPy class whose data is on CPU.
X_cpu = torch.tensor(X_np, device="cpu")
y_cpu = torch.tensor(y_np, device="cpu")
model_cpu = LinearRegression()
with config_context(array_api_dispatch=True):
model_cpu.fit(X_cpu, y_cpu)
pred_torch = model_cpu.predict(X_cpu[:5])
pred_np = model_cpu.predict(X_cpu[:5].numpy())
assert isinstance(pred_torch, X_cpu.__class__)
assert isinstance(pred_np, np.ndarray)
assert pred_torch.__class__ != pred_np.__class__

.. tab:: With DPNP arrays
.. code-block:: python

Expand Down Expand Up @@ -193,20 +184,6 @@ GPU operations on GPU arrays
pred = model.predict(X[:5])
assert isinstance(pred, X.__class__)

# Fitted models can be passed array API inputs of a different class
# than the training data, as long as their data resides in the same
# device. This now fits a model using a non-NumPy class whose data is on CPU.
X_cpu = dpnp.array(X_np, device="cpu")
y_cpu = dpnp.array(y_np, device="cpu")
model_cpu = LinearRegression()
with config_context(array_api_dispatch=True):
model_cpu.fit(X_cpu, y_cpu)
pred_dpnp = model_cpu.predict(X_cpu[:5])
pred_np = model_cpu.predict(X_cpu[:5].asnumpy())
assert isinstance(pred_dpnp, X_cpu.__class__)
assert isinstance(pred_np, np.ndarray)
assert pred_dpnp.__class__ != pred_np.__class__


``array-api-strict``
--------------------
Expand Down
Loading