In the beginning, I was trying to write an RNN class following the MLP class example, but registering a custom pytree made the code a little ugly. Then I discovered Equinox, which converts your custom class into a pytree for you. The current RNN class implementation builds on amazing previous work by Equinox and @phinate’s demo, and uses the architecture config from the PyTorch implementation of the 224N class. The notebook experiment is performed on Penn Treebank https://paperswithcode.com/dataset/penn-treebank, similar to https://docs.chainer.org/en/stable/examples/ptb.html.
The “from-scratch” code is shared below:
I recommend watching this amazing MIT lecture by Ava Soleimany. She gives a great overview of the development of sequence modeling, including recent innovations like Transformers:
What is a Language Model?
A model that estimates the probability of a token or sequence of tokens occurring in a longer sequence of tokens. (Google ML Glossary)
Figure 1. Example prediction (from the 224N lecture slides)
It is a fundamental component of many natural language processing applications, such as speech recognition, machine translation, and text generation. The goal of a language model is to capture the underlying structure and patterns of natural language, and use this knowledge to generate coherent and grammatical text.
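The simplest type of language model is the n-gram model, a purely statistical technique without ML: it estimates the probability of the next word by counting how often each word follows the previous n-1 words in the corpus.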
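To make “counting instead of learning” concrete, here is a minimal bigram (n = 2) sketch in plain Python. The toy corpus and the <s>/</s> sentence markers are illustrative assumptions; a real n-gram model would also need smoothing for word pairs never seen in training.

```python
from collections import Counter, defaultdict


def train_bigram(corpus):
    """Estimate P(next word | previous word) by counting adjacent word pairs."""
    counts = defaultdict(Counter)
    for sentence in corpus:
        # Boundary markers let the model learn which words start and end sentences.
        tokens = ["<s>"] + sentence.split() + ["</s>"]
        for prev, nxt in zip(tokens, tokens[1:]):
            counts[prev][nxt] += 1
    # Normalize each row of counts into a conditional probability distribution.
    return {
        prev: {w: c / sum(nxt_counts.values()) for w, c in nxt_counts.items()}
        for prev, nxt_counts in counts.items()
    }


corpus = ["the cat sat", "the dog sat", "the cat ran"]
model = train_bigram(corpus)
print(model["the"])  # "cat" gets 2/3, "dog" gets 1/3
```

No parameters are optimized here: the “model” is just normalized counts, which is why the n-gram approach is described as statistics rather than ML.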
Interesting. But explain it to me in ML language.
In ML, a language model is typically trained on a large corpus of text data, such as books, articles, or web pages. During training, the language model analyzes the data to learn the probabilities of different words and their dependencies on each other. It builds a probabilistic model that estimates the likelihood of observing a specific word given a sequence of contextual words.
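Formally, the chain rule of probability factorizes a sequence’s probability into next-word probabilities, which is why estimating “the likelihood of a word given its context” is enough to score whole sentences. Training then typically minimizes the average negative log-likelihood of the observed next words (the per-token cross-entropy):

```latex
P(w_1, \dots, w_T) = \prod_{t=1}^{T} P(w_t \mid w_1, \dots, w_{t-1})

\mathcal{L} = -\frac{1}{T} \sum_{t=1}^{T} \log P(w_t \mid w_1, \dots, w_{t-1})
```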
Figure 2. Example MLP prediction
Many language models are trained to predict the next word in a sentence from the preceding context. In doing so, they learn the statistical patterns and relationships between words in the training data. The model updates its internal parameters, such as weights and biases, with optimization algorithms like gradient descent, minimizing the cross-entropy between its predicted next-word probabilities and the words actually observed in the training data.
What ML model can we use? RNN.
Theoretically, most ML models can be used. We could even use xgboost if there is nothing better to do 🙂. Recurrent Neural Networks are a popular choice, and we will explain why in the next section. RNN is actually a family of neural network architectures; here we use the simplest RNN to demonstrate the idea.
Let’s not worry about the model form for now. The model form is easier to understand with JAX code. We first clarify the input and target of an example RNN model:
- input: a sequence of words or tokens
- target: predict the next word or token in the sequence based on the preceding context.
Figure 3. Example RNN prediction
During training, the RNN language model is presented with pairs of input sequences and corresponding target sequences. The input sequence consists of a sequence of words, and the target sequence is the same sequence shifted by one word, representing the next word that the model needs to predict.
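The one-token shift is easy to see on a toy example. The vocabulary, sentence, and make_training_pair helper below are hypothetical; note that the RNN conditions each prediction on all earlier inputs through its hidden state, not only on the token printed next to it.

```python
def make_training_pair(token_ids):
    """Split one tokenized sentence into (input, target) sequences shifted by one token."""
    inputs = token_ids[:-1]   # every token except the last
    targets = token_ids[1:]   # the token that follows each input position
    return inputs, targets


# Hypothetical toy vocabulary and sentence.
vocab = {"<eos>": 0, "the": 1, "cat": 2, "sat": 3, "down": 4}
id_to_word = {i: w for w, i in vocab.items()}
sentence = ["the", "cat", "sat", "down", "<eos>"]
token_ids = [vocab[w] for w in sentence]

inputs, targets = make_training_pair(token_ids)
for x, y in zip(inputs, targets):
    print(f"input {id_to_word[x]:>5} -> target {id_to_word[y]}")
```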
The RNN model processes the input sequence one word at a time, while maintaining an internal hidden state that captures the information from the previous words. This hidden state acts as the memory of the model, allowing it to capture and remember the context as it progresses through the sequence.
What makes RNNs good?
Why don’t we just train an MLP as Figure 2 suggests? RNNs have the following advantages over MLPs:
- Sequential processing: RNNs handle sequential data and capture dependencies.
- Variable-length inputs: RNNs can process sequences of varying lengths.
- Temporal connections: RNNs capture temporal dependencies between words.
- Parameter sharing: RNNs share parameters across time steps, enabling generalization.
- Flexibility in input representation: RNNs can handle various input types (e.g., word embeddings, character-level encodings).
- Text generation: RNNs are effective for generating coherent and contextually relevant text.
MLPs lack these sequential processing and contextual modeling capabilities.
RNN architecture overview
This section gives a brief introduction to the RNN architecture, its key parameters, and the forward propagation process of the language model. It also highlights the advantages of RNNs over MLPs for language modeling, such as handling sequential and variable-length inputs, temporal connections, and parameter sharing. Lastly, it briefly touches on the backward propagation process and optimization techniques used for training the RNN.
Parameters Preview
The whole point of training an RNN is to learn these parameters (the names in parentheses are the attributes used in the JAX code later):
- $E$ (w_embedding): Embedding matrix used to convert the input symbols into their corresponding embeddings. An ML model can only work with numeric vectors, not strings.
- $W_h$ (w_h_hprev): Weight matrix for previous hidden state to hidden state connections
- $W_e$ (w_h_embedding): Weight matrix for input embedding to hidden state connections
- $U$ (w_output_h): Weight matrix for hidden state to output connections
- $b_1$ (h_bias): Bias vector for the hidden state
- $b_2$ (output_bias): Bias vector for the output
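To make the shapes concrete, here is a minimal initialization sketch in JAX. The dict keys mirror the attribute names used in the Equinox class’s __call__ shown later, and the shapes are the ones implied by its matrix products; the scaled-normal initialization, zero biases, and toy sizes are assumptions for illustration, not the notebook’s actual settings.

```python
import jax
import jax.numpy as jnp


def init_params(key, vocab_size, embed_size, hidden_size, scale=0.1):
    """Randomly initialize the six RNN parameters (hypothetical init scheme)."""
    k_emb, k_hh, k_he, k_out = jax.random.split(key, 4)
    return {
        "w_embedding": scale * jax.random.normal(k_emb, (embed_size, vocab_size)),
        "w_h_hprev": scale * jax.random.normal(k_hh, (hidden_size, hidden_size)),
        "w_h_embedding": scale * jax.random.normal(k_he, (hidden_size, embed_size)),
        "w_output_h": scale * jax.random.normal(k_out, (vocab_size, hidden_size)),
        "h_bias": jnp.zeros((hidden_size,)),
        "output_bias": jnp.zeros((vocab_size,)),
    }


params = init_params(jax.random.PRNGKey(0), vocab_size=10, embed_size=4, hidden_size=8)
# Total number of trainable scalars across all six arrays.
n_params = sum(p.size for p in params.values())
```

No shape involves the sequence length: the same six arrays are reused at every time step, which is the parameter sharing listed among the RNN’s advantages.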
Forward Propagation
Embedding lookup
The input is actually a sequence of one-hot vectors. Assuming we have vocab_size unique tokens in the corpus and each token has its own index, the one-hot vector for the word “the” is a zero vector of size vocab_size with a 1 at that word’s index. For example, assuming we have 10 unique words and the (zero-based) index for “the” is 2, its one-hot vector is $[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]$.
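In JAX, jax.nn.one_hot builds these vectors directly. A small sketch, assuming zero-based indices and the toy vocabulary size of 10 from the example above; the indices 5 and 7 for the other words are hypothetical.

```python
import jax
import jax.numpy as jnp

vocab_size = 10
# One-hot vector for "the", assuming its zero-based index is 2.
the_one_hot = jax.nn.one_hot(2, vocab_size)

# A whole sentence of token indices becomes a (seq_len, vocab_size) matrix, one row per token.
sentence_ids = jnp.array([2, 5, 7])
sentence_one_hot = jax.nn.one_hot(sentence_ids, vocab_size)
```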
The one-hot vectors are useful because we can use the embedding layer to reduce the dimension and create semantic relationships among words (e.g. “dog” and “cat” end up close to each other in the embedding space). We use matrix multiplication with an embedding matrix to retrieve each token’s embedding.
- The embedding matrix $E$ (w_embedding) has dimensions (embed_size × vocab_size), where embed_size is the dimensionality of the embeddings and vocab_size is the number of unique tokens in the vocabulary.
- Recall that each one-hot vector $x_t$ has length vocab_size; a sentence is stored as a (seq_len × vocab_size) matrix with one one-hot row per token.
The product $E x_t$ is the embedding vector $e_t$ of the input token; since $x_t$ contains a single 1, the product simply selects the matching column of $E$. Thanks to JAX’s batching support (jax.vmap), applying this to every token in the sequence is one line of code:
make_sequence_embedding = jax.vmap(lambda w_embeds, one_hot_word: w_embeds @ one_hot_word, in_axes=(None, 0))
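A quick sanity check of the claim that the product only selects a column: the sketch below compares the vmap one-liner with plain column indexing on a random stand-in for the embedding matrix (the sizes and token indices are hypothetical).

```python
import jax
import jax.numpy as jnp

# Same one-liner as above: multiply the embedding matrix by each one-hot row.
make_sequence_embedding = jax.vmap(
    lambda w_embeds, one_hot_word: w_embeds @ one_hot_word, in_axes=(None, 0)
)

vocab_size, embed_size = 10, 4
w_embedding = jax.random.normal(jax.random.PRNGKey(0), (embed_size, vocab_size))
token_ids = jnp.array([2, 5, 7])
one_hots = jax.nn.one_hot(token_ids, vocab_size)

via_matmul = make_sequence_embedding(w_embedding, one_hots)  # (3, embed_size)
# Multiplying by a one-hot vector selects one column, so indexing gives the same rows.
via_indexing = w_embedding[:, token_ids].T                   # (3, embed_size)
```

Indexing (or jnp.take) skips building the one-hot vectors entirely and is what embedding layers typically do in practice; the matmul form is kept because it mirrors the math.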
Transformations and Activations
After we get the embeddings (inputs), things are pretty straightforward in the MLP world: just multiply the inputs by the weights and apply activations, layer by layer, until the output layer. In the RNN world it’s a bit more complicated because of the recurrence.
We use tanh as the hidden layer’s activation function.
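Additional notes on the notation below:
- $h_t$: Hidden state at time step $t$
- $x_t$: Input (one-hot token) at time step $t$; its embedding is $e_t = E x_t$, where $E$ is the embedding matrix (w_embedding)
- $W_h$, $W_e$, $U$, $b_1$, $b_2$: the remaining parameters (w_h_hprev, w_h_embedding, w_output_h, h_bias, output_bias in the code below)
- $\hat{y}_t$: Predicted probability distribution over the next token
Let’s calculate the forward pass step by step:
- At $t=0$ (first token)
Typically we use a zero vector as the “previous hidden state” for the first token ($h_{-1} = \mathbf{0}$), so the $W_h$ term vanishes and the forward pass looks like an MLP forward pass:
$$h_0 = \tanh(W_e e_0 + b_1), \qquad \hat{y}_0 = \mathrm{softmax}(U h_0 + b_2)$$
- At $t=1$ (second token)
This is where things get more interesting. Not only do we look at the current token, we also look at the hidden state from the previous token. The idea is that the hidden state carries historical information that can be used to predict the next token. The hidden state is carried over:
$$h_1 = \tanh(W_h h_0 + W_e e_1 + b_1), \qquad \hat{y}_1 = \mathrm{softmax}(U h_1 + b_2)$$
- At $t=m$ (the $(m+1)$-th token)
$$h_m = \tanh(W_h h_{m-1} + W_e e_m + b_1), \qquad \hat{y}_m = \mathrm{softmax}(U h_m + b_2)$$
In code, the model returns the logits $U h_t + b_2$; the softmax is applied inside the loss.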
Let’s wrap it in JAX code:
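def __call__(self, sent):
    # sent: (seq_len, vocab_size) one-hot rows -> embedding: (seq_len, embed_size)
    embedding = make_sequence_embedding(self.w_embedding, sent)
    # Zero vector as the "previous hidden state" for the first token
    initial_h = jnp.zeros((self.hidden_size,))

    def f(h_prev, word_embedding):
        # h_t = tanh(W_h h_{t-1} + W_e e_t + b_1)
        h_curr = jax.nn.tanh(
            self.w_h_hprev @ h_prev
            + self.w_h_embedding @ word_embedding
            + self.h_bias
        )
        # Logits for the next token (softmax is applied in the loss)
        curr_outputs = self.w_output_h @ h_curr + self.output_bias
        # Return (carry for the next step, output for this step)
        return h_curr, curr_outputs

    # scan threads h through the sequence and stacks the per-step logits
    _, out = jax.lax.scan(f, initial_h, embedding)
    return out  # (seq_len, vocab_size)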
jax.lax.scan is the magic function that runs the forward propagation through the time steps $t$. It carries $h_t$ from one step to the next and stacks the per-step outputs. Read more about it here: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html
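For readers who want to run the forward pass outside the Equinox class, here is a functional sketch of the same computation. It assumes the parameters live in a plain dict keyed by the attribute names used above, and it uses random, untrained values with made-up sizes; it is an illustration, not the notebook’s code.

```python
import jax
import jax.numpy as jnp


def rnn_forward(params, one_hot_sent):
    """Run a simple RNN over one sentence; returns logits of shape (seq_len, vocab_size)."""
    # Embedding lookup for all tokens at once: (seq_len, vocab) @ (vocab, embed).
    embeddings = one_hot_sent @ params["w_embedding"].T
    hidden_size = params["w_h_hprev"].shape[0]

    def step(h_prev, e_t):
        # h_t = tanh(W_h h_{t-1} + W_e e_t + b_1)
        h_t = jnp.tanh(
            params["w_h_hprev"] @ h_prev
            + params["w_h_embedding"] @ e_t
            + params["h_bias"]
        )
        # Unnormalized next-token scores: U h_t + b_2.
        logits_t = params["w_output_h"] @ h_t + params["output_bias"]
        return h_t, logits_t  # (carry for the next step, output for this step)

    h_init = jnp.zeros((hidden_size,))  # zero "previous hidden state" for the first token
    _, logits = jax.lax.scan(step, h_init, embeddings)
    return logits


# Hypothetical sizes and a random (untrained) parameter set for illustration.
V, D, H = 10, 4, 8
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "w_embedding": 0.1 * jax.random.normal(k1, (D, V)),
    "w_h_hprev": 0.1 * jax.random.normal(k2, (H, H)),
    "w_h_embedding": 0.1 * jax.random.normal(k3, (H, D)),
    "w_output_h": 0.1 * jax.random.normal(k4, (V, H)),
    "h_bias": jnp.zeros((H,)),
    "output_bias": jnp.zeros((V,)),
}
one_hot_sent = jax.nn.one_hot(jnp.array([2, 5, 7]), V)
logits = rnn_forward(params, one_hot_sent)  # shape (3, 10)
```

jax.lax.scan traces the step function once instead of unrolling a Python loop over time steps, which keeps compile time from growing with sentence length; the carry (here h) must keep the same shape and dtype at every step.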
Loss
Recall that each time step $t$ has its own target (the next token), so we get a cross-entropy loss for each output. Thanks to JAX, the batched version (vectorized over time steps with jax.vmap) is pretty simple too:
cost_func = jax.vmap(optax.softmax_cross_entropy, in_axes=(0, 0))
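With one-hot labels, softmax cross-entropy reduces to the negative log-probability the model assigns to the true next token. The sketch below computes that per step with jax.nn.log_softmax and integer targets (a substitution for optax.softmax_cross_entropy, which takes one-hot or soft labels); the uniform logits are a toy case whose loss should equal log(4) ≈ 1.386 at every step.

```python
import jax
import jax.numpy as jnp


def per_step_cross_entropy(logits, target_ids):
    """Cross-entropy at every time step: -log softmax(logits_t)[target_t]."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)  # (seq_len, vocab_size)
    # Pick the log-probability assigned to the true next token at each step.
    picked = jnp.take_along_axis(log_probs, target_ids[:, None], axis=-1)[:, 0]
    return -picked  # (seq_len,)


# Uniform logits over a 4-word vocabulary: every step should cost log(4).
logits = jnp.zeros((3, 4))
targets = jnp.array([1, 3, 0])
losses = per_step_cross_entropy(logits, targets)
mean_loss = losses.mean()
```

optax.softmax_cross_entropy already reduces over the last axis and accepts leading batch dimensions, so the jax.vmap wrapper around it is optional.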
Backward Propagation
This is why we love computation graphs and JAX. Backpropagation through time, if implemented by hand with NumPy, would be fairly complicated. With JAX’s automatic differentiation and optax, it just “happens”:
opt = optax.chain(
    optax.clip(1),
    optax.adamw(learning_rate=lr),
)
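We use clip to guard against exploding gradients: optax.clip(1) clamps every element of the gradient to [-1, 1] before it reaches adamw, the optimizer itself (Adam with decoupled weight decay). optax.clip_by_global_norm is a common alternative that rescales the whole gradient instead of clamping element-wise.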
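Since optax.clip works element-wise, it may help to see how that differs from norm-based clipping. The pure-JAX sketch below reimplements both on a toy gradient; the clip_by_global_norm helper is a simplified single-array stand-in for optax.clip_by_global_norm (which computes the norm across all leaves of a pytree), not optax’s code.

```python
import jax.numpy as jnp

grads = jnp.array([3.0, -0.5, 4.0])  # toy gradient with an "exploding" norm of about 5.02

# Element-wise clipping, the behavior of optax.clip(1): each entry is clamped to
# [-1, 1] independently, which can change the gradient's direction.
clipped_elementwise = jnp.clip(grads, -1.0, 1.0)


def clip_by_global_norm(g, max_norm):
    """Rescale g so its L2 norm is at most max_norm, keeping its direction (single-array version)."""
    norm = jnp.linalg.norm(g)
    return jnp.where(norm > max_norm, g * (max_norm / norm), g)


clipped_by_norm = clip_by_global_norm(grads, 1.0)
```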
Here is what happens in the fit call:
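# Losses and gradients with a leading batch axis (helper not shown here)
loss, grads = filter_jit_batched_loss_and_gradient(self, batched_inputs, batched_targets)
# Average the gradients over the batch axis, leaf by leaf (jax.tree_map is deprecated)
avg_grads = jax.tree_util.tree_map(lambda g: g.mean(axis=0), grads)
# clip + adamw turn the averaged gradients into parameter updates
updates, opt_state = opt.update(avg_grads, opt_state, params=self)
self = eqx.apply_updates(self, updates)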
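The averaging step relies on a simple fact: differentiation is linear, so the mean of per-example gradients equals the gradient of the mean loss. The sketch below checks this on a hypothetical linear model with squared error standing in for the RNN; the internals of filter_jit_batched_loss_and_gradient are not shown in the post, so this is not a reconstruction of it.

```python
import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
    """Hypothetical per-example loss: squared error of a linear model."""
    return (jnp.dot(w, x) - y) ** 2


w = jnp.array([0.5, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 0.5], [-1.0, 1.0]])
ys = jnp.array([1.0, 0.0, 2.0])

# Per-example gradients via vmap, then a mean over the batch axis (as in fit).
per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(w, xs, ys)
avg_grads = jax.tree_util.tree_map(lambda g: g.mean(axis=0), per_example_grads)


# Equivalent: differentiate the batch-mean loss directly.
def mean_loss(w):
    return jax.vmap(loss_fn, in_axes=(None, 0, 0))(w, xs, ys).mean()


direct_grads = jax.grad(mean_loss)(w)
```

Differentiating the mean loss directly avoids materializing one gradient per example, which can save a lot of memory with large batches.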
Does it work?
We can see the training loss trending down as more batches come through. The loss curve looks a bit suspicious, though, since it’s not very smooth. Better initialization and tuned hyperparameters would probably improve the optimization; I’ll leave that for future work.
It’s far from over
So far we have only covered the fitting of the network, but there are many other interesting things:
- What does the embedding matrix look like now? We can take a look at a few words and check their cosine similarities.
- How can we use the language model to generate text?
- What are the problems of RNNs?
- What are the evaluation metrics?
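For the first question, one possible starting point: each column of the embedding matrix (shape embed_size × vocab_size) is one word’s vector, so comparing columns with cosine similarity shows which words the model places close together. In the sketch below a random matrix stands in for the trained w_embedding and the token indices are hypothetical; with the real model you would look up the indices of words such as “dog” and “cat”.

```python
import jax
import jax.numpy as jnp


def cosine_similarity_matrix(w_embedding, token_ids):
    """Pairwise cosine similarities between the embeddings (columns) of the given tokens."""
    vecs = w_embedding[:, token_ids].T                            # (n_tokens, embed_size)
    vecs = vecs / jnp.linalg.norm(vecs, axis=1, keepdims=True)    # scale each vector to unit length
    return vecs @ vecs.T                                          # (n_tokens, n_tokens)


# Hypothetical stand-in for a trained embedding matrix of shape (embed_size, vocab_size).
w_embedding = jax.random.normal(jax.random.PRNGKey(0), (4, 10))
sims = cosine_similarity_matrix(w_embedding, jnp.array([1, 2, 3]))
```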
More on them later 🙂