diff --git a/doc/sources/array_api.rst b/doc/sources/array_api.rst index f493878b36..54db180211 100644 --- a/doc/sources/array_api.rst +++ b/doc/sources/array_api.rst @@ -68,11 +68,16 @@ in many cases they are. .. warning:: If array API inputs are passed to an estimator's ``.fit()``, subsequent data passed to methods such as - ``.predict()`` or ``.score()`` of the fitted model might be of a different class than the ``X``/``y`` passed to - ``.fit()``, but **it must reside on the same device** - meaning: a model that was fitted with GPU arrays cannot - make predictions on CPU arrays, and a model fitted with CPU array API inputs cannot make predictions on GPU - arrays, even if they are of the same class. Attempting to pass data on the wrong device might lead to - process-wide crashes. + ``.predict()`` or ``.score()`` of the fitted model **must reside on the same device** - meaning: a model that + was fitted with GPU arrays cannot make predictions on CPU arrays, and a model fitted with CPU array API inputs + cannot make predictions on GPU arrays, even if they are of the same class. Attempting to pass data on the + wrong device might lead to process-wide crashes. + +.. note:: + An estimator fitted to array API inputs should only be passed objects of the same class that was passed to + ``.fit()`` in subsequent calls to ``.predict()``, ``.score()``, and similar. In some cases, it might be + possible to pass a different class at prediction time without errors (particularly when fitting on CPU only), + but this is generally not supported and users should not rely on these interchanges working reliably. .. note:: The ``target_offload`` option in config contexts and settings is not intended to work with array API @@ -145,20 +150,6 @@ GPU operations on GPU arrays pred = model.predict(X[:5]) assert isinstance(pred, torch.Tensor) - # Fitted models can be passed array API inputs of a different class - # than the training data, as long as their data resides in the same - # device. This now fits a model using a non-NumPy class whose data is on CPU. - X_cpu = torch.tensor(X_np, device="cpu") - y_cpu = torch.tensor(y_np, device="cpu") - model_cpu = LinearRegression() - with config_context(array_api_dispatch=True): - model_cpu.fit(X_cpu, y_cpu) - pred_torch = model_cpu.predict(X_cpu[:5]) - pred_np = model_cpu.predict(X_cpu[:5].numpy()) - assert isinstance(pred_torch, X_cpu.__class__) - assert isinstance(pred_np, np.ndarray) - assert pred_torch.__class__ != pred_np.__class__ - .. tab:: With DPNP arrays .. code-block:: python @@ -193,20 +184,6 @@ GPU operations on GPU arrays pred = model.predict(X[:5]) assert isinstance(pred, X.__class__) - # Fitted models can be passed array API inputs of a different class - # than the training data, as long as their data resides in the same - # device. This now fits a model using a non-NumPy class whose data is on CPU. - X_cpu = dpnp.array(X_np, device="cpu") - y_cpu = dpnp.array(y_np, device="cpu") - model_cpu = LinearRegression() - with config_context(array_api_dispatch=True): - model_cpu.fit(X_cpu, y_cpu) - pred_dpnp = model_cpu.predict(X_cpu[:5]) - pred_np = model_cpu.predict(X_cpu[:5].asnumpy()) - assert isinstance(pred_dpnp, X_cpu.__class__) - assert isinstance(pred_np, np.ndarray) - assert pred_dpnp.__class__ != pred_np.__class__ - ``array-api-strict`` --------------------