By Olivier Grisel, ML Engineer at Probabl & core maintainer of scikit-learn
For over a decade, scikit-learn has served as the bedrock of machine learning, supporting the work of millions of data scientists worldwide and recently surpassing 4 billion downloads [1]. Scikit-learn was originally designed for a CPU-centric world, relying heavily on the foundational stack of NumPy, SciPy, and Cython. However, with the advent of new hardware, there are new opportunities to accelerate machine learning pipelines with scikit-learn.
Speeding up machine learning pipelines matters for enterprises, where compute bottlenecks are not only a technical lag but also a barrier to operational agility. When model training takes hours instead of minutes, the time-to-insight stretches and the impact of data science projects is delayed. Even for smaller datasets, when model training takes minutes instead of seconds, interactive model development in notebooks stops being interactive, which breaks the quick iteration cycle that keeps data scientists focused and productive.
In this post, I bring you up to speed on our efforts to adopt the Python array API standard in scikit-learn in order to tackle this problem and facilitate hardware acceleration in data science workflows. This transition is more than a performance optimization: it is a fundamental re-engineering that lets data scientists keep using scikit-learn’s 200+ estimators while delegating performance-critical tasks to GPU-backed libraries such as PyTorch, CuPy, and maybe soon JAX. The result can be substantial speed-ups for data scientists building complex machine learning pipelines.
Historically, library maintainers like us at scikit-learn faced a vendor lock-in challenge. If we wanted to support GPUs, we would have had to write specialized code paths for every specific backend (e.g., one for NumPy, one for CuPy, another for PyTorch). This would have led to fragmented codebases and maintenance overhead.
The Python array API standard [2] solves this by providing a unified specification for NumPy-like operations. It is a common language adopted by major array libraries. By targeting this specification, scikit-learn can remain "backend agnostic."
Core Concept: When an estimator is array API-compliant, it inspects the input data. If you pass a PyTorch tensor residing on an NVIDIA GPU, scikit-learn uses the array API to dispatch the underlying linear algebra to PyTorch’s GPU kernels. The computation happens on the device where the data lives.
Converting a library as vast as scikit-learn, which has over 200 estimators, is a significant undertaking. Whenever an estimator is converted, we also set up automated testing to ensure that it behaves numerically consistently across backends. This is a multi-year effort involving deep collaboration between Probabl, Quansight, NVIDIA, and the broader scientific Python community.
So far, approximately 25 estimators out of 200 are partially compatible, fully compatible, or in the final stages of integration. Most metric functions (e.g. R2, log loss, Brier score) and tools such as cross-validation functions and the scoring API have been updated. Specific tests and continuous integration configurations have also been put in place to regularly monitor the correct execution of those components on a GPU, and more test infrastructure work is in progress.
To be a bit more precise, let me explain some of the technical changes involved in converting from NumPy to the array API.
Before, the code would explicitly import NumPy (as “np”) and perform linear algebra operations on NumPy arrays passed as input to the scikit-learn functions. Now, compliant functions accept any array API-compliant input without any explicit hard dependencies on those libraries: the underlying module is retrieved (as “xp”) by inspecting the input arguments. Subsequent linear algebra operations are therefore delegated to input-specific libraries without having to couple the source code explicitly to any of those array libraries.
In practice, not all array API compatible libraries are 100% compliant with the specification (yet), and importing array_api_compat is a pragmatic way to handle the transition. For instance, PyTorch implements some features from the spec under different names. So instead of retrieving the array namespace from PyTorch directly, we ask array_api_compat for a standard-compliant PyTorch wrapper. If the input array stems from a compliant library, array_api_compat simply returns that module as is.
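The before/after change described above can be sketched roughly as follows, using the R2 score as a small example. This is an illustration only and not scikit-learn's actual implementation: the helper `get_namespace` and the functions `r2_before` and `r2_after` are hypothetical names, and the NumPy fallback used when array-api-compat is not installed is an assumption made to keep the sketch runnable anywhere.

```python
import numpy as np


def get_namespace(*arrays):
    """Return the array namespace ("xp") shared by the inputs.

    Hypothetical helper for this sketch, not scikit-learn's internal one.
    """
    try:
        # array-api-compat returns a standard-compliant wrapper for libraries
        # that are not fully compliant, and the library module itself otherwise.
        from array_api_compat import array_namespace
    except ImportError:
        # Assumption for this sketch: without array-api-compat installed,
        # only NumPy inputs are supported.
        return np
    return array_namespace(*arrays)


def r2_before(y_true, y_pred):
    """Before: the code is explicitly coupled to the `np` module."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    return float(1.0 - ss_res / ss_tot)


def r2_after(y_true, y_pred):
    """After: the module is retrieved from the inputs as `xp`."""
    xp = get_namespace(y_true, y_pred)
    ss_res = xp.sum((y_true - y_pred) ** 2)
    ss_tot = xp.sum((y_true - xp.mean(y_true)) ** 2)
    return float(1.0 - ss_res / ss_tot)


y_true = np.asarray([1.0, 2.0, 3.0, 4.0])
y_pred = np.asarray([1.1, 1.9, 3.2, 3.8])
print(r2_before(y_true, y_pred), r2_after(y_true, y_pred))
```

With array-api-compat installed, `r2_after` should in principle accept arrays from other libraries it supports (for example PyTorch tensors) without any change to the function body, since only `xp` differs.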
On top of this, array-api-extra brings extra benefits that go beyond the spec and enable support for other libraries with special design constraints, such as JAX.
The value-add for the data scientist: The significance of this work for the millions of data scientists around the world who use scikit-learn lies in the seamless scalability that has been unlocked. In the past, moving a scikit-learn pipeline to a GPU required a complete rewrite using different libraries. With the array API, this transition no longer requires such a rewrite. You can now tell scikit-learn to delegate compute-intensive work to GPU-aware, array API-compliant libraries.
To illustrate the impact of this work, I measured the time it takes to fit and evaluate the following multistep polynomial regression pipeline:
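from functools import partial

import sklearn
import torch
from sklearn.kernel_approximation import Nystroem
from sklearn.linear_model import Ridge
from sklearn.model_selection import cross_validate
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer, SplineTransformer

sklearn.set_config(array_api_dispatch=True)  # see [3] for prerequisites
poly_reg_torch_gpu = make_pipeline(
    SplineTransformer(n_knots=5),
    FunctionTransformer(partial(torch.asarray, device="cuda")),
    Nystroem(kernel="poly", degree=2, n_components=300, random_state=0),
    Ridge(solver="svd", alpha=1e-3),
)
cv_results_torch_gpu = cross_validate(poly_reg_torch_gpu, X, y_torch_gpu, cv=5)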
In the code above, SplineTransformer has not yet been updated to accept array API inputs, while the other steps have. To upgrade this pipeline, we therefore insert a FunctionTransformer step that calls torch.asarray(out, device="cuda") on the output of the first step. The resulting PyTorch GPU array is then passed to the Nystroem step, which dramatically accelerates the last two steps by letting them operate on the CUDA device.
By offloading these steps to a GPU using the array API, I observed a 15x speed-up compared to traditional CPU execution.
Takeaway: Thanks to GPU acceleration, we can now tune the hyperparameters in a complex pipeline to get a very good model in the time it would take to run a single cross-validation on the Google Colab CPU. More importantly, the training speed is fast enough to avoid disrupting the model development flow of the data scientist interactively editing the Google Colab notebook.
I recently had the pleasure of joining NVIDIA experts Andy Terrel, Sergey Maydanov, Ashwin Srinath, and Leo Fang for a technical deep dive into the CUDA Python roadmap and the adoption of the Python array API in scikit-learn.
We discussed strategies for making GPU-accelerated computing more seamless and accessible for Python developers and data scientists.
We had over 700 people tune in from all over the world. If you missed the live event, I encourage you to watch the replay to see the live demo and the array API in action.
By bridging the gap between the easy-to-use and familiar interface of Python libraries and the power of GPUs, we are lowering the barrier to entry for high-performance AI, making it a practical reality for enterprises of all sizes and skill levels.
Gaël Varoquaux, our CSO, and Yann Lechelle, our Executive President, will be at NVIDIA GTC 2026 in San José next week. Don't be a stranger; connect with them there!
On March 17 (3:00 PM – 3:40 PM PDT / 11:00 PM – 11:40 PM CET), Gaël will be speaking on the “Accelerating Open Science: Incorporating CUDA Into the SciPy Ecosystem” panel. Gaël will discuss adopting CUDA in scikit-learn without sacrificing usability, portability, or community values alongside Leo Fang, Ianna Osborne, Travis Oliphant, and Katrina Riehl.
On March 18 (5:00 AM – 5:50 AM PDT / 1:00 PM – 1:50 PM CET), Yann will be speaking on the “Europe’s AI Launchpad: Unlock Startup Growth Through Sovereign AI Infrastructure [S81898]” panel. Yann will discuss the dynamic AI compute landscape as well as public and private compute options for startups alongside Cedric Auliac, Pierre-Antoine Beaudoin, and Sadaf Alam.
Our efforts to adopt the array API are the result of a massive team effort. I want to extend my gratitude to the maintainers and contributors from Quansight, NVIDIA, and the community of scikit-learn core contributors. The work performed by Probabl and Quansight on scikit-learn and SciPy is supported by the NASA ROSES grant 80NSSC25K7215 "Ensuring a fast and secure core for scientific Python." This support is vital for maintaining the health of the open source ecosystem that the world’s scientific and industrial infrastructure relies upon.
[1] https://clickpy.clickhouse.com/dashboard/scikit-learn Please note that PyPI downloads are a proxy for adoption and should be taken with a grain of salt; they are not the only way to download a Python library, and they may not accurately convey usage.
[2] Python array API standard https://data-apis.org/array-api/latest/
[3] Enabling array API support in scikit-learn https://scikit-learn.org/stable/modules/array_api.html
[4] Colab notebook of the demo: https://colab.research.google.com/drive/1YrCt5iBPT6gnmp7geahRn_9OqCPrfoLb?usp=sharing