Deep Learning with Flax¶
Flax and Optax are libraries for deep learning, written in JAX and maintained by Google. Flax provides JAX components for composing models (e.g., neural network layers), while Optax focuses on tools for training these models (e.g., loss functions and optimisers such as stochastic gradient descent).
This notebook demonstrates how to use these libraries to compose and train a simple model on the MNIST dataset.
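As a quick illustration of what Optax provides, here is a minimal sketch of its low-level optimiser API, with made-up params and grads values purely for illustration; the training loop later in this notebook uses the same machinery indirectly, via Flax's TrainState.
import jax.numpy as jnp
import optax
# Sketch only: an Optax optimiser is a gradient transformation - initialise its
# state from the parameters, turn raw gradients into updates, then apply them.
params = {"w": jnp.zeros(3)}
grads = {"w": jnp.array([0.1, -0.2, 0.3])}
optimiser = optax.sgd(learning_rate=0.01)
opt_state = optimiser.init(params)
updates, opt_state = optimiser.update(grads, opt_state)
params = optax.apply_updates(params, updates)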
Imports and Configuration¶
from typing import Any, Callable, Dict, Tuple
import flax.linen as nn
import jax
import optax
import torchvision
from flax.training.train_state import TrainState
from jax import random, numpy as jnp
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
Get Dataset¶
We will lean on PyTorch's datasets and loaders and adapt them for JAX.
import numpy as np
def flatten_and_cast(x: Any) -> jnp.ndarray:
    """Convert an image into a flattened float32 JAX array."""
    return jnp.ravel(jnp.array(x, dtype=jnp.float32))
def numpy_collate(batch):
    """Recursively collate a batch into NumPy arrays rather than PyTorch tensors."""
if isinstance(batch[0], np.ndarray):
return np.stack(batch)
elif isinstance(batch[0], (tuple, list)):
transposed = zip(*batch)
return [numpy_collate(samples) for samples in transposed]
else:
return np.array(batch)
train_data = torchvision.datasets.MNIST(
root="./data", train=True, download=True, transform=flatten_and_cast
)
test_data = torchvision.datasets.MNIST(
root="./data", train=False, download=True, transform=flatten_and_cast
)
training_data_loader = DataLoader(
dataset=train_data, batch_size=500, collate_fn=numpy_collate
)
test_data_loader = DataLoader(
dataset=test_data, batch_size=1000, collate_fn=numpy_collate
)
print(f"{len(train_data):,} instances of training data")
print(f"{len(test_data):,} instances of training data")
60,000 instances of training data 10,000 instances of training data
Inspect a single instance of training data.
data_instance, data_label = train_data[0]
print(f"label = {data_label}")
_ = plt.imshow(data_instance.reshape(28, 28), cmap="gray")
label = 5
Training a Classification Model¶
Start by defining the network that we want to train, which in this case is a neural network with a single hidden layer.
class ClassifyMNIST(nn.Module):
n_hidden: int = 28 * 28
n_classes: int = 10
@nn.compact
def __call__(self, x):
x = nn.Dense(self.n_hidden, name="hidden_layer")(x)
x = nn.relu(x)
x = nn.Dense(self.n_classes, name="output_layer")(x)
return x
model = ClassifyMNIST()
model
ClassifyMNIST(
    # attributes
    n_hidden = 784
    n_classes = 10
)
And then the training routine to use with it, which in this case is plain Stochastic Gradient Descent (SGD). JAX is a purely functional framework, and this has an impact on how training loops interact with models. In JAX, the forward pass of a model is a function of both the model parameters and the inputs. During training, this means that parameters need to be managed separately from the model. Flax provides the TrainState object to help with this.
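To make this concrete, here is a minimal sketch of the pattern (the PRNG key and dummy input are placeholders for illustration): parameters are created with model.init, the forward pass is a pure function of those parameters via model.apply, and TrainState.create bundles the parameters, the forward function and the Optax optimiser together.
# Sketch only: the functional init/apply pattern used throughout the training loop.
key = random.PRNGKey(42)
dummy_input = jnp.ones((28 * 28,))

params = model.init(key, dummy_input)      # parameters live outside the model
logits = model.apply(params, dummy_input)  # forward pass = f(params, inputs)

state = TrainState.create(
    apply_fn=model.apply,                  # pure forward-pass function
    params=params,                         # current parameter values
    tx=optax.sgd(learning_rate=0.01),      # Optax optimiser
)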
def compute_metrics(training_state: TrainState, X: jnp.ndarray, y: jnp.ndarray) -> Dict[str, float]:
"""Loss and accuracy calculations."""
logits = training_state.apply_fn(training_state.params, X)
loss = optax.softmax_cross_entropy(
logits=logits, labels=jax.nn.one_hot(y, 10)
).mean()
accuracy = jnp.mean(jnp.argmax(logits, -1) == y)
metrics = {
"loss": loss.tolist(),
"accuracy": accuracy.tolist(),
}
return metrics
def train(loss_fn: Callable, n_epochs: int = 10, learning_rate: float = 0.01) -> optax.Params:
"""Archetypal Flax training loop is a function of a loss function."""
@jax.jit
def process_batch(
training_state: TrainState, X: jnp.ndarray, y: jnp.ndarray
) -> TrainState:
loss_grads = jax.grad(loss_fn)(training_state.params, X, y)
return training_state.apply_gradients(grads=loss_grads)
def process_epoch(
training_state: TrainState,
) -> Tuple[TrainState, Dict[str, float]]:
for X, y in training_data_loader:
training_state = process_batch(training_state, X, y)
metrics = compute_metrics(training_state, X, y)
return training_state, metrics
key1, key2 = random.split(random.PRNGKey(0))
training_state = TrainState.create(
apply_fn=model.apply,
params=model.init(key2, random.normal(key1, (28 * 28,))),
tx=optax.sgd(learning_rate=learning_rate),
)
for n in range(n_epochs):
training_state, metrics = process_epoch(training_state)
print(f"epoch={n}; metrics={metrics}")
return training_state.params
We can now apply the training algorithm to the model, using the default settings of 10 epochs and a learning rate of 0.01.
def classify_mnist_loss(params: optax.Params, X: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""Loss as an explicit function of model parameter for gradient calulation."""
loss_values = optax.softmax_cross_entropy(
model.apply(params, X), jax.nn.one_hot(y, 10)
)
return loss_values.mean()
trained_params = train(classify_mnist_loss)
epoch=0; metrics={'loss': 0.2294992208480835, 'accuracy': 0.9340000152587891}
epoch=1; metrics={'loss': 0.17606227099895477, 'accuracy': 0.9620000720024109}
epoch=2; metrics={'loss': 0.15006238222122192, 'accuracy': 0.9700000286102295}
epoch=3; metrics={'loss': 0.1338450163602829, 'accuracy': 0.9720000624656677}
epoch=4; metrics={'loss': 0.11482236534357071, 'accuracy': 0.9760000705718994}
epoch=5; metrics={'loss': 0.10640661418437958, 'accuracy': 0.9780000448226929}
epoch=6; metrics={'loss': 0.09827155619859695, 'accuracy': 0.9780000448226929}
epoch=7; metrics={'loss': 0.09185998886823654, 'accuracy': 0.9780000448226929}
epoch=8; metrics={'loss': 0.08489469438791275, 'accuracy': 0.9800000190734863}
epoch=9; metrics={'loss': 0.07957303524017334, 'accuracy': 0.9820000529289246}
We can use the trained model to compute classification accuracy on the test dataset. Note once again that the model.apply method is a function of both the trained parameters and the input data.
def predict(X: jnp.ndarray) -> Any:
logit = model.apply(trained_params, X)
return jnp.argmax(logit, -1).tolist()
correct = 0
for X, y in test_data_loader:
y_pred = predict(X)
correct += (y_pred == y).sum()
accuracy = correct / len(test_data)
print(f"test_data accuracy = {accuracy}")
test_data accuracy = 0.9462
We can also make inferences for individual instances of data.
predict(data_instance)
5
Comparison with PyTorch¶
The same example used in this notebook has also been implemented with PyTorch. This enables a direct comparison between the two frameworks in terms of code verbosity and style. Initial impressions are that the training loop requires more lines of code in JAX, but it is easier to reason about how the model parameters are updated (as they must be passed around explicitly). Similarly, model definition is slightly more compact in JAX.
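To illustrate the point about explicit versus implicit parameter handling, the fragment below is a rough sketch of what the core of the equivalent PyTorch training step might look like (it is not taken from the linked notebook, and the layer sizes simply mirror the Flax model above). The optimiser mutates the model's parameters in place, so nothing is passed around or returned.
import torch
from torch import nn

# Hypothetical PyTorch counterparts to the Flax model and SGD optimiser above.
torch_model = nn.Sequential(
    nn.Linear(28 * 28, 28 * 28), nn.ReLU(), nn.Linear(28 * 28, 10)
)
optimiser = torch.optim.SGD(torch_model.parameters(), lr=0.01)
cross_entropy = nn.CrossEntropyLoss()

def process_batch(X: torch.Tensor, y: torch.Tensor) -> None:
    optimiser.zero_grad()
    loss = cross_entropy(torch_model(X), y)
    loss.backward()   # gradients are accumulated on the parameters themselves
    optimiser.step()  # parameters are updated in place; nothing is returned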