Multi-Layer Perceptron from Scratch with JAX

Jiarui Xu | May 19, 2023

This is the first article of the “Fun with Language Models” series. It starts with a simple MLP implemented from scratch using JAX. I use JAX for the series because:

  • It’s gaining more popularity within the ML community
  • A JAX implementation in a blog makes for good content: it reveals the inner workings without being overly specific about optimization aspects such as gradient derivation, batching, and GPU execution.
  • Most important of all: I personally want to learn more about JAX (the Feynman technique at play).

Another keyword here is "scratch". Einstein said, "If you can't explain it simply, you don't understand it well enough." Novel products and creations start from simple ideas and need only simple explanations. To refresh my memory of the classes I took at school, I believe implementing things from scratch helps me understand them better. It is dangerous to become accustomed to black-box libraries, as they can cause you to forget the fundamentals and the "big picture".

The history of language models starts with n-grams, which are usually the first topic of introductory Natural Language Processing courses. However, I want to progress faster to the Transformers chapter, so I have decided to revisit n-grams later. Starting with MLP, I will then implement an RNN (Recurrent Neural Network) from scratch and build a small language model with it. Then we will come back to MLP to create Word2Vec. Before diving into the recent era of Transformers, I will implement GloVe from scratch as well.

Multi-Layer Perceptron (MLP)

A Multi-Layer Perceptron (MLP) is a type of artificial neural network that is widely used in machine learning. It is composed of multiple layers of nodes, where each node is connected to the nodes in the previous and following layers. The MLP is capable of learning complex non-linear relationships between inputs and outputs, making it a popular choice for various tasks such as classification, regression, and pattern recognition.

The rest of the blog assumes a basic understanding of how an MLP works. StatQuest has a pretty simple math demo here:

Some additional resources for backpropagation and computational graph:

Some Personal History

I first learned about Artificial Neural Networks at HKUST during the 2014 summer programme. The other Computer Science courses filled up quickly, and this was the only one available that I could take to transfer the credits back to UIUC. Soon after, Machine Learning became a sought-after course due to its increasing popularity, and I believe it is now among the fastest-filling courses in many CS departments.

The first part of the course at HKUST was building a neural network to recognize MNIST images. This is how we built MLPs back then:

[Image: the visual neural network builder used in the HKUST course]

Nowadays, there are numerous libraries and packages for building and training MLPs efficiently, along with countless blogs and videos explaining the process. The UI above looks outdated, but as a visual learner I found it a great visual NN builder that gives a high-level view of what’s happening. I was really impressed that those nodes and edges could be optimized to recognize digits, magically.

Let’s build MLP with JAX

JAX is a Python library that provides a high-performance implementation of automatic differentiation and numerical computation. It is designed to be fast, flexible, and composable, making it an excellent choice for building machine learning models. I really liked it while writing Revisit Cox Proportional Hazards with JAX, so I am using it again for this MLP.

The ToyMLP class code is shown below and available on GitHub: https://github.com/jxucoder/fun_with_language_models/blob/master/ironhide/models/mlp.py

The toy MLP is really simple and doesn’t have regularization in the cost function.

The code is mostly adapted from JAX’s notebook https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb with some additional changes: a custom pytree class implementation, use of the jax.nn module, newer features such as decorators, and jaxopt.

import jax
import jax.numpy as jnp
import jaxopt
from functools import partial
from jax import grad, jit, vmap
from jax.nn import relu
from jax.scipy.special import logsumexp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyMLP:
    def __init__(self, parameters=None, layer_sizes=None):
        self.parameters = parameters
        self.layer_sizes = layer_sizes

        if (not self.parameters) and self.layer_sizes:
            self.initialize_params()

    def initialize_params(self):
        initializer = jax.nn.initializers.normal(0.01)
        self.parameters = [
            [initializer(jax.random.PRNGKey(42), (n, m), dtype=jnp.float32),
             initializer(jax.random.PRNGKey(42), (n,), dtype=jnp.float32)]
            for m, n in zip(self.layer_sizes[:-1], self.layer_sizes[1:])
        ]

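    # vmap maps forward over the batch axis of x; self and parameters are not mapped (in_axes=None)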
    @partial(vmap, in_axes=(None, None, 0))
    def forward(self, parameters, x):
        activations = x
        for w, b in parameters[:-1]:
            outputs = jnp.dot(w, activations) + b
            activations = relu(outputs)

        final_w, final_b = parameters[-1]
        logits = jnp.dot(final_w, activations) + final_b
        return logits - logsumexp(logits)

    @jit
    def cost_func(self, params, x, y):
        preds = self.forward(params, x)
        return -jnp.mean(preds * y)

    @jit
    def update(self, x, y, step_size):
        grads = grad(self.cost_func)(self.parameters, x, y)
        return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(self.parameters, grads)]

    def predict(self, x):
        return jnp.argmax(self.forward(self.parameters, x), axis=1)

    def fit(self, train_x, train_y, num_epochs=50, step_size=0.05, use_jaxopt=False):
        if use_jaxopt:
            solver = jaxopt.GradientDescent(fun=self.cost_func, stepsize=step_size)
            params, state = solver.run(self.parameters, x=train_x, y=train_y)
            self.parameters = params
        else:
            for epoch in range(num_epochs):
                for x, y in batchify(train_x, train_y):  # batchify: mini-batch iterator helper (see sketch after the class)
                    self.parameters = self.update(x, y, step_size)

    def tree_flatten(self):
        children = (self.parameters, self.layer_sizes)  # arrays / dynamic values
        aux_data = {}  # static values
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)
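
The fit loop above relies on a batchify helper defined elsewhere in the repo. A minimal sketch of what such a mini-batch iterator could look like (the batch size and the lack of shuffling here are my assumptions, not necessarily the repo’s implementation):

def batchify(train_x, train_y, batch_size=128):
    # Yield consecutive mini-batches of the training data.
    num_examples = train_x.shape[0]
    for start in range(0, num_examples, batch_size):
        end = start + batch_size
        yield train_x[start:end], train_y[start:end]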

Step-by-step Details

This demo will use MNIST as the learning task: Given an MNIST image, predict a digit from 0 to 9.

[Image: sample MNIST digit images]

1. Architecture

To implement an MLP with JAX, we first define the architecture of the network by specifying the number of layers and the number of nodes in each layer:

An MNIST image has 28 x 28 pixels, which leads to an input layer of size 784 (one input node for every pixel). We have 10 digits, which leads to an output layer of size 10. We define two hidden layers of sizes 256 and 128, so the layer sizes are [784, 256, 128, 10]. The hidden layer sizes are chosen arbitrarily. This network is too big to visualize in this blog, so I show a visualization (credit: https://alexlenail.me/NN-SVG/) of layer sizes [32, 24, 16, 10] instead.

[Image: NN-SVG visualization of an MLP with layer sizes [32, 24, 16, 10]]
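
To make the architecture concrete, constructing the toy model with these layer sizes is a one-liner (a usage sketch; loading the training data is omitted):

layer_sizes = [784, 256, 128, 10]  # input pixels -> two hidden layers -> 10 digit classes
mlp = ToyMLP(layer_sizes=layer_sizes)  # parameters are initialized inside __init__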

2. Forward Propagation

Forward propagation is the process of passing input data through a neural network to obtain a prediction or output. In the context of a multi-layer perceptron (MLP), forward propagation involves computing the output of each layer of the network from the output of the previous layer, until the final output of the network is obtained.

We will use the Rectified Linear Unit (ReLU) as the activation function. Some other common choices are sigmoid, tanh, and softplus [2].

[Image: activation function plot]

ReLU is defined as: f(x) = max(0, x).

from jax.nn import relu

activations = relu(wx + b)

jax.nn is a module in the JAX library that provides a collection of neural network primitives, most notably activation functions (such as relu, sigmoid, and softmax) and, under jax.nn.initializers, weight initializers. These functions are designed to be fast, composable, and differentiable, making them well suited for building deep learning models. In the code here, jax.nn.relu is used as the activation function for the MLP.

We then initialize the weights and biases of the network using random values.

jax.nn.initializers is a module in the JAX library that provides a collection of functions for initializing the weights and biases of neural networks. These functions are designed to provide a good starting point for the optimization process and can help to prevent issues such as vanishing or exploding gradients. In the given code, jax.nn.initializers.normal is used to initialize the weights and biases with random values drawn from a normal distribution with mean 0 and standard deviation 0.01.

initializer = jax.nn.initializers.normal(0.01)
self.parameters = [
    [initializer(jax.random.PRNGKey(42), (n, m), dtype=jnp.float32),
     initializer(jax.random.PRNGKey(42), (n,), dtype=jnp.float32)]
    for m, n in zip(self.layer_sizes[:-1], self.layer_sizes[1:])
]
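
One thing to note in the code above: the same PRNGKey(42) is reused for every weight and bias tensor, so all parameters are drawn from the same random stream. A variant that splits the key per layer (a sketch, not the repo’s implementation) could look like this:

def initialize_params_split_keys(layer_sizes, seed=42, scale=0.01):
    # Same shapes as initialize_params, but each layer gets its own subkeys.
    initializer = jax.nn.initializers.normal(scale)
    keys = jax.random.split(jax.random.PRNGKey(seed), len(layer_sizes) - 1)
    params = []
    for key, (m, n) in zip(keys, zip(layer_sizes[:-1], layer_sizes[1:])):
        w_key, b_key = jax.random.split(key)
        params.append([initializer(w_key, (n, m), dtype=jnp.float32),
                       initializer(b_key, (n,), dtype=jnp.float32)])
    return params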

Let’s now demonstrate how the model’s prediction is calculated: a forward propagation.

@partial(vmap, in_axes=(None, None, 0))
def forward(self, parameters, x):
    activations = x
    for w, b in parameters[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)

    final_w, final_b = parameters[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

The code has two interesting details:

  • vmap() for automatic vectorization or batching.

vmap() is a JAX function for automatic vectorization or batching. It takes a function that performs an operation on individual examples and returns a new function that can apply the operation to multiple examples in parallel. This can be useful for speeding up computations on large datasets or for parallelizing computations on GPUs or TPUs.

  • logsumexp() for numerical stability when computing the log-softmax output.

You can read more about the log-sum-exp trick at [3].
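
To make these two details concrete, here is a tiny self-contained sketch (independent of the ToyMLP code) that batches a per-example log-softmax with vmap and uses logsumexp to keep it numerically stable:

import jax.numpy as jnp
from jax import vmap
from jax.scipy.special import logsumexp

def log_softmax_single(logits):
    # Stable log-softmax for one example: subtracting logsumexp avoids the
    # overflow that a naive exp() would hit on large logits.
    return logits - logsumexp(logits)

batched_log_softmax = vmap(log_softmax_single)  # maps over axis 0 of the input

logits_batch = jnp.array([[1.0, 2.0, 3.0],
                          [1000.0, 1000.0, 1000.0]])  # large values are fine
print(batched_log_softmax(logits_batch))  # second row: log(1/3) per class, no overflow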

Now, finally, with the neural network’s output, we can predict a digit out of 10.

The predict function takes an input array x and returns an array of predicted labels for each example in x. First, it calls the forward function to compute the output of the MLP for each example in x. Then, it applies the argmax function along the second axis of the output array to obtain the index of the maximum value for each example, which corresponds to the predicted label.

def predict(self, x):
    return jnp.argmax(self.forward(self.parameters, x), axis=1)

3. Loss

Cross-entropy is the most commonly used loss function for classification tasks, with mean squared error the usual choice for regression. In the code here, the loss is the negative mean of the element-wise product of the model's log-probability outputs and the one-hot target, which is cross-entropy up to a constant factor. It is implemented as follows:

@jit
def cost_func(self, params, x, y):
    preds = self.forward(params, x)
    return -jnp.mean(preds * y)

Here, preds contains the model's log-probabilities and y is the one-hot target output. The * operator performs element-wise multiplication between preds and y, which keeps only the log-probability of the correct class for each example. jnp.mean then averages over all elements of the resulting array (batch size times number of classes), and the negative sign turns log-likelihood maximization into a minimization problem, which is what most optimization algorithms expect.
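
As a small sanity check (a sketch with made-up numbers, not taken from the repo), the element-wise product with a one-hot target picks out the log-probability of the correct class:

import jax.numpy as jnp

# log-probabilities from the forward pass, for 2 examples and 3 classes
preds = jnp.log(jnp.array([[0.7, 0.2, 0.1],
                           [0.1, 0.1, 0.8]]))
# one-hot targets: class 0 for the first example, class 2 for the second
y = jnp.array([[1.0, 0.0, 0.0],
               [0.0, 0.0, 1.0]])

cost = -jnp.mean(preds * y)
print(cost)  # -(log 0.7 + log 0.8) / 6: cross-entropy averaged over batch * classes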

4. Backward Propagation

We can use JAX to calculate gradients of the loss function with respect to the weights and biases, which is a key step in training the MLP.

@jit
def update(self, x, y, step_size):
    grads = grad(self.cost_func)(self.parameters, x, y)
    return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(self.parameters, grads)]

The update method is used to update the weights and biases of the MLP during training. It takes as input the input data x, the target output y, and the step size step_size which controls the size of the weight and bias updates.

The method first calculates the gradients of the loss function with respect to the weights and biases using the grad function provided by JAX. The gradients are then used to update the weights and biases according to the following rule:

w_new = w_old - step_size * dw
b_new = b_old - step_size * db

Here, w_old and b_old are the old values of the weights and biases, and dw and db are the gradients of the loss function with respect to the weights and biases, respectively. The step size step_size controls the size of the weight and bias updates and is typically chosen to be a small positive value.

The update method returns a list of tuples containing the updated weights and biases for each layer of the MLP.
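
One detail worth highlighting: grad returns gradients with the same nested structure as its input, which is why the list comprehension can zip self.parameters and grads layer by layer. A tiny self-contained sketch (independent of ToyMLP, with made-up shapes):

import jax
import jax.numpy as jnp

def toy_cost(params, x):
    # params mimics ToyMLP.parameters: a list of [w, b] pairs
    w, b = params[0]
    return jnp.sum((w @ x + b) ** 2)

params = [[jnp.ones((2, 3)), jnp.zeros(2)]]
x = jnp.array([1.0, 2.0, 3.0])

grads = jax.grad(toy_cost)(params, x)   # differentiates w.r.t. the first argument
print(grads[0][0].shape, grads[0][1].shape)  # (2, 3) (2,) -- mirrors params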

5. JAXopt: Hardware accelerated, batchable and differentiable optimizers for JAX

Here comes JAXopt: Hardware accelerated, batchable and differentiable optimizers in JAX.

  • Hardware accelerated: the implementations run on GPU and TPU, in addition to CPU.
  • Batchable: multiple instances of the same optimization problem can be automatically vectorized using JAX's vmap.
  • Differentiable: optimization problem solutions can be differentiated with respect to their inputs either implicitly or via autodiff of unrolled algorithm iterations.

JAXopt provides various optimization algorithms, such as gradient descent, that can be used to update the weights and biases and minimize the loss function during training. The use_jaxopt branch of fit uses jaxopt.GradientDescent:

solver = jaxopt.GradientDescent(fun=self.cost_func, stepsize=step_size)
params, state = solver.run(self.parameters, x=train_x, y=train_y)
self.parameters = params
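
Note one design difference between the two branches of fit: the manual loop runs mini-batch gradient descent over batchify batches, while the jaxopt branch hands the entire training set to GradientDescent, i.e. full-batch gradient descent (jaxopt solvers also accept a maxiter argument to cap the number of iterations). Switching between the two is a one-line change (variable names here are illustrative):

mlp.fit(train_x, train_y, use_jaxopt=False)  # manual mini-batch loop
mlp.fit(train_x, train_y, use_jaxopt=True)   # jaxopt full-batch GradientDescent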

6. JAX with classes

The original JAX notebook implements the NN using global variables. One fun part of JAX is pytrees, which enable JAX tracing on class methods: https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods. The recommended approach is:

Strategy 3: Making CustomClass a PyTree

So long as your tree_flatten and tree_unflatten functions correctly handle all relevant attributes in the class, you should be able to use objects of this type directly as arguments to JIT-compiled functions, without any special annotations.

Hence we have the following implementation:

def tree_flatten(self):
    children = (self.parameters, self.layer_sizes)  # arrays / dynamic values
    aux_data = {}  # static values
    return (children, aux_data)

@classmethod
def tree_unflatten(cls, aux_data, children):
    return cls(*children, **aux_data)

Don’t forget the @register_pytree_node_class decorator.

JAX has a great pytree 101 tutorial here: https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05.1-pytrees.ipynb#scrollTo=Wh6BApZ9lrR1
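
As a quick check that the registration works (a sketch; the exact leaves depend on the chosen layer sizes), jax.tree_util.tree_leaves should now see through a ToyMLP instance:

import jax

mlp = ToyMLP(layer_sizes=[4, 3, 2])
leaves = jax.tree_util.tree_leaves(mlp)
# The leaves are the per-layer weight and bias arrays, plus the layer_sizes
# entries, since both were registered as children in tree_flatten.
print([getattr(leaf, "shape", leaf) for leaf in leaves])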

Evaluation

Finally, we can evaluate the performance of the MLP on a test dataset to see how well it generalizes to new data.
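
Concretely, the evaluation can be done with a small helper along these lines (a sketch; test_x and test_labels stand for the flattened MNIST test images and their integer labels, however they are loaded):

import jax.numpy as jnp

def accuracy(model, test_x, test_labels):
    # Fraction of test images whose predicted digit matches the true label.
    predicted = model.predict(test_x)  # shape: (num_examples,)
    return jnp.mean(predicted == test_labels)

print(accuracy(mlp, test_x, test_labels))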

Without training (right after initialization):

[Image: evaluation results right after initialization]

Training without JAXopt:

[Image: evaluation results after training without JAXopt]

Training with JAXopt:

[Image: evaluation results after training with JAXopt]

Spot-checking on the test example of digit 2:

[Image: a test example of the digit 2]

The trained model can correctly predict 2:

[Image: the model's prediction output for the digit 2]

References

[1] ChatGPT and Notion AI generated helpful text for the writing of this blog.

[2] CMU Deep Learning lecture-2 https://www.cs.cmu.edu/~bhiksha/courses/deeplearning/Fall.2019/www.f19/document/lecture/lecture-2.pdf

[3] The Log-Sum-Exp Trick by Gregory Gundersen https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/