# Introduction to JAX

JAX is a framework for accelerated scientific computing. It provides an alternative implementation of the NumPy API for array and linear algebra computation, and adds automatic differentiation and Just-in-Time (JIT) compilation. Unlike similar frameworks such as PyTorch or TensorFlow, it works within a purely functional programming paradigm.
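As a minimal sketch of these three features together, the snippet below uses `jax.numpy` in place of NumPy, differentiates a pure function with `jax.grad`, and compiles the result with `jax.jit`. The function name `sum_of_squares` and the input values are illustrative assumptions, not taken from the demo notebooks.

```python
import jax
import jax.numpy as jnp

# jax.numpy mirrors much of the NumPy API.
x = jnp.arange(3.0)   # array with values 0, 1, 2
y = jnp.dot(x, x)     # 0*0 + 1*1 + 2*2


# A pure function: the output depends only on the inputs, with no side effects.
# (Hypothetical example function, chosen for illustration.)
def sum_of_squares(v):
    return jnp.sum(v ** 2)


# jax.grad returns a new function that computes the gradient, here 2 * v.
grad_fn = jax.grad(sum_of_squares)

# jax.jit compiles the function with XLA the first time it is called.
fast_grad_fn = jax.jit(grad_fn)

g = fast_grad_fn(x)
print(y, g)
```

Because `grad` and `jit` take a function and return a function, they compose freely, which is where the functional style pays off.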

The Flax and Optax packages build on JAX to make it easier to define and train deep learning models: Flax for defining models and Optax for optimisation.

## Demo Objectives

- How to manipulate tensors - i.e., JAX as an alternative to NumPy.
- How to use auto-differentiation and minimise arbitrary functions.
- How to build and train ML models from first principles - linear regression.
- How to build and train a deep learning model for image classification using Flax and Optax.

## Running the Demo

This demo spans three Jupyter notebooks:

- `demos/jax/introduction_to_jax.ipynb`
- `demos/jax/linear_regression.ipynb`
- `demos/jax/mnist_with_flax_and_optax.ipynb`

Make sure the required Python packages are installed in the Jupyter kernel you use, so that the notebooks run successfully.
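One quick way to check the kernel is to run a cell like the sketch below. The package list is an assumption inferred from the libraries named on this page; the demo's actual requirements may include more.

```python
import importlib.util

# Assumed package list, inferred from the libraries this demo mentions.
required = ["jax", "flax", "optax"]

# find_spec returns None when a top-level package cannot be imported.
missing = [name for name in required if importlib.util.find_spec(name) is None]

if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All assumed packages are importable.")
```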