Perfect, Rajeev 🌟 — this is where you’ll truly understand how GPT learns —
not just how it generates text, but how it’s trained to become intelligent.
If Day 5 taught you how GPT predicts the next token,
then Day 6 teaches you how GPT learns to predict the next token accurately — through Pretraining and the Next-Token-Prediction Objective.
We’ll break this into a full intuitive explanation → with visual reasoning → practical training loop → and deep insights into why this simple learning rule creates language understanding.
🌎 DAY 6 — Pretraining & Next-Token Prediction Objective
(How GPT Learns to Understand, Think, and Generate Language)
🧠 1️⃣ The Big Picture: How Does GPT Learn?
GPT is trained on massive text corpora — internet articles, Wikipedia, books, code, forums, etc.
During pretraining, it repeatedly sees snippets of text and learns to predict the next token at each position.
That’s it.
No explicit grammar rules, logic rules, or labeled datasets.
Just:
“Given the past text, guess the next word.”
This single, simple rule, applied billions of times, ends up teaching GPT a surprising range of skills:
grammar, facts, patterns of reasoning, humor, even coding.
🧩 2️⃣ The Training Objective — Next Token Prediction
🔹 Formal Definition
For every training example (a sequence of tokens):
$$x = [x_1, x_2, x_3, \dots, x_T]$$
GPT learns to model:
$$P(x) = P(x_1)\,P(x_2 \mid x_1)\,P(x_3 \mid x_1, x_2)\cdots P(x_T \mid x_{<T})$$
The model’s loss is the negative log-likelihood over the sequence:
$$\mathcal{L} = -\sum_{t=1}^{T} \log P(x_t \mid x_{<t})$$
So at every position, the model is trained to predict the next token from the tokens that came before it.
That’s why it’s called causal (autoregressive) language modeling.
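To make the two formulas concrete, here is a small sketch that multiplies per-step conditional probabilities into P(x) and sums their negative logs into the loss. The probabilities are made-up numbers for illustration, not outputs of a real model.

```python
import math

# Hypothetical conditional probabilities for "The sky is blue":
# P(x_1), P(x_2 | x_1), P(x_3 | x_1, x_2), P(x_4 | x_<4)
step_probs = [0.05, 0.30, 0.60, 0.79]

# Chain rule: P(x) is the product of the per-step conditionals
p_sequence = math.prod(step_probs)

# Negative log-likelihood: sum of -log P(x_t | x_<t) over all positions
nll = -sum(math.log(p) for p in step_probs)

print(f"P(x) = {p_sequence:.6f}")
print(f"NLL  = {nll:.4f}")
# Summing logs equals -log of the product, but avoids underflow on long sequences
print(math.isclose(nll, -math.log(p_sequence)))  # True
```

In practice the sum of logs is what gets computed: multiplying thousands of probabilities below 1 would underflow to zero.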
🧮 Example
Text: "The sky is blue"
Training samples:
| Input (context) | Target (next token) |
|---|---|
| The | sky |
| The sky | is |
| The sky is | blue |
Loss = −log of the probability the model gave the actual next token, averaged over positions: the less probability it put on the right token, the higher the loss.
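The table can be generated mechanically from any token sequence. A minimal sketch, using a whitespace split as a stand-in for a real tokenizer (an assumption for readability; GPT-2 and GPT-3 actually use byte-level BPE subword tokens):

```python
def next_token_pairs(tokens):
    """Return (context, target) pairs: each prefix paired with the token that follows it."""
    return [(tokens[:t], tokens[t]) for t in range(1, len(tokens))]

tokens = "The sky is blue".split()  # whitespace split stands in for a real tokenizer
for context, target in next_token_pairs(tokens):
    print(" ".join(context), "->", target)
# The -> sky
# The sky -> is
# The sky is -> blue
```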
⚙️ 3️⃣ The Training Loop (Simplified)
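import torch.nn.functional as F

# Assumes model, dataset (yielding [B, T] token-ID tensors) and optimizer already exist
for batch in dataset:
    inputs, targets = batch[:, :-1], batch[:, 1:]      # shift by one position
    logits = model(inputs)                             # [B, T-1, vocab_size]: a prediction at every position
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    optimizer.zero_grad()                              # clear gradients left over from the previous step
    loss.backward()
    optimizer.step()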
💡 Key idea:
- Inputs = tokens 1 → n-1
- Targets = tokens 2 → n
- The model learns to minimize how far its predictions are from real next tokens.
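The loop above needs a real model and dataset. As a dependency-free sketch of the same idea, the toy below swaps the transformer for a bigram logit table (an assumption made purely to keep the example tiny: it only looks at the previous character) and trains it with plain SGD. The shift-by-one, the cross-entropy loss, and the gradient step are the same ingredients.

```python
import math
import random

def softmax(row):
    m = max(row)                      # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def train_bigram(token_ids, vocab_size, epochs=200, lr=0.5, seed=0):
    """Toy next-token trainer: a bigram logit table W[prev][next] fit by SGD on cross-entropy."""
    rng = random.Random(seed)
    W = [[rng.uniform(-0.01, 0.01) for _ in range(vocab_size)] for _ in range(vocab_size)]
    inputs, targets = token_ids[:-1], token_ids[1:]   # same shift-by-one as the loop above
    losses = []
    for _ in range(epochs):
        total = 0.0
        for x, y in zip(inputs, targets):
            probs = softmax(W[x])
            total += -math.log(probs[y])
            # Gradient of cross-entropy w.r.t. the logits is (probs - one_hot(y))
            for j in range(vocab_size):
                W[x][j] -= lr * (probs[j] - (1.0 if j == y else 0.0))
        losses.append(total / len(targets))
    return W, losses

text = "the cat sat on the mat"
vocab = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(vocab)}
ids = [stoi[ch] for ch in text]
W, losses = train_bigram(ids, len(vocab))
print(f"first epoch loss {losses[0]:.3f}, last epoch loss {losses[-1]:.3f}")
```

The loss cannot reach zero here: after a space, several different characters follow in this text, so some uncertainty is irreducible for a model that only sees one character of context.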
🧩 4️⃣ Where Does the Data Come From?
GPT is trained on a mixture of diverse, cleaned datasets, including:
- Common Crawl (large internet snapshot)
- Wikipedia (detailed factual text)
- Books corpora (e.g., BooksCorpus; story / narrative language)
- WebText (OpenAI’s filtered web data)
- GitHub code (for code models like Codex)
Each document is tokenized, and the token stream is split into fixed-length sequences (e.g., 2,048 tokens for GPT-3)
that are fed into the model during training.
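A minimal sketch of that splitting step, assuming documents are already tokenized into integer IDs and are joined with an end-of-text separator (GPT-2's tokenizer has `<|endoftext|>` for this purpose; the ID 0 and the `pack_sequences` name below are placeholders). Real pipelines vary, for example by sampling random offsets instead of cutting fixed chunks.

```python
def pack_sequences(docs, seq_len, eos_id):
    """Concatenate tokenized docs with an end-of-text separator, then cut the stream
    into non-overlapping chunks of seq_len + 1 tokens. Each chunk gives
    seq_len inputs (chunk[:-1]) and seq_len targets (chunk[1:])."""
    stream = []
    for doc in docs:
        stream.extend(doc)
        stream.append(eos_id)
    chunk = seq_len + 1
    # Drop the ragged tail that would not fill a whole chunk
    return [stream[i:i + chunk] for i in range(0, len(stream) - chunk + 1, chunk)]

docs = [[5, 6, 7], [8, 9], [10, 11, 12, 13]]   # hypothetical token IDs
chunks = pack_sequences(docs, seq_len=3, eos_id=0)
print(chunks)  # [[5, 6, 7, 0], [8, 9, 0, 10], [11, 12, 13, 0]]
```

Note that a chunk may span a document boundary; the separator token tells the model where one text ends and the next begins.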
🧩 Data Size vs Model Scale
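| Model | Parameters | Training data |
|---|---|---|
| GPT-2 | 1.5 B | WebText: ~40 GB of text (roughly 8 B tokens, an estimate) |
| GPT-3 | 175 B | ~300 B tokens, drawn from a mix whose filtered Common Crawl portion alone was ~570 GB |
| GPT-4 | Not officially disclosed | Not officially disclosed |
More training tokens generally improve generalization, with diminishing returns; scaling-law studies suggest growing data and model size together.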
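Parameters and tokens together set the training cost. A common rule of thumb from scaling-law work (an approximation, not an exact count) is about 6 FLOPs per parameter per token for a dense transformer: roughly 2 for the forward pass and 4 for the backward pass. Plugging in the GPT-3 row of the table:

```python
def train_flops(params, tokens):
    """Approximate training compute for a dense transformer: ~6 * N * D FLOPs."""
    return 6 * params * tokens

gpt3 = train_flops(175e9, 300e9)
print(f"GPT-3: ~{gpt3:.2e} FLOPs")  # GPT-3: ~3.15e+23 FLOPs
```

This lands close to the roughly 3.14e23 FLOPs reported in the GPT-3 paper, which is why the 6ND estimate is popular for back-of-the-envelope planning.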
🔥 5️⃣ Why Next-Token Prediction Works So Well
Even though GPT only predicts the next word, this objective forces it to learn:
| Skill | Why It Emerges |
|---|---|
| Grammar & syntax | Must predict valid sequences |
| Semantics & meaning | Must choose words that make sense |
| World knowledge | Completing “The capital of France is” requires recalling the fact |
| Reasoning | Must track long logical dependencies |
| Memory | Must attend to past context via self-attention |
It’s emergent learning — patterns arise naturally from prediction pressure.
⚙️ 6️⃣ Autoregressive Training Visualization
Let’s visualize a small example:
"The cat sat on the mat."
During training:
Input: [The]
Target: [cat]
Input: [The, cat]
Target: [sat]
Input: [The, cat, sat]
Target: [on]
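Loss is computed for every position in parallel: the targets are the inputs shifted by one, and a causal mask stops each position from seeing the tokens after it.
So a single forward pass gives GPT a training signal from every (prefix, next token) pair in the sequence.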
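The causal mask is what makes this parallel training honest. A minimal sketch of the mask as a lower-triangular table, where True means "allowed to attend" (a convention chosen here for readability; PyTorch's `nn.MultiheadAttention` uses the opposite convention for boolean masks, True meaning "blocked", which is why the notebook below builds its mask with `torch.triu(..., diagonal=1)`):

```python
def causal_mask(t):
    """mask[i][j] is True when query position i may attend to key position j (j <= i)."""
    return [[j <= i for j in range(t)] for i in range(t)]

tokens = ["The", "cat", "sat", "on"]
mask = causal_mask(len(tokens))
for i, row in enumerate(mask):
    visible = [tok for tok, ok in zip(tokens, row) if ok]
    print(f"position {i} sees {visible}")
# position 0 sees ['The']
# position 1 sees ['The', 'cat']
# position 2 sees ['The', 'cat', 'sat']
# position 3 sees ['The', 'cat', 'sat', 'on']
```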
🧮 7️⃣ Loss Function: Cross Entropy
Each step, GPT predicts a probability distribution over the vocabulary:
P("blue") = 0.79
P("dark") = 0.12
P("banana") = 0.01
If the correct token is “blue,” loss = −log(0.79) ≈ 0.24 (natural log).
Minimizing this loss pushes “blue” probability higher next time.
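Models output raw scores (logits), not probabilities, so the loss is usually computed straight from logits. A plain-Python sketch of that computation using the log-sum-exp trick; the logit values are hypothetical:

```python
import math

def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed with the log-sum-exp trick for stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

# With probabilities given directly, the loss is just -log p(correct)
print(round(-math.log(0.79), 4))   # 0.2357
print(round(-math.log(0.01), 4))   # 4.6052  (if the answer were "banana": a confident miss costs far more)

# From raw logits (hypothetical values for a 3-token vocabulary)
print(round(cross_entropy([2.0, 1.0, 0.1], target=0), 4))
```

Subtracting the maximum logit before exponentiating leaves the result unchanged mathematically but prevents overflow when logits are large.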
⚙️ Implementation in PyTorch
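import torch
import torch.nn.functional as F
batch, seq_len, vocab_size = 2, 8, 100                     # toy sizes for illustration
logits = torch.randn(batch, seq_len, vocab_size)           # [batch, seq_len, vocab_size]
targets = torch.randint(0, vocab_size, (batch, seq_len))   # [batch, seq_len]
# Flatten so each position is one classification example; the default reduction averages over all positions
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))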
This trains all positions in the sequence simultaneously —
each predicting its next token.
⚙️ 8️⃣ Training Dynamics — How the Model Improves
1️⃣ At first, predictions are nearly uniform → loss ≈ log(vocab size)
2️⃣ As weights update, probabilities for correct tokens rise
3️⃣ Model starts learning syntax (“subject → verb → object”)
4️⃣ Then semantics (“The capital of France is” → “Paris”)
5️⃣ Finally patterns of reasoning, coding, storytelling emerge
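Step 1 doubles as a sanity check for a new training run: near-uniform predictions give every token probability about 1/V, so the first loss should sit close to log(V). A starting loss far above that is a common sign that something (initialization, labels, shapes) is off. A sketch:

```python
import math

def uniform_loss(vocab_size):
    """Cross-entropy of a uniform prediction: -log(1 / V) = log(V)."""
    return -math.log(1.0 / vocab_size)

print(f"GPT-2 vocab (50,257 tokens): {uniform_loss(50257):.2f}")   # 10.82
print(f"26 letters + space:          {uniform_loss(27):.2f}")      # 3.30
```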
🧩 9️⃣ Learning Representations Inside the Model
As GPT trains:
- Early layers tend to capture local features (spelling, word order)
- Middle layers tend to capture syntactic relations (grammar, phrase structure)
- Deeper layers tend to capture semantics & world knowledge
Each token’s hidden state (its contextual representation, not just its input embedding) becomes a compressed summary of its context;
that’s why representations from GPT-style models are useful for downstream tasks.
🧠 🔟 How Attention Aids Learning
During training, attention heads can learn useful relationships such as:
- “it” → nearest noun
- “because” → clause relationship
- “dog” → adjective “furry”
- “if…then” → logical pairing
These relationships are reinforced when the next-token loss rewards correct dependencies.
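These relationships live in the attention weights. A minimal single-head sketch in plain Python shows where those weights come from and how the causal mask restricts them. It skips the learned query/key/value projections a real head would have (an intentional simplification) and uses hypothetical 2-dimensional token vectors.

```python
import math

def causal_self_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask.
    q, k, v: lists of T vectors (each a list of d floats). Returns (outputs, weights)."""
    d = len(q[0])
    outputs, weights = [], []
    for i in range(len(q)):
        # Scores only for keys j <= i: future positions are masked out entirely
        scores = [sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d) for j in range(i + 1)]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        w = [e / total for e in exps]
        weights.append(w)
        # Output is the attention-weighted average of the visible value vectors
        outputs.append([sum(w[j] * v[j][c] for j in range(i + 1)) for c in range(len(v[0]))])
    return outputs, weights

x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # 3 toy token vectors (hypothetical)
out, w = causal_self_attention(x, x, x)    # q = k = v = x, no learned projections
print([len(row) for row in w])             # [1, 2, 3]
```

Training adjusts the projections (absent here) so that, for example, the query of "it" scores highest against the key of the noun it refers to; the loss only rewards this when it improves the next-token prediction.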
⚙️ 11️⃣ Optimization & Hardware
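- Optimizer: AdamW
- Batch size: usually counted in tokens, often millions per step (GPT-3 175B used about 3.2 M tokens, roughly 1,500 sequences of 2,048 tokens)
- Learning rate scheduling: linear warmup, then cosine decay
- Hardware: thousands to tens of thousands of GPUs/TPUs
- Duration: weeks to months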
GPT-3 reportedly took ~355 GPU years of compute to pretrain.
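The warmup + cosine schedule listed above can be sketched as a plain function of the step number. The exact shape varies between projects; this version (linear warmup, then cosine decay to a floor) and all the numbers in the usage lines are illustrative assumptions. The GPT-3 paper describes decaying to 10% of the peak rate, which is what `min_lr` mimics here.

```python
import math

def lr_at(step, max_lr, warmup_steps, total_steps, min_lr=0.0):
    """Linear warmup to max_lr, then cosine decay to min_lr (one common variant)."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)          # hold at min_lr after total_steps
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

# Hypothetical settings for illustration only
for s in [0, 99, 100, 550, 1000]:
    print(s, f"{lr_at(s, max_lr=6e-4, warmup_steps=100, total_steps=1000, min_lr=6e-5):.2e}")
```

Warmup keeps early updates small while AdamW's moment estimates are still noisy; the cosine tail lets the model settle with progressively smaller steps.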
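🧩 12️⃣ Gradient Accumulation, Activation Checkpointing & Mixed Precision
Since a full batch and its activations usually can’t fit in GPU memory at once, training uses:
- Gradient accumulation (adding up gradients over several micro-batches before a single optimizer step)
- Activation checkpointing (storing fewer activations and recomputing them during the backward pass to save GPU memory)
- Mixed precision (computing in float16/bfloat16, or FP8 on the newest hardware, while keeping float32 master weights)
Combined with splitting the model and data across many devices, these tricks make training very large models practical.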
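Gradient accumulation works because the gradient of an average loss is the average of the gradients. A plain-Python sketch with a one-parameter model and made-up data shows that two equal-sized micro-batches reproduce the full-batch gradient. In PyTorch the same idea amounts to dividing each micro-batch loss by the number of accumulation steps, calling `loss.backward()` on each (gradients add up in `.grad`), and calling `optimizer.step()` once.

```python
def grad_mse(w, batch):
    """Gradient of mean squared error for a 1-parameter model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2)]   # toy (x, y) pairs
w = 0.5

full = grad_mse(w, data)                     # gradient on the full batch of 4

micro_batches = [data[:2], data[2:]]         # two micro-batches of 2 that "fit in memory"
accum = 0.0
for mb in micro_batches:
    # Scale each micro-batch gradient by 1/num_micro so the sum is an average
    accum += grad_mse(w, mb) / len(micro_batches)

print(full, accum)   # equal up to floating-point rounding
```

The equality relies on the micro-batches being the same size; with uneven splits, each micro-batch must be weighted by its share of the examples instead.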
🔍 13️⃣ From Pretraining → Fine-tuning
After pretraining (on general text), GPT is fine-tuned for specific behavior:
| Phase | Goal |
|---|---|
| Supervised Fine-tuning (SFT) | Teach to follow instructions |
| RLHF (Reinforcement Learning from Human Feedback) | Align with human preferences |
| Continual Training | Domain adaptation (e.g., coding, legal, medical) |
We’ll cover SFT + RLHF in Day 8–9.
🧩 14️⃣ Visual Summary: How GPT Learns
[Raw Text Dataset]
↓
Tokenization
↓
Token IDs
↓
Model Input
↓
Predicted Next Token Probabilities
↓
Cross-Entropy Loss
↓
Backpropagation
↓
Weight Updates
↓
Repeat for Trillions of Tokens
Over time, the model becomes very good at predicting language, and in the process it learns much of the structure that makes language meaningful.
⚡ 15️⃣ Emergent Understanding — Why It Works So Well
Next-token prediction indirectly teaches:
- Language structure (syntax & grammar)
- World knowledge (co-occurrence patterns)
- Logic & causality (“if X then Y”)
- Common sense reasoning
- Abstract representation learning
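As model size, data, and compute grow, some abilities (such as multi-step reasoning) appear to improve sharply; researchers still debate whether this is a true threshold or partly an artifact of how those abilities are measured.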
✅ 16️⃣ Summary — What You Learned Today
| Concept | Description |
|---|---|
| Pretraining | Massive self-supervised training on raw text (the text supplies its own labels) |
| Objective | Predict next token given previous context |
| Loss | Cross-entropy (minimize prediction error) |
| Outcome | Learns language patterns and world knowledge |
| Why it works | Predictive pressure forces compositional understanding |
# %% [markdown]
# Day 6 — Mini-GPT Next-Token Prediction (Colab Notebook)
This notebook trains a tiny autoregressive model (Mini-GPT) on a toy corpus so you can see **next-token prediction** in action end-to-end: tokenization, training loop, loss plot, and generation sampling.
**Notes:**
- Designed to run quickly in Colab (CPU or GPU).
- Uses PyTorch and matplotlib for plotting.
- A character-level tokenizer is used for simplicity.
---
# %% [markdown]
## 0. Install / Import Dependencies
Run this cell first. If running on Colab, enable GPU via `Runtime → Change runtime type → GPU`.
---
# %%
# Install (uncomment if running in a fresh Colab env)
# !pip install -q torch matplotlib
import math
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from IPython.display import clear_output
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
# %% [markdown]
## 1. Toy Corpus & Character Tokenizer
We'll create a small dataset of sentences. We'll build char-level tokenization (stoi, itos) for simplicity.
---
# %%
corpus = [
"hello world",
"hello there",
"how are you",
"i love ai",
"gpt predicts next",
"the cat sat on the mat",
"the dog lay on the rug",
"a quick brown fox",
"pack my box with five dozen liquor jugs",
"sphinx of black quartz judge my vow"
]
text = "\n".join(corpus)
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for ch,i in stoi.items()}
print('Vocab size:', vocab_size)
print('Chars:', chars)
print('\nSample text:\n', text[:200])
# Encode entire text as sequence of ints
data = torch.tensor([stoi[c] for c in text], dtype=torch.long)
# Create train sequences (sliding window)
seq_len = 32
inputs = []
targets = []
for i in range(0, len(data) - seq_len):
    inputs.append(data[i:i+seq_len])
    targets.append(data[i+1:i+seq_len+1])
inputs = torch.stack(inputs)
targets = torch.stack(targets)
print('Dataset shape (num_seq, seq_len):', inputs.shape)
# %% [markdown]
## 2. Mini GPT Model (tiny) — masked autoregressive with causal mask
This is a minimal decoder-only transformer with a few blocks (two by default). It's intentionally small so it trains fast on this tiny corpus.
---
# %%
class GPTBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_hidden):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_hidden),
            nn.GELU(),
            nn.Linear(ff_hidden, embed_dim)
        )
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x, attn_mask=None):
        # Self-attention: queries, keys and values all come from x.
        # attn_weights are averaged over heads by default: [batch, seq_len, seq_len]
        attn_out, attn_weights = self.attn(x, x, x, attn_mask=attn_mask)
        x = self.ln1(x + attn_out)      # residual + LayerNorm (post-LN ordering)
        ff_out = self.ff(x)
        x = self.ln2(x + ff_out)
        return x, attn_weights
class MiniGPT(nn.Module):
    def __init__(self, vocab_size, embed_dim=64, n_layers=2, num_heads=4, ff_hidden=256, max_len=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos = nn.Parameter(torch.zeros(1, max_len, embed_dim))  # learned positional embeddings
        self.blocks = nn.ModuleList([GPTBlock(embed_dim, num_heads, ff_hidden) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size, bias=False)
        self.max_len = max_len

    def forward(self, idx):
        b, t = idx.shape
        assert t <= self.max_len, 'sequence length exceeds model max_len'
        x = self.embed(idx) + self.pos[:, :t, :]
        attn_weights_all = []
        # Boolean mask: True above the diagonal means "may not attend", which blocks future positions
        causal_mask = torch.triu(torch.ones(t, t, device=idx.device), diagonal=1).bool()
        for block in self.blocks:
            x, attn_weights = block(x, attn_mask=causal_mask)
            attn_weights_all.append(attn_weights)
        x = self.ln_f(x)
        logits = self.head(x)   # [batch, seq_len, vocab_size]
        return logits, attn_weights_all
model = MiniGPT(vocab_size).to(device)
print(model)
# %% [markdown]
## 3. Training Setup
We'll use AdamW, cross-entropy loss, and a simple training loop with periodic sampling and loss plotting.
---
# %%
batch_size = 16
dataset_size = inputs.shape[0]
print('Num sequences:', dataset_size)
def get_batch(i, batch_size=batch_size):
    start = i * batch_size
    end = start + batch_size
    if end > dataset_size:
        end = dataset_size
        start = max(0, end - batch_size)
    return inputs[start:end].to(device), targets[start:end].to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3)
epochs = 30
all_losses = []
# %% [markdown]
## 4. Training Loop — Train the MiniGPT
This loop prints the loss, plots the loss curve, and generates sample text after every epoch.
---
# %%
model.train()
for epoch in range(epochs):
    perm = torch.randperm(dataset_size)    # shuffle sequence order each epoch
    epoch_loss = 0.0
    steps = 0
    for i in range(0, dataset_size, batch_size):
        idx = perm[i:i+batch_size]
        xb = inputs[idx].to(device)
        yb = targets[idx].to(device)
        optimizer.zero_grad()
        logits, _ = model(xb)
        loss = F.cross_entropy(logits.view(-1, vocab_size), yb.view(-1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        steps += 1
    avg_loss = epoch_loss / steps
    all_losses.append(avg_loss)

    # display progress
    clear_output(wait=True)
    print(f'Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f}')

    # plot loss
    plt.figure(figsize=(6,3))
    plt.plot(all_losses, label='train loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.grid(True)
    plt.legend()
    plt.show()

    # sample generation
    model.eval()
    with torch.no_grad():
        # start from a random 10-character seed taken from the corpus
        seed = random.choice(corpus)[:10]
        context = torch.tensor([[stoi[c] for c in seed]], dtype=torch.long).to(device)
        generated = seed
        for _ in range(80):
            logits, _ = model(context)
            # take last token logits
            last = logits[0, -1, :]
            probs = F.softmax(last, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1).item()
            generated += itos[next_id]
            # append to context (keep last seq_len tokens)
            new_ctx = torch.cat([context, torch.tensor([[next_id]], device=device)], dim=1)
            if new_ctx.shape[1] > seq_len:
                new_ctx = new_ctx[:, -seq_len:]
            context = new_ctx
    print('\nSample generation:')
    print(generated)
    model.train()
# %% [markdown]
## 5. Inspecting Attention Weights
We can run a single forward pass and visualize the attention matrices (averaged over heads) for a sample input. This helps show "who attends to whom".
---
# %%
model.eval()
with torch.no_grad():
    sample_text = 'the cat sat on the mat'
    idx = torch.tensor([[stoi[c] for c in sample_text]], dtype=torch.long).to(device)
    logits, attn_weights_all = model(idx)

# attn_weights_all: one tensor per layer. nn.MultiheadAttention averages over heads
# by default, so each tensor is already [batch, seq_len, seq_len]
for layer, attn in enumerate(attn_weights_all):
    attn_avg = attn[0].cpu().numpy()   # first (and only) example in the batch
    plt.figure(figsize=(5,4))
    plt.imshow(attn_avg, cmap='viridis')
    plt.title(f'Layer {layer}: avg attention (heads averaged)')
    plt.colorbar()
    plt.xlabel('Key position')
    plt.ylabel('Query position')
    plt.show()
# %% [markdown]
## 6. Save & Export Model (Optional)
You can save the trained model weights with `torch.save`; on Colab, save them to a mounted Google Drive folder if you want them to persist after the session ends.
---
# %%
# torch.save(model.state_dict(), 'minigpt.pth')
# print('Saved minigpt.pth')
# %% [markdown]
## 7. Exercises
1. Increase `epochs` to 100 and observe the loss curve. Does the model overfit? (Training loss alone can't tell you; hold out part of the corpus as a validation set.)
2. Replace the char-level tokenizer with a simple word-level tokenizer and retrain.
3. Try different sampling strategies: greedy, top-k, temperature scaling.
4. Print logits and top-k tokens at each generation step to understand the probability distribution.
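For exercise 3, here is a dependency-free sketch of greedy, temperature, and top-k sampling over one step's logits. The function name `sample_next` and the logit values are hypothetical; to try it in the generation loop of section 4 you could call it on `last.tolist()` in place of the `torch.multinomial` line.

```python
import math
import random

def sample_next(logits, temperature=1.0, top_k=None, rng=random):
    """Pick a token index from raw logits. temperature=0 means greedy (argmax)."""
    if temperature == 0:
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [z / temperature for z in logits]          # <1 sharpens, >1 flattens
    candidates = list(range(len(logits)))
    if top_k is not None:
        # Keep only the k highest-scoring tokens
        candidates = sorted(candidates, key=lambda i: scaled[i], reverse=True)[:top_k]
    m = max(scaled[i] for i in candidates)
    weights = [math.exp(scaled[i] - m) for i in candidates]   # unnormalized softmax
    return rng.choices(candidates, weights=weights, k=1)[0]

logits = [2.0, 1.0, 0.5, -1.0]      # hypothetical logits for a 4-token vocabulary
print(sample_next(logits, temperature=0))             # greedy -> 0
rng = random.Random(0)
print([sample_next(logits, 0.7, top_k=2, rng=rng) for _ in range(10)])  # only 0s and 1s
```

Greedy decoding is deterministic but tends to loop on a tiny model; temperature and top-k trade some accuracy for variety.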
---
# %% [markdown]
## End of Notebook
This mini notebook demonstrates next-token prediction training on a toy corpus end-to-end. Use it to experiment and observe how the model learns to produce coherent text from a tiny dataset.