JAX (software)
JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. It is developed by Google with contributions from Nvidia and other community contributors.
JAX is described as combining a modified version of Autograd, an automatic differentiation system, with OpenXLA's XLA (Accelerated Linear Algebra) compiler. It is designed to follow the structure and workflow of NumPy as closely as possible, and it works alongside existing frameworks such as TensorFlow and PyTorch. The primary features of JAX are:
- A unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.
- Built-in just-in-time (JIT) compilation via OpenXLA, an open-source machine learning compiler ecosystem.
- Efficient evaluation of gradients via its automatic differentiation transformations.
- Automatic vectorization to efficiently map functions over arrays representing batches of inputs.
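The features listed above can be illustrated with a minimal sketch. The function `loss` and the sample values are hypothetical and chosen only for illustration; the calls `jax.jit`, `jax.grad`, and `jax.vmap` are the library's transformations for compilation, differentiation, and vectorization respectively.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Hypothetical toy function: sum of squares of the scaled input.
    return jnp.sum((w * x) ** 2)


x = jnp.array([1.0, 2.0, 3.0])
w = 2.0

# NumPy-like interface: jnp mirrors familiar NumPy functions.
value = loss(w, x)

# JIT compilation: the function is traced and compiled through XLA
# on its first call, then reused for inputs of the same shape and type.
fast_loss = jax.jit(loss)
jit_value = fast_loss(w, x)

# Automatic differentiation: gradient with respect to the first argument.
# Analytically, d/dw sum((w * x)^2) = 2 * w * sum(x^2).
dloss_dw = jax.grad(loss)(w, x)

# Automatic vectorization: map over a batch of w values (axis 0)
# while holding x fixed (in_axes entry None).
ws = jnp.array([1.0, 2.0, 3.0])
batched = jax.vmap(loss, in_axes=(0, None))(ws, x)

print(value, jit_value, dloss_dw, batched)
```

The transformations compose, so an expression such as `jax.jit(jax.vmap(jax.grad(loss), in_axes=(0, None)))` would produce a compiled function that returns one gradient per batch element.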
Libraries using JAX
- Flax
- Equinox
- Optax
See also
- NumPy
- TensorFlow
- PyTorch
- CUDA
- Accelerated Linear Algebra
- Comparison of machine learning software
- List of numerical libraries
External links
- Documentation
- Colab (Jupyter/IPython) Quickstart Guide
- TensorFlow's XLA (Accelerated Linear Algebra)
- YouTube TensorFlow Channel "Intro to JAX: Accelerating Machine Learning research"
- Original paper