Introduction to JAX
JAX is a framework for accelerated scientific computing. It provides an alternative implementation of the NumPy API for numerical and linear algebra computation, with added automatic differentiation and Just-in-Time (JIT) compilation. Unlike similar frameworks such as PyTorch or TensorFlow, it works within a purely functional programming paradigm.
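A minimal sketch of the three ideas in the paragraph above: NumPy-style arrays via `jax.numpy`, automatic differentiation via `jax.grad`, JIT compilation via `jax.jit`, and functional-style immutable arrays. The function `f` is a hypothetical example chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# jax.numpy mirrors much of the NumPy API.
x = jnp.arange(5.0)  # values 0.0 .. 4.0

def f(v):
    # A pure function: the result depends only on the input, with no side effects.
    return jnp.sum(v ** 2)

# Automatic differentiation: jax.grad returns a new function computing df/dv.
df = jax.grad(f)
grad_x = df(x)  # analytically 2 * v

# JIT compilation: jax.jit traces f and compiles it with XLA on first call.
f_fast = jax.jit(f)
value = f_fast(x)  # same result as f(x), computed by the compiled version

# Functional paradigm: JAX arrays are immutable, so an "update" returns a new array.
# (Writing x[0] = 10.0 would raise a TypeError.)
y = x.at[0].set(10.0)
```

Because `f` has no side effects, JAX can safely transform it (differentiate, compile) and compose those transformations, e.g. `jax.jit(jax.grad(f))`.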
This demo covers:
- How to manipulate tensors, i.e., how to use JAX as an alternative to NumPy.
- How to use automatic differentiation to minimise arbitrary functions.
- How to build and train ML models from first principles, using linear regression as an example.
- How to build and train a deep learning model for image classification using Flax and Optax.
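As a rough sketch of the "minimise arbitrary functions" and "linear regression from first principles" items, the block below fits `y = w * x + b` to synthetic data by hand-written gradient descent on the mean squared error. The data (`y = 3x + 2` plus noise), the learning rate, and the step count are illustrative assumptions; this is not necessarily how the notebooks implement it, and it deliberately avoids Optax to keep the update rule explicit.

```python
import jax
import jax.numpy as jnp

# Hypothetical synthetic data: y = 3x + 2 plus a little Gaussian noise.
# JAX uses explicit random keys rather than global random state.
key = jax.random.PRNGKey(0)
key_x, key_noise = jax.random.split(key)
x = jax.random.uniform(key_x, (100,))
y = 3.0 * x + 2.0 + 0.05 * jax.random.normal(key_noise, (100,))

LEARNING_RATE = 0.1  # assumed value, chosen for stable convergence on this data

def predict(params, x):
    w, b = params
    return w * x + b

def mse_loss(params, x, y):
    # Mean squared error between predictions and targets.
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def update(params, x, y):
    # One gradient-descent step; returns new parameters instead of mutating them.
    grads = jax.grad(mse_loss)(params, x, y)
    return tuple(p - LEARNING_RATE * g for p, g in zip(params, grads))

params = (jnp.array(0.0), jnp.array(0.0))  # initial (w, b)
for _ in range(2000):
    params = update(params, x, y)

w, b = params
print(f"w={float(w):.2f}, b={float(b):.2f}")
```

The fitted `w` and `b` should end up close to 3 and 2. `jax.grad` differentiates with respect to the first argument, here the `(w, b)` tuple, so the gradients come back with the same tuple structure as the parameters.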
Running the Demo
This demo spans several Jupyter notebooks.
For the notebooks to run successfully, make sure the required Python packages are installed in the Jupyter kernel you use to run them.
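One way to check this is to run a small cell inside the kernel itself. The sketch below only checks for the three packages this introduction names (`jax`, `flax`, `optax`); the notebooks' full requirements list is an assumption here and may include more packages.

```python
import importlib.util
import sys

# Packages named in this introduction; the actual requirements may be longer.
REQUIRED = ["jax", "flax", "optax"]

# find_spec returns None when a top-level package cannot be imported.
missing = [name for name in REQUIRED if importlib.util.find_spec(name) is None]

print("Kernel Python:", sys.executable)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All listed packages found.")
```

Printing `sys.executable` shows which interpreter the kernel is using. In a notebook, the `%pip install ...` magic installs packages into that running kernel's environment, which avoids installing them into a different Python by mistake.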