
Training Transformers: A Review

It has been a while since I’ve trained models from scratch, but now that I am stepping back into the AI bubble, I wanted to do a nice little exercise to recalibrate my intuition on training neural networks. And since transformers seem to be everywhere these days, I figured I’d try to train a mini GPT-2 model from scratch.

I imagine this website will be somewhat of a living document that I’ll update as I learn more, not just about ML but also about the tech required to usher in the age of abundance. Cheesy? Yes, but it’s coming. Expect more educational (for myself) posts like this from me in the future as I learn things to build my projects.

This post, and a follow-up coming soon, contains my notes and learnings about language models from and beyond the original Transformer paper (Vaswani et al., 2017).

Note: This post is not intended to replace reading the original papers. I won’t include all the details from the Transformer, GPT-2, and GPT-3 papers, only mentioning details when they actually matter. While I hope this serves as a helpful note, please let me know via Twitter if I’ve missed or misunderstood any important details.

Overview

In this post, I’ll try to train a mini GPT-2 model from scratch with a single GPU. I’ll cover the following:

  • Describe my data and tokenization approach
  • Go over the intuition behind attention mechanisms with examples
  • Explain how positional encoding fits in so the model knows token order
  • Discuss regularization implementation details (dropout, layernorm, residual connections)
  • Talk about training issues I encountered and their solutions

Data and Tokenization

I was able to quickly train a small proof-of-concept using a toy dataset of short phrases like “car drives on”, “phones connect people”, etc. This led to some interesting problems I’ll discuss later.

With this toy dataset, the model produced a degenerate attention pattern and found a trivial solution where each token only attended to itself (one-hot attention vectors for each token). Foolishly enough, I kept trying to debug the model until I realized the implementation wasn’t the issue; the problem was simply a lack of data.

So I downloaded C4 from the dataset repo and, given that I’m training a mini LLM, selected only the first 300,000 sentences. I then tried two tokenizers.

Tokenization Approaches

I experimented with two main approaches:

1. Custom Word-Based Tokenizer

This simple approach:

  • Counted word frequencies in the dataset
  • Selected top 50k most frequent words as vocabulary
  • Mapped everything else to <unk>

This actually resulted in a decent model that learned syntax and grammar pretty well. However, it predicted <unk> far too often, since any word outside the top 50k mapped to it, and that was quite a lot of words (there were around 160k unique words in the dataset). The model would keep generating things like 'This is a nice <unk>', 'Lets go to the <unk>', etc.
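For concreteness, here is a minimal sketch of that word-level approach (the function names are illustrative, not my exact code):

from collections import Counter

def build_word_vocab(sentences, vocab_size=50_000):
    """Keep the top-k most frequent words; everything else maps to <unk>."""
    counts = Counter(word for s in sentences for word in s.split())
    vocab = {"<pad>": 0, "<unk>": 1}
    for word, _ in counts.most_common(vocab_size - len(vocab)):
        vocab[word] = len(vocab)
    return vocab

def encode_words(sentence, vocab):
    """Map each word to its ID, falling back to <unk> for out-of-vocab words."""
    return [vocab.get(word, vocab["<unk>"]) for word in sentence.split()]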

2. BPE Tokenizer

To fix this, I switched to Byte-Pair Encoding (BPE) and trained it on my corpus (a subset of C4). For about 300k sentence samples, the 50k-vocab BPE tokenizer ended up producing ~143M tokens total. BPE works by merging the most frequent character pairs, creating subword units, which drastically reduces unrecognizable tokens. By training a BPE tokenizer on my own data, I got a tokenizer better aligned with the text the model is learning from.

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

def train_bpe_tokenizer(sentences, vocab_size=50000):
    """Train a BPE tokenizer on the given sentences"""
    tokenizer = Tokenizer(BPE(unk_token="<unk>"))
    trainer = BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=["<pad>", "<unk>", "<s>", "</s>"],
    )
    tokenizer.pre_tokenizer = Whitespace()

    # Train the tokenizer on an iterator of raw sentences
    tokenizer.train_from_iterator(sentences, trainer)
    return tokenizer

I also tried the tiktoken GPT-2 tokenizer, which was trained on a different corpus. Using a tokenizer mismatched to my data produced weird partial decodes, so the model wasn’t learning properly.

In the final training scripts, I mostly rely on this BPE tokenizer. Whenever the model sees a sentence, it breaks it into subwords, maps them to IDs, and any truly unknown text can be spelled out via partial merges.
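As a quick sanity check, here is roughly how the trained tokenizer gets used, assuming sentences is the list of C4 sentences from earlier (the example sentence is arbitrary):

tokenizer = train_bpe_tokenizer(sentences)
encoding = tokenizer.encode("Car drives on roads")
print(encoding.tokens)  # the subword pieces the model sees
print(encoding.ids)     # the corresponding integer IDs fed into the embedding layer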

Attention Folks

There are probably much better explanations of the attention mechanism elsewhere (e.g., ask ChatGPT), but I’m going to try my best here.

At a high level, attention lets each token communicate with every other token in the sequence to gather context.

Think of a sentence: “Car drives on”. If I want to predict the next token (“roads,” for example), the word “on” needs to figure out what it’s referencing. “On” by itself is ambiguous - it could mean “turned on,” “switched on,” “located on,” etc. But with attention, the token “on” can look back at “car” and “drives” to gather that the context is about a car driving, so the next token is likely “roads,” or something similar. Pretty simple concept, right? But the question is: how does the model know exactly where to look?

Below is a very simplified implementation of the attention mechanism; I explain each part afterwards.

# Part A. x is the (embedded) input of shape (B, S, embed_dim), where B is the batch size
# and S is the sequence length (number of tokens) of each batch.
# You get a query, key, and value vector for each token in each batch, which results in
# Q, K, V matrices of shape (B, S, embed_dim).
# w_q, w_k, w_v are all learnable linear layers
Q = w_q(x)
K = w_k(x)
V = w_v(x)

# Compute the attention scores for each pair of tokens with Q dot K.T,
# scaled by sqrt(d_k) so the softmax doesn't saturate
d_k = Q.size(-1)
attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)  # (B, S, S)

# Part B. the causal mask hides future tokens:
# for each token, the scores of the tokens in front of it are set to -inf
mask = ~torch.triu(torch.ones(S, S), diagonal=1).bool().unsqueeze(0).to(x.device)  # (1, S, S)
scores = attention_scores.masked_fill(mask == 0, float('-inf'))

# Part C. softmax makes the attention weights sum to one for every token
# (the -inf entries become zero, so future tokens get no weight)
attn_weights = torch.softmax(scores, dim=-1)  # (B, S, S)

# the resulting output is, for each token vector, a new token vector that encapsulates the context
output = torch.matmul(attn_weights, V)        # (B, S, embed_dim)

# this output then just goes through a linear layer to predict the next-token logits

A. Each token’s embedding is used to create a query and a key vector. For a given token, the query asks: “Among the other tokens in the sequence, whose information do I need?”, after which the keys of the other tokens answer, “I have the context you need”. This matching is done by computing the similarity of the query with each key, which is where the dot product of the q and k vectors comes in.

Think of it like a smart search engine:

  • Query (q) = “What you’re looking for”
  • Key (k) = “Labels/tags describing each piece of content”
  • The more similar q and k are, the more relevant that content is

You can loosely apply this analogy here. For the words “Car drives on”, we compute the similarity score between every pair of words in the sentence, which results in a 3x3 matrix. The first row consists of the similarity scores of the word ‘Car’ with ‘Car’ (yes, with itself too), ‘drives’, and ‘on’; the second row consists of the similarity scores of ‘drives’ with ‘Car’, ‘drives’, and ‘on’; and the third row does the same for ‘on’. This is (the unnormalized version of) the attention matrix, represented as QK^T, which is then scaled by the square root of the dimensionality of the Q (or K, it’s the same) vectors so that the subsequent softmax doesn’t blow up.
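To make that 3x3 picture concrete, here is a tiny runnable version with made-up embeddings (the numbers are random; only the shapes matter):

import math
import torch

torch.manual_seed(0)
x = torch.randn(1, 3, 4)              # (B=1, S=3, embed_dim=4) for "Car drives on"
w_q = torch.nn.Linear(4, 4, bias=False)
w_k = torch.nn.Linear(4, 4, bias=False)

Q, K = w_q(x), w_k(x)
scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))  # (1, 3, 3)
print(scores[0])  # row i holds how strongly token i "asks about" tokens 0..2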

B. Also, in a decoder-only model like GPT-2, each token can only attend to the tokens before it (via the causal mask). Effectively, this exists so that the model can practice predicting the future based on the past, without peeking into the future. So a mask is applied to the attention matrix, setting the scores of future positions to -inf so that after the softmax their weights are zero and no word attends to the words after it.

C. After the (softmax-normalized) attention weights are computed, they are multiplied by the actual token embeddings (the V matrix, which consists of an embedding vector v for every token), so that each [word vector] v_i updates itself to [some new vector] v_i_prime that encapsulates the context of the sequence. When updating to v_i_prime, the attention weights dictate how much each of the past tokens in that sequence should be attended to.

For example, in the case of the token ‘on’, that [some new ‘on’ vector] would be a linear combination of the ‘Car’, ‘drives’ and ‘on’ embedding vectors, such that the final ‘on’ vector encapsulates the entire context and is much better equipped to predict the next token, ‘roads’.

Parallelism: The beauty of attention is that, even though we do next-token prediction based on all previous context, at training time the model can process the entire sequence in parallel. This is more efficient than recurrent models (RNNs) that process one token at a time recursively, making it hard to parallelize and thus train them at scale.

Multi-Headed Attention

Although omitted in the code example above, sufficiently large transformers are trained by splitting the embedding dimension into multiple heads. Each head learns its own attention weights, so different heads can specialize: for instance, one head can learn about subject-verb relationships, another about grammar, and so on.

This simply adds another batch dimension to the matmuls above. We then concatenate the heads’ outputs into one representation before passing it through a final linear layer, as in the sketch below.
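Here is a rough sketch of what that extra head dimension looks like in code (the causal mask is omitted for brevity, and the shapes assume embed_dim is divisible by n_heads):

import math
import torch

B, S, embed_dim, n_heads = 2, 10, 64, 4
head_dim = embed_dim // n_heads
x = torch.randn(B, S, embed_dim)

w_q = torch.nn.Linear(embed_dim, embed_dim)
w_k = torch.nn.Linear(embed_dim, embed_dim)
w_v = torch.nn.Linear(embed_dim, embed_dim)
w_o = torch.nn.Linear(embed_dim, embed_dim)   # final linear layer after concatenation

def split_heads(t):
    # (B, S, embed_dim) -> (B, n_heads, S, head_dim)
    return t.view(B, S, n_heads, head_dim).transpose(1, 2)

Q, K, V = split_heads(w_q(x)), split_heads(w_k(x)), split_heads(w_v(x))
scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)   # (B, nH, S, S)
weights = torch.softmax(scores, dim=-1)
out = weights @ V                                         # (B, nH, S, head_dim)
# Concatenate the heads back into one vector per token, then project.
out = out.transpose(1, 2).contiguous().view(B, S, embed_dim)
out = w_o(out)                                            # (B, S, embed_dim)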

Positional Encoding

Transformers are permutation-invariant if we just do attention alone. That means the set of tokens [car, drives, on] is the same as [drives, car, on] to the raw attention mechanism unless we inject positional information.

Positional encoding is how we do this: for each token, we add a position-information vector to the token embedding. A common technique is the sinusoidal approach. Now the model can distinguish “car drives on” from “drives car on,” ensuring “on” knows how to make sense of its context. In the GPT-2 paper, the authors use separate learnable weights for these encoding vectors, but for simplicity I just used the sinusoidal implementation.
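A minimal sketch of the sinusoidal encoding I’m referring to (this follows the standard formula from the original paper; d_model is assumed to be even):

import math
import torch

def sinusoidal_positional_encoding(max_len, d_model):
    """Return a (max_len, d_model) matrix of fixed positional encodings."""
    position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)          # (max_len, 1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float)
                         * (-math.log(10000.0) / d_model))                    # (d_model/2,)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
    return pe

# Added to the token embeddings before the first attention block:
# x = token_embeddings + sinusoidal_positional_encoding(max_len, embed_dim)[:S]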

Regularization Techniques

Regularization is in general important for stabilizing training and preventing overfitting. In Transformers, we typically use:

  1. Dropout: I applied dropout to attention weights and feed-forward sub-layers. This randomly zeroes out a fraction of values to reduce over-dependence on particular neurons.
  2. Residual Connections: Each sub-layer does x = x + sublayer_output, letting the model preserve the original embedding while layering on transformations. This helps gradient flow and prevents vanishing or exploding gradients.
  3. LayerNorm: Normalizes each token’s embedding vector to have stable mean/variance. Typically used before or after each attention/feed-forward sub-layer, e.g., pre-norm or post-norm.

Pre-Norm vs. Post-Norm

A side note: the original attention paper applies layer norm after the attention sub-layer (post-norm), while GPT-2 uses pre-norm (layer norm applied before attention). Some implementations also add a final layer norm after the last block, as I did in mine.
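Putting dropout, residuals, and pre-norm together, a single block looks roughly like this. This is just a sketch; it uses PyTorch’s nn.MultiheadAttention for brevity, whereas my actual implementation builds the attention manually as shown earlier:

import torch
import torch.nn as nn

class Block(nn.Module):
    """A GPT-2-style pre-norm transformer block (sketch)."""
    def __init__(self, embed_dim, n_heads, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, n_heads,
                                          dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, causal_mask=None):
        # Pre-norm: normalize first, run the sub-layer, then add the residual back.
        h = self.ln1(x)
        a, _ = self.attn(h, h, h, attn_mask=causal_mask, need_weights=False)
        x = x + a
        x = x + self.mlp(self.ln2(x))
        return x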

Training Issues I Ran Into and Their Solutions

  1. Degenerate Attention Matrices

• When I worked with tiny datasets, sometimes each token just attended to itself or to the first token in the sequence (one-hot attention vectors). That’s a valid “shortcut” solution if the task is trivial, but it didn’t generalize to unseen data at all (there was a huge training/validation loss discrepancy). Solution: I simply added more data (a C4 subset of 300k samples, ~150M tokens after BPE). This forced the model to learn more generalizable relationships.

  2. Exploding Loss

Occasionally, the training loss would blow up after a few hundred epochs. Solution: again, this only happened when the dataset was tiny. Neither gradient clipping nor better learning-rate scheduling helped until I simply added more data.

For reference, the “clip grad norm” I’m referring to is PyTorch’s gradient-norm clipping; below is a minimal sketch of where it fits in a training step (max_norm=1.0 is just an illustrative value, not necessarily the one I used).
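import torch

model = torch.nn.Linear(8, 8)                 # stand-in for the transformer
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

x = torch.randn(4, 8)
loss = model(x).pow(2).mean()                 # dummy loss just to produce gradients
loss.backward()
# Clip the global gradient norm before the optimizer step.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
opt.step()
opt.zero_grad()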

  3. Slow or Minimal Loss Convergence

• I found increasing max_len to 50 or 100 significantly improved context modeling, as the network sees up to 49 prior tokens instead of 7.

Also, ensuring the BPE vocab is well-trained on the domain means fewer tokens per sentence and a smoother time for the model.

Recognize a pattern?

Conclusion

Overall, it was quite refreshing to watch the scaling laws manifest themselves quite strongly even at tiny scales like this. I got the gist of the implementation in a few hours, but I couldn’t quite get the model to do much besides predicting some basic grammar structures. It was only after I let go of trying to understand and debug the complex training dynamics of the neural network, and simply made the model bigger and added more data, that everything kind of fixed itself.

Putting this a bit more bluntly: there’s a great deal of work out there that claims to lay out “how ChatGPT works,” where you work through some exercises and feel like you have a solid grasp of the fundamental technology at play. Generally speaking, though, these things will not help you understand how the actual thing works. The toy examples are a vehicle for you to understand the thing that eventually became ChatGPT, and in your head you just impute some vague idea that they made it huge and trained it on thousands of GPUs.

And this is just pretraining, not ChatGPT. The pretraining phase is where most of the heavy lifting happens, and it’s the step most tutorials teach you about when you’re learning about language models in a pedagogical setting.

Things I didn’t cover: everything in post-training, mixed-precision training, hardware optimization, and the KV cache.

Back when I was training neural networks, this did not work at scale, let alone power an entire industry.

To really get how to go from this sort of technology to a bona fide frontier LLM, you run into problems that a toy exercise like this doesn’t prepare you for.

I also found out that tokenization matters a lot, i.e., training your own BPE vs. using some off-the-shelf tokenizer.

  • Attention is just a fancy mechanism for queries to find the relevant keys/values in the sequence. The dot products measure similarity, and we apply a softmax to get the weighting.
  • Multi-head setups let different heads learn different relational patterns.
  • Regularization (dropout, layernorm, residuals) is non-negotiable, especially as you scale the model or data.
  • More data cures a host of problems like degenerate attention and random explosive losses - assuming the model can handle it.

Now that I have a stable mini GPT-2, I’m thinking of looking into larger-scale experiments and better computational efficiency.

You can check out my implementation in this Google Colab notebook. I had the Pro version, so I ran it on an A100, and it converges in around 5–10 hours. It gets there by spending about 50 credits, which is about $5.



Mathematical Formulation

For a more precise formulation, the scaled dot-product attention is defined as:

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

where:

\[\begin{align*} Q &= \text{query matrix} \\ K &= \text{key matrix} \\ V &= \text{value matrix} \\ d_k &= \text{dimension of key/query vectors} \end{align*}\]

“In our transformer, the query \(Q\), key \(K\), and value \(V\) are obtained using learned projections, and their similarity is scaled by the square root of the key dimension, \(d_k\).”
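For completeness, those learned projections are just linear maps of the input embeddings \(X\) (single-head case, matching the simplified code above; the weight shapes are the standard ones from the paper):

\[Q = XW_Q, \quad K = XW_K, \quad V = XW_V\]

where \(X \in \mathbb{R}^{S \times d_{\text{model}}}\), \(W_Q, W_K \in \mathbb{R}^{d_{\text{model}} \times d_k}\), and \(W_V \in \mathbb{R}^{d_{\text{model}} \times d_v}\).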