6.4. Large Language Model
Since the introduction of ChatGPT in 2022, large language models (LLMs) have become one of the most visible applications of artificial intelligence. In this tutorial, we will look at how modern LLMs work under the hood and implement a basic language model that generates text after training on a small dataset.
At a fundamental level, language models are pretrained either to predict the continuation of a text segment or to fill in missing parts of it. The two main types are:
Autoregressive models: Predict the continuation of a segment. For example, given the prompt “I like to eat,” the model might predict “pizza” or “ice cream.”
Masked models (also called ‘cloze’ models): Fill in the missing parts of a segment. For example, given “I like to [__] [__] cream,” the model might predict that “eat” and “ice” are the missing words.
In this notebook, we will implement the autoregressive approach: our model will learn to predict the next “token” (the basic unit of text the model processes) in a sequence from the tokens that precede it.
LLMs are statistical models: from large collections of text (much of it taken from the internet), they learn the probability of the next token given the preceding context. This training is self-supervised, meaning the targets come from the text itself, so no human-labelled answers are needed.
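To make “self-supervised” concrete, the plain-Python sketch below cuts training targets out of a raw string, at the character level as in this notebook. The helper names (autoregressive_pairs, masked_example) and the "_" mask symbol are illustrative assumptions, not part of the notebook's code. The autoregressive helper turns every prefix into an input whose target is the next character. The masked helper hides chosen positions and uses the hidden characters as targets.

```python
# Plain-Python sketch: helper names and the "_" mask symbol are illustrative only.
def autoregressive_pairs(text):
    """Every prefix of the text is an input; the character right after it is the target."""
    return [(text[:i], text[i]) for i in range(1, len(text))]

def masked_example(text, positions, mask_char="_"):
    """Hide the characters at the given positions; the hidden characters become the targets."""
    chars = list(text)
    targets = {}
    for p in positions:
        targets[p] = chars[p]
        chars[p] = mask_char
    return "".join(chars), targets

pairs = autoregressive_pairs("I like")
for context, target in pairs:
    print(f"{context!r:10} -> {target!r}")

masked_input, masked_targets = masked_example("I like", positions=[2, 4])
print(masked_input, masked_targets)  # I _i_e {2: 'l', 4: 'k'}
```

In both cases, no labels are supplied from outside. A six-character string already yields five autoregressive training examples.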
The materials presented here are inspired by Andrej Karpathy’s excellent work on nanoGPT, and I highly recommend his lecture on the subject.
0. Preparation
Let’s start by importing the necessary libraries and setting a random seed for reproducibility.
import numpy as np
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
# Set a fixed random seed for reproducibility
torch.manual_seed(1365)
<torch._C.Generator at 0x762f3a8c4c50>
1. Dataset
We start by loading the dataset to see what our model will be trained on. Here, we’ll use the “Tiny Shakespeare” dataset, which collects a selection of Shakespeare’s works into a single text file.
# Create a folder for the data, if it doesn't already exist
os.makedirs("data/", exist_ok=True)
# Download the Tiny Shakespeare dataset into our data folder
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -O data/tinyshakespeare.txt
--2024-11-11 08:31:14-- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1,1M) [text/plain]
Saving to: ‘data/tinyshakespeare.txt’
data/tinyshakespear 100%[===================>] 1,06M --.-KB/s in 0,05s
2024-11-11 08:31:15 (22,6 MB/s) - ‘data/tinyshakespeare.txt’ saved [1115394/1115394]
# Read and display the dataset
with open('data/tinyshakespeare.txt', 'r', encoding='utf-8') as f:
    db_text = f.read()
print(f"Number of characters in the dataset: {len(db_text)}")
Number of characters in the dataset: 1115394
The Tiny Shakespeare dataset has over 1.1 million characters. Let’s print the first 250 characters to get a sense of the text.
# Display the first 250 characters
print(db_text[:250])
First Citizen:
Before we proceed any further, hear me speak.
All:
Speak, speak.
First Citizen:
You are all resolved rather to die than to famish?
All:
Resolved. resolved.
First Citizen:
First, you know Caius Marcius is chief enemy to the people.
Tokens
Our language model works by predicting the next token based on the preceding context. A token is simply a unit of text that the model understands and processes. In this tutorial, we’ll use individual characters as tokens to keep things straightforward. This means that the model will learn to predict the next character based on previous characters in the sequence.
In larger language models, tokens usually represent bigger units, such as whole words or subwords (pieces of words). Larger tokens let a model cover more text with the same number of tokens. For example, the phrase “natural language processing” could be split into three word tokens rather than 27 character tokens. You can read more about this in the text-classification notebook.
In this project, however, we’ll keep it simple and focus on character-level tokens. This approach allows us to train a smaller model while still learning basic patterns and sequences within text.
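The trade-off between the two granularities can be seen on a single phrase. The sketch below assumes that splitting on spaces is a good enough stand-in for a word-level tokeniser. Real word and subword tokenisers also deal with punctuation, capitalisation and unseen words.

```python
# Character-level vs word-level tokenisation of the same phrase.
# Splitting on spaces is a simplification; real word and subword tokenisers
# also deal with punctuation, capitalisation and unseen words.
phrase = "natural language processing"

char_tokens = list(phrase)
word_tokens = phrase.split()

char_vocab = sorted(set(char_tokens))
word_vocab = sorted(set(word_tokens))

print(len(char_tokens), "character tokens; vocabulary:", len(char_vocab))  # 27 character tokens; vocabulary: 14
print(len(word_tokens), "word tokens; vocabulary:", len(word_vocab))       # 3 word tokens; vocabulary: 3
```

On a real corpus, the pattern holds at scale. Character vocabularies stay tiny (65 symbols for Tiny Shakespeare), but sequences get long. Word vocabularies keep growing with the corpus, and the tokeniser has to cope with words it never saw.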
To proceed, let’s identify all unique characters in our dataset and assign each one an integer code.
# Unique characters in the text
db_chars = sorted(list(set(db_text)))
vocab_size = len(db_chars)
print("Unique characters in the dataset:", ''.join(db_chars))
print("Total number of unique characters:", vocab_size)
Unique characters in the dataset:
!$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Total number of unique characters: 65
Tokeniser
Computers only understand numbers. In images, for instance, colours are represented by numbers (in an 8-bit greyscale image, 0 represents black and 255 represents white). Text data follows the same principle: characters must be converted into numbers before the model can work with them, so we’ll convert each character into a unique integer.
To do this, we’ll create a tokeniser that assigns a unique integer to each character in the dataset. We iterate over the sorted list of unique characters (db_chars) and assign each character its index in that list. This mapping is stored in two dictionaries: str2int, which maps characters to integers, and int2str, which maps integers back to characters.
Next, we define two lambda functions, encode_txt and decode_txt, to convert between strings and lists of integers. The encode_txt function takes a string and returns a list of integers, while decode_txt does the reverse, converting a list of integers back to text.
# Map characters to integers
str2int = {ch: i for i, ch in enumerate(db_chars)}
int2str = {i: ch for i, ch in enumerate(db_chars)}
# Functions for encoding and decoding
encode_txt = lambda s: [str2int[c] for c in s] # Encode: converts string to list of integers
decode_txt = lambda l: ''.join([int2str[i] for i in l]) # Decode: converts list of integers to string
print(encode_txt("hii there"))
print(decode_txt(encode_txt("hii there")))
[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there
We can see in our mapping that each character has a unique numerical code. For example, “h” is mapped to 46, “i” to 47, and the space character (“ ”) to 1. To illustrate how our language model views the data, let’s print the first 250 characters as numbers:
print(encode_txt(db_text[:250]))
[18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44, 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43, 43, 42, 1, 39, 52, 63, 1, 44, 59, 56, 58, 46, 43, 56, 6, 1, 46, 43, 39, 56, 1, 51, 43, 1, 57, 54, 43, 39, 49, 8, 0, 0, 13, 50, 50, 10, 0, 31, 54, 43, 39, 49, 6, 1, 57, 54, 43, 39, 49, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 37, 53, 59, 1, 39, 56, 43, 1, 39, 50, 50, 1, 56, 43, 57, 53, 50, 60, 43, 42, 1, 56, 39, 58, 46, 43, 56, 1, 58, 53, 1, 42, 47, 43, 1, 58, 46, 39, 52, 1, 58, 53, 1, 44, 39, 51, 47, 57, 46, 12, 0, 0, 13, 50, 50, 10, 0, 30, 43, 57, 53, 50, 60, 43, 42, 8, 1, 56, 43, 57, 53, 50, 60, 43, 42, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 18, 47, 56, 57, 58, 6, 1, 63, 53, 59, 1, 49, 52, 53, 61, 1, 15, 39, 47, 59, 57, 1, 25, 39, 56, 41, 47, 59, 57, 1, 47, 57, 1, 41, 46, 47, 43, 44, 1, 43, 52, 43, 51, 63, 1, 58, 53, 1, 58, 46, 43, 1, 54, 43, 53, 54, 50, 43, 8, 0]
In the original text, you may have noticed line breaks between the parts of the conversation (for example, between the lines of “First Citizen” and “All”). How are these line breaks represented numerically? We can check the encoding of the newline symbol \n in our tokeniser:
encode_txt('\n')
[0]
In the list of numbers above, 0 represents a line break. This allows the language model to learn structural elements of the text, such as new lines, even though it processes everything as numbers.
PyTorch Dataset Preparation
Before training a language model in PyTorch, we need to convert our text data into a format that PyTorch can process. Specifically, we’ll store the dataset in a torch.Tensor, which holds the data as numerical values that the model can use.
Convert text to a tensor: Using our earlier tokenisation, we’ll encode the entire text dataset into integers and store them in a torch.Tensor.
# Encode the text as integers and store in a tensor
db_tensor = torch.tensor(encode_txt(db_text), dtype=torch.long)
# Print the shape and data type of the tensor to confirm
print(db_tensor.shape, db_tensor.dtype)
torch.Size([1115394]) torch.int64
Split data into training and validation sets: We train on the first 90% of the data and keep the last 10% as a validation set, which lets us monitor how well the model performs on text it has not been trained on.
# Split the tensor into training (first 90%) and validation (last 10%) sets
n = int(0.9 * len(db_tensor))
train_data = db_tensor[:n]
val_data = db_tensor[n:]
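The split above is contiguous rather than shuffled, so the validation set is one unbroken stretch of the play. One likely reason this matters: a random character-level split would scatter validation characters between training characters, and overlapping context windows would then leak validation text into training. The sketch below uses a hypothetical helper, split_sequence, and a Python range of the dataset’s length as a stand-in for db_tensor, since both slice the same way.

```python
def split_sequence(seq, train_fraction=0.9):
    """Contiguous split: the first part for training, the remainder for validation."""
    n_train = int(train_fraction * len(seq))
    return seq[:n_train], seq[n_train:]

tokens = range(1115394)  # stand-in with the dataset's length; a range slices like a tensor
train_part, val_part = split_sequence(tokens)
print(len(train_part), len(val_part))  # 1003854 111540
```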
Autoregressive Model
An autoregressive model predicts each character based on the sequence of preceding characters. This setup means that, for each step, the model uses the characters it has already seen to predict the next one. Let’s walk through how this works.
Define a context length (block size): The context length (or block size) is the maximum number of previous characters the model considers when predicting the next character. Here, we set block_size to 12, meaning the model will look at up to the 12 most recent characters when making each prediction.
block_size = 12
print(f"LLM sees: {train_data[:block_size+1]}")
print(f"Human sees: {decode_txt(train_data[:block_size+1].tolist())}")
LLM sees: tensor([18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52])
Human sees: First Citizen
Autoregressive training example: During training, the model’s goal is to predict the next character based on the sequence it has just seen. For instance, if the model sees the character “F” (encoded as 18), it should predict the next character, “i” (encoded as 47).
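Let’s print each input context and its expected target character across our chosen block of 12 characters. A single block of 12 characters therefore contains 12 training examples, one for each context length from 1 to 12.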
# Define input and target sequences for our block size
x = train_data[:block_size]
y = train_data[1:block_size + 1]
max_width = block_size * 5 # Set max width for aligned printing
# Display the input and target character-by-character
for t in range(block_size):
    context = x[:t + 1]
    target = y[t]
    context_str = f"{context}".ljust(max_width)  # Left-justify with padding
    print(f"input: {context_str} target: {target}")
input: tensor([18]) target: 47
input: tensor([18, 47]) target: 56
input: tensor([18, 47, 56]) target: 57
input: tensor([18, 47, 56, 57]) target: 58
input: tensor([18, 47, 56, 57, 58]) target: 1
input: tensor([18, 47, 56, 57, 58, 1]) target: 15
input: tensor([18, 47, 56, 57, 58, 1, 15]) target: 47
input: tensor([18, 47, 56, 57, 58, 1, 15, 47]) target: 58
input: tensor([18, 47, 56, 57, 58, 1, 15, 47, 58]) target: 47
input: tensor([18, 47, 56, 57, 58, 1, 15, 47, 58, 47]) target: 64
input: tensor([18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64]) target: 43
input: tensor([18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43]) target: 52
Human-readable format: To make the output clearer, let’s print the characters in their original text format.
# Display in human-readable form
x = train_data[:block_size]
y = train_data[1:block_size + 1]
max_width = block_size # Set width for readability
for t in range(block_size):
    context = x[:t + 1]
    target = y[t]
    context_str = f"{decode_txt(context.tolist())}".ljust(max_width)  # Left-justify with padding
    print(f"input: {context_str} target: {int2str[target.tolist()]}")
input: F target: i
input: Fi target: r
input: Fir target: s
input: Firs target: t
input: First target:
input: First target: C
input: First C target: i
input: First Ci target: t
input: First Cit target: i
input: First Citi target: z
input: First Citiz target: e
input: First Citize target: n
In PyTorch, batching is typically handled with a Dataset class (inheriting from torch.utils.data.Dataset) and a DataLoader (from torch.utils.data.DataLoader). In this simple setup, however, we’ll use a function called get_batch to generate batches of data for us.
The get_batch function works as follows:
It takes a split ('train' or 'val') and selects either the training or the validation dataset.
It then randomly selects batch_size starting indices and, for each one, extracts an input sequence (x) of length block_size and the corresponding target sequence (y).
y is x shifted by one token, so y[i, t] is the token that follows x[i, :t+1]. In other words, every position in each input sequence comes with its own next-token target.
def get_batch(split, device='cpu'):
    """Sample a random batch of input sequences x and their next-token targets y.

    Relies on the globals train_data, val_data, block_size and batch_size.
    """
    data = train_data if split == 'train' else val_data
    # Random start indices, leaving room for the one-token-shifted target
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y
Let’s set up a batch to see its structure:
batch_size = 4 # number of sequences processed in parallel
block_size = 12 # maximum context length for predictions
batch_data, target_tokens = get_batch('train')
print(f"Size of input data in the batch: {batch_data.shape}")
print(f"Size of target in the batch: {target_tokens.shape}")
Size of input data in the batch: torch.Size([4, 12])
Size of target in the batch: torch.Size([4, 12])
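For comparison, here is a sketch of the Dataset/DataLoader route mentioned above. CharWindowDataset is a hypothetical name. Dataset, DataLoader and their batch_size and shuffle arguments are standard torch.utils.data APIs. A toy torch.arange stream stands in for train_data, which makes the shifted targets easy to check: y equals x + 1.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class CharWindowDataset(Dataset):
    """Hypothetical Dataset: item i is the window starting at i and its one-token-shifted target."""

    def __init__(self, data, block_size):
        self.data = data
        self.block_size = block_size

    def __len__(self):
        # Same range of start indices that get_batch samples from
        return len(self.data) - self.block_size

    def __getitem__(self, i):
        x = self.data[i:i + self.block_size]
        y = self.data[i + 1:i + self.block_size + 1]
        return x, y

data = torch.arange(100)  # toy token stream standing in for train_data
loader = DataLoader(CharWindowDataset(data, block_size=12), batch_size=4, shuffle=True)
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)  # torch.Size([4, 12]) torch.Size([4, 12])
```

One design difference is worth noting. get_batch draws fresh random start indices on every call, so the same window can appear twice. A shuffled DataLoader visits every window exactly once per epoch.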
2. Network
We’ll implement two types of models:
Bigram Language Model: A simple n-gram model that considers only the previous token for prediction.
GPT-based Language Model: A more complex transformer-based model similar to those used in state-of-the-art LLMs like ChatGPT and Gemini.
Bigram Language Model
A bigram language model is an n-gram model with n = 2, meaning it considers only the one previous token when predicting the next one. If we considered the two previous tokens, we would have a trigram model. This type of model learns dependencies between pairs of adjacent tokens.
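The classical way to build a bigram model is to count adjacent pairs and normalise the counts into conditional probabilities. The sketch below does this in plain Python; bigram_probabilities is a hypothetical helper. The neural BigramLanguageModel learns a table of logits by gradient descent instead of counting. For a free table of logits trained with cross-entropy, the loss is minimised when the softmax of each row matches these empirical frequencies, so the two approaches aim at the same table.

```python
from collections import Counter, defaultdict

def bigram_probabilities(text):
    """Estimate P(next char | current char) from counts of adjacent character pairs."""
    counts = defaultdict(Counter)
    for current, nxt in zip(text, text[1:]):
        counts[current][nxt] += 1
    # Normalise each row of counts into a probability distribution
    return {
        current: {nxt: c / sum(followers.values()) for nxt, c in followers.items()}
        for current, followers in counts.items()
    }

probs = bigram_probabilities("Speak, speak.")
print(probs["k"])  # {',': 0.5, '.': 0.5}
print(probs["a"])  # {'k': 1.0}
```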
The BigramLanguageModel class is built as follows:
Initialising the model: The model is a subclass of nn.Module, which provides the machinery for managing parameters and model behaviour. Its only layer, token_embedding_table, is an nn.Embedding that maps each token in our vocabulary to a vector of values that the model optimises during training. Here the table is a square matrix of size vocab_size x vocab_size: row i holds the logits (scores) for the token that follows token i, so the model works like a learnable lookup table of token-pair relationships.
Forward pass (forward method): This method takes input_tokens, a batch of sequences of token indices, and passes them through the embedding layer to get logits. The logits are the raw, unnormalised scores for each possible next token, before softmax is applied.
Calculating the loss (calculate_loss method): The model uses the cross_entropy loss function, i.e. the negative log-likelihood of the target tokens, to measure how well its predictions (logits) match the targets. F.cross_entropy expects logits of shape (N, vocab_size) and targets of shape (N,), so we flatten logits from (batch_size, sequence_length, vocab_size) to (batch_size * sequence_length, vocab_size) and target_tokens from (batch_size, sequence_length) to (batch_size * sequence_length). This computes the loss over all tokens in the batch at once. The staticmethod decorator indicates that the method belongs to the class itself rather than to a particular instance (it does not use self). This gives a single, standard way of computing the loss that can also be called without a model instance.
Generating a sequence (generate_sequence method): This method starts with a sequence (such as a single token) and generates new tokens one at a time. At each step, it:
Passes the input sequence through the model.
Keeps only the logits for the most recent (last) token in the sequence.
Applies a softmax function to convert the logits to probabilities.
Uses torch.multinomial to sample the next token from these probabilities, which introduces randomness.
Appends the new token to the sequence and repeats until max_length new tokens have been generated.
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()  # Initialise parent class (nn.Module)
        # Embedding table that maps each token to an embedding vector
        # This table is of size (vocab_size, vocab_size) so each token can "read" the logits of the next token
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, input_tokens):
        """
        Forward pass: This function processes input data through the model.
        Args:
            input_tokens (Tensor): Batch of input sequences with shape (batch_size, sequence_length).
        Returns:
            Tensor: Output logits with shape (batch_size, sequence_length, vocab_size).
        """
        # Pass input tokens through embedding layer
        logits = self.token_embedding_table(input_tokens)  # (batch_size, sequence_length, vocab_size)
        return logits

    @staticmethod
    def calculate_loss(logits, target_tokens):
        """
        Calculate cross-entropy loss comparing predicted logits to actual target tokens.
        Args:
            logits (Tensor): Model output logits of shape (batch_size, sequence_length, vocab_size).
            target_tokens (Tensor): Ground truth tokens of shape (batch_size, sequence_length).
        Returns:
            Tensor: Calculated cross-entropy loss.
        """
        batch_size, sequence_length, vocab_size = logits.shape  # Unpack tensor dimensions
        logits = logits.view(batch_size * sequence_length, vocab_size)  # Flatten for cross-entropy
        target_tokens = target_tokens.view(batch_size * sequence_length)  # Flatten for cross-entropy
        loss = F.cross_entropy(logits, target_tokens)  # Calculate cross-entropy loss
        return loss

    def generate_sequence(self, input_tokens, max_length):
        """
        Generate a sequence by predicting one token at a time based on previous tokens.
        Args:
            input_tokens (Tensor): Initial token to start generating from, of shape (batch_size, 1).
            max_length (int): Number of new tokens to generate.
        Returns:
            Tensor: Generated token sequence.
        """
        for _ in range(max_length):
            # Run the input through the model to get logits (predictions)
            logits = self.forward(input_tokens)
            # Get only the logits for the last token position
            logits = logits[:, -1, :]  # (batch_size, vocab_size)
            # Apply softmax to convert logits to probabilities
            probabilities = F.softmax(logits, dim=-1)
            # Randomly select the next token based on the probabilities
            next_token = torch.multinomial(probabilities, num_samples=1)  # (batch_size, 1)
            # Add the new token to the input tokens sequence
            input_tokens = torch.cat((input_tokens, next_token), dim=1)  # (batch_size, sequence_length + 1)
        return input_tokens
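Two details of the class can be checked directly in a standalone snippet. First, an nn.Embedding lookup is plain row indexing into its weight matrix. Second, flattening before F.cross_entropy gives the mean negative log-probability of the target tokens. In the sketch, table is a fresh stand-in for token_embedding_table, and the three tokens are “Fir” with targets “irs”, encoded as earlier in the notebook.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

vocab_size = 65
table = nn.Embedding(vocab_size, vocab_size)   # stand-in for token_embedding_table

tokens = torch.tensor([[18, 47, 56]])          # "Fir", shape (batch=1, sequence=3)
targets = torch.tensor([[47, 56, 57]])         # "irs": each token's successor

logits = table(tokens)                          # (1, 3, 65)
# An embedding lookup is plain row indexing into the weight matrix
print(torch.equal(logits[0, 0], table.weight[18]))

# Flatten (B, T, C) -> (B*T, C) and (B, T) -> (B*T,), as calculate_loss does
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))

# The same value by hand: mean negative log-probability of each target token
log_probs = F.log_softmax(logits, dim=-1)
manual = -log_probs[0, torch.arange(3), targets[0]].mean()
print(loss.item(), manual.item())
```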
Let’s create an instance of BigramLanguageModel and check its structure. The Embedding matrix has dimensions 65 by 65, matching the number of unique characters in our dataset.
# Instantiate the model
bigram_net = BigramLanguageModel(vocab_size)
print(bigram_net)
BigramLanguageModel(
(token_embedding_table): Embedding(65, 65)
)
Let’s feed one batch of data into the model and check the shape of the output, which should have three dimensions: (batch_size, sequence_length, vocab_size).
output_logits = bigram_net(batch_data)
print(f"Model's output size: {output_logits.shape}")
Model's output size: torch.Size([4, 12, 65])
We can now calculate the initial loss to see how well (or poorly) the untrained model performs. For comparison, we also calculate a “chance-level” baseline: the loss of a model that assigns the same probability, 1/vocab_size, to every token, which is -ln(1/65) = ln(65).
# Calculate and print the loss
initial_loss = bigram_net.calculate_loss(output_logits, target_tokens)
print(f"Initial loss (untrained model): {initial_loss:.3f}")
# Calculate the chance-level baseline
chance_level_loss = -np.log(1 / vocab_size)
print(f"Chance-level baseline loss: {chance_level_loss:.3f}")
Initial loss (untrained model): 4.630
Chance-level baseline loss: 4.174
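Why is the untrained loss above chance level? The NumPy sketch below illustrates one plausible explanation. It assumes that the untrained logits behave like independent standard-normal draws, which is how nn.Embedding initialises its weights (N(0, 1)). Equal logits give exactly ln(65). Random logits make the model confidently wrong on some tokens, and that pushes the average cross-entropy above ln(65). cross_entropy here is a hypothetical single-prediction helper, not the PyTorch function.

```python
import numpy as np

def cross_entropy(logits, target):
    """Cross-entropy for one prediction: -log softmax(logits)[target]."""
    shifted = logits - logits.max()            # subtract the max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return -log_probs[target]

vocab_size = 65
uniform = np.zeros(vocab_size)                 # equal logits -> uniform distribution
print(cross_entropy(uniform, target=7), np.log(vocab_size))

rng = np.random.default_rng(0)
random_losses = [cross_entropy(rng.standard_normal(vocab_size), target=7) for _ in range(10_000)]
print(np.mean(random_losses))                  # noticeably above ln(65)
```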
Finally, let’s use our untrained model to generate a sequence of tokens. We expect the output to be somewhat random, as the model hasn’t learned any patterns yet.
# Generate a text sequence from the initial token
start_token = torch.zeros((1, 1), dtype=torch.long)  # Start with token 0 (the newline character)
generated_sequence = bigram_net.generate_sequence(start_token, max_length=100)
# Convert generated token indices back to characters
generated_text = decode_txt(generated_sequence[0].tolist())
print(generated_text)
QwuiVZZEKzSKlV.ATrRlzEaV?3ZBWApyiBkQdtNAz
uqMVCD.jNMGgDmC&OuoDLYpVu
uMTClnrk,AaIagaUx 'zkl,ATe
?csZ&
The output will show a random sequence of characters, which is expected at this stage.
Transformer Architecture and Self-Attention
To create a modern language model like GPT, we need a transformer architecture. Introduced in 2017 in the influential paper “Attention Is All You Need”, transformers have become a versatile architecture applied to many types of data, from text to images and beyond. Their key operation is self-attention, called “Scaled Dot-Product Attention” in the original paper and defined as:
\(\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V,\)
where \(Q\), \(K\), and \(V\) denote the queries, keys and values, respectively, and \(d_k\) is the dimensionality of the keys.
Attention can be thought of as a communication mechanism in which tokens in a sequence “look” at each other and learn to focus on the most relevant ones. Imagine each token as a node in a directed graph: each node gathers information from the nodes connected to it, weighted by how important each connection is. In a language model like ours, a causal mask restricts these connections so that each token only receives information from itself and the tokens before it. This weighted communication allows the model to learn relationships between tokens in a sequence, and it forms the basis of powerful language models.
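To make the formula concrete, here is a minimal NumPy sketch of scaled dot-product attention for a single sequence, with a causal mask. This is an illustration under simplifying assumptions: Q, K and V are random stand-ins rather than learned projections, there is no batch dimension, and causal_attention is a hypothetical helper name.

```python
import numpy as np

def causal_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V with a causal mask, for one sequence.

    Q, K: (T, d_k); V: (T, d_v). Position t may only attend to positions <= t.
    """
    T, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)                        # (T, T) pairwise scores
    future = np.triu(np.ones((T, T), dtype=bool), k=1)     # True above the diagonal
    scores = np.where(future, -np.inf, scores)             # block attention to the future
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)         # row-wise softmax
    return weights @ V, weights

rng = np.random.default_rng(0)
T, d_k = 5, 8
Q, K, V = (rng.standard_normal((T, d_k)) for _ in range(3))
out, weights = causal_attention(Q, K, V)
print(np.round(weights, 2))  # lower-triangular; each row sums to 1
```

Each row of weights describes how much one token draws from the earlier tokens. The first token can only attend to itself, so its output is exactly its own value vector.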
The following code implements the key components of the transformer architecture: SelfAttention, MultiHeadAttention, MLP, and TransformerBlock. Let’s go through each one.
Self-Attention
In transformers, self-attention lets each token focus on other tokens in the sequence by computing an attention weight for every pair of tokens. In the code below, we define a SelfAttention class that represents a single “head” of self-attention.
class SelfAttention(nn.Module):
    """One head of self-attention: calculates attention for each token in relation to others."""

    def __init__(self, head_size, embedding_dim, dropout_rate):
        super().__init__()
        # Linear transformations for computing the key, query, and value matrices
        self.key = nn.Linear(embedding_dim, head_size, bias=False)
        self.query = nn.Linear(embedding_dim, head_size, bias=False)
        self.value = nn.Linear(embedding_dim, head_size, bias=False)
        # Lower-triangular (causal) mask that hides future tokens; sized by the global block_size
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """
        Forward pass for self-attention head.
        Args:
            x (Tensor): Input tensor of shape (batch_size, sequence_length, embedding_dim).
        Returns:
            Tensor: Output tensor of shape (batch_size, sequence_length, head_size).
        """
        batch_size, sequence_length, _ = x.shape
        # Calculate key and query matrices
        keys = self.key(x)  # Shape: (batch_size, sequence_length, head_size)
        queries = self.query(x)  # Shape: (batch_size, sequence_length, head_size)
        # Compute attention scores as the dot product of queries and keys,
        # scaled by 1/sqrt(head_size) (the 1/sqrt(d_k) factor in the formula) to keep the softmax from saturating
        attention_scores = queries @ keys.transpose(-2, -1) * (keys.shape[-1] ** -0.5)
        # Apply causal mask to prevent attention to future tokens
        attention_scores = attention_scores.masked_fill(self.tril[:sequence_length, :sequence_length] == 0, float('-inf'))
        # Convert attention scores to probabilities
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        # Calculate weighted sum of values
        values = self.value(x)
        output = attention_probs @ values  # Shape: (batch_size, sequence_length, head_size)
        return output
Key, Query, and Value matrices: Each input token is transformed into these three representations by separate linear layers. The dot product of queries and keys produces an attention score for each pair of tokens.
Causal mask: We use a lower-triangular mask so that each token can only attend to itself and previous tokens, ensuring future information isn’t used when predicting the next token.
Softmax and weighted sum: We apply softmax to convert the scores to probabilities, which are then used to calculate a weighted sum of the values.
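The \(\sqrt{d_k}\) in the attention formula compensates for dot products growing with dimension. The NumPy sketch below assumes queries and keys with independent unit-variance entries. Under that assumption, the variance of q·k is roughly d_k, and dividing by \(\sqrt{d_k}\) brings it back to about 1. Without the scaling, large scores would push the softmax towards nearly one-hot outputs with very small gradients. For a single head, d_k corresponds to head_size.

```python
import numpy as np

rng = np.random.default_rng(0)
results = {}
for d_k in (4, 16, 64):
    q = rng.standard_normal((10_000, d_k))   # 10,000 random query vectors
    k = rng.standard_normal((10_000, d_k))   # 10,000 random key vectors
    dots = (q * k).sum(axis=1)               # one dot product per pair
    results[d_k] = (dots.var(), (dots / np.sqrt(d_k)).var())
    print(d_k, round(results[d_k][0], 1), round(results[d_k][1], 2))  # raw variance ~ d_k, scaled ~ 1
```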
Multi-Head Attention
To increase the model’s capacity to learn complex relationships, transformers use multi-head attention, which runs several self-attention heads in parallel. Each head has its own key, query and value projections, so different heads can learn to attend to different kinds of relationships between tokens.
class MultiHeadAttention(nn.Module):
    """Combines multiple self-attention heads in parallel."""

    def __init__(self, num_heads, head_size, embedding_dim, dropout_rate):
        super().__init__()
        # Initialise multiple self-attention heads
        self.heads = nn.ModuleList([SelfAttention(head_size, embedding_dim, dropout_rate) for _ in range(num_heads)])
        # Project concatenated output of all heads back to embedding dimension
        self.proj = nn.Linear(head_size * num_heads, embedding_dim)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """
        Forward pass for multi-head attention.
        Args:
            x (Tensor): Input tensor of shape (batch_size, sequence_length, embedding_dim).
        Returns:
            Tensor: Output tensor of shape (batch_size, sequence_length, embedding_dim).
        """
        # Concatenate outputs from each head along the last dimension
        multi_head_output = torch.cat([head(x) for head in self.heads], dim=-1)
        # Apply final linear projection and dropout
        output = self.dropout(self.proj(multi_head_output))
        return output
Here, MultiHeadAttention concatenates the outputs of all self-attention heads and applies a final linear projection to map the result back to the original embedding dimension.
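The shapes involved can be traced with the model’s default sizes (embedding_dim=64, num_heads=4, so head_size=16) and the batch shape used earlier. In this sketch, random tensors stand in for the outputs of the four heads. Only the concatenation and projection steps of MultiHeadAttention are reproduced.

```python
import torch
import torch.nn as nn

batch_size, seq_len, embedding_dim, num_heads = 4, 12, 64, 4
head_size = embedding_dim // num_heads            # 16, as TransformerBlock computes it

# Random stand-ins for the outputs of four heads, each (B, T, head_size)
head_outputs = [torch.randn(batch_size, seq_len, head_size) for _ in range(num_heads)]
concatenated = torch.cat(head_outputs, dim=-1)    # (B, T, num_heads * head_size)
proj = nn.Linear(num_heads * head_size, embedding_dim)
out = proj(concatenated)                          # back to (B, T, embedding_dim)
print(concatenated.shape, out.shape)  # torch.Size([4, 12, 64]) torch.Size([4, 12, 64])
```

Because head_size is chosen as embedding_dim // num_heads, the concatenated width already equals embedding_dim. The projection still lets the model mix information across heads.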
Multilayer Perceptron (MLP)
After the attention layers, transformers apply a small feedforward neural network called an MLP (multilayer perceptron) to each token position independently, which learns further transformations of the data.
class MLP(nn.Module):
    """Defines a feedforward neural network (MLP) for additional processing after attention."""

    def __init__(self, embedding_dim, dropout_rate):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, 4 * embedding_dim),  # Expand the embedding dimension
            nn.ReLU(),
            nn.Linear(4 * embedding_dim, embedding_dim),  # Project back down
            nn.Dropout(dropout_rate),
        )

    def forward(self, x):
        """
        Forward pass for the MLP.
        Args:
            x (Tensor): Input tensor of shape (batch_size, sequence_length, embedding_dim).
        Returns:
            Tensor: Processed tensor of the same shape.
        """
        return self.net(x)
This MLP expands each token’s representation to four times the embedding dimension, applies a ReLU non-linearity, and projects it back down, which allows for more complex transformations than a single linear layer.
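Two properties of this MLP can be checked in isolation with the model’s default embedding_dim of 64. The first is its parameter count. The second is that nn.Linear acts on the last dimension only, so changing one position leaves the outputs at the other positions untouched. The sketch rebuilds the same layer stack without dropout, which has no parameters and is irrelevant to the check.

```python
import torch
import torch.nn as nn

embedding_dim = 64
mlp = nn.Sequential(
    nn.Linear(embedding_dim, 4 * embedding_dim),
    nn.ReLU(),
    nn.Linear(4 * embedding_dim, embedding_dim),
)
n_params = sum(p.numel() for p in mlp.parameters())
print(n_params)  # 64*256 + 256 + 256*64 + 64 = 33088

x = torch.randn(2, 5, embedding_dim)   # (batch, sequence, embedding_dim)
y = mlp(x)
x2 = x.clone()
x2[:, 0] += 1.0                        # perturb position 0 only
y2 = mlp(x2)
print(torch.allclose(y[:, 1:], y2[:, 1:]))  # True: other positions are unaffected
```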
Transformer Block
Finally, we combine the self-attention and MLP layers into a Transformer Block. This block is the core building unit of transformers: each one contains multi-head self-attention followed by feedforward (MLP) processing, with layer normalisation and residual connections to stabilise learning.
class TransformerBlock(nn.Module):
    """Defines a single transformer block with self-attention and MLP layers."""

    def __init__(self, embedding_dim, num_heads, dropout_rate):
        super().__init__()
        head_size = embedding_dim // num_heads
        self.attention = MultiHeadAttention(num_heads, head_size, embedding_dim, dropout_rate)
        self.feedforward = MLP(embedding_dim, dropout_rate)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """
        Forward pass for the transformer block.
        Args:
            x (Tensor): Input tensor of shape (batch_size, sequence_length, embedding_dim).
        Returns:
            Tensor: Output tensor of the same shape.
        """
        # Normalise, apply multi-head attention, then add the result back to the input (residual connection)
        x = x + self.attention(self.norm1(x))
        # Normalise, apply the MLP, then add the result back to the input (residual connection)
        x = x + self.feedforward(self.norm2(x))
        return x
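In this block:
Multi-head attention and MLP layers give the model two complementary ways to process the data: attention mixes information across tokens, while the MLP transforms each token’s representation independently.
Layer normalisation helps stabilise learning by normalising each token’s features to zero mean and unit variance (followed by a learnable scale and shift) before each sub-layer.
Residual connections (x = x + ...) add each sub-layer’s output back to its input, which helps gradients flow through deep stacks of blocks.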
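The normalisation step can be written out by hand. The NumPy sketch below follows the computation nn.LayerNorm performs over the last dimension: biased variance, eps=1e-5 by default, then an elementwise gamma and beta. Here gamma and beta are fixed at their initial values of ones and zeros rather than learned, and layer_norm is a hypothetical helper name.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalise each vector along the last axis, then scale by gamma and shift by beta."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)          # biased variance, as nn.LayerNorm uses
    return (x - mean) / np.sqrt(var + eps) * gamma + beta

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=5.0, size=(2, 4, 8))   # (batch, sequence, embedding_dim)
y = layer_norm(x, gamma=np.ones(8), beta=np.zeros(8))
print(y.mean(axis=-1).round(6))   # ~0 for every token
print(y.std(axis=-1).round(3))    # ~1 for every token
```

Each token is normalised on its own, so the result does not depend on the batch size or on the other tokens in the sequence.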
GPT Language Model
Now that we’ve built the foundational components of the transformer architecture, we can create the GPTLanguageModel. It is structured much like the simpler BigramLanguageModel, with one major distinction: instead of predicting the next token from the current token alone, GPTLanguageModel uses the transformer architecture. It stacks several transformer blocks, each containing multiple self-attention heads followed by an MLP, which lets it use context from the whole block of preceding tokens.
The following code defines GPTLanguageModel, which embeds tokens and their positions, then passes them through the transformer layers before producing predictions.
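Two steps in the model are easy to miss: adding positional embeddings by broadcasting, and cropping the context during generation. The sketch below reproduces them with fresh embedding layers, using the notebook’s sizes (vocab_size=65, block_size=12, embedding_dim=64) on a random 20-token input. The position table only has block_size rows, so looking up position 12 or beyond fails. This is why generation keeps only the last block_size tokens.

```python
import torch
import torch.nn as nn

vocab_size, block_size, embedding_dim = 65, 12, 64
token_embedding = nn.Embedding(vocab_size, embedding_dim)
position_embedding = nn.Embedding(block_size, embedding_dim)

input_ids = torch.randint(vocab_size, (4, 20))       # 20 tokens: longer than block_size
cropped = input_ids[:, -block_size:]                 # keep only the last block_size tokens
positions = torch.arange(cropped.shape[1])           # 0, 1, ..., block_size - 1
# (4, 12, 64) + (12, 64): the position embeddings broadcast over the batch dimension
x = token_embedding(cropped) + position_embedding(positions)
print(cropped.shape, x.shape)

# Without cropping, positions 12..19 do not exist in the table
try:
    position_embedding(torch.arange(20))
    raised = False
except IndexError:
    raised = True
print(raised)  # True
```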
class GPTLanguageModel(nn.Module):
"""A GPT-based language model that utilises transformer blocks to generate sequences."""
def __init__(self, vocab_size, embedding_dim=64, num_heads=4, num_layers=4, dropout_rate=0):
"""
Initialises the model with specified vocabulary size, embedding dimension, number of heads,
number of transformer layers, and dropout rate.
Args:
vocab_size (int): Size of the vocabulary.
embedding_dim (int): Dimension of token embeddings.
num_heads (int): Number of attention heads.
num_layers (int): Number of transformer layers.
dropout_rate (float): Dropout probability for regularisation.
"""
super().__init__()
# Embedding layer for token representation
self.token_embedding = nn.Embedding(vocab_size, embedding_dim)
# Embedding layer for positional representation to add sequence information
self.position_embedding = nn.Embedding(block_size, embedding_dim)
# Stack of transformer blocks
        self.transformer_blocks = nn.Sequential(
            *[TransformerBlock(embedding_dim, num_heads=num_heads, dropout_rate=dropout_rate) for _ in range(num_layers)]
        )
        # Final layer normalisation for stable outputs
        self.final_layer_norm = nn.LayerNorm(embedding_dim)
        # Output layer mapping the final transformer output to one logit per vocabulary entry
        self.language_model_head = nn.Linear(embedding_dim, vocab_size)
        # Initialise weights for stability
        self.apply(self._init_weights)
    def _init_weights(self, module):
        """Initialises weights for linear and embedding layers."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    def forward(self, input_ids):
        """
        Forward pass of the model.
        Args:
            input_ids (Tensor): Tensor of shape (batch_size, sequence_length) with input token indices.
        Returns:
            Tensor: Logits (unnormalised scores) of shape (batch_size, sequence_length, vocab_size);
                applying softmax over the last dimension turns them into next-token probabilities.
        """
        batch_size, sequence_length = input_ids.shape
        # Create token embeddings
        token_embeddings = self.token_embedding(input_ids)  # Shape: (batch_size, sequence_length, embedding_dim)
        # Create positional embeddings to give a sense of order in the sequence
        positions = torch.arange(sequence_length, device=input_ids.device)
        position_embeddings = self.position_embedding(positions)  # Shape: (sequence_length, embedding_dim)
        # Add token and positional embeddings (broadcast over the batch dimension)
        x = token_embeddings + position_embeddings  # Combined shape: (batch_size, sequence_length, embedding_dim)
        # Pass through stacked transformer blocks
        x = self.transformer_blocks(x)  # Shape: (batch_size, sequence_length, embedding_dim)
        # Apply final layer normalisation
        x = self.final_layer_norm(x)  # Shape: (batch_size, sequence_length, embedding_dim)
        # Convert to logits for each token in the vocabulary
        logits = self.language_model_head(x)  # Shape: (batch_size, sequence_length, vocab_size)
        return logits
    @staticmethod
    def calculate_loss(logits, target_ids):
        """
        Calculates cross-entropy loss between predicted logits and target token indices.
        Args:
            logits (Tensor): Predicted logits of shape (batch_size, sequence_length, vocab_size).
            target_ids (Tensor): Target indices of shape (batch_size, sequence_length).
        Returns:
            Tensor: Mean cross-entropy loss over all positions in the batch.
        """
        batch_size, sequence_length, vocab_size = logits.shape
        logits = logits.view(batch_size * sequence_length, vocab_size)
        target_ids = target_ids.view(batch_size * sequence_length)
        # Cross-entropy loss over flattened logits and targets
        loss = F.cross_entropy(logits, target_ids)
        return loss
    @torch.no_grad()  # No gradients are needed while sampling
    def generate(self, input_ids, max_new_tokens):
        """
        Generates text by iteratively sampling new tokens.
        Args:
            input_ids (Tensor): Initial token indices of shape (batch_size, initial_sequence_length).
            max_new_tokens (int): Number of new tokens to generate.
        Returns:
            Tensor: The input sequence extended by max_new_tokens sampled tokens.
        """
        # The position embedding table fixes the longest context the model can handle
        max_context = self.position_embedding.num_embeddings
        for _ in range(max_new_tokens):
            # Crop the context to the last max_context tokens
            input_ids_cond = input_ids[:, -max_context:]
            # Forward pass to get logits
            logits = self.forward(input_ids_cond)
            # Focus only on the last position's logits
            logits = logits[:, -1, :]  # Shape: (batch_size, vocab_size)
            # Convert logits to probabilities using softmax
            probs = F.softmax(logits, dim=-1)  # Shape: (batch_size, vocab_size)
            # Sample from the probability distribution to get the next token index
            next_token_id = torch.multinomial(probs, num_samples=1)  # Shape: (batch_size, 1)
            # Append sampled token to input_ids
            input_ids = torch.cat((input_ids, next_token_id), dim=1)  # Updated shape: (batch_size, current_length + 1)
        return input_ids
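calculate_loss flattens the batch and sequence dimensions because F.cross_entropy expects class scores along dimension 1 (for example shape (N, C)) and targets of shape (N,), so every position in every sequence becomes one classification example. The NumPy sketch below reproduces that mean cross-entropy by hand; the function name cross_entropy and the random targets are illustrative assumptions, not part of the notebook. It checks two limiting cases: all-zero logits give a uniform prediction, so the loss equals ln(vocab_size), and a large score on every correct token drives the loss towards zero.

```python
import numpy as np

def cross_entropy(logits, targets):
    """Mean cross-entropy, mirroring what calculate_loss computes with F.cross_entropy.

    logits:  (batch_size, sequence_length, vocab_size) unnormalised scores
    targets: (batch_size, sequence_length) integer token indices
    """
    batch_size, sequence_length, vocab_size = logits.shape
    # Flatten so every (sequence, position) pair becomes one classification example
    flat_logits = logits.reshape(batch_size * sequence_length, vocab_size)
    flat_targets = targets.reshape(batch_size * sequence_length)
    # Numerically stable log-softmax: subtract the row maximum before exponentiating
    shifted = flat_logits - flat_logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability assigned to the correct next token of each example
    picked = log_probs[np.arange(flat_targets.size), flat_targets]
    return -picked.mean()

rng = np.random.default_rng(0)
vocab_size = 65
targets = rng.integers(0, vocab_size, size=(4, 8))  # hypothetical batch of targets

# All-zero logits mean a uniform prediction, so the loss equals ln(vocab_size)
uniform_loss = cross_entropy(np.zeros((4, 8, vocab_size)), targets)
print(uniform_loss, np.log(vocab_size))

# A large score on each correct token: the loss approaches zero
confident = np.zeros((4, 8, vocab_size))
np.put_along_axis(confident, targets[..., None], 20.0, axis=-1)
confident_loss = cross_entropy(confident, targets)
print(confident_loss)
```

With 65 characters, ln(65) is about 4.17, which is why an untrained model with near-zero logits starts with a loss in that range.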
Let’s go over each component in the model:
Token and Positional Embeddings: Each token index is mapped to a learned vector through an embedding table (token_embedding). Because self-attention on its own does not take the order of its inputs into account, a second table (position_embedding) assigns a learned vector to each position, and the two vectors are added together.
Transformer Blocks: The model stacks num_layers transformer blocks. Each block combines multi-head self-attention, which lets each position draw on information from other positions in the context, with a feedforward MLP that transforms each position independently.
Final Layer Norm and Language Model Head: After the transformer blocks, a final layer normalisation stabilises the activations, and the language model head (a linear layer) maps each position's vector to one logit per vocabulary entry.
Weight Initialisation: _init_weights draws the weights of every linear and embedding layer from a normal distribution with mean 0 and standard deviation 0.02 and sets linear biases to zero. These small initial weights keep the initial logits close to zero, so the untrained model starts near a uniform prediction, which helps stabilise early training.
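The embedding step in forward amounts to two table lookups and a broadcast addition. The NumPy sketch below mimics it, with plain arrays standing in for the weight matrices of token_embedding and position_embedding (an nn.Embedding lookup is row indexing into its weight matrix, and PyTorch follows NumPy's broadcasting rules). The sizes are taken from the printed model further down; the token ids are arbitrary.

```python
import numpy as np

# Sizes taken from the printed model: 65 characters, 64-dim embeddings, 12 positions
vocab_size, embedding_dim, max_positions = 65, 64, 12
rng = np.random.default_rng(0)

# Stand-ins for the embedding weight matrices, drawn like _init_weights does (mean 0, std 0.02)
token_table = rng.normal(0.0, 0.02, size=(vocab_size, embedding_dim))
position_table = rng.normal(0.0, 0.02, size=(max_positions, embedding_dim))

# An embedding lookup is row indexing into the table
input_ids = np.array([[7, 7, 3, 12, 40],
                      [1, 0, 7, 7, 64]])                          # (batch_size=2, sequence_length=5)
sequence_length = input_ids.shape[1]
token_embeddings = token_table[input_ids]                         # (2, 5, 64)
position_embeddings = position_table[np.arange(sequence_length)]  # (5, 64)

# Broadcasting adds the same positional vectors to every sequence in the batch
x = token_embeddings + position_embeddings                        # (2, 5, 64)
print(x.shape)

# Token 7 appears at positions 0 and 1 of the first sequence, yet gets different vectors
print(np.allclose(x[0, 0], x[0, 1]))
```

Without the positional term, the two occurrences of token 7 would be indistinguishable to the attention layers.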
To understand the model’s structure, we can create an instance of GPTLanguageModel and print its layers.
# Create GPTLanguageModel instance
gpt_net = GPTLanguageModel(vocab_size=vocab_size)
# Print the model architecture
print(gpt_net)
GPTLanguageModel(
(token_embedding): Embedding(65, 64)
(position_embedding): Embedding(12, 64)
(transformer_blocks): Sequential(
(0): TransformerBlock(
(attention): MultiHeadAttention(
(heads): ModuleList(
(0-3): 4 x SelfAttention(
(key): Linear(in_features=64, out_features=16, bias=False)
(query): Linear(in_features=64, out_features=16, bias=False)
(value): Linear(in_features=64, out_features=16, bias=False)
(dropout): Dropout(p=0, inplace=False)
)
)
(proj): Linear(in_features=64, out_features=64, bias=True)
(dropout): Dropout(p=0, inplace=False)
)
(feedforward): MLP(
(net): Sequential(
(0): Linear(in_features=64, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=64, bias=True)
(3): Dropout(p=0, inplace=False)
)
)
(norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
)
(1): TransformerBlock(
(attention): MultiHeadAttention(
(heads): ModuleList(
(0-3): 4 x SelfAttention(
(key): Linear(in_features=64, out_features=16, bias=False)
(query): Linear(in_features=64, out_features=16, bias=False)
(value): Linear(in_features=64, out_features=16, bias=False)
(dropout): Dropout(p=0, inplace=False)
)
)
(proj): Linear(in_features=64, out_features=64, bias=True)
(dropout): Dropout(p=0, inplace=False)
)
(feedforward): MLP(
(net): Sequential(
(0): Linear(in_features=64, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=64, bias=True)
(3): Dropout(p=0, inplace=False)
)
)
(norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
)
(2): TransformerBlock(
(attention): MultiHeadAttention(
(heads): ModuleList(
(0-3): 4 x SelfAttention(
(key): Linear(in_features=64, out_features=16, bias=False)
(query): Linear(in_features=64, out_features=16, bias=False)
(value): Linear(in_features=64, out_features=16, bias=False)
(dropout): Dropout(p=0, inplace=False)
)
)
(proj): Linear(in_features=64, out_features=64, bias=True)
(dropout): Dropout(p=0, inplace=False)
)
(feedforward): MLP(
(net): Sequential(
(0): Linear(in_features=64, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=64, bias=True)
(3): Dropout(p=0, inplace=False)
)
)
(norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
)
(3): TransformerBlock(
(attention): MultiHeadAttention(
(heads): ModuleList(
(0-3): 4 x SelfAttention(
(key): Linear(in_features=64, out_features=16, bias=False)
(query): Linear(in_features=64, out_features=16, bias=False)
(value): Linear(in_features=64, out_features=16, bias=False)
(dropout): Dropout(p=0, inplace=False)
)
)
(proj): Linear(in_features=64, out_features=64, bias=True)
(dropout): Dropout(p=0, inplace=False)
)
(feedforward): MLP(
(net): Sequential(
(0): Linear(in_features=64, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=64, bias=True)
(3): Dropout(p=0, inplace=False)
)
)
(norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
)
)
(final_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
(language_model_head): Linear(in_features=64, out_features=65, bias=True)
)
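The printed summary contains enough information to count the parameters by hand: embeddings of 65 characters and 12 positions, four blocks with four bias-free 16-dimensional key/query/value projections per block, an output projection, a 256-unit MLP, and two LayerNorms each, followed by a final LayerNorm and the language model head. The helper count_parameters below is hypothetical (it is not part of GPTLanguageModel); its defaults mirror the printed shapes.

```python
def count_parameters(vocab_size=65, context_length=12, embedding_dim=64,
                     num_heads=4, num_layers=4, mlp_hidden=256):
    """Hand count of trainable parameters for the architecture printed above."""
    def linear(n_in, n_out, bias=True):
        return n_in * n_out + (n_out if bias else 0)

    head_size = embedding_dim // num_heads  # 16 in the printed model
    layer_norm = 2 * embedding_dim          # elementwise weight and bias

    attention = num_heads * 3 * linear(embedding_dim, head_size, bias=False)  # key, query, value
    attention += linear(embedding_dim, embedding_dim)                         # proj
    mlp = linear(embedding_dim, mlp_hidden) + linear(mlp_hidden, embedding_dim)
    block = attention + mlp + 2 * layer_norm                                  # norm1, norm2

    embeddings = (vocab_size + context_length) * embedding_dim
    output = layer_norm + linear(embedding_dim, vocab_size)  # final_layer_norm + language_model_head
    return embeddings + num_layers * block + output

total = count_parameters()
print(total, f"{total / 1e6:.2f} million")
```

This gives 208,449 parameters, consistent with the 0.21 million reported when training starts. The same helper is handy for the capacity exercise at the end: doubling embedding_dim together with mlp_hidden roughly quadruples the per-block count.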
3. Training#
Now that we have both our dataset and model architecture set up, we are ready to train the network. Training adjusts the model’s parameters to minimise the next-character prediction loss on the training data, while the loss on held-out validation data tells us whether the model generalises beyond the text it was trained on. This section includes utility functions to monitor performance during training and the main training loop that optimises the model.
Utility Functions#
The estimate_loss function evaluates the model on both the training and validation sets, averaging the loss over eval_iters batches so that the estimate is less noisy than the loss of a single batch. By monitoring these metrics periodically, we can gauge how well the model is learning and whether it is overfitting or underfitting.
@torch.no_grad()  # Disables gradient calculations for evaluation
def estimate_loss(model, eval_iters=200):
    """
    Estimates the average loss on the training and validation sets.
    Args:
        model (nn.Module): The model to evaluate.
        eval_iters (int): Number of evaluation iterations for averaging.
    Returns:
        dict: Dictionary containing mean loss for 'train' and 'val' sets.
    """
    losses = {}  # Dictionary to store loss values
    model.eval()  # Sets model to evaluation mode (important for layers like dropout)
    for split in ['train', 'val']:
        split_losses = torch.zeros(eval_iters)  # Holds loss values for each iteration
        for i in range(eval_iters):
            # Get a batch of data for the current split ('train' or 'val')
            x_batch, y_batch = get_batch(split, device=device)
            # Perform a forward pass through the model to get predictions
            logits = model(x_batch)
            # Calculate loss between model predictions and actual values
            loss = model.calculate_loss(logits, y_batch)
            # Store the loss value
            split_losses[i] = loss.item()
        # Calculate the mean loss for the current split
        losses[split] = split_losses.mean().item()
    model.train()  # Reset model to training mode
    return losses
Optimisation#
Now let’s proceed to set up the main training loop. This includes defining hyperparameters, setting the device, creating an optimiser, and iteratively adjusting the model weights based on the training data.
# Hyperparameters - these are parameters we set before training
batch_size = 16 # Number of sequences processed in parallel
context_length = 32 # Maximum context length the model considers
max_steps = 5000 # Total number of optimisation steps
eval_interval = 100 # Frequency of evaluation (in steps)
learning_rate = 1e-3 # Step size for the optimiser
# Determine whether a GPU is available, otherwise default to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Instantiate the GPT language model and move it to the selected device
gpt_net = GPTLanguageModel(vocab_size)
gpt_net = gpt_net.to(device)
# Display the number of parameters in the model (in millions) for reference
print(f"Model has {sum(p.numel() for p in gpt_net.parameters()) / 1e6:.2f} million parameters")
# Set up an optimiser, which updates model parameters to minimise loss
optimizer = torch.optim.AdamW(gpt_net.parameters(), lr=learning_rate)
# Training loop - iterates over multiple steps to update model weights
for step in range(max_steps):
    # Evaluate the model at regular intervals on both train and validation sets
    if step % eval_interval == 0 or step == max_steps - 1:
        loss_values = estimate_loss(gpt_net)
        print(f"Step {step:04d}: Train Loss = {loss_values['train']:.4f}, Validation Loss = {loss_values['val']:.4f}")
    # Fetch a batch of training data (input and target outputs)
    x_batch, y_batch = get_batch('train', device=device)
    # Forward pass: compute model predictions for the batch
    logits = gpt_net(x_batch)
    # Calculate the training loss for the batch
    loss = gpt_net.calculate_loss(logits, y_batch)
    # Clear previous gradients to prepare for new backpropagation
    optimizer.zero_grad(set_to_none=True)
    # Backward pass: compute gradients of the loss with respect to model parameters
    loss.backward()
    # Update model parameters using computed gradients
    optimizer.step()
Model has 0.21 million parameters
Step 0000: Train Loss = 4.1622, Validation Loss = 4.1633
Step 0100: Train Loss = 2.6702, Validation Loss = 2.6733
Step 0200: Train Loss = 2.5020, Validation Loss = 2.5005
Step 0300: Train Loss = 2.4623, Validation Loss = 2.4609
Step 0400: Train Loss = 2.4038, Validation Loss = 2.3998
Step 0500: Train Loss = 2.3359, Validation Loss = 2.3330
Step 0600: Train Loss = 2.3081, Validation Loss = 2.3049
Step 0700: Train Loss = 2.2813, Validation Loss = 2.3080
Step 0800: Train Loss = 2.2317, Validation Loss = 2.2513
Step 0900: Train Loss = 2.2236, Validation Loss = 2.2414
Step 1000: Train Loss = 2.1879, Validation Loss = 2.2254
Step 1100: Train Loss = 2.1890, Validation Loss = 2.2017
Step 1200: Train Loss = 2.1839, Validation Loss = 2.2106
Step 1300: Train Loss = 2.1571, Validation Loss = 2.1766
Step 1400: Train Loss = 2.1315, Validation Loss = 2.1574
Step 1500: Train Loss = 2.1298, Validation Loss = 2.1645
Step 1600: Train Loss = 2.1149, Validation Loss = 2.1543
Step 1700: Train Loss = 2.0948, Validation Loss = 2.1246
Step 1800: Train Loss = 2.0930, Validation Loss = 2.1245
Step 1900: Train Loss = 2.0747, Validation Loss = 2.1126
Step 2000: Train Loss = 2.0830, Validation Loss = 2.1236
Step 2100: Train Loss = 2.0663, Validation Loss = 2.1259
Step 2200: Train Loss = 2.0432, Validation Loss = 2.0875
Step 2300: Train Loss = 2.0203, Validation Loss = 2.0790
Step 2400: Train Loss = 2.0133, Validation Loss = 2.0695
Step 2500: Train Loss = 2.0531, Validation Loss = 2.0896
Step 2600: Train Loss = 2.0054, Validation Loss = 2.0621
Step 2700: Train Loss = 2.0133, Validation Loss = 2.0705
Step 2800: Train Loss = 1.9907, Validation Loss = 2.0790
Step 2900: Train Loss = 2.0025, Validation Loss = 2.0603
Step 3000: Train Loss = 1.9779, Validation Loss = 2.0670
Step 3100: Train Loss = 1.9832, Validation Loss = 2.0736
Step 3200: Train Loss = 1.9668, Validation Loss = 2.0423
Step 3300: Train Loss = 1.9508, Validation Loss = 2.0493
Step 3400: Train Loss = 1.9756, Validation Loss = 2.0617
Step 3500: Train Loss = 1.9516, Validation Loss = 2.0476
Step 3600: Train Loss = 1.9452, Validation Loss = 2.0163
Step 3700: Train Loss = 1.9485, Validation Loss = 2.0374
Step 3800: Train Loss = 1.9499, Validation Loss = 2.0312
Step 3900: Train Loss = 1.9122, Validation Loss = 2.0228
Step 4000: Train Loss = 1.8988, Validation Loss = 2.0397
Step 4100: Train Loss = 1.9105, Validation Loss = 2.0145
Step 4200: Train Loss = 1.9201, Validation Loss = 2.0038
Step 4300: Train Loss = 1.9242, Validation Loss = 2.0216
Step 4400: Train Loss = 1.8937, Validation Loss = 2.0118
Step 4500: Train Loss = 1.9095, Validation Loss = 2.0197
Step 4600: Train Loss = 1.9087, Validation Loss = 2.0326
Step 4700: Train Loss = 1.9043, Validation Loss = 2.0039
Step 4800: Train Loss = 1.8758, Validation Loss = 2.0251
Step 4900: Train Loss = 1.8800, Validation Loss = 1.9979
Step 4999: Train Loss = 1.8826, Validation Loss = 2.0042
Summary of Training Steps#
Hyperparameters: We set batch_size, context_length, max_steps, eval_interval and learning_rate before training to control its behaviour. Adjusting them affects how quickly and how well the model learns.
Device Selection: A GPU, if available, processes the batched tensor operations much faster than a CPU.
Model Initialisation: We create an instance of GPTLanguageModel and move it to the selected device.
Optimiser Setup: The optimiser, here AdamW, updates the model parameters during training. AdamW is a variant of the popular Adam optimiser that applies weight decay directly to the weights, decoupled from the adaptive gradient step; this regularises the weights and helps control overfitting.
Main Training Loop:
Evaluation: Every eval_interval steps, and at the final step, estimate_loss measures the training and validation loss. This helps monitor learning progress and decide whether hyperparameters need adjusting.
Batch Processing: At each training step, get_batch fetches a new batch of training inputs and targets.
Forward Pass: The model processes the batch and produces logits (unnormalised scores over the vocabulary) for every position of every input sequence.
Loss Calculation: calculate_loss compares these predictions with the true next tokens using cross-entropy, giving a measure of how well the model is performing.
Backward Pass: After clearing the previous gradients with optimizer.zero_grad, loss.backward() computes the gradient of the loss with respect to every model parameter.
Optimiser Step: optimizer.step() updates the model’s weights using the computed gradients, gradually reducing the loss.
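To make "decoupled" concrete, the sketch below takes a single update of one scalar weight whose loss gradient is zero, comparing Adam with an L2 penalty folded into the gradient against AdamW, which first shrinks the weight by lr * weight_decay and then takes the ordinary Adam step (the order PyTorch's AdamW uses). It assumes lr=1e-3 (the notebook's learning_rate) and weight_decay=0.01, PyTorch's AdamW default, which the training code leaves unchanged. The adam_step helper is a hypothetical scalar re-implementation for illustration, not torch.optim code.

```python
import math

def adam_step(p, g, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    """One plain Adam update for a scalar parameter; returns (p, m, v)."""
    b1, b2 = betas
    m = b1 * m + (1 - b1) * g      # running mean of gradients
    v = b2 * v + (1 - b2) * g * g  # running mean of squared gradients
    m_hat = m / (1 - b1 ** t)      # bias correction
    v_hat = v / (1 - b2 ** t)
    return p - lr * m_hat / (math.sqrt(v_hat) + eps), m, v

lr, weight_decay = 1e-3, 0.01
p, grad = 1.0, 0.0  # a weight whose loss gradient happens to be zero

# Adam + L2: the decay term joins the gradient and is then rescaled by Adam
p_l2, _, _ = adam_step(p, grad + weight_decay * p, 0.0, 0.0, t=1, lr=lr)

# AdamW: shrink the weight directly, then take the ordinary Adam step
p_w, _, _ = adam_step(p * (1 - lr * weight_decay), grad, 0.0, 0.0, t=1, lr=lr)

print(p_l2, p_w)
```

With the L2 version, Adam's normalisation turns the small penalty into a full step of about lr (the weight drops to roughly 0.999), whatever the size of weight_decay. AdamW instead shrinks the weight by exactly lr * weight_decay (to 0.99999), so the decay strength stays under the control of that one parameter.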
Over the course of training, both losses fall from about 4.16 at step 0, mostly steadily: after 5000 steps the training loss is about 1.88 and the validation loss about 2.00. This suggests that the model is learning the statistics of the text. The gap between the two curves, close to zero early on and about 0.12 at the end, is an early sign of overfitting. Training for more steps could reduce the loss further and may improve the quality of the generated text, provided the validation loss keeps improving as well.
A lower loss means the model assigns higher probability to the correct next token, which here is a single character. Continued training can refine this ability, but it is essential to monitor the validation loss to detect overfitting, where the model becomes too specialised to the training data at the expense of generalising to new text.
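Because F.cross_entropy uses the natural logarithm, exp(loss) gives the perplexity, which can be read as the effective number of equally likely characters the model is choosing between at each step. The short calculation below uses validation losses copied from the training log; perplexity is a standard way to read cross-entropy, not a quantity the notebook itself reports.

```python
import math

vocab_size = 65
# Validation losses copied from the training log (natural-log cross-entropy per character)
val_loss = {"step 0": 4.1633, "step 4999": 2.0042}

for step, loss in val_loss.items():
    # exp(loss) is the perplexity: the effective number of equally likely next characters
    print(f"{step}: loss {loss:.4f} -> perplexity {math.exp(loss):.1f} (of {vocab_size} characters)")
```

At step 0 the perplexity is about 64, essentially a uniform guess over all 65 characters; after training it is about 7.4.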
Generating Text#
Now that our model is trained, let’s generate some text. We start from a minimal context, torch.zeros((1, 1)): a single token with index 0, which corresponds to the newline character (\n). Setting max_new_tokens=250 instructs the model to sample 250 new tokens, which here means 250 characters.
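The sampling loop in generate can be mirrored in NumPy to see each step in isolation: crop the context to the positions the model knows, keep only the last position's logits, apply softmax, sample one index, and append it. The sketch assumes the vocabulary was built as sorted(set(text)), as in NanoGPT, which is why index 0 is the newline; the short text sample, toy_logits_fn (a hypothetical stand-in returning random logits instead of a trained network) and the context length of 12 are illustrative assumptions.

```python
import numpy as np

# Assumption: the vocabulary is sorted(set(text)). The newline (code point 10)
# sorts before every printable character and therefore receives index 0.
text = "First Citizen:\nBefore we proceed any further, hear me speak.\n"
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
print(repr(chars[0]), stoi["\n"])

def softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)  # for numerical stability
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)

def generate(logits_fn, input_ids, max_new_tokens, context_length, rng):
    """NumPy mirror of the sampling loop in GPTLanguageModel.generate."""
    for _ in range(max_new_tokens):
        window = input_ids[:, -context_length:]       # crop to the positions the model knows
        probs = softmax(logits_fn(window)[:, -1, :])  # distribution for the next token only
        next_ids = np.array([[rng.choice(len(row), p=row)] for row in probs])
        input_ids = np.concatenate([input_ids, next_ids], axis=1)
    return input_ids

rng = np.random.default_rng(0)

def toy_logits_fn(ids):
    # Hypothetical stand-in for the trained network: random logits of the right shape
    return rng.normal(size=(*ids.shape, len(chars)))

context = np.zeros((1, 1), dtype=np.int64)  # one '\n' token, like torch.zeros((1, 1))
out = generate(toy_logits_fn, context, max_new_tokens=20, context_length=12, rng=rng)
print(out.shape)
print(repr("".join(chars[i] for i in out[0])))
```

The output keeps the starting newline and adds exactly max_new_tokens characters, so its shape is (1, 21). With random logits the text is gibberish; the trained model replaces toy_logits_fn with learned predictions.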
# Generate text from the model with an empty initial context
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_text = gpt_net.generate(context, max_new_tokens=250)[0].tolist()
print(decode_txt(generated_text))
Whacence theirs no to not! makle
This fit welse of or a brom wery repits. Were'er caross,
I did that my leave Is for me, and would, me it vencess!
AUDDY ARbawnly sweath pouch'd to dears, bendert my were can nows,
The sun she so be high like you like
We can also provide a custom context to see how the model continues a given text passage. Here, we supply a short poetic passage as a prompt:
# Define a custom context for text generation
context = """
O mighty mind, in circuits vast and deep,
Thou dost all knowledge in thy logic keep.
With language broad, and understanding high,
Thou speaks as man, though but an artful lie.
"""
# Encode the custom context and reshape it to match input format
context = torch.tensor(encode_txt(context))
context = torch.unsqueeze(context, dim=0)
context = context.to(device)
# Generate text based on the provided context
generated_text = gpt_net.generate(context, max_new_tokens=250)[0].tolist()
print(decode_txt(generated_text))
O mighty mind, in circuits vast and deep,
Thou dost all knowledge in thy logic keep.
With language broad, and understanding high,
Thou speaks as man, though but an artful lie.
DUKE OF OMERCSAPLAMIBELLA:
That for my which all sore-till'st that age so weptorbed have hums at be incciation, and you brannt my gnot have any love idserp mee even, the whey haves hour parittiz?
TARD III:
U, so wore but me, would which you ran the
Discussion#
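In this notebook, we built a small character-level language model (about 0.21 million parameters, trained on Tiny Shakespeare) that follows, at a basic level, the same recipe as large language models (LLMs) such as ChatGPT and Gemini: a stack of transformer blocks with self-attention, trained to predict the next token and sampled one token at a time to generate text. Production LLMs differ mainly in scale, with far more parameters, longer contexts, subword rather than character tokens, and vastly more training data.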
Exercises#
Here are some exercises to further explore and enhance your understanding:
Experiment with different prompts: Try providing various types of starting contexts and see how the model responds. Does it produce coherent and contextually relevant text?
Adjust max_new_tokens: Observe how changing the length of the generated text affects the quality and coherence of the output.
Increase model capacity: Experiment with the model architecture, for example by adding more transformer layers or attention heads. Monitor how these changes affect training time and the quality of the generated text.