RNN Language Modeling from scratch with JAX

In the beginning, I tried to write an RNN class following the MLP class example, but handling the custom pytree made the code a little ugly. Then I discovered equinox, which converts your custom class into a pytree for you. The current RNN class implementation borrows from amazing prior work: the equinox examples and @phinate’s demo. It uses the architecture config from the PyTorch implementation from the 224N class. The notebook experiment is performed on Penn Treebank (https://paperswithcode.com/dataset/penn-treebank), similar to https://docs.chainer.org/en/stable/examples/ptb.html.

The “from-scratch” code is shared below:

I recommend watching this amazing MIT lecture by Ava Soleimany. She gives a great overview of the development of sequence modeling, including recent innovations like Transformers:

What is a Language Model?

A model that estimates the probability of a token or sequence of tokens occurring in a longer sequence of tokens. (Google ML Glossary)
Figure 1. Example prediction (from the 224N lecture slides)

It is a fundamental component of many natural language processing applications, such as speech recognition, machine translation, and text generation. The goal of a language model is to capture the underlying structure and patterns of natural language, and use this knowledge to generate coherent and grammatical text.

The simplest type of language model is the n-gram model, a statistical counting technique that does not involve ML at all.
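
For intuition, here is a minimal bigram (2-gram) counting sketch in plain Python. The corpus is made up and this is not part of the original notebook; it only illustrates the counting idea.

from collections import Counter, defaultdict

corpus = "the cat sat on the mat the cat slept".split()

# count how often each word follows another
pair_counts = defaultdict(Counter)
for prev_word, next_word in zip(corpus, corpus[1:]):
    pair_counts[prev_word][next_word] += 1

def bigram_prob(prev_word, next_word):
    counts = pair_counts[prev_word]
    total = sum(counts.values())
    return counts[next_word] / total if total else 0.0

print(bigram_prob("the", "cat"))  # 0.666...: "the" is followed by "cat" 2 times out of 3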

Interesting. But explain it to me in ML language.

In ML, a language model is typically trained on a large corpus of text data, such as books, articles, or web pages. During training, the language model analyzes the data to learn the probabilities of different words and their dependencies on each other. It builds a probabilistic model that estimates the likelihood of observing a specific word given a sequence of contextual words.

Figure 2. Example MLP prediction

Many models are trained to predict the next word in a sentence based on the preceding context. The model does so by learning the statistical patterns and relationships between words in the training data. It updates its internal parameters, such as weights and biases, through optimization algorithms like gradient descent to minimize the difference between its predicted probabilities and the true probabilities observed in the training data.

What ML model can we use? RNN.

Theoretically, most ML models can be used. We could even use xgboost if there is nothing better to do 🙂. The Recurrent Neural Network (RNN) is a popular choice, and we will explain why in the next section. RNN is actually a family of neural network architectures, and here we use the simplest RNN to demonstrate the idea.

Let’s not worry about the exact model form for now; it is easier to understand with JAX code. We first clarify the input and target of an example RNN model:

  • input: a sequence of words or tokens
  • target: the next word or token in the sequence, predicted based on the preceding context.
Figure 3. Example RNN prediction

During training, the RNN language model is presented with pairs of input sequences and corresponding target sequences. The input sequence consists of a sequence of words, and the target sequence is the same sequence shifted by one word, representing the next word that the model needs to predict.

The RNN model processes the input sequence one word at a time, while maintaining an internal hidden state that captures the information from the previous words. This hidden state acts as the memory of the model, allowing it to capture and remember the context as it progresses through the sequence.
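
As a tiny illustration (the token ids below are made up), the target sequence is just the input sequence shifted left by one position:

import jax.numpy as jnp

token_ids = jnp.array([4, 17, 9, 2, 31])   # a 5-token sentence as vocabulary indices
inputs = token_ids[:-1]                    # [ 4, 17,  9,  2]
targets = token_ids[1:]                    # [17,  9,  2, 31]  -> the "next token" at each step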

What makes RNN good?

Why don’t we just train an MLP, as Figure 2 suggests? RNNs have the following advantages over MLPs:

  • Sequential processing: RNNs handle sequential data and capture dependencies.
  • Variable-length inputs: RNNs can process sequences of varying lengths.
  • Temporal connections: RNNs capture temporal dependencies between words.
  • Parameter sharing: RNNs share parameters across time steps, enabling generalization.
  • Flexibility in input representation: RNNs can handle various input types (e.g., word embeddings, character-level encodings).
  • Text generation: RNNs are effective for generating coherent and contextually relevant text.

MLPs lack these sequential processing and contextual modeling capabilities.

RNN architecture overview

This section gives a brief introduction to the RNN architecture, its key parameters, and the forward propagation process of the language model. It also highlights the advantages of RNNs over MLPs for language modeling, such as handling sequential and variable-length inputs, temporal connections, and parameter sharing. Lastly, it briefly touches on the backward propagation process and optimization techniques used for training the RNN.

Parameters Preview

The whole point of training an RNN is to learn these parameters:

  • $W_{embedding}$: Embedding matrix used to convert the input tokens into their corresponding embeddings. An ML model can only work with numeric vectors, not strings.
  • $W_{hh}$: Weight matrix for the previous-hidden-state-to-hidden-state connections
  • $W_{xh}$: Weight matrix for the input-embedding-to-hidden-state connections
  • $W_{hy}$: Weight matrix for the hidden-state-to-output connections
  • $b_h$: Bias vector for the hidden state
  • $b_y$: Bias vector for the output
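
To make the later code easier to follow, here is a rough sketch of how these parameters could be laid out as an equinox Module, using the field names that show up in the forward-pass code below. The initialization scheme here is just an illustration, not the notebook’s exact setup.

import jax
import jax.numpy as jnp
import equinox as eqx

class RNN(eqx.Module):
    w_embedding: jnp.ndarray    # (embed_size, vocab_size)   -> W_embedding
    w_h_embedding: jnp.ndarray  # (hidden_size, embed_size)  -> W_xh
    w_h_hprev: jnp.ndarray      # (hidden_size, hidden_size) -> W_hh
    w_output_h: jnp.ndarray     # (vocab_size, hidden_size)  -> W_hy
    h_bias: jnp.ndarray         # (hidden_size,)             -> b_h
    output_bias: jnp.ndarray    # (vocab_size,)              -> b_y
    hidden_size: int = eqx.field(static=True)  # not a trainable parameter

    def __init__(self, vocab_size, embed_size, hidden_size, key):
        k1, k2, k3, k4 = jax.random.split(key, 4)
        scale = 0.1  # naive scaling, purely for illustration
        self.w_embedding = scale * jax.random.normal(k1, (embed_size, vocab_size))
        self.w_h_embedding = scale * jax.random.normal(k2, (hidden_size, embed_size))
        self.w_h_hprev = scale * jax.random.normal(k3, (hidden_size, hidden_size))
        self.w_output_h = scale * jax.random.normal(k4, (vocab_size, hidden_size))
        self.h_bias = jnp.zeros((hidden_size,))
        self.output_bias = jnp.zeros((vocab_size,))
        self.hidden_size = hidden_size

Because equinox registers this class as a pytree, jax.jit, jax.grad, and optax can treat the model instance itself as the parameter container.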

Forward Propagation

Embedding lookup

The input is actually a sequence of one-hot vectors. Assuming we have vocab_size unique tokens in the corpus, and each token has its unique index, the one-hot vector for a word is a zero vector of size vocab_size with a 1 at that word’s index position. For example, if we have 10 unique words and the index for “the” is 2, the one-hot vector for “the” is:

$[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]$
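
With JAX this encoding is a single call to jax.nn.one_hot (the numbers follow the toy example above):

import jax

vocab_size = 10
the_index = 2
one_hot_the = jax.nn.one_hot(the_index, vocab_size)
print(one_hot_the)  # [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]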

One-hot vectors are useful because we can then use an embedding layer to reduce the dimensionality and create semantic relationships among words (e.g. “dog” and “cat” end up close to each other in the embedding space). We use a matrix multiplication with the embedding matrix $W_{embedding}$ to retrieve each token’s embedding.

  • $W_{embedding}$ is a matrix with dimensions (embed_size x vocab_size), where embed_size is the dimensionality of the embeddings and vocab_size is the number of unique tokens in the vocabulary.
  • Recall that the one-hot vector $x_{ohe}$ is a vector of dimension (vocab_size,).

The product of $W_{embedding}$ and $x_{ohe}$ is the embedding vector for that input token. Thanks to JAX’s batching support (jax.vmap), applying this to a whole sequence is as easy as one line of code:

make_sequence_embedding = jax.vmap(lambda w_embeds, one_hot_word: w_embeds @ one_hot_word, in_axes=(None, 0))
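
As a quick sanity check with made-up shapes (embed_size=8, vocab_size=10, and a 4-token sentence are just example numbers):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
w_embeds = jax.random.normal(key, (8, 10))                      # (embed_size, vocab_size)
one_hot_sentence = jax.nn.one_hot(jnp.array([2, 5, 0, 7]), 10)  # (seq_len, vocab_size)

embeddings = make_sequence_embedding(w_embeds, one_hot_sentence)
print(embeddings.shape)  # (4, 8): one embed_size vector per token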

Transformations and Activations

After we get the embeddings (the inputs), things would be pretty straightforward in the MLP world: just multiply the inputs by the weights and run them through activations, layer by layer, until the final output layer. In the RNN world, it’s a bit more complicated because of the “Recurrent” part.

We use tanh as the hidden layer’s activation function.


Additional notes on the notation below:

  • $h_t$: Hidden state at time step $t$
  • $x_t$: Input at time step $t$

Let’s calculate the forward pass step by step:

  • At t=0 (the first token):

$$h_0 = tanh(W_{xh} * x_0 + W_{hh} * h_{-1} + b_h)$$
$$output_0 = softmax(W_{hy} * h_0 + b_y)$$

Typically we use a zero vector for the “previous hidden state” $h_{-1}$ of the first token, which makes the forward pass look just like an MLP forward pass:

$$h_0 = tanh(W_{xh} * x_0 + b_h)$$
$$output_0 = softmax(W_{hy} * h_0 + b_y)$$

  • At t=1 (the second token):

$$h_1 = tanh(W_{xh} * x_1 + W_{hh} * h_0 + b_h)$$
$$output_1 = softmax(W_{hy} * h_1 + b_y)$$

This is where things get more interesting. We look not only at the current token but also at the hidden state from the previous token. The idea is that the hidden state carries historical information that helps predict the next token; it is carried over from step to step.

  • At t=m (the (m+1)-th token):

$$h_m = tanh(W_{xh} * x_m + W_{hh} * h_{m-1} + b_h)$$
$$output_m = softmax(W_{hy} * h_m + b_y)$$

Let’s wrap it in JAX code:

def __call__(self, sent):
    # sent: a sequence of one-hot vectors, shape (seq_len, vocab_size)
    embedding = make_sequence_embedding(self.w_embedding, sent)  # (seq_len, embed_size)
    initial_h = jnp.zeros((self.hidden_size,))  # h_{-1} is a zero vector

    def f(h_prev, word_embedding):
        # h_t = tanh(W_hh @ h_{t-1} + W_xh @ x_t + b_h)
        h_curr = jax.nn.tanh(
            self.w_h_hprev @ h_prev
            + self.w_h_embedding @ word_embedding
            + self.h_bias
        )
        # output_t: logits over the vocabulary (the softmax is applied inside the loss)
        curr_outputs = self.w_output_h @ h_curr + self.output_bias
        return h_curr, curr_outputs

    # scan carries the hidden state across time steps and stacks the per-step outputs
    _, out = jax.lax.scan(f, initial_h, embedding)

    return jnp.array(out)

jax.lax.scan is the magic function that handles the forward propagation through the time steps: it carries h_t over time. Read more about it here: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html
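
If jax.lax.scan is new to you, here is a minimal, unrelated example (a running sum) that shows the carry/output pattern the RNN relies on:

import jax
import jax.numpy as jnp

def step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry   # (carry passed to the next step, output collected at this step)

final_carry, outputs = jax.lax.scan(step, 0.0, jnp.array([1.0, 2.0, 3.0]))
# final_carry == 6.0, outputs == [1.0, 3.0, 6.0]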

Loss

Recall that each time step t has its own target (the next token), so we get a cross-entropy loss for each output. Thanks to JAX, the batched version is pretty simple too:

cost_func = jax.vmap(optax.softmax_cross_entropy, in_axes=(0, 0))
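
To see the shapes, here is a toy, self-contained run. The 4-token sequence and 10-word vocabulary are made up, and averaging the per-token losses into a scalar is my assumption about the reduction used.

import jax
import jax.numpy as jnp
import optax

# per-token cross-entropy over one sequence of 4 tokens and a 10-word vocabulary
cost_func = jax.vmap(optax.softmax_cross_entropy, in_axes=(0, 0))

logits = jax.random.normal(jax.random.PRNGKey(0), (4, 10))  # stand-in for RNN outputs, (seq_len, vocab_size)
targets = jnp.array([5, 0, 7, 2])                           # next-token indices
one_hot_targets = jax.nn.one_hot(targets, 10)               # (seq_len, vocab_size)

per_token_loss = cost_func(logits, one_hot_targets)  # (seq_len,)
loss = per_token_loss.mean()                         # scalar; the mean reduction is my assumption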

Backward Propagation

Well, this is why we love computation graphs and JAX. Backward propagation, if implemented by hand with NumPy, would be fairly complicated. With JAX and optax, it just “happens”:

opt = optax.chain(
    optax.clip(1),
    optax.adamw(learning_rate=lr),
)

We use gradient clipping (optax.clip) together with AdamW (optax.adamw) to avoid exploding gradients.
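
One piece the snippet above doesn’t show is initializing the optimizer state. A typical equinox + optax pattern (my assumption, not taken from the notebook) is:

import equinox as eqx

# `model` is an instance of the RNN module; only its array leaves carry optimizer state
opt_state = opt.init(eqx.filter(model, eqx.is_array))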

Here is what happens inside the fit call:

# per-example losses and gradients for the batch
loss, grads = filter_jit_batched_loss_and_gradient(self, batched_inputs, batched_targets)
# average the gradients over the batch dimension
avg_grads = jax.tree_map(lambda g: g.mean(axis=0), grads)

# optax computes the parameter updates; equinox applies them to the model
updates, opt_state = opt.update(avg_grads, opt_state, params=self)
self = eqx.apply_updates(self, updates)
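
The helper filter_jit_batched_loss_and_gradient is not shown in this post. Purely as a guess at its shape (using eqx.filter_jit, eqx.filter_value_and_grad, and jax.vmap, plus cost_func from the Loss section), it might look roughly like this; the real notebook code may differ:

import equinox as eqx
import jax

def sequence_loss(model, one_hot_inputs, one_hot_targets):
    # mean cross-entropy over the time steps of a single sequence
    logits = model(one_hot_inputs)                    # (seq_len, vocab_size)
    return cost_func(logits, one_hot_targets).mean()

# value_and_grad per sequence, vmapped over the batch, wrapped in equinox's filtered jit;
# this yields per-example losses and gradients, hence the .mean(axis=0) above.
filter_jit_batched_loss_and_gradient = eqx.filter_jit(
    jax.vmap(eqx.filter_value_and_grad(sequence_loss), in_axes=(None, 0, 0))
)

The key point is that eqx.filter_value_and_grad differentiates only with respect to the array leaves of the model, which matches what opt.update expects.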

Does it work?

We can see the training loss trending down as more batches come through. We could improve the initialization and tune the hyperparameters to further improve the optimization process.

Figure: training loss over batches

The loss reduction curve looks a bit suspicious since it’s not very smooth. I could probably pick better hyperparameters or initial values. I’ll leave that for future work.

It’s far from ending

So far we have only covered the fitting of the network, but there are many other interesting things:

  • What does the embedding matrix look like now? We can take a look at a few words and check their cosine similarities.
  • How can we use the language model to generate text?
  • What are the problems with RNNs?
  • What are the evaluation metrics?

More on them later 🙂