ai engineering
TL;DR: Covers the core stack of modern AI engineering: vector embeddings and similarity search, tokenization, transformer architecture, KV caching, and RAG. Notes from Gaurav Sen’s AI course with code throughout.
Speedrunning Gaurav Sen’s AI course over the weekend to see what I can learn
Prerequisites: Python, math
vector embeddings fundamentals
how are vectors constructed
How are objects converted into vectors? Say you have a PDF file (or any textual document).
How do you describe its properties to a computer that only understands math? Through numbers.
Look at every document and plot it in 2D space
X axis -> length of document
Y axis -> how real the document is (fictional, non fictional)
[File] -> [2.6, 3.5, 1.5, 8.2, 6.7] (Spatial Domain)

Problem? 2D isn’t enough to represent all our data.
Say we take a 3rd dimension -> popularity, a 4th dimension -> the year it was written, etc.
We’d end up with an infinite number of dimensions to represent every single document.
- So we put a constraint on the number of dimensions D (LLaMA, Meta’s language model, uses 4096)
How do we map an infinitely complex object into a fixed number of dimensions? Compression?
Sounds like something neural networks do…


So 100 documents will have 100 vectors, which can be represented in an n-dimensional space.
Vectors closer to other vectors will be nearby in the space:
- Harry Potter & Dan Brown docs near each other
- Research papers are near one another, and so on…
Similar objects will have a small (Euclidean) distance to each other.
So if two similar objects end up nearby, you reward the model; otherwise you penalise it.
What do we mean by “reward” and “penalise” here?
This works on Backpropagation. More on this in my blog.
Hence the model learns what objects are close to each other & what objects are far away when it comes to similarity. Look up this beautiful blog on vector embeddings for more.
Key Concept: you’re squishing complex objects into a fixed number of dimensions and similar things (Harry Potter, Lord of the Rings) end up close together in vector space. That’s what makes similarity search work.
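To make that concrete, here is a tiny numpy sketch with hand-picked 2D “embeddings” (the coordinates are made up for illustration; real models learn thousands of dimensions):

```python
import numpy as np

# Hand-picked toy coordinates: x ~ document length, y ~ how factual it is
harry_potter  = np.array([3.0, 1.0])   # long-ish, fictional
lord_of_rings = np.array([3.4, 1.2])   # a similar kind of book
physics_paper = np.array([1.0, 9.0])   # short, highly factual

def euclidean(a, b):
    # straight-line distance between two points in the vector space
    return float(np.linalg.norm(a - b))

print(euclidean(harry_potter, lord_of_rings))  # small distance: similar objects
print(euclidean(harry_potter, physics_paper))  # large distance: dissimilar objects
```

Similarity search is just this comparison repeated over every stored vector (and, as later sections show, sped up with indexes).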
Sidenote: how are we penalising the output? Gradient Change (explained below)
Penalty
When the model makes a mistake - for example, it calculates that a Harry Potter book is similar to a quantum physics paper - we don’t literally “punish” it. Instead, we have a mathematical function called a Loss Function. This function measures how far off the model’s prediction is from the actual, desired output.
If two dissimilar objects are close together in the vector space, the loss function will produce a high value. This high value is the “penalty.”

The red dot is incorrectly placed near the blue cluster, resulting in a large error or “penalty”.
How do we use this penalty to fix the model? Backpropagation and gradient descent.
- Backpropagation calculates the gradient of the loss function. The gradient is a vector that points in the direction of the steepest increase in the loss.
- Gradient Descent takes this gradient and moves the model’s weights (the numbers that determine the vector positions) in the opposite direction. This is the direction of the steepest decrease in loss.
For the full chain rule derivation and the calculus behind backpropagation, see the math post.
Think of it like a ball rolling down a hill. The gradient tells you which way is “up,” so you move the ball “down” to get to the bottom of the valley (the minimum loss).
Visualized on a loss surface:

gradient descent rolling down the loss surface to find the minimum
“Rewarding” just means the loss is low when predictions are right (Harry Potter near Dan Brown). Gradient is small, weights barely change. Model keeps doing what it’s doing.

The red dot has moved to its correct position. The loss is now low, and the gradient is small.
Calculate loss, determine gradient, update weights. Repeat millions of times during training. That’s how the model learns relationships between objects in n-dimensional space.
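That loop can be sketched on a toy 1D “model”: a single weight and a made-up loss function with a known minimum (real training does the same thing over billions of weights):

```python
# Minimal gradient descent sketch (toy example, not a real embedding model):
# loss(w) = (w - 3)^2 has its minimum at w = 3. The gradient says which way
# is "uphill"; each update steps the opposite way.
w = 0.0      # start from a "random" weight
lr = 0.1     # learning rate: how big each downhill step is
for _ in range(100):
    grad = 2 * (w - 3)   # d/dw of (w - 3)^2
    w -= lr * grad       # move against the gradient (downhill)

print(round(w, 4))  # converges to ~3.0, the minimum of the loss
```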
Once you have these vector embeddings of similar words being nearby, you can do fun stuff with it.
Example -> if you have man and woman (which are obviously nearby in this space) and you have plotted king in the same space, you can quickly do king - man + woman to get a NEW word -> queen
Illustrations of how nearby objects relate



Words with similar meanings live near each other in vector space. Vector arithmetic captures semantic relationships: king - man + woman = queen.
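A minimal sketch of that arithmetic with hand-picked 2D toy embeddings (dimension 0 is roughly “royalty”, dimension 1 roughly “gender”; the values are invented, real models learn such structure on their own):

```python
import numpy as np

# Toy 2D embeddings chosen so the analogy works by construction
emb = {
    "man":   np.array([0.0, 0.0]),
    "woman": np.array([0.0, 1.0]),
    "king":  np.array([1.0, 0.0]),
    "queen": np.array([1.0, 1.0]),
}

target = emb["king"] - emb["man"] + emb["woman"]

# nearest word to the resulting point, excluding the query words themselves
nearest = min(
    (w for w in emb if w not in {"king", "man", "woman"}),
    key=lambda w: np.linalg.norm(emb[w] - target),
)
print(nearest)  # queen
```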
Summary:
- Vector is an ordered array of numerical values representing features/data points/parameters in a multi-dimensional space
- Similar objects generate vectors such that they are closer when plotted in this space
- Train baseline model/neural network using loss functions OR fine-tune the neural network to improve performance
vector databases & storage
where should we store these vector embeddings/lists/arrays?
AKA what database to pick?
Why do we need an additional vector database when we already have NoSQL (JSON-like queries) and relational DBs (SQL)?
VectorDB: Optimized for storing high-dimensional vectors - Store these vectors optimally and also SEARCH for these vectors efficiently
SQL databases create indexes (B/B+ trees), which are compute-expensive here and are ideal for sorted 1D data, not for similarity across many dimensions. Similarly, with NoSQL you end up with lots of JSON queries, which are hard (compute-expensive) to search.
A vector DB can cluster vectors so that finding nearby vectors is very quick.
They also compress the data inside them.
A vector DB won’t work for strongly consistent systems (like a bank), but we don’t need consistency here: all we need is to map similar vectors nearby (king and queen close together) and search for them as fast as we can (which it does).
Think of Netflix’s recommendation system: no need for strong consistency. All we need is, if I like the TV show Silicon Valley, what other tech shows are like it, and how fast can I get them (which a vector DB does really well).
Some famous vector DB:
- pgvector
- Milvus
- Vertex AI Vector Search (Google)
- OpenSearch

so what makes a vector DB special? How does it compress high-dimensional vectors and search them efficiently?
How does vector search work? Every document has an (x, y) coordinate in the vector space (assuming a 2D space).
All relevant documents to this query have similar (x,y) coordinates. Example: Harry Potter and Dan Brown will have similar (x,y) coordinates.
- How does the (x,y) distance comparison happen?
Harry Potter = (x1,y1)
Dan Brown = (x2,y2)
The Euclidean distance is given by $d = \sqrt{(x_2 - x_1)^2 + (y_2 - y_1)^2}$.
if d < threshold:
relevant document
else:
not relevant
Brute Force complexity: O(N) where N is the number of documents
How to speedup search queries?
- Compression: if every vector has N dimensions at 2 bytes per dimension, you need 2N bytes/vector, so compressing lets you fit more vectors in RAM.
Two techniques:
1. Product Quantization
- Break each vector into sub vectors & for each sub-vector find nearest centroid and store only index of centroid.
What does storing the centroid index mean?
Storing only the index of the centroid means that instead of saving the actual centroid values, you store an identifier that points to the centroid’s location in memory or a data structure.
Instead of storing 128 floats (512 B) -> store 8 small integers (8 bytes) -> a 64x reduction
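A rough product-quantization sketch using sklearn’s KMeans (toy data and sizes chosen so it runs fast; real libraries such as FAISS implement this far more carefully):

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
N, D, M, K = 1000, 128, 8, 16      # vectors, dims, sub-vectors, centroids per sub-space
X = rng.random((N, D), dtype=np.float32)

sub_dim = D // M                   # 128 / 8 = 16 dims per sub-vector
codebooks, codes = [], []
for m in range(M):
    sub = X[:, m*sub_dim:(m+1)*sub_dim]                 # one sub-vector per point
    km = KMeans(n_clusters=K, n_init=3, random_state=0).fit(sub)
    codebooks.append(km.cluster_centers_)               # K centroids for this sub-space
    codes.append(km.labels_)                            # index of nearest centroid per point

codes = np.stack(codes, axis=1).astype(np.uint8)        # N x M matrix of tiny integers
print(codes.shape)                                      # (1000, 8): 8 bytes/vector vs 512

# "decompress" vector 0 by looking up each sub-vector's centroid
recon = np.concatenate([codebooks[m][codes[0, m]] for m in range(M)])
```

The stored codes are all you keep; distances are computed against the looked-up centroids, which is where the accuracy loss comes from.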

2. Scalar Quantization
- Compress each float individually, converting each 32-bit float to an 8-bit integer.
The original high-dimensional vectors capture fine-grained relationships. After compression, very close neighbours still match correctly (like king, queen), but slightly farther yet still relevant vectors get noisy (like king, pawn).
We can tolerate this much error for top-1 accuracy, but it kills recall (how many of the truly relevant vectors the search retrieves), which is bad for recommendations.
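Scalar quantization is simple enough to sketch directly: rescale each float into the 0–255 range and store it as one byte (a toy min/max scheme; production systems pick the scale more carefully):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random(768).astype(np.float32)     # original 32-bit vector: 3072 bytes

# quantize: map [min, max] onto the 256 values an unsigned byte can hold
lo, hi = x.min(), x.max()
q = np.round((x - lo) / (hi - lo) * 255).astype(np.uint8)   # 768 bytes: 4x smaller

# dequantize to see how much precision was lost
x_hat = q.astype(np.float32) / 255 * (hi - lo) + lo
print(float(np.abs(x - x_hat).max()))      # small per-value error (< half a step)
```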

In the context of a chatbot, precision refers to the proportion of relevant responses generated by the chatbot out of all responses it provided. Recall measures the proportion of relevant responses that the chatbot successfully identified from the total number of relevant responses that could have been given. Accuracy represents the overall correctness of the chatbot’s responses, calculated as the ratio of correctly identified responses to the total number of responses made.
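In retrieval terms, the same metrics look like this (hypothetical document IDs):

```python
# Toy retrieval example: which of the returned documents were actually relevant?
relevant = {"doc1", "doc2", "doc3", "doc4"}   # everything that is truly relevant
returned = {"doc1", "doc2", "doc9"}           # what the system retrieved

tp = len(relevant & returned)                 # true positives
precision = tp / len(returned)                # 2/3: how much of what we returned is good
recall = tp / len(relevant)                   # 2/4: how much of the good stuff we found
print(precision, recall)
```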
vector search?
Vector databases deal with multiple dimensions. Sorting across all dimensions is not possible. Binary search cannot be applied in a high-dimensional vector space.
Instead, vector databases use specialized techniques to find approximate nearest neighbors. This includes building spatial partitioning structures like HNSW or IVF.
K-D tree: allows search in k-dimensional space (what we need)
Splits the tree on different dimensions at each level; works the same as a BST if you have taken an algorithms class.

Drawback: space (lots of branching) + not accurate (root is heavily branched)
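For low-dimensional data, scipy ships a ready-made K-D tree; a quick sketch (note that K-D trees degrade badly in high dimensions, which is part of why vector DBs prefer HNSW/IVF):

```python
import numpy as np
from scipy.spatial import cKDTree

rng = np.random.default_rng(0)
points = rng.random((1000, 3))       # 1000 points in 3D space
tree = cKDTree(points)               # build the splitting structure once

query = np.array([0.5, 0.5, 0.5])
dist, idx = tree.query(query, k=3)   # exact 3 nearest neighbours
print(idx, dist)                     # indices and distances, nearest first
```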
HNSW: Hierarchical Navigable Small Worlds
- Every vector is connected to its k nearest points. (Outliers are connected to only a few points; central, popular vectors are connected to many others.)
- Popular vectors are promoted to the next level up (fewer vectors that are a good representation of the underlying space).
The query starts at the top layer and finds the nearest vector there; if it’s close enough -> OK, else you drill down into the next layer.


HNSW search - query starts at the top layer and drills down to find nearest neighbors
Drawback: at very large scale it doesn’t work well.
So what does work?
- Inverted File Index; doc vectors -> clusters (kmeans) -> each cluster has a centroid
- Find nearest centroid to that query (input query) :: Search k nearest clusters
IVF (Inverted File) is GPU-friendly because you can process clusters in parallel on GPUs which is exactly what GPUs are good at. HNSW is more CPU-friendly because it’s graph traversal which works better with CPU caches and lower memory overhead.

In IVF (Inverted File Index), distance is first calculated based on the centroid closest to the query vector and not the actual vector points themselves. This means the search might miss the true nearest neighbor if it falls into a different cluster. As a result, IVF can return a point that’s actually farther from the query in real vector space.

Increasing the number of clusters improves accuracy: with more clusters, the chances are higher that the actual closest point to the query ends up in a nearby or selected cluster, leading to better search results.
A search can fail when the point closest to the query was assigned to a centroid other than the nearest ones we probed.
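A toy IVF sketch with numpy + sklearn’s KMeans (cluster count and nprobe are arbitrary demo values):

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
X = rng.random((5000, 64)).astype(np.float32)

# build the "inverted file": cluster all vectors, remember each vector's cluster
k = 20
km = KMeans(n_clusters=k, n_init=3, random_state=0).fit(X)

def ivf_search(q, nprobe=3):
    """Search only the nprobe clusters whose centroids are closest to q."""
    d_cent = np.linalg.norm(km.cluster_centers_ - q, axis=1)
    probe = np.argsort(d_cent)[:nprobe]             # nearest clusters to the query
    cand = np.where(np.isin(km.labels_, probe))[0]  # candidate vectors in those clusters
    d = np.linalg.norm(X[cand] - q, axis=1)
    return cand[np.argmin(d)]                       # best candidate (approximate NN)

q = rng.random(64).astype(np.float32)
approx = ivf_search(q, nprobe=3)
exact = np.argmin(np.linalg.norm(X - q, axis=1))    # brute-force ground truth
print(approx, exact)   # often equal; raising nprobe raises the hit rate
```

Probing every cluster (nprobe = k) degenerates to exact brute force, which is exactly the accuracy/speed dial the paragraph above describes.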
Revisiting Entire Flow
Input: some english text -> query.
Step 1: Query is encoded into a high dimensional vector (query vector) where similar items are close to each other in the space.
Step 2: Search the index: we can’t binary-search a sorted column here, so we use ANN (approximate nearest neighbour) indexes like IVF and HNSW.
Step 3: Score + rank: for each candidate vector, compute a similarity score using Euclidean distance, cosine similarity, or dot product -> return the top-K closest vectors.
Step 4: Post-processing + filters (language=English, published post-2009, category=fiction): these filters are applied after vector scoring to produce the final result.
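The whole flow, minus the real encoder, can be sketched in a few lines (random vectors stand in for encoded documents, and the language tags are hypothetical metadata):

```python
import numpy as np

rng = np.random.default_rng(0)
# Step 1 stand-in: pretend these are encoded documents and an encoded query
docs = rng.random((100, 8)).astype(np.float32)
q = rng.random(8).astype(np.float32)

# Step 3: score every candidate with cosine similarity, keep the top-K
dn = docs / np.linalg.norm(docs, axis=1, keepdims=True)
qn = q / np.linalg.norm(q)
sims = dn @ qn
topk = np.argsort(-sims)[:5]          # indices of the 5 most similar docs

# Step 4 stand-in: post-filter on metadata (made-up language tags)
langs = np.array(["en" if i % 2 == 0 else "fr" for i in range(100)])
results = [i for i in topk if langs[i] == "en"]
print(topk, results)
```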

How does it work in production grade systems?
Milvus DB::


The first diagram shows Milvus’s high-level architecture: a shared-storage design where the access layer handles client requests, the coordinator service manages cluster topology and load balancing, and the worker nodes handle the actual indexing and querying of vectors. The second diagram breaks down the data flow - how vectors get inserted, indexed, and searched across distributed nodes. More on this here
programming section
Q. Store + Retrieve vectors in pgvector database
# Storing and retrieving vectors in a pgvector DB
# install PostgreSQL + the pgvector extension
'''
CREATE EXTENSION IF NOT EXISTS vector;
CREATE TABLE embeddings (
    id SERIAL PRIMARY KEY,
    text TEXT,
    embedding VECTOR(768) -- BERT uses 768 dimensions
);
'''
# pip install psycopg2-binary pgvector
import psycopg2
import numpy as np
from pgvector.psycopg2 import register_vector  # teaches psycopg2 to send/receive the vector type
connection = psycopg2.connect(host="localhost", dbname='db', user='samit', password='westsidegunn')
register_vector(connection)
cursor = connection.cursor()
# main part
embedding = np.random.rand(768).astype(np.float32)  # numpy array, adapted by register_vector
text = "hi this is samit"
cursor.execute("INSERT INTO embeddings (text, embedding) VALUES (%s, %s)", (text, embedding))
# end
connection.commit()
# retrieve vectors by ID
cursor.execute("SELECT embedding FROM embeddings WHERE id = %s", (1, ))
vector = cursor.fetchone()[0]
print(vector) # returns floats
# similarity search: heart of vector
query_vec = np.random.rand(768).astype(np.float32).tolist()
cursor.execute(
"""
SELECT id, text, embedding <=> %s AS distance
FROM embeddings
ORDER BY embedding <=> %s
LIMIT 5;
""", (query_vec, query_vec))
rows = cursor.fetchall()
for r in rows:
    print(r)
'''
Result for similarity vector:
Nearest neighbors:
(1, 'andrej karpathy', 0.239182)
(4, 'machine learning basics', 0.298522)
(3, 'deep learning intro', 0.330182)
'''
# BASED ON HNSW: pgvector supports this for fast approximate search
# CREATE INDEX ON embeddings USING hnsw (embedding vector_cosine_ops) WITH (m=16, ef_construction=200);
cursor.close()
connection.close()
Q. Find all relevant objects of an input object with threshold distance from it less than T. Compress all vectors + run the clustering algorithm provided.
import numpy as np
import time
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
# 10,000 vectors, dim for each vector = 768
N = 10000
D = 768
vectors = np.random.rand(N, D).astype(np.float32)
print("Loaded vectors:", vectors.shape)
ids = np.arange(1, N + 1) # fake ids
def find_relevant(vectors, input_vec, threshold):
    # normalise rows so the dot product below equals cosine similarity
    V = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
    q = input_vec / np.linalg.norm(input_vec)
    # cosine distance = 1 - cosine similarity
    cosine_dist = 1 - np.dot(V, q)
    idx = np.where(cosine_dist < threshold)[0]
    return idx, cosine_dist[idx]
# query vector
input_vec = np.random.rand(D).astype(np.float32)
T = 0.30 # threshold
# search
t0 = time.time()
idx_orig, dist_orig = find_relevant(vectors, input_vec, T)
t1 = time.time()
print("Relevant docs (Original)")
print("IDs:", ids[idx_orig][:10], "...") # show first 10
print("Count:", len(idx_orig))
print("Time:", t1 - t0, "seconds")
# pca
TARGET_DIM = 64 # reduce 768 → 64
t0 = time.time()
pca = PCA(n_components=TARGET_DIM)
vectors_pca = pca.fit_transform(vectors)
input_vec_pca = pca.transform([input_vec])[0]
t1 = time.time()
print("PCA Compression")
print("Original Dim:", D)
print("Compressed Dim:", TARGET_DIM)
print("PCA Fit Time:", t1 - t0, "seconds")
# search on compressed pca'd vector
t0 = time.time()
idx_comp, dist_comp = find_relevant(vectors_pca, input_vec_pca, T)
t1 = time.time()
print("Relevant docs (Compressed)")
print("IDs:", ids[idx_comp][:10], "...")
print("Count:", len(idx_comp))
print("Time:", t1 - t0, "seconds")
# kmeans (clustering) on original
k = 10 # number of clusters
t0 = time.time()
kmeans_orig = KMeans(n_clusters=k, random_state=42, n_init='auto').fit(vectors)
t1 = time.time()
print("KMeans (Original)")
print("Time:", t1 - t0, "seconds")
# kmeans on compressed
t0 = time.time()
kmeans_comp = KMeans(n_clusters=k, random_state=42, n_init='auto').fit(vectors_pca)
t1 = time.time()
print("KMeans (Compressed)")
print("Time:", t1 - t0, "seconds")
'''
Output
Loaded vectors: (10000, 768)
Relevant docs (Original)
IDs: [ 123 425 188 ...]
Count: 57
Time: 0.022 seconds
PCA Compression
Original Dim: 768
Compressed Dim: 64
PCA Fit Time: 0.55 seconds
Relevant docs (Compressed)
IDs: [123 425 188 ...]
Count: 58
Time: 0.006 seconds
KMeans (Original)
Time: 5.3 seconds
KMeans (Compressed)
Time: 0.8 seconds
'''
Q. Using BERT API, show how vectors are mapped in 2DSpace, and object -> vector conversion algorithm (ngram + bloom filters + jaccard index)
"""
APIs
1) /embed → BERT embeddings
2) /project2d → 2D projection (PCA)
3) /ngram_vec → object → vector via (n-grams + Bloom filter)
4) /jaccard → Jaccard over Bloom vectors
"""
from fastapi import FastAPI
from pydantic import BaseModel
import numpy as np
import hashlib
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.decomposition import PCA
app = FastAPI()
# ------------------------ BERT ------------------------
MODEL = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModel.from_pretrained(MODEL).eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
def bert(texts):
    """Return L2-normalized BERT sentence embeddings."""
    with torch.no_grad():
        tok = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(device)
        out = model(**tok).last_hidden_state
        mask = tok.attention_mask.unsqueeze(-1)
        pooled = (out * mask).sum(dim=1) / mask.sum(dim=1)
        emb = torch.nn.functional.normalize(pooled, p=2, dim=1)
    return emb.cpu().numpy().astype(np.float32)

def pca_2d(X):
    """Project embeddings to 2D for visualization."""
    X = PCA(50).fit_transform(X) if X.shape[1] > 50 else X
    return PCA(2).fit_transform(X)

# ----------------- N-gram + Bloom filter -----------------
def ngrams(text, n=3):
    """Simple word n-grams."""
    toks = text.split()
    if len(toks) < n:          # short text: fall back to one n-gram of all tokens
        return [" ".join(toks)]
    return [" ".join(toks[i:i+n]) for i in range(len(toks)-n+1)]

def bloom(tokens, m=2048, k=4, seed=0):
    """Bloom filter vector: boolean bit array."""
    bits = np.zeros(m, bool)
    for t in tokens:
        for i in range(k):
            h = hashlib.sha256(f"{seed}_{i}_{t}".encode()).digest()
            bits[int.from_bytes(h[:4], "big") % m] = True
    return bits

def jaccard(a, b):
    """Approximate Jaccard using Bloom bitsets."""
    inter = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    return inter / union if union else 0.0
# --------------------------- Schemas ---------------------------
class TextList(BaseModel):
    texts: list[str]

class OneText(BaseModel):
    text: str
    n: int = 3
    m: int = 2048
    k: int = 4

class TwoTexts(BaseModel):
    text_a: str
    text_b: str
    n: int = 3
    m: int = 2048
    k: int = 4
# --------------------------- Endpoints ---------------------------
@app.post("/embed")
def embed(req: TextList):
    E = bert(req.texts)
    return {"dim": E.shape[1], "embeddings": E.tolist()}

@app.post("/project2d")
def project2d(req: TextList):
    return {"coords": pca_2d(bert(req.texts)).tolist()}

@app.post("/ngram_vec")
def ngram_vec(req: OneText):
    bits = bloom(ngrams(req.text, req.n), req.m, req.k)
    idx = np.nonzero(bits)[0].tolist()
    return {"num_bits": len(idx), "bit_indices": idx[:200]}

@app.post("/jaccard")
def jacc(req: TwoTexts):
    A = bloom(ngrams(req.text_a, req.n), req.m, req.k)
    B = bloom(ngrams(req.text_b, req.n), req.m, req.k)
    return {"approx_jaccard": float(jaccard(A, B))}
tokenization - from text to numbers
Before LLMs can do anything, text needs to become numbers. That’s tokenization.
You can’t just assign each character a number though. Modern tokenization algorithms balance efficiency (small vocabulary size) with meaning preservation (keeping word boundaries intact).

Real Time Demonstration of Tokenisation
tokenization in practice
text = "Hi I'm samit and I like computers"
tokens = tokenizer.encode(text)  # tokenizer: e.g. a Hugging Face or tiktoken tokenizer
# [464, 15592, 2746, 18616, 2420]
- More tokens -> more computation / higher context length
- Text -> tokens -> token IDs, which get converted into vectors
- Each token ID maps to a vector via the embedding layer (a lookup table, like a hash map)
Why do token IDs help? Reuse/caching -> “catch” and “ing” each get a separate token ID.
“Catching” can mean different things in different contexts (catching a ball is different from catching a fish, but both share some inherent common meaning).
catch-ing (do-ing) :: some verb + “ing” -> all this information is picked up by the model while it’s training.
token_id = 464 # Token ID for "The"
embedding_vector = embedding_matrix[token_id]
# Result: a vector of 4096 floating-point numbers
Similarly, every token (which has a token ID) becomes a vector of 4096 dimensions; remember this.
Because similar vectors are close to each other, we can use that information.
But we can’t use similarity alone (“I love you” and “You love I” contain the same words but mean completely different things); we need some positional awareness of the tokens.
Order of the tokens -> adding positional encodings to the vectors helps (it tells the model where each token is).
positions = [0, 1, 2, 3, 4]
positional_embeddings = positional_encoding[positions]
input_embeddings = embeddings + positional_embeddings # our final input
positional embeddings
This is a great video to explain positional encodings
“I love you” → embeddings; “You love I” → the same embeddings, but a different meaning
Transformers have no idea about sequence order by default. Add position information into the input vectors.
input_to_transformer = word_embedding + positional_encoding
Each position number (0, 1, 2, 3… ) is converted into a bunch of sine/cos waves:
- Long waves capture coarse position
- Short waves capture fine position
- Mixed waves uniquely identify every position
Think multiple clock hands rotating at different speeds - each timestamp combination is unique.
For a position pos and dimension pair index i: \(\text{PE}[pos,\; 2i] = \sin\!\left( \frac{pos}{10000^{\,2i / d_{\text{model}}}} \right)\)
\[\text{PE}[pos,\; 2i+1] = \cos\!\left( \frac{pos}{10000^{\,2i / d_{\text{model}}}} \right)\] (Even indices: sine, odd indices: cosine.)
The two sentences have different meanings because the positions of the words differ, hence we add positional encodings to the input vectors/tokens so the model knows where each word actually is.

Code for positional encodings
import torch
import numpy as np
import torch.nn as nn

class PE(nn.Module):
    def __init__(self, dimension, max_len=5000):
        super().__init__()
        # create the matrix
        pe = torch.zeros(max_len, dimension)
        # create position indices (0, 1, ..., max_len-1)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # using the log/exp trick for numerical stability
        scaling_term = torch.exp(torch.arange(0, dimension, 2).float() * (-np.log(10000.0) / dimension))
        # apply sin to even indices, cos to odd indices
        pe[:, 0::2] = torch.sin(position * scaling_term)
        pe[:, 1::2] = torch.cos(position * scaling_term)
        # add batch dimension
        pe = pe.unsqueeze(0)
        # register as buffer (not a parameter, but part of module state)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # add position encoding to embeddings (sliced to current sequence length)
        x = x + self.pe[:, :x.size(1)]
        return x
Each token ID indexes into the embedding matrix, retrieving its corresponding embedding vector. If the model has 4096 dimensions, each token becomes a vector of 4096 floating-point numbers.
programming section
Q. Implementing BPE (Byte Pair Encoding) from Scratch
'''
How BPE learns
It repeatedly:
counts frequent symbol pairs
merges the most frequent pair into a new token
updates the corpus
'''
from collections import defaultdict

def count_pairs(words):
    """Count frequencies of all adjacent symbol pairs in the corpus."""
    freq = defaultdict(int)
    for w in words:
        for a, b in zip(w, w[1:]):
            freq[(a, b)] += 1
    return freq

def merge(words, pair, new_sym):
    """Replace every occurrence of 'pair' with merged 'new_sym'."""
    A, B = pair
    out = []
    for w in words:
        i = 0
        new = []
        while i < len(w):
            if i < len(w)-1 and w[i] == A and w[i+1] == B:
                new.append(new_sym)
                i += 2
            else:
                new.append(w[i])
                i += 1
        out.append(tuple(new))
    return out
# ---------------------- BPE Training ----------------------
def learn_bpe(corpus, num_merges=10):
    """
    Learn BPE merges from training corpus.
    Corpus: list of words -> turned into lists of characters.
    """
    words = [tuple(w) for w in corpus]
    merges = {}
    for _ in range(num_merges):
        pairs = count_pairs(words)
        if not pairs:
            break
        pair = max(pairs, key=pairs.get)   # most frequent pair
        new_sym = "".join(pair)            # merge into single token
        words = merge(words, pair, new_sym)
        merges[pair] = new_sym
    return merges, words
# ---------------------- BPE Inference ----------------------
def apply_bpe(text, merges):
    """Apply learned merges to new text (in the order they were learned)."""
    tokens = list(text)
    for pair, new_sym in merges.items():
        merged = []
        i = 0
        while i < len(tokens):
            if i < len(tokens)-1 and (tokens[i], tokens[i+1]) == pair:
                merged.append(new_sym)
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens

corpus = ["low", "lowest", "newer", "wider"]
merges, encoded = learn_bpe(corpus, num_merges=5)
print("Merges:", merges)
print("Encoded Corpus:", encoded)
test = "lowests"
print("Encoded Test:", apply_bpe(test, merges))
'''
Output (with num_merges=5):
Merges: {('l', 'o'): 'lo', ('lo', 'w'): 'low', ('e', 'r'): 'er', ('low', 'e'): 'lowe', ('lowe', 's'): 'lowes'}
Encoded Corpus: [('low',), ('lowes', 't'), ('n', 'e', 'w', 'er'), ('w', 'i', 'd', 'er')]
Encoded Test: ['lowes', 't', 's']
'''
How this works:
- Start with individual characters as tokens
- Find the most frequent adjacent pair
- Merge that pair into a new token
- Repeat for num_merges iterations
- Apply learned merges to new text
Same algorithm GPT and other modern LLMs use for tokenization.
language models & transformers
LLM internals
How do these LLMs work?
Transformers power GPT, BERT, LLaMA, basically every modern LLM. Two key operations per layer:
- Multi-head self-attention: how one token relates to another
- Feed-forward networks: per-token transformations whose weights are trained to minimise the loss function so that predicted text matches the ground truth (if you ask “What is the state of water”, the prediction tries to be solid, liquid, and gaseous as accurately as it can)
Enough of this during training and the model gets good enough to come up with its own answers based on all the text on the internet.
Stack many layers (GPT-3 has 96!) and the model progressively builds richer representations until it can predict what comes next.
self attention mechanism (how one word relates to itself and others)
“The crane flew away.” - you think bird.
“The crane lifted a car.” - you think machine.
Same word, different meaning based on context. Self-attention lets the model figure this out.
Q, K, V matrices computed from tokens, scaled dot-product scores, softmax weights, and the final weighted output.
Every word looks at every other word and decides how important each one is. That importance score builds a better representation of the word that captures its role in the full sentence.

The most important part of self-attention: three matrices: Query, Key, Value
some analogies for q,k,v matrices
- Query: What is my question (What is the highest salary of Software Engineer)
- Key: What is the parameter for my question (in Microsoft)
- Value: What is my answer (300k+ (lol))
- Query: what the current token is looking for e.g. a vowel may be looking for a consonant (in the case of character-level tokenization), a noun may be looking for verbs associated with it etc.
- Key: what I can offer i.e. what response a token offers for a query, what a token is (I am a verb, an adjective…)
- Value: the aforementioned matrix that encompasses the context, meaning, and weight of each token (e.g. token 2 is valuable in predicting tokens 10 and 15; token 2 has such an embedding, carries such weight in predicting tokens x and y…)
- Query: Spotlight projector
- Key: Shifting of spotlight from one place to another
- Value: When I shine the spotlight on a particular screen, what do I see?
The query tells you which word you’re computing attention for, and the keys tell you how much each other word affects this word (crane is affected by wings but not by the word THE)… Then a matrix is formed between Query and Key by multiplying them: dot_product(Q, K) for all word pairs.
beautiful blog on construction of Q, K, V matrices
This is the representation of the matrix formed between Query and Key:

But how are these Query, Key, Value matrices calculated in the first place?
# Self-attention computation
Q = input @ W_query # Shape: [batch, seq_len, dim]
K = input @ W_key # Shape: [batch, seq_len, dim]
V = input @ W_value # Shape: [batch, seq_len, dim]
W_query, W_key, W_value are weight matrices which are learned during training (via backprop) {They start from random values}
This is how the model knows what english words mean and how it can figure out the meaning of the input text.
Attention mechanism: how much every word / token should attend to every other word / token.
How much importance to give to other words, how do the words relate to each other so we can figure out the meaning of the sentence.
\[S = \frac{QK^{\mathsf{T}}}{\sqrt{d_k}}\] \[\text{attention} = \mathrm{softmax}(S)\] \[\text{output} = \text{attention} \cdot V\]
# attention score
scores = (Q @ K.transpose()) / sqrt(dim) # divide to normalise in case the dot products get too big
attention_weights = softmax(scores) # another normalisation trick (everything gets mapped into the range [0, 1])
output = attention_weights @ V # weight the Value vectors by the attention scores to get the answer
Common Mistake: Forgetting to scale attention scores by √d_k leads to numerical instability. Without scaling, dot products can grow very large (e.g., 100+), pushing softmax into regions with near-zero gradients (vanishing gradient problem). This causes training instability and the model fails to learn long-range dependencies.
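A quick numpy sketch of why the scaling matters (random toy vectors; dot products of d_k-dimensional random vectors grow roughly like √d_k):

```python
import numpy as np

def softmax(x):
    # numerically stable softmax over a 1-D array
    e = np.exp(x - x.max())
    return e / e.sum()

d_k = 512
rng = np.random.default_rng(0)
q = rng.standard_normal(d_k)            # one query vector
keys = rng.standard_normal((4, d_k))    # four key vectors

raw = keys @ q                # unscaled scores: magnitudes grow with d_k
scaled = raw / np.sqrt(d_k)   # what attention actually uses

print(softmax(raw))     # much more peaked -> near-zero softmax gradients
print(softmax(scaled))  # smoother distribution -> gradients can flow
```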
This process runs multiple times, depending on the model. GPT-3 has 96 layers, each with 96 attention heads, as discussed before.
So this operation runs 96 times on our input. Every time, the model learns something new about the text. It happens in parallel on GPUs, hence it’s fast.
For example:
“Hi I am samit and I like computers”
- MHA1 -> user’s name is samit
- MHA2 -> user likes computers
- MHA3 -> “I” here refers to samit
- MHA4 -> samit likes computers
Here MHA is a multi-head-attention block
It learns to focus on different aspects of the relationships between tokens
The outputs of all the attention heads get concatenated and passed to a feed-forward network (two linear transformations with a non-linear activation function in between).
Linear transformations are matrix multiplications (think y = mx + c slope calculation), and non-linearity refers to ReLU/Leaky ReLU activation functions which take the output y: if y < 0 then answer = 0, else answer = y. This introduces non-linearity so the function can learn complex patterns beyond straight lines.
# ffn
hidden = activation(input @ W1 + b1) # Expand to 4x dimension (basically y = x * m + c)
output = hidden @ W2 + b2 # Project back down (same y = mx + c {input -> hidden -> output})
Attention Code
# https://www.deep-ml.com/problems/53
import numpy as np
import torch

def softmax(x):
    # make sure outputs are in range (0-1) and each row sums to 1
    ex = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return ex / np.sum(ex, axis=-1, keepdims=True)

def compute_qkv(X, W_q, W_k, W_v):
    # compute query, key, value matrices
    Q = np.dot(X, W_q)  # query = X (input) @ Weight_Query (learnt via backprop, starts random)
    K = np.dot(X, W_k)  # key   = X (input) @ Weight_Key   (learnt via backprop)
    V = np.dot(X, W_v)  # value = X (input) @ Weight_Value (learnt via backprop)
    return Q, K, V

def self_attention(Q, K, V):
    d_k = Q.shape[-1]  # columns of Q: the scaling dimension
    attn_logits = Q @ K.T  # Q: [3,2], K: [3,2] -> (m,n) @ (p,q) needs n == p, hence the transpose
    attn_logits = attn_logits / np.sqrt(d_k)  # prevents the logits from blowing up
    attn_weights = softmax(attn_logits)  # convert scores to probabilities in (0-1)
    attention_output = attn_weights @ V  # final answer -> weighted sum of Value vectors
    return attention_output

def causal_mask(seq_len):
    # strict upper triangle: positions j > i are masked when computing output at i
    # this is a trick so that the model doesn't see future tokens all at once
    # while predicting "Hi my name is Samit" -> hide all later words while it's guessing "my" (more explained later)
    m = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool), diagonal=1)
    return m  # shape (T, T)
This is what the attention matrix looks like:

Multiply original vectors being pushed by the right amount of force (values in attention matrix)
This is what attention_logits @ V looks like

This attention output goes to a feed-forward neural network (information flows strictly in one direction, from input nodes through hidden layers, if any, to output nodes). Basically it learns more features about the matrix, and hence about the sentence.
recap
What happens when you press enter on the ChatGPT button?
S1: Create tokens for the sentence

S2: Compute similarity between the token vectors (scaled dot product of Query and Key) and create the attention matrix

S3: Feed to feed-forward neural network
S4: Predict Logits (next word)

attention visualizations
Some visualizations:
How one word relates to another
Attention Matrix @ Value to get final matrix
Causal masking: hiding future words at training / prediction time so the model doesn't cheat by peeking at the words it is supposed to predict
The LLM has two parts:
- Encoder: Turns raw text into structured meaning by converting tokens into contextual vectors. It learns "what is being said."
- Decoder: Uses those vectors to generate the next tokens, predicting text step-by-step while referencing context. It learns "what to say next."
Together: understanding → generation.
causal masking and why we need it:
The Problem: “Cheating” During Training
Say you’re translating “Hello how are you” to Spanish, word by word:
“Hello” → “Hola” “how” → “cómo” “are” → “estás” “you” → “tú”
During training, we already know the full target sentence is “Hola cómo estás tú”.
Without causal masking:
- When the model is trying to predict “cómo”, it can accidentally look ahead and see “estás” and “tú” in the target. This is cheating! The model learns to copy from the future instead of actually learning to translate.
With causal masking:
- When predicting “Hola” → Can only see: [nothing]
- When predicting “cómo” → Can only see: [Hola]
- When predicting “estás” → Can only see: [Hola, cómo]
- When predicting “tú” → Can only see: [Hola, cómo, estás]
Example:
Target: [Hola, cómo, estás, tú]
Mask (1 = can see, 0 = blocked):
Hola cómo estás tú
Hola [ 1 0 0 0 ] ← Predicting "Hola" sees nothing
cómo [ 1 1 0 0 ] ← Predicting "cómo" sees only "Hola"
estás [ 1 1 1 0 ] ← Predicting "estás" sees "Hola, cómo"
tú [ 1 1 1 1 ] ← Predicting "tú" sees all previous
Why “Causal”?
Because cause comes before effect. Each word can only see its causes (previous words), not its effects (future words).
Causal masking forces the model to learn left-to-right generation, just like it will have to do at inference time when it generates one word at a time without knowing what comes next
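The mask in action on a tiny score matrix (the score values are random, just to show the mechanics):

```python
import torch

seq_len = 4
scores = torch.randn(seq_len, seq_len)  # raw attention logits for 4 tokens

# strict upper triangle marks future positions (j > i)
mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
masked = scores.masked_fill(mask, float("-inf"))  # future scores -> -inf
weights = masked.softmax(dim=-1)                  # -inf becomes exactly 0 after softmax

print(weights)  # row i has non-zero weights only for positions 0..i
```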
How causal masking works



the full transformer block: embeddings -> positional encoding -> encoder (self-attention + FFN) -> decoder (masked self-attention + cross-attention + FFN) -> output logits
programming section
Q. Coding Transformer Architecture along with intuitive comments of whatever explained above
import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
"""
Adds positional information to token embeddings.
Why? Transformers have no inherent notion of token order (unlike RNNs).
We use sinusoidal functions so the model can learn to attend to relative positions.
Why sinusoidal? They allow the model to extrapolate to sequence lengths
not seen during training, and can represent relative positions as linear functions.
"""
def __init__(self, d_model, max_len=5000):
super().__init__()
# Create a matrix of shape (max_len, d_model) for positional encodings
pe = torch.zeros(max_len, d_model)
# Position indices: [0, 1, 2, ..., max_len-1]
position = torch.arange(0, max_len).unsqueeze(1) # (max_len, 1)
# Division term for the sinusoidal functions
# Creates different frequencies for each dimension
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
# Apply sin to even indices, cos to odd indices
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# Add batch dimension: (1, max_len, d_model)
pe = pe.unsqueeze(0)
# Register as buffer (not a parameter, but part of module state)
self.register_buffer("pe", pe)
def forward(self, x):
"""
Args:
x: Token embeddings of shape (batch_size, seq_len, d_model)
Returns:
x with positional encoding added
"""
# Add positional encoding to input embeddings
# We slice pe to match the sequence length
x = x + self.pe[:, :x.size(1)]
return x
def attention(q, k, v, mask=None):
"""
Scaled Dot-Product Attention.
Why scaled? Without scaling by sqrt(d_k), dot products can grow large,
pushing softmax into regions with tiny gradients (vanishing gradient problem).
Args:
q: Queries (batch, heads, seq_len, d_k)
k: Keys (batch, heads, seq_len, d_k)
v: Values (batch, heads, seq_len, d_k)
mask: Optional mask to prevent attention to certain positions
Returns:
output: Attention-weighted values
attn: Attention weights (useful for visualization)
"""
d_k = q.size(-1)
# Compute attention scores: how much should each query attend to each key?
# Shape: (batch, heads, seq_len, seq_len)
scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)
# Apply mask (optional): set scores to -inf so they become 0 after softmax
# Used for padding tokens or causal (look-ahead) masking
if mask is not None:
scores = scores.masked_fill(mask == 0, float("-inf"))
# Convert scores to probabilities (they sum to 1 across keys)
attn = torch.softmax(scores, dim=-1)
# Apply attention weights to values
# This is a weighted sum: "I'll take this much from each value"
return attn @ v, attn
class MultiHeadAttention(nn.Module):
"""
Multi-Head Attention mechanism.
Why multiple heads? Different heads can learn to attend to different aspects:
- One head might focus on syntactic relationships
- Another on semantic meaning
- Another on long-range dependencies
This is like having multiple "representation subspaces" working in parallel.
"""
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # Dimension of each head
# Linear projections for Q, K, V
# Why? To learn different representations of the input
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
# Final output projection
self.out_proj = nn.Linear(d_model, d_model)
def split_heads(self, x):
"""
Split the last dimension into (num_heads, d_k).
Reshape from (batch, seq_len, d_model) to (batch, num_heads, seq_len, d_k)
Why? So each head operates on a smaller dimension independently.
"""
batch_size, seq_len, d_model = x.size()
x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
return x.transpose(1, 2) # (batch, num_heads, seq_len, d_k)
def forward(self, q, k, v, mask=None):
"""
Args:
q, k, v: Query, Key, Value tensors (can be the same for self-attention)
mask: Optional attention mask
"""
# Linear projections and split into multiple heads
q = self.split_heads(self.q_proj(q))
k = self.split_heads(self.k_proj(k))
v = self.split_heads(self.v_proj(v))
# Apply attention for all heads in parallel
out, attn = attention(q, k, v, mask)
# Concatenate heads and apply final linear projection
# Reshape: (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
batch_size = out.size(0)
out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
return self.out_proj(out)
class FeedForward(nn.Module):
"""
Position-wise Feed-Forward Network.
Why? Adds non-linearity and capacity to transform the representations.
Applied to each position independently (same weights for all positions).
Why 2-layer MLP? Standard architecture - expand then compress.
The expansion (d_model -> dim_ff) allows the model to learn more complex functions.
"""
def __init__(self, d_model, dim_ff=2048):
super().__init__()
self.net = nn.Sequential(
nn.Linear(d_model, dim_ff), # Expand
nn.ReLU(), # Non-linearity
nn.Linear(dim_ff, d_model), # Compress back
)
def forward(self, x):
return self.net(x)
class EncoderLayer(nn.Module):
"""
Single Transformer Encoder Layer.
Architecture: Self-Attention -> Add & Norm -> Feed-Forward -> Add & Norm
Why Add & Norm?
- Residual connections (Add) help gradients flow during backprop
- LayerNorm stabilizes training and speeds convergence
"""
def __init__(self, d_model, num_heads, dim_ff=2048):
super().__init__()
self.attn = MultiHeadAttention(d_model, num_heads)
self.ff = FeedForward(d_model, dim_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, mask=None):
"""
Args:
x: Input tensor (batch, seq_len, d_model)
mask: Optional attention mask
"""
# Self-attention with residual connection
attn_out = self.attn(x, x, x, mask)
x = self.norm1(x + attn_out) # Add & Norm
# Feed-forward with residual connection
ff_out = self.ff(x)
x = self.norm2(x + ff_out) # Add & Norm
return x
class DecoderLayer(nn.Module):
"""
Single Transformer Decoder Layer.
Architecture:
1. Masked Self-Attention (on target) -> Add & Norm
2. Cross-Attention (target attends to encoder output) -> Add & Norm
3. Feed-Forward -> Add & Norm
Why masked self-attention? During training, we don't want the decoder to "cheat"
by looking at future tokens. The mask ensures position i can only attend to positions <= i.
Why cross-attention? This is where the decoder "reads" from the encoder output,
allowing it to use the source information when generating the target.
"""
def __init__(self, d_model, num_heads, dim_ff=2048):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.cross_attn = MultiHeadAttention(d_model, num_heads)
self.ff = FeedForward(d_model, dim_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
def forward(self, x, enc_out, tgt_mask=None, enc_mask=None):
"""
Args:
x: Decoder input (batch, tgt_len, d_model)
enc_out: Encoder output (batch, src_len, d_model)
tgt_mask: Causal mask for target sequence
enc_mask: Mask for encoder output (e.g., for padding)
"""
# Masked self-attention on target (causal)
out = self.self_attn(x, x, x, tgt_mask)
x = self.norm1(x + out)
# Cross-attention: query from decoder, key & value from encoder
out = self.cross_attn(x, enc_out, enc_out, enc_mask)
x = self.norm2(x + out)
# Feed-forward
out = self.ff(x)
x = self.norm3(x + out)
return x
class Transformer(nn.Module):
"""
Full Transformer Model (Encoder-Decoder architecture).
Used for sequence-to-sequence tasks like translation, summarization, etc.
Args:
d_model: Dimension of embeddings and hidden states (typically 512 or 768)
num_heads: Number of attention heads (typically 8)
num_layers: Number of encoder and decoder layers (typically 6)
vocab_size: Size of the vocabulary
"""
def __init__(self, d_model=512, num_heads=8, num_layers=6, vocab_size=30522):
super().__init__()
# Token embeddings (converts token IDs to vectors)
self.embedding = nn.Embedding(vocab_size, d_model)
# Positional encoding (adds position information)
self.pos = PositionalEncoding(d_model)
# Stack of encoder layers
self.encoder_layers = nn.ModuleList(
[EncoderLayer(d_model, num_heads) for _ in range(num_layers)]
)
# Stack of decoder layers
self.decoder_layers = nn.ModuleList(
[DecoderLayer(d_model, num_heads) for _ in range(num_layers)]
)
# Final projection to vocabulary (to get logits for each token)
self.final_linear = nn.Linear(d_model, vocab_size)
def make_causal_mask(self, size):
"""
Creates a causal (lower triangular) mask for the decoder.
Why? Prevents the decoder from attending to future tokens during training.
Position i can only see positions 0 to i (not i+1, i+2, ...).
Example for size=4:
[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1]]
"""
mask = torch.tril(torch.ones(size, size, dtype=torch.uint8))
return mask.unsqueeze(0).unsqueeze(0) # Add batch and head dimensions
def forward(self, src, tgt):
"""
Forward pass through the transformer.
Args:
src: Source token IDs (batch, src_len)
tgt: Target token IDs (batch, tgt_len)
Returns:
logits: Output logits (batch, tgt_len, vocab_size)
"""
# Convert token IDs to embeddings and add positional encoding
src = self.pos(self.embedding(src))
tgt = self.pos(self.embedding(tgt))
# Pass through encoder stack
enc_out = src
for layer in self.encoder_layers:
enc_out = layer(enc_out)
# Create causal mask for decoder (prevents looking ahead)
causal_mask = self.make_causal_mask(tgt.size(1)).to(tgt.device)
# Pass through decoder stack
dec_out = tgt
for layer in self.decoder_layers:
dec_out = layer(dec_out, enc_out, tgt_mask=causal_mask)
# Project to vocabulary size to get logits
logits = self.final_linear(dec_out)
return logits
# Example usage
if __name__ == "__main__":
# Create a small transformer for demonstration
model = Transformer(
d_model=512, # Embedding dimension
num_heads=8, # Number of attention heads
num_layers=6, # Number of encoder/decoder layers
vocab_size=10000 # Vocabulary size
)
# Example input (batch_size=2, sequence_length=10)
src = torch.randint(0, 10000, (2, 10)) # Source sequence
tgt = torch.randint(0, 10000, (2, 8)) # Target sequence
# Forward pass
output = model(src, tgt)
print(f"Input source shape: {src.shape}")
print(f"Input target shape: {tgt.shape}")
print(f"Output logits shape: {output.shape}") # (2, 8, 10000)
Output of the 1st transformer block -> input of the 2nd transformer block
- 1st block: has some context of the sentence
- 2nd block: one order higher, able to understand semantics / sarcasm / emotion (crane wings -> is the crane going to get food, or escaping from a predator?)
- 3rd block: and so on.. keep stacking
- By the time you reach the 96th block: all the info has been extracted from the input to give a good output
- Output of the 96th block: probability distribution over what the next word is going to be
Which word actually gets picked is controlled by sampling settings like temperature: take the second or third most likely word, or be greedy and pick the one with the highest probability (this is model / configuration dependent)
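A minimal sketch of temperature sampling (the logit values and token names are made up):

```python
import numpy as np

def sample_next_token(logits, temperature=1.0, seed=0):
    rng = np.random.default_rng(seed)
    scaled = np.asarray(logits) / temperature  # T < 1 sharpens (greedier), T > 1 flattens (more random)
    probs = np.exp(scaled - scaled.max())      # softmax over the scaled logits
    probs /= probs.sum()
    return rng.choice(len(probs), p=probs)

logits = [4.0, 2.0, 1.0]  # e.g. "internet", "web", "blog"
print(sample_next_token(logits, temperature=0.1))  # near-greedy: almost certainly index 0
```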

recap: from internet to ChatGPT
Step 1: Tokenization.
Step 2: Embedding lookup. Each token ID indexes into the embedding matrix, retrieving its corresponding embedding vector. If the model has 4096 dimensions, each token becomes a vector of 4096 floating-point numbers.
Step 3: Add positional encodings
Step 4: The input embeddings flow through each transformer layer. For a 32-layer model, this happens 32 times.
Step 5: Generate first token. After the final layer, the hidden states get projected to vocabulary size through a linear layer, then softmax converts these logits to probabilities over all possible next tokens.
Step 6: Decode phase. Now we generate tokens one at a time. For each new token, we only compute fresh Q, K, V for that token, retrieving cached values for all previous tokens.
Step 7: Detokenization. Finally, the sequence of token IDs gets converted back to text using the tokenizer’s vocabulary. output_text = tokenizer.decode(generated_tokens)
This entire process repeats for every token generated, with the KV cache growing at each step. The decode phase continues until the model generates a stop token or reaches a maximum length limit.
So how is a model like ChatGPT actually created, from raw data to production assistant?
step 1: pre-training (building the base model)
Download and preprocess the internet
- Scrape text from web pages, books, papers, code repositories
- Clean and deduplicate data
- Tools: FineWeb for large-scale data
Tokenization
- Convert text → sequence of tokens using BPE
- Start with 256 byte-level tokens
- Iteratively merge most common pairs
- Visualize: TikTokenizer
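The merge loop above can be sketched in a few lines. This toy version starts from characters instead of the 256 raw byte values, for readability:

```python
from collections import Counter

def bpe_merges(words, num_merges):
    # each word is a tuple of symbols; real BPE starts from byte-level tokens
    corpus = dict(Counter(tuple(w) for w in words))
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for word, freq in corpus.items():
            for a, b in zip(word, word[1:]):
                pairs[(a, b)] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # most common adjacent pair
        merges.append(best)
        new_corpus = {}
        for word, freq in corpus.items():
            out, i = [], 0
            while i < len(word):
                if i + 1 < len(word) and (word[i], word[i + 1]) == best:
                    out.append(word[i] + word[i + 1])  # fuse the pair into one symbol
                    i += 2
                else:
                    out.append(word[i])
                    i += 1
            new_corpus[tuple(out)] = new_corpus.get(tuple(out), 0) + freq
        corpus = new_corpus
    return merges

print(bpe_merges(["low", "low", "lower", "newest", "newest"], num_merges=3))
```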
Neural Network Training
- Input: Sequence of tokens
- Process: Transformer layers (attention + FFN)
- Loss: Predict next token correctly
- Backpropagation tunes weights and biases
- Train for weeks/months on thousands of GPUs

Inference/Prediction
- Predict one token at a time
- Generate probability distribution over vocabulary
- “Hi my name is Samit and this website is my home on the _”
- internet: 90%
- web: 8%
- blog: 2%
- Sample from distribution (temperature controls randomness)
Result: A “base model” - an internet document simulator that can continue text, but doesn’t follow instructions or avoid harmful outputs.
step 2: post-training (making it useful & safe)
The base model needs alignment - it has to follow instructions and not be harmful.
Supervised Fine-Tuning (SFT)
- Human labelers write example conversations
- Model learns to respond to instructions rather than just continue text
- Think: Base model = theory, SFT = worked examples
- Today: Much of this is LLM-assisted (humans edit AI-generated responses)
Challenges the base model has:
- Can’t count accurately
- Poor spelling (sees tokens, not letters)
- Needs tools for math/code execution
- May hallucinate facts

Reinforcement Learning from Human Feedback (RLHF)
- Generate multiple responses for each query
- Humans rank them (A vs B comparisons)
- Train a reward model to predict human preferences
- Use PPO to optimize the LLM toward high-reward outputs
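The reward model in step 3 is typically trained with a pairwise (Bradley-Terry style) loss: the reward of the human-preferred response should beat the rejected one. A minimal sketch, with a stand-in linear layer in place of a real transformer reward model:

```python
import torch
import torch.nn as nn

reward_model = nn.Linear(16, 1)  # stand-in: real reward models are full transformers

def preference_loss(chosen_feats, rejected_feats):
    r_chosen = reward_model(chosen_feats)
    r_rejected = reward_model(rejected_feats)
    # maximize P(chosen beats rejected) = sigmoid(r_chosen - r_rejected)
    return -torch.nn.functional.logsigmoid(r_chosen - r_rejected).mean()

chosen = torch.randn(4, 16)    # features of 4 human-preferred responses (made up)
rejected = torch.randn(4, 16)  # features of the 4 rejected ones
loss = preference_loss(chosen, rejected)
loss.backward()  # gradients push r_chosen up and r_rejected down
print(loss.item())
```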

Result: an assistant that follows instructions, uses tools when needed, and doesn’t go off the rails.
the complete flow
Raw Internet Data
↓
Tokenization (BPE)
↓
Pre-Training (Next Token Prediction)
↓
Base Model
↓
Supervised Fine-Tuning (SFT)
↓
Instruction-Following Model
↓
RLHF (Human Preference Optimization)
↓
ChatGPT / Claude / LLaMA-Chat
That’s how you go from “internet document simulator” to “useful AI assistant.”
How to improve these models? Models are better at evaluating than generating, so whatever answer you get goes back to the model to evaluate if this is correct or not.
Looking at a response, it’s much easier to evaluate the correctness of that response than to generate the response in the first place
Another way is prompt engineering -> giving the model more context on the side (significantly improves the model) basically giving it a certain personality (system prompt)
Few shot prompting: Give model couple of examples on how to answer questions, model knows how to behave -> quality of responses are much better and accuracy can be improved.
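Few-shot prompting is literally just prepending examples to the prompt. A toy template (the store, the questions, and the format are all made up):

```python
few_shot_prompt = """You are a support bot for a clothing store.

Q: Do you ship to Canada?
A: Yes! Standard shipping to Canada takes 5-7 business days.

Q: Can I return a worn item?
A: Unworn items can be returned within 30 days with tags attached.

Q: {user_question}
A:"""

# the model sees the two worked examples and mimics their tone and format
print(few_shot_prompt.format(user_question="Do you have gift cards?"))
```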

how do we bring down the costs of LLMs
Each query takes 2 cents (gpt) -> how to bring this down?
1) Caching: we can't cache on ChatGPT's side (it's not our server), so we need our own caching layer
2) Vector based caching.
Take input queries -> convert them into vectors (capturing the meaning of each query). If a previous query is similar enough, re-use its stored output as the response (kind of like how hashmaps work).
Key: vectors, Value: answers
- Example: everyone asks what the weather is, a common question: cache it so you don't have to look it up each time
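A sketch of vector-based caching. The `embed` function here is a toy stand-in (word buckets); a real system would call an actual embedding model, and the 0.9 threshold is an arbitrary choice:

```python
import numpy as np

def embed(text):
    # toy stand-in embedding: bucket words into a fixed-size unit vector
    # (a real system would call an embedding model here)
    v = np.zeros(64)
    for word in text.lower().split():
        v[sum(ord(c) for c in word) % 64] += 1.0
    return v / (np.linalg.norm(v) + 1e-8)

cache = []  # list of (embedding, answer) pairs -- Key: vector, Value: answer

def semantic_lookup(query, threshold=0.9):
    q = embed(query)
    for vec, answer in cache:
        if float(q @ vec) >= threshold:  # cosine similarity (vectors are unit-norm)
            return answer                # close enough: reuse the cached answer
    return None

cache.append((embed("what is the weather today"), "Sunny, 24C"))
print(semantic_lookup("what is the weather today"))   # cache hit
print(semantic_lookup("explain transformers to me"))  # None: cache miss
```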
3) Choosing lower cost AI Model (DeepSeek, or smaller models like Mistral)
4) Bring down number of tokens : Smaller context windows, fewer tokens {can also do rate-limiting: query limits per user/session}
5) How to improve user experience.
- A user asks "what is the size of this t-shirt" -> but what the user actually wants to know is whether the t-shirt is going to fit them
- AI Agents can help here: an agent looks at the customer, the flow, and the state they are currently at, and recommends better answers (e.g. Swarm by OpenAI)
- Security / Privacy (data leakage) -> the model can save queries, but ideally you want to anonymize them; credit card / email data -> don't store it. No personally identifiable information in the system or database.
training beyond next-token prediction: reasoning & alignment
Modern LLMs don’t just predict the next token. They reason, follow instructions, and avoid harmful outputs.
the reasoning challenge
Language models are pattern matchers. They memorize well but struggle with actual reasoning. A chess model can memorize millions of games but can’t “think” about strategy for a novel position - it just pattern-matches similar positions it’s seen before.
So how do you teach a model to reason?
RLHF: reinforcement learning from human feedback
We can’t build true logical reasoning (if we could, we’d have AGI), but we can simulate reasoning behavior through reinforcement learning.
The chess analogy: Teach a model chess by giving it rewards:
- Checkmate in 2 moves → +100 points
- Checkmate in 4 moves → 0 points
- Checkmate in 6 moves → -100 points

Model learns to navigate the search space of possible moves, preferring paths that lead to better outcomes. That’s reinforcement learning.
Applying RLHF to LLMs:
For LLMs, we can’t give numeric rewards directly. Instead:
- Generate multiple responses to a query
- Have humans rank them (“Which is better: Option A or Option B?”)
- Train a reward model to predict human preferences
- Use reinforcement learning to steer the LLM toward high-reward outputs
Example:
- Query: “Write a blog post about AI”
- LLM generates 5 different blog posts
- Humans pick the best one
- Model learns what makes a “good” blog post

PPO (Proximal Policy Optimization): Human feedback is expensive and noisy. PPO solves this by taking small, conservative policy updates. Think of it like gradient descent for decision-making:
- If one response got +100 reward, nearby responses likely get +99, +98, etc.
- The model learns a smooth reward landscape from sparse human feedback
- Small steps prevent catastrophic forgetting
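The "small, conservative updates" come from PPO's clipped objective. A minimal sketch of just the loss (not a full training loop; the log-probabilities and advantages are made-up numbers):

```python
import numpy as np

def ppo_clip_loss(logp_new, logp_old, advantage, eps=0.2):
    # probability ratio between the new and old policy for the same actions
    ratio = np.exp(logp_new - logp_old)
    unclipped = ratio * advantage
    # clipping keeps the update small even when the ratio drifts far from 1
    clipped = np.clip(ratio, 1 - eps, 1 + eps) * advantage
    return -np.minimum(unclipped, clipped).mean()

logp_old = np.array([-1.0, -2.0])
logp_new = np.array([-0.5, -2.5])
advantage = np.array([1.0, -1.0])  # first action was good, second was bad
print(ppo_clip_loss(logp_new, logp_old, advantage))
```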
Limitations:
- Humans aren’t perfect evaluators (that’s why ChatGPT uses pairwise comparisons, not scoring)
- Human feedback is expensive to collect at scale
- Sometimes humans disagree or give malicious feedback
chain of thought (CoT): teaching models to “think”
Train models to show their reasoning steps instead of jumping to an answer.
Example without CoT:
- Q: “What is 17 × 23?”
- A: “391”
Example with CoT:
- Q: “What is 17 × 23?”
- A: “Let me break this down:
- 17 × 20 = 340
- 17 × 3 = 51
- 340 + 51 = 391”

By fine-tuning on CoT examples, models learn to decompose problems. We can enhance this further with RLHF:

Advanced technique: Generate multiple CoT reasoning paths, convert them to vectors, use majority voting, then train via RL on the best paths.
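The majority-voting step (often called self-consistency) in a sketch; the sampled answers here are hard-coded stand-ins for the final answers extracted from multiple CoT samples:

```python
from collections import Counter

# final answers pulled from several sampled chain-of-thought paths (stand-ins)
sampled_answers = ["391", "391", "381", "391", "401"]

def majority_vote(answers):
    # the answer most reasoning paths converge on is usually the most reliable
    return Counter(answers).most_common(1)[0][0]

print(majority_vote(sampled_answers))  # "391"
```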
tree of thought (ToT): exploring multiple reasoning paths
Tree of Thought explores multiple reasoning branches simultaneously:
The process:
- Decomposition: Break problem into subgoals
- Branching: Generate multiple possible next steps
- Self-evaluation: Score each branch (accuracy, feasibility)
- Pruning: Discard low-scoring branches, backtrack if needed
- Consolidation: Combine the best paths into final output

This is similar to Monte Carlo Tree Search (MCTS) used in game-playing AIs - the model explores a tree of possibilities and picks the most promising path.
tool usage: when LLMs need help
LLMs struggle with:
- Precise arithmetic (109 ÷ 32 = ?)
- Real-time information (current weather, stock prices)
- Code execution (sorting an array)
- API calls (play a song on Spotify)
Solution: Give them tools!
During fine-tuning, models learn to recognize when they need external help:
- Math problem → call calculator tool
- “Sort this array” → execute Python code
- “Play music” → call Spotify API via MCP server

How it works: The model outputs a special token indicating tool use:
{
"tool": "calculator",
"input": "109 / 32"
}
The system intercepts this, runs the tool, and returns:
{
"result": 3.40625
}
The model then incorporates this into its response: “The answer is 3.40625”
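A sketch of the intercept-and-run step, using the same JSON shape as the example above (real frameworks like LangChain or MCP define their own schemas, and a real calculator tool would use a safe expression parser):

```python
import json

def calculator(expression):
    # toy calculator tool: handles "a op b" only
    a, op, b = expression.split()
    a, b = float(a), float(b)
    return {"+": a + b, "-": a - b, "*": a * b, "/": a / b}[op]

TOOLS = {"calculator": calculator}

def dispatch(model_output):
    # the model emitted a tool call instead of plain text: run the tool
    call = json.loads(model_output)
    result = TOOLS[call["tool"]](call["input"])
    return {"result": result}  # fed back into the model's context

print(dispatch('{"tool": "calculator", "input": "109 / 32"}'))  # {'result': 3.40625}
```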
Examples:
- LangChain: Framework for orchestrating tool calls
- Claude MCP servers: Anthropic’s Model Context Protocol for tool integration
- ChatGPT plugins: OpenAI’s tool ecosystem
This was huge in 2024 - LLMs went from text generators to things that can actually do stuff.
programming section
Q. Monte Carlo Tree Search Algorithm Code (Selection -> Expansion -> Simulation -> Backprop)
import math, random
# ===================== NODE =====================
class Node:
"""Single MCTS node."""
def __init__(self, state, parent=None, action=None):
self.state = state
self.parent = parent
self.action = action
self.children = []
self.untried = list(state.legal_actions())
self.wins = 0
self.visits = 0
def uct(self, c=1.4):
"""UCT score = exploitation + exploration."""
return (self.wins / (self.visits + 1e-8) +
c * math.sqrt(math.log(self.parent.visits + 1) /
(self.visits + 1e-8)))
# ===================== MCTS =====================
def mcts(root_state, iters=1000):
root = Node(root_state)
for _ in range(iters):
node = root
state = root_state.clone()
# ---- 1. SELECTION ----
while not node.untried and node.children:
node = max(node.children, key=lambda c: c.uct())
state = state.next_state(node.action)
# ---- 2. EXPANSION ----
if node.untried:
a = node.untried.pop()
state = state.next_state(a)
child = Node(state, parent=node, action=a)
node.children.append(child)
node = child
# ---- 3. SIMULATION (ROLLOUT) ----
while not state.is_terminal()[0]:
a = random.choice(state.legal_actions())
state = state.next_state(a)
winner = state.is_terminal()[1]
# ---- 4. BACKPROPAGATION ----
while node:
node.visits += 1
# win = opponent lost
if winner == -node.state.to_play:
node.wins += 1
node = node.parent
# best child = most visits
return max(root.children, key=lambda c: c.visits).action
# ===================== SIMPLE TIC-TAC-TOE STATE =====================
class TTState:
def __init__(self, board=None, to_play=1):
self.board = board[:] if board else [0]*9
self.to_play = to_play
def clone(self): return TTState(self.board, self.to_play)
def legal_actions(self):
return [i for i,v in enumerate(self.board) if v == 0]
def next_state(self, action):
b = self.board[:]
b[action] = self.to_play
return TTState(b, -self.to_play)
def is_terminal(self):
wins = [(0,1,2),(3,4,5),(6,7,8),
(0,3,6),(1,4,7),(2,5,8),
(0,4,8),(2,4,6)]
for i,j,k in wins:
s = self.board[i] + self.board[j] + self.board[k]
if s == 3: return True, 1
if s == -3: return True, -1
if all(v != 0 for v in self.board): return True, 0
return False, 0
# ===================== RUN MCTS =====================
state = TTState()
best_move = mcts(state, iters=2000)
print("Best move:", best_move)
inference optimization - making LLMs fast
Transformers are computationally expensive. Billions of matrix operations per response. For production systems serving thousands of users, optimization isn’t optional.
What this chapter covers:
- KV caching: avoid recomputing attention for past tokens (O(n^2) to O(n))
- Flash attention: optimize memory bandwidth through tiling
- Quantization: reduce precision (FP32 to INT8) for 4x speedup
- Paged attention: efficient memory management for variable-length sequences
Starting with the biggest win: KV caching.
KV cache
The attention matrix grows quadratically with sequence length. KV caching stores keys and values so they don’t need recomputation.
In Transformers, the model processes input tokens in parallel (computes Q,K,V for each token) making sure all tokens can see all other tokens (input) and compute attention score for every pair of position.
input_tokens = [token_1, token_2, ..., token_n]
# Process all tokens at once
for layer in model.layers:
Q, K, V = compute_qkv(input_tokens)
attention_output = attention(Q, K, V)
layer_output = feedforward(attention_output)
Then the decode phase starts and the model produces tokens one at a time. Each new token is based on all previous tokens, BUT only the latest token needs fresh Q, K, V computation. Recomputing K and V for all previous tokens would mean redoing huge matrix multiplications we have already done, so we store them in a cache instead (the KV Cache).
KV Caching & Attention Optimization: From O(n²) to O(n)
KV caching stores previously computed attention keys and values, avoiding recomputation and reducing complexity from O(n²) to O(n) for autoregressive generation.
Say you’re using a language model to generate a sentence, one word at a time.
At every step, the model runs the attention mechanism for all the words until now.
Example: I … (1 x 1 attention matrix)
I am … (2 x 2 attention matrix)
I am going … (3 x 3 attention matrix)
I am going to … (4 x 4 attention matrix)
…
This becomes very slow for long texts.
KV caching reuses previous computations by storing them in memory. Instead of redoing the full attention calculation every time, the model stores the Key and Value vectors for past tokens.
When it generates the next word, it only needs to calculate the Query for that word, and reuse the previously saved Keys and Values.
This speeds things up massively. Most production-grade LLMs (like GPT-4, LLaMA, Mistral, etc.) use KV caching
KV caching stores Key (K) and Value (V) vectors for past tokens so they don’t need to be recomputed. O(n^2) to O(n).
why transformer inference is slow
During generation at step t, naïve decoding recomputes:
- hidden states
- all Keys
- all Values
for every token up to t-1.
complexity
- Work per token ≈ (n × L × H)
- Total inference cost ≈ O(n²)
naïve recompute
Step t:
Recompute K/V for tokens [1 ... t-1]
Compute Q_t
Attention(Q_t, K_[1..t-1], V_[1..t-1])
KV caching: core idea
Compute K and V once per token, store them, and reuse later.
Cache_K: [K1, K2, ... Ki]
Cache_V: [V1, V2, ... Vi]
At step t:
Compute K_t, V_t
Append to cache
Use cache for attention
mathematical form
For token i:
K_i = X_i * W_K
V_i = X_i * W_V
Attention now becomes: \(\text{Attn}(Q_t, K_{\text{cache}}, V_{\text{cache}}) = \operatorname{softmax}\!\left(\frac{Q_t\, K_{\text{cache}}^{\mathsf{T}}}{\sqrt{d}}\right) \, V_{\text{cache}}\)
code
import torch

class KVCache:
    def __init__(self):
        self.cache_k = None
        self.cache_v = None

    def update(self, new_k, new_v):
        if self.cache_k is None:
            self.cache_k = new_k
            self.cache_v = new_v
        else:
            # Concatenate new K, V with cached values along the sequence dimension
            self.cache_k = torch.cat([self.cache_k, new_k], dim=1)
            self.cache_v = torch.cat([self.cache_v, new_v], dim=1)

    def get(self):
        return self.cache_k, self.cache_v
Example: “I love cats”
without caching
- For “love”: recompute K/V for “I”
- For “cats”: recompute K/V for “I”, “love”
with caching
- Compute K/V for “I” once
- Reuse for all later tokens
current_token = first_generated_token
while not done:
# Only compute for the new token
q_new = compute_query(current_token)
# Retrieve cached K, V from previous tokens
k_cached, v_cached = retrieve_cache()
# Compute attention with cached values
attention_output = attention(q_new, k_cached, v_cached)
next_token = generate_token(attention_output)
current_token = next_token
Naïve: O(n²)
Caching: O(n)
Memory ≈ 2 × L × H × seq_len × head_dim × dtype_size
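Plugging some hypothetical numbers (a LLaMA-7B-ish configuration is assumed here) into that formula:

```python
num_layers = 32   # L: transformer layers
num_heads = 32    # H: attention heads per layer
head_dim = 128    # dimension per head
seq_len = 4096    # tokens cached
dtype_size = 2    # bytes per value in fp16

# 2x because we store both K and V
cache_bytes = 2 * num_layers * num_heads * seq_len * head_dim * dtype_size
print(f"{cache_bytes / 2**30:.1f} GiB per sequence")  # 2.0 GiB
```

Two gigabytes per active sequence is why production servers care so much about cache memory management (which is where PagedAttention comes in).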

For each transformer layer and each attention head, the model maintains separate KV caches. When generating the nth token, the cache stores K and V matrices for all n-1 previous tokens.
Obviously, storing such huge matrices in a cache costs memory, but that's okay: it's a tradeoff between recomputing them in real time vs. spending memory. On GPUs, spending memory to save computation is usually the winning trade.
programming section
Q. KV Cache Code in PyTorch
import torch
import torch.nn as nn

embed_dim = 512
attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=8, batch_first=True)
tokens = torch.randn(1, 5, embed_dim)  # [batch, seq_len, embedding_dim]

# without kv cache: recomputes attention for all tokens (even the ones before the current token) at every step
for t in range(1, tokens.size(1)):
    x = tokens[:, :t, :]         # tokens from 1 to t
    out, _ = attention(x, x, x)  # recomputes Q, K, V for all past tokens again

# with kv cache
# nn.MultiheadAttention packs its Q, K, V projections into one matrix (in_proj_weight),
# so we slice out the three projections manually
W = attention.in_proj_weight  # shape (3 * embed_dim, embed_dim)
b = attention.in_proj_bias    # shape (3 * embed_dim,)
past_k, past_v = None, None
for t in range(tokens.size(1)):
    x = tokens[:, t:t+1, :]  # only the new token
    # Project to Q, K, V (like attention does internally)
    q = x @ W[:embed_dim].T + b[:embed_dim]
    k = x @ W[embed_dim:2*embed_dim].T + b[embed_dim:2*embed_dim]
    v = x @ W[2*embed_dim:].T + b[2*embed_dim:]
    # append the new K, V to the cache
    past_k = k if past_k is None else torch.cat([past_k, k], dim=1)
    past_v = v if past_v is None else torch.cat([past_v, v], dim=1)
    # attention uses the cached keys and the new query instead of recomputing all keys
    attention_score = torch.matmul(q, past_k.transpose(-1, -2)) / (k.size(-1) ** 0.5)
    attention_weights = attention_score.softmax(dim=-1)
    output = torch.matmul(attention_weights, past_v)
Common Mistake: Not clearing KV cache between different conversations causes context contamination. The cache stores ALL previous tokens, so a new query will see tokens from previous conversations, leading to irrelevant information in responses and potential privacy issues. Always clear the cache between users or conversations.
Flash Attention
Tiles of Q, K, and V are loaded into fast SRAM, partial attention is computed there, and results are written back, avoiding materializing the full N x N matrix in slow memory.
some GPU terms
- SRAM (Static Random Access Memory)
- DRAM (Dynamic Random Access Memory)
In the context of GPUs utilizing SRAM, the “static” in SRAM indicates that it does not require refresh cycles, making it suitable for fast access and minimizing latency. Conversely, “dynamic” in DRAM implies that it needs regular refreshing to maintain the stored data, which can contribute to longer access times
Great place to learn GPU terms
Flash Attention quick review
Problem:
- Attention computes large matrices (QKᵀ , softmax, etc.).
- Consumes a lot of memory and has poor hardware efficiency.
Solution:
- Flash Attention fuses softmax + matmul into one GPU kernel.
- Uses tiling and streaming - loads only small parts of Q, K, V into GPU memory at a time.
- Reduces total reads and writes.
- Enables faster attention with lower peak memory.
An LLM runs on a GPU: tons of parallel instructions, plus a small on-chip memory (SRAM) that loads data from the larger off-chip memory (HBM/DRAM). Each load is an IO call and is slow.
So loading data from DRAM/HBM -> SRAM every time you perform a matrix multiplication is slow.
- Perform all operations together, get result for a particular Attention Block.
- Each block attention is computed -> results are all aggregated and put together.
Benefits
- 10-20x lower memory consumption
- 2-4x speedup!
- You can perform flash attention along KV Caching and PagedAttention (later)
The bottleneck
GPU spends time doing 2 things -> computing matrices {matrix * matrix} & loading data from memory for previous tokens {matrix * vector}
Attention formulas revisited:
\[S = \frac{QK^{\mathsf{T}}}{\text{scaling factor}}\] \[\text{attention} = \mathrm{softmax}(S)\] \[\text{output} = \text{attention} \cdot V\]
Take the main matrix multiplication operation: \(QK^{\mathsf{T}}\)
How does this happen on a GPU and why do we need flash attention at all?
gpu time
This operation happens on the HBM (High Bandwidth Memory) outside the GPU cores
- Load Q, K by blocks from HBM -> Compute dot product -> Save result S back to HBM
- Load S from HBM -> Compute attention -> Save results A (attention) back to HBM
- Load A, V from HBM -> Compute output matrix -> Save results O back to HBM
This back and forth travelling from HBM -> GPU Cores back to -> HBM takes a lot of time!!
speeding this up
Memory revision (keep this in mind)
- CPU has DRAM (>1TB )
- GPU has HBM on the outside (40gb)
- GPU on_chip has SRAM (20mb)
Goal is to shift these matrices to on_chip and have it do calculation there itself instead of travelling back and forth. Problem is SRAM is only 20mb and we can’t fit the entire attention matrix in it.
Enter Tiling
Matrix Multiplication Revisited:

For 4x4 matrix
- To compute each output value we need to load 1 row (4 values) from matrix A and 1 column (4 values) from matrix B: 8 memory accesses per element, and with 16 output elements that's 8*16 = 128 memory accesses in total.
We need this number to go down.
What if we do block multiplication of these matrices? Instead of computing each dot product individually (one row and one column at a time), we take 2x2 blocks of data and multiply them together. Each loaded element is now reused across a whole tile, halving the total to 64 memory accesses.
So by calculating NxN blocks we reduce memory accesses by a factor of N. This is huge!!

Partition the matrices A, B and C into 2x2 blocks.
Now these 2x2 blocks can be moved to on_chip SRAM for faster processing and then combine these partial tile matrix multiplications to get final results.

How do you apply this to attention? We have already figured out the matrix multiplication part. What about the softmax part? \(\mathrm{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}}\)
where the x_i are the scores being normalized.
Let’s take 1 row of the attention matrix.

To fix the overflow we use a little trick: subtract m, the max value of the sequence, from each input. Now all values are <= 0, so e^x can't overflow. This is known as Safe Softmax.

- Pass 1: find the max m
- Pass 2: calculate the sum of all e^(x_i - m)
- Pass 3: normalize: divide each e^(x_i - m) by that sum
This process is IO intensive -> iterating the sequence 3 times. How do we reduce this?
We can combine the first and second passes: keep a running max and a running sum, and whenever the max changes, rescale the running sum instead of waiting for the final max.
We find a recurrence relation (again, you must have heard this in your algorithms class) and the 3-step process becomes 2 steps.

Apply same idea to self attention. \(x_i = q\,k^{\mathsf{T}}\)
xi are the pre-softmax logits computed by dot product, we can use the same technique (instead of 3 passes -> 2 passes)
Applying the same trick again we can reduce 2 passes -> 1 pass hence only 1 access to the memory for softmax.
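The merged-pass recurrence can be sketched in a few lines of NumPy (a toy version; the real fused kernel does this per tile on the GPU):

```python
import numpy as np

def online_softmax(x):
    """Safe softmax with the max-and-sum recurrence done in ONE streaming pass."""
    m = float("-inf")   # running max
    s = 0.0             # running sum of e^(x_i - m)
    for xi in x:
        m_new = max(m, xi)
        s = s * np.exp(m - m_new) + np.exp(xi - m_new)  # rescale old sum if max moved
        m = m_new
    return np.exp(np.array(x) - m) / s                  # final normalization pass

print(online_softmax([2.0, 1.0, 0.5, 3.0]))
```

Two streaming passes total (one for m and s, one to normalize) instead of three; FlashAttention pushes this recurrence into the tiled attention loop so even the normalization gets fused.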

flash attention: tiling + kernel fusion to keep computation on-chip SRAM
Flash Attention fuses all computations together, and we still get same results.
Recap:
- Partition Matrix (Q,K,V) into tiles. First Q,K,V vectors go to HBM -> onchip SRAM and we perform attention and get o1 (partial result saved to HBM)
- Load next tile, perform attention in SRAM and update partial results from o1 -> o1’
- Repeat this for all tiles
- End of Loop we have all attention results for Q1, similarly we do this for Q2 and all further query vectors.
There are even further optimizations to this like FlashAttention2 and FlashAttention3 Paper, but that is out of scope for this blog.
I know this explanation wasn’t enough so here is -
more on tiling (read this slowly)
“Tiling” is used to reduce global memory accesses by taking advantage of the shared memory on the GPU. Tiling can be seen as a way to boost the execution efficiency of the kernel.
Since everything is matrix multiplication, we had better make sure it's as optimized as it can be.
Some basic CUDA terms (CUDA is NVIDIA's low-level platform for accelerated computing; it's where matrix multiplication and all model compute happens) you should know if you have taken an OS class:
- Thread: single unit of execution (each thread has its own memory called registers)
- Block: group of threads (think process) : all threads in a block have access to shared memory.
- Grid: group of blocks (access to global memory and constant memory)
Input: Two 4x4 matrices A, B.
Output: 4x4 matrix C.
C consists of 16 elements, each computed as a dot product of a row of A and a column of B. So let's launch 16 threads, where each thread calculates 1 output element.
Threads can be organised into 2x2 blocks, hence 4 blocks in a grid.
What does each thread do? It loads its input elements into shared memory (within its block), so each of the 4 threads in a block can see what the other three load. The block then does a mini matrix multiplication using shared memory, stores the temporary result, and continues summing the temporary results of the next mini matrix multiplication.
When we are finished with each individual mini-matrix multiplication, each thread would load their corresponding result to the output C element that they are mapped to.
Without tiling: In order to calculate one output element, a thread will need to access one entire row of input A and one entire column of input B, for calculating the dot product. In our example, that is 8 accesses per thread.
With tiling: Each thread ends up loading two elements from input A and two elements from input B, which totals up to 4 accesses per thread.
Say we multiply two large square matrices of size S×S, where S is a multiple of 32. Obviously, the result is also a square matrix of size S×S.
With the naive algorithm, to compute each element of the result we need to fetch S elements from both matrices. The output matrix has \(S^2\) elements, therefore the total count of loaded elements is \(2S^3\).
With 32×32 tiling, to compute each 32×32 tile of the result we need to fetch S/32 tiles from both matrices.
The output size in tiles is \((S/32)^2\), so the total count of loaded tiles is \(2(S/32)^3\). Each 32×32 tile contains \(32^2\) elements, so the total count of loaded elements is \((32^2) \cdot 2(S/32)^3 = (2/32)S^3\).
Hence the tiling reduced global memory bandwidth by the factor of 32, which is a huge performance win.
Since every layer has to compute matrix multiplication again and again, we must make it faster::
Efficient matrix multiplication employs a tiling strategy. The large matrix operation gets divided into smaller tiles that fit in shared memory, reducing expensive global memory accesses.
for each 16x16 output tile:
    for each 16x16 input tile along K dimension:
        # This becomes a single Tensor Core instruction
        output_tile += tensor_core_mma_16x16x16(A_tile, B_tile)
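The same loop can be written out as a runnable NumPy sketch (T=2 plays the role of the 16x16 tile; the result is identical to a plain matmul, just with tiled data movement):

```python
import numpy as np

def tiled_matmul(A, B, T=2):
    """Compute A @ B one T x T tile at a time, accumulating partial products."""
    n = A.shape[0]
    C = np.zeros((n, n))
    for i in range(0, n, T):              # output tile row
        for j in range(0, n, T):          # output tile column
            for k in range(0, n, T):      # tiles along the shared dimension
                # this small product is what would run out of on-chip SRAM
                C[i:i+T, j:j+T] += A[i:i+T, k:k+T] @ B[k:k+T, j:j+T]
    return C

A = np.arange(16.0).reshape(4, 4)
B = np.ones((4, 4))
print(np.allclose(tiled_matmul(A, B), A @ B))  # True: same result, fewer global loads
```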
programming section
Q. Flash Attention
"""
Minimal FlashAttention example.
Shows how Q/K/V are packed and how flash-attn replaces softmax(QKᵀ)V.
"""
import torch
import torch.nn as nn
try:
from flash_attn.flash_attention import flash_attn_unpadded_qkvpacked_func
FLASH = True
except Exception:
FLASH = False
class FlashMHA(nn.Module):
"""Multi-head attention using FlashAttention (no fallback)."""
def __init__(self, dim, heads, causal=False):
super().__init__()
assert dim % heads == 0
self.dim = dim
self.heads = heads
self.hd = dim // heads
self.causal = causal
self.qkv = nn.Linear(dim, 3 * dim, bias=False)
self.out = nn.Linear(dim, dim, bias=False)
def forward(self, x):
if not FLASH:
raise RuntimeError("flash-attn not installed")
B, T, D = x.shape # (batch, seq, dim)
qkv = self.qkv(x) # (B, T, 3D)
# ---- reshape to packed QKV ----
qkv = qkv.view(B, T, self.heads, 3, self.hd) # (B, T, H, 3, Hd)
qkv = qkv.permute(0, 2, 1, 3, 4).contiguous() # (B, H, T, 3, Hd)
qkv = qkv.view(B * self.heads, T, 3 * self.hd) # (BH, T, 3Hd)
# FlashAttention prefers float16/bfloat16
qkv_dtype = qkv.dtype
if qkv_dtype != torch.float16:
qkv = qkv.half()
# cu_seqlens: prefix sum of sequence lengths for each (batch * head)
cu = torch.arange(0, (B * self.heads + 1) * T, step=T,
dtype=torch.int32, device=x.device)
# ---- FlashAttention call ----
# Computes softmax(QKᵀ)V using tiling (no O(T²) memory)
out = flash_attn_unpadded_qkvpacked_func(
qkv, cu, T,
causal=self.causal,
dropout_p=0.0
) # (BH, T, Hd)
# ---- reshape back ----
out = out.view(B, self.heads, T, self.hd)
out = out.permute(0, 2, 1, 3).reshape(B, T, D)
return self.out(out)
if __name__ == "__main__":
B, T, D = 2, 128, 512
x = torch.randn(B, T, D, device="cuda", dtype=torch.float16)
attn = FlashMHA(dim=D, heads=8, causal=False).cuda().half()
y = attn(x)
print("Output:", y.shape) # (2, 128, 512)
paged attention: efficient memory management
Problem
Predicting how big KV cache memory will be is difficult - the next token could be one word, one sentence, or one paragraph. Transformers can’t tell how large the answer is going to be, so we don’t know how large the KV Cache memory needs to be.
In production we have limited memory. If we don't know how big the context will grow, we either overallocate the KV cache (wasting memory) or underallocate it (cache too small -> recompute) :: BIG ISSUE
40-60% memory is wasted due to this.
Improvement: Paged Attention (if you took an OS class, this will sound familiar!) = dynamically sized cache for each request (near-100% memory utilization, no wasted reservations)
Cache -> Virtual Blocks -> Physical Blocks (not contiguous) {Exactly how paging works in OS}


In transformers, KV caching stores per-token keys and values that feed an N x N attention computation.
We don't know how large the output will be, so guessing the final number N is impossible. This means we can't allocate exact fixed memory up front.
Instead we choose a large safe value of N, like 2000, reserving a massive 2000 x 2000 worth of attention state for a single request.
This leads to fragmentation and wasted memory, especially in real-time systems. Solution:
- Use fixed-size memory pages instead of trying to allocate the entire buffer up front.
- Each page stores part of the KV cache. If the output is longer, just add more pages.
PagedAttention is an attention algorithm inspired by the classic idea of virtual memory and paging in operating systems. Unlike traditional attention algorithms, PagedAttention allows storing continuous keys and values in non-contiguous memory space. Specifically, PagedAttention partitions the KV cache of each sequence into blocks, each block containing the keys and values for a fixed number of tokens. During the attention computation, the PagedAttention kernel identifies and fetches these blocks efficiently.

Because the blocks do not need to be contiguous in memory, we can manage the keys and values in a more flexible way as in OS’s virtual memory: one can think of blocks as pages, tokens as bytes, and sequences as processes.
PagedAttention also does efficient memory sharing: when multiple output sequences are generated from the same prompt, the computation and memory for the prompt are shared (just as processes share physical pages in an OS). This helps in chain-of-thought style generation.

Advantage of this (Reasoning):
- Some tokens are the same across branches, so one attention block can be stored once in physical memory and shared: 1 physical attention block -> many virtual attention blocks, each branching off the original prefix, and the model can pick the best branch.
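A toy sketch of the idea (PAGE_SIZE, the pool size, and the dict-based block table are all made up for illustration; real systems like vLLM do this in CUDA with reference counting):

```python
import numpy as np

PAGE_SIZE = 4   # tokens per page (real systems use e.g. 16)
DIM = 8         # key/value vector size (illustrative)

class PagedKVCache:
    """Toy paged KV cache: a pool of fixed-size pages plus a per-sequence block table."""
    def __init__(self, num_pages):
        self.pool = np.zeros((num_pages, PAGE_SIZE, DIM))  # physical pages
        self.free = list(range(num_pages))                 # free-page list
        self.block_table = {}                              # seq_id -> [page ids]
        self.length = {}                                   # seq_id -> tokens stored

    def append(self, seq_id, kv):
        n = self.length.get(seq_id, 0)
        if n % PAGE_SIZE == 0:                 # current page full -> grab a new one
            self.block_table.setdefault(seq_id, []).append(self.free.pop(0))
        page = self.block_table[seq_id][-1]
        self.pool[page, n % PAGE_SIZE] = kv
        self.length[seq_id] = n + 1

    def gather(self, seq_id):
        """Rebuild the logically contiguous K/V from scattered physical pages."""
        n = self.length[seq_id]
        pages = self.pool[self.block_table[seq_id]].reshape(-1, DIM)
        return pages[:n]

cache = PagedKVCache(num_pages=8)
for t in range(6):                       # 6 tokens -> only 2 pages, no huge upfront buffer
    cache.append("req-1", np.full(DIM, float(t)))
print(cache.gather("req-1").shape)       # (6, 8)
```

Memory grows one page at a time with the output, which is exactly the over/underallocation fix described above.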

MOE (mixture of experts)
Only activate a small part of the model for any given input.
Treat the neural network as a graph and activate a subgraph based on the input query. A query about shopping triggers a different part of the model than a query about emotions. Only a few “experts” are active at any time.
Less compute, same capability.
Activate subgraphs which are expert in those queries (Kind of like load balancer)
Advantage: Number of neurons that are going to be fired are less but capability is same as LLM
Gate chooses top1/top2 subgraph which has the answer. More on this




The router decides which expert sub-model handles a query at inference time.
The routing network itself is learned during training, jointly with the experts (often with auxiliary losses that balance load across experts), so it learns to associate kinds of inputs with the right experts.
Problem:
- Scaling LLMs with more parameters = more compute.
- But you don’t always need the full model for every input. A small section of the LLM would suffice, depending on the input query.
Solution:
- Split the model into experts (independent subnetworks).
- Use a gating mechanism to activate only a few experts per input.
- Reduces compute but keeps the model expressive.
Example:
- Input 1: “Translate to French” → Activates Experts 2 and 4
- Input 2: “Summarize this text” → Activates Experts 1 and 5
Each forward pass uses only ~10% of the total parameters.
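A minimal sketch of top-k gating in PyTorch (ToyMoE, the expert count, and the per-token loop are my illustrative choices; real MoE layers batch tokens per expert for speed):

```python
import torch
import torch.nn as nn

class ToyMoE(nn.Module):
    """Toy mixture-of-experts layer: a learned gate picks top-k experts per token."""
    def __init__(self, dim, num_experts=4, k=2):
        super().__init__()
        self.experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_experts))
        self.gate = nn.Linear(dim, num_experts)   # the router
        self.k = k

    def forward(self, x):                          # x: (tokens, dim)
        scores = self.gate(x)                      # (tokens, num_experts)
        weights, idx = scores.topk(self.k, dim=-1) # keep only top-k experts per token
        weights = weights.softmax(dim=-1)
        out = torch.zeros_like(x)
        for token in range(x.size(0)):             # only k experts actually run per token
            for slot in range(self.k):
                e = idx[token, slot].item()
                out[token] += weights[token, slot] * self.experts[e](x[token])
        return out

moe = ToyMoE(dim=16)
y = moe(torch.randn(5, 16))
print(y.shape)  # torch.Size([5, 16])
```

With 2 of 4 experts active, only half the expert parameters are touched per token, which is the whole point.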
optimisation techniques with tradeoffs
1. Quantisation : converting the model's floating point numbers from higher-bit to lower-bit representations without losing much precision or accuracy.
Taking weights of model and representing them using fewer bits.
The cache in KV Cache grows linearly with sequence length. For a 13-billion parameter model like LLaMA-2, each output token requires approximately 1 MB of cache storage.
A 4,000 token context needs about 4 GB just for the cache, comparable to the model size itself.
Ways to improve this: quantise the cache (FP16 -> FP8: lower the number of bits used to store the vectors without losing much precision/accuracy), use sliding window attention that only retains recent tokens, or implement attention approximations that reduce cache requirements.
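The per-token cache figure can be sanity-checked with back-of-the-envelope math (assuming a LLaMA-2-13B-like config: 40 layers, 40 heads, head dim 128, FP16):

```python
# KV cache per token = 2 (K and V) * layers * n_heads * head_dim * bytes per value
layers, heads, head_dim = 40, 40, 128   # assumed LLaMA-2 13B-like config
fp16 = 2                                # bytes per value

per_token = 2 * layers * heads * head_dim * fp16
print(per_token / 1e6, "MB per token")             # ~0.82 MB, the "roughly 1 MB" figure
print(4000 * per_token / 1e9, "GB for 4k context") # ~3.3 GB, in the ballpark of 4 GB
# quantising the cache FP16 -> FP8 halves both numbers
```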
# Precision formats
FP32: 1 sign bit, 8 exponent bits, 23 mantissa bits
FP16: 1 sign bit, 5 exponent bits, 10 mantissa bits
BF16: 1 sign bit, 8 exponent bits, 7 mantissa bits
INT8: 8 bits for integer representation
INT4: 4 bits for integer representation
A 7-billion parameter model at FP16 precision requires approximately 14 GB of memory (7B parameters × 2 bytes per parameter). Quantizing to INT4 reduces this to 3.5 GB, enabling inference on consumer hardware.
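Here is what quantisation boils down to, as a toy symmetric per-tensor INT8 sketch (real pipelines use per-channel scales, calibration data, etc.):

```python
import numpy as np

def quantize_int8(w):
    """Symmetric per-tensor INT8: store int8 weights plus a single float scale."""
    scale = np.abs(w).max() / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

w = np.linspace(-1.0, 1.0, 16, dtype=np.float32).reshape(4, 4)
q, scale = quantize_int8(w)
print(w.nbytes, "->", q.nbytes)                         # 64 -> 16 bytes (4x smaller)
print(np.abs(w - dequantize(q, scale)).max() <= scale)  # rounding error under one step
```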
beautiful blog on quantisation
Q. Perform Quantisation on vanilla model.pt -> Convert to FP16 (32 bits -> 16 bits)
Trained model -> .pt
- First convert it to .onnx (platform independent model file)
import torch
import torch.nn as nn

model = torch.load("model.pt", map_location="cpu")  # your model
model.eval()
dummy_input = torch.randn(1, 3, 640, 640)  # default img size is 640x640 pixels
onnx_path = "model.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    input_names=["input"],
    output_names=["output"],
    opset_version=12,
    do_constant_folding=True,
    dynamic_axes={
        "input": {0: "batch", 2: "height", 3: "width"},
        "output": {0: "batch"},
    },
)
Then convert to TensorRT (model.plan) using trtexec engine or python script below:
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

def build_engine(onnx_file, engine_file):
    logger = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(logger)
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, logger)
    # Parse ONNX
    with open(onnx_file, "rb") as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            return
    config = builder.create_builder_config()
    # Enable FP16
    if builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)  # 4GB
    # Optimization profile for dynamic input size
    profile = builder.create_optimization_profile()
    input_tensor = network.get_input(0)
    profile.set_shape(
        input_tensor.name,
        (1, 3, 640, 640),  # min
        (1, 3, 640, 640),  # opt
        (4, 3, 640, 640),  # max (batch=4)
    )
    config.add_optimization_profile(profile)
    print("Building TensorRT Engine")
    engine_bytes = builder.build_serialized_network(network, config)
    if engine_bytes is None:
        raise RuntimeError("Engine build failed!")
    with open(engine_file, "wb") as f:
        f.write(engine_bytes)

build_engine("model.onnx", "model_fp16.plan")
Now your model goes from .pt -> FP16 .plan which is significantly faster at inference/prediction time.
Typical results (RTX 3090 / A100 class GPUs):
- ResNet-50: 4× faster
- YOLOv5/YOLOv8 Object Detection: 2.5-3x faster
- Transformers: 2-3x faster
Memory depends on: the total number of weights and how many bits are required to store each weight.
- Since weights are smaller, you can store more weights in cache -> faster inference.
When to perform quantisation?
- Step1: Train LLM
- Step2: Answer queries (You perform the quantisation part here)
2. Sparse Attention
Take one token and only focus on those queries that are near it.
Idea: relevant context for a token lies near the token itself, we don’t have to search the entire space.

For damaged we only need to look at the damaged column for relevant context.
Example
In a Harry Potter novel, if he's angry, the reason is usually in the recent context, not 5 pages ago (it being addressed only now is highly improbable).

Attention is only computed for these (diagonal) cells in the matrix
This significantly reduces compute: O(n^2) -> O(n*w), where w is the width of the diagonal band.
- Drawback: the effective context window is shorter (if the relevant context is from 3 pages ago, it won't be picked up), but it's a tradeoff.
- If the input is only a sentence -> dense / normal attention
- If the input is an entire document (a Harry Potter book) -> sparse attention
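The diagonal-band idea can be sketched as a mask (w is the window width; only True cells get attention computed):

```python
import numpy as np

def sliding_window_mask(n, w):
    """Band mask: token i attends only to the w tokens behind it (causal local window)."""
    mask = np.zeros((n, n), dtype=bool)
    for i in range(n):
        mask[i, max(0, i - w):i + 1] = True
    return mask

mask = sliding_window_mask(n=6, w=2)
print(int(mask.sum()), "of", 6 * 6, "cells computed")  # O(n*w) instead of O(n^2)
```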


Smarter way to do this

- Drawback: a tradeoff between accuracy and speed.
- Flash Attention runs on GPUs in blocks; with sparse attention you have this thick diagonal, but how do you send it to the GPU?
- With scattered random cells, the GPU doesn't get a complete block, only individual cells, and Flash Attention doesn't work because it only works on blocks.
- So you can use random blocks for sparse attention: local attention is computed per block, and the entire block is sent to the GPU so Flash Attention still works.
- But this is harder to implement and not widely used.
- GPU-centric optimization is hard (DeepSeek has made it work).
3. Distillation and SLMs (optimizations used in deployment)
- Reduce number of parameters (weights) in model.
- Reduce parameters in the FFN: a smaller model which also works well (Small Language Model). 100 billion parameters -> 100 million parameters. Lossy (less intelligent), but it covers most use cases as long as you fine-tune the model.
This is called distillation.
Large language model: Teacher, student is the smaller language model.
Same input goes to both models; compare their outputs (the difference is the loss, and backpropagation fixes it).

The student tries to mimic the teacher (kind of like GANs).
- Drawback: reduced intelligence.
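The teacher-student comparison is usually a KL divergence between softened output distributions; a minimal sketch (the temperature T, shapes, and vocab size are illustrative):

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, T=2.0):
    """KL divergence between softened teacher and student distributions.

    T > 1 softens both, so the student also learns the teacher's relative
    preferences among "wrong" answers.
    """
    s = F.log_softmax(student_logits / T, dim=-1)
    t = F.softmax(teacher_logits / T, dim=-1)
    return F.kl_div(s, t, reduction="batchmean") * (T * T)

teacher_logits = torch.randn(8, 100)   # big model's outputs over a 100-token vocab
student_logits = torch.randn(8, 100)   # small model's outputs for the same input
loss = distillation_loss(student_logits, teacher_logits)
# backprop this loss into the student only; the teacher's weights stay frozen
```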
4. Speculative Decoding
LLMs generate one token at a time. This creates latency, especially for long answers.
Speculative Decoding uses a smaller model to speculate on the next few tokens. The large model verifies or corrects them in parallel.
The draft model proposes tokens fast, the verifier checks them all in one pass - accepting most, correcting the rest.
Process:
- Small model predicts the next 4-8 tokens.
- Large model checks these predictions.
- If valid, the LLM accepts; else correct this token and send to SLM for further processing.
Example: “What is 319 ÷ 215?”
Small model predicts: “The result is 2.59.”
Large model verifies all tokens in parallel (green / red are predictions).
- Verification 1: (The)
- Verification 2: The -> (result)
- Verification 3: The result -> (is)
- Verification 4: The result is -> (1.48)
It finds the token “2.59” incorrect and replaces it with “1.48”.
Large model now asks Small model to predict the rest of the sequence.
This happens recursively until full output is generated.
This results in 2x faster decoding without accuracy loss.
Different tokens require different efforts
- Query : what is 5+2
- Output: the answer to 5+2 is 7 (most effort required was at 7, rest were easy)
This works because GPT is trained on facts: if it has seen 5+2 many times on the internet, the next token is probably 7, so it guesses that. It doesn't mathematically compute anything; it is only trained on data.
Corrections are done by the LLM; the rest is done by the small language model. Basically a teacher checking the student's output at every token. Simple yet brilliant concept.
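The whole draft-verify loop fits in a few lines. A toy sketch with deterministic stand-in "models" (the lambdas are fake; real systems compare token distributions, not exact matches):

```python
def speculative_decode(target, draft, prefix, n_draft=4, max_len=12):
    out = list(prefix)
    while len(out) < max_len:
        # 1) the small model drafts a few tokens quickly
        proposed = []
        for _ in range(n_draft):
            proposed.append(draft(out + proposed))
        # 2) the large model verifies each position (in parallel on real hardware)
        accepted = []
        for tok in proposed:
            correct = target(out + accepted)
            if tok == correct:
                accepted.append(tok)        # draft was right -> keep it for free
            else:
                accepted.append(correct)    # first mistake: fix it and redraft
                break
        out += accepted
    return out[:max_len]

# toy deterministic "models": next token = len(seq) % 10; draft is wrong every 3rd call
target = lambda seq: len(seq) % 10
draft = lambda seq: len(seq) % 10 if len(seq) % 3 else 9

print(speculative_decode(target, draft, prefix=[0]))  # identical to pure target decoding
```

On real hardware step 2 is a single batched forward pass, which is why accepted drafts are nearly free.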

optimization techniques comparison
| Technique | Speedup | Memory Impact | Quality Loss |
|---|---|---|---|
| KV Cache | 3-5x | +80% (cache growth) | None |
| Flash Attention | 2-4x | -50% (peak usage) | None |
| Quantization (INT8) | 2-3x | -75% | <1% accuracy loss |
| Quantization (INT4) | 3-4x | -87% | 2-3% accuracy loss |
| Paged Attention | 1.2-1.5x | +20% utilization | None |
| Sparse Attention | 2-10x | -60% | 2-5% accuracy loss |
| Speculative Decoding | 2-3x | +30% (dual models) | None |
| Distillation (SLM) | 5-10x | -90% | 10-20% capability loss |
| MOE | 2-4x | 0% (sparse activation) | None |
Key Concept: Optimizations compound! Combining KV cache (3x) + Flash attention (2x) + INT8 (2x) = ~12x total speedup. Don’t optimize in isolation - apply multiple techniques together for maximum impact.
RAG, AI agents & production deployment
Gaurav explains AI Agents and RAG pretty well in his live sessions, I would recommend you watch those directly, but here is whatever I learnt:
Knowing how LLMs work is nice but actually deploying them is a different game. This chapter covers:
- RAG (Retrieval Augmented Generation): ground model outputs in retrieved documents
- AI Agents: give LLMs access to tools (search, calculators, APIs)
- Cost optimization: make production deployment not bankrupt you
- Prompt engineering: get better outputs through better inputs
the journey: text → tokens → vectors → LLMs
Complete flow of how text reaches a language model:
1) User input: “What is the best fictional book?”
2) Tokenization: Text is split into subword units (tokens)
3) Vector retrieval: Query is converted to vector, similar documents fetched from vector DB
4) Context assembly: User query + relevant documents are tokenized together
5) LLM processing: Combined tokens are processed by the transformer
6) Output generation: Model produces answer token by token: “Harry Potter”
User query + relevant context from vector DB get tokenized together and passed to the LLM. That’s RAG in a nutshell.
- The user asks a question/query.
- The question is sent to a vector DB, from which documents closely related to the query are fetched.
- These documents, along with the query, are sent to the LLM.
- The LLM reads the text from the documents. This is the context for the question; using it, the LLM answers the user's query.
- The user receives the appropriate output.
So my query of what is the best fictional book along with relevant context (all fictional book documents) are sent to LLMs and answer is generated (Harry Potter)
RAG
Information not necessarily in weights but you give it external relevant sources of data. In this way it is able to AUGMENT the knowledge base of LLM with relevant documents.

So every LLM learns in 3 ways:
- Model pre-training (data + GPU)
- Model fine-tuning (data + GPU)
- Passing contextual info (data)
RAG works on contextual information passing.
Think ChatGPT but with your custom notes for OS. How easy would it be to study for the exam.
The model can answer all queries using your notes. Why is this so beneficial? Instead of giving it a very big sample space of data to search, we limit it (reducing $$ and GPU compute required).
This is beneficial for small companies that want to build a customer support chatbot which accesses all company-relevant documents and answers only from them (eg: Zomato, MMT), instead of building an LLM from scratch, which gives vague, general answers rather than niche company answers.
So you can have a personalised chatbot at the fraction of the cost of fine tuning.
So how does RAG work? Say you give it a prompt “My order #3129 wasn’t delivered but shows out for delivery, can you help me with this” on Zomato. Here’s the flow::
- Index + Embedding
- The data is huge and we can't pass everything in the model prompt (> context length), so we only give it relevant data.
- Vector Store DB
- Data is stored in chunks, which are converted to embeddings (words represented as numbers).
- Chunking happens by sentence or paragraph (there are multiple ways of doing this, but assume fixed-size chunking for now: chunks of 100 words each).
- Contextual information about the document/external database is stored as vectors. Vector databases are designed for efficient storage of vectors and provide a query interface that retrieves at high search rates and speeds.
- By controlling the chunk size, RAG can maintain a balance between comprehensiveness and precision.
- Filtering + Ranking
- Improving data quality :: Filtering
- Content-based filtering: filter out docs that don't contain 'xyz'
- Metadata filtering: filter out docs based on author, source link, etc.
- Threshold filtering: after initial retrieval, docs are filtered on the basis of a threshold similarity score
- Re-ranking is used to order the filtered docs to prioritise those that are relevant.
- Semantic re-ranking: reorder based on semantic content (meaning)
- Contextual re-ranking: reorder based on contextual info
- LTR (learning to rank): an ML model (LambdaNet) trained to predict the relevance of documents (features: doc length, chat history, term frequency)
- Retrieval
- Searching can be via KNN, ANN, tree-based (Annoy), clustering (FAISS), etc.
- Prompt Augmentation
- Ingestion of the retrieved docs into the model's prompt so the model can generate a response
- Generation
- Personalised, high-quality results according to user need, by combining existing knowledge + the relevant data in the prompt
Vector databases allow semantically similar searches but may fetch irrelevant data; graph databases extract entity-relationship data but require exact query matching (no semantic similarity), so both have their tradeoffs.
After retrieving -> filter, rank, and only keep the top-k results.
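The retrieve-then-rank step can be sketched end to end (the docs, the bag-of-words "embeddings", and the vocab are all toy stand-ins for a real embedding model + vector DB):

```python
import numpy as np

docs = [
    "refund policy refunds are issued within seven days",
    "harry potter is a fantasy novel series",
    "orders out for delivery can be tracked in the app",
]

# toy bag-of-words "embeddings" built from the corpus vocabulary
vocab = {w: i for i, w in enumerate(sorted({w for d in docs for w in d.split()}))}

def embed(text):
    v = np.zeros(len(vocab))
    for word in text.lower().split():
        if word in vocab:
            v[vocab[word]] += 1.0
    n = np.linalg.norm(v)
    return v / n if n else v

def retrieve(query, k=1):
    q = embed(query)
    scores = [float(q @ embed(d)) for d in docs]   # cosine similarity per chunk
    top = np.argsort(scores)[::-1][:k]             # rank and keep top-k chunks
    return [docs[i] for i in top]

context = retrieve("my order is out for delivery but not here", k=1)
print(context[0])   # the delivery doc: the closest chunk becomes the LLM's context
```

The retrieved chunk then gets pasted into the prompt, which is the "Prompt Augmentation" step above.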


reducing hallucinations with RAG
RAG helps but LLMs still hallucinate. Some strategies:
1. Ground prompts with facts
- Use RAG to feed real documents into context
- Provide specific, relevant sources rather than generic context
2. Reduce prompt scope
- Don’t stuff 20 documents into every query
- Smaller, targeted context = sharper, more accurate responses
- Use semantic search to find the MOST relevant chunks
3. Force answer shape
- Add explicit instructions: “Answer ONLY based on documents provided”
- “If information is not in the context, say ‘I don’t have enough information’”
- Constrain the model’s behavior through prompts
4. Validation and post-processing
- Apply business rules: “If model says ‘7 days’, verify against actual policy”
- Use confidence scores to flag uncertain responses
- Human-in-the-loop for critical decisions
Example - Refund Policy Chatbot:
Bad:
User: "What's your refund policy?"
LLM: [Generates plausible-sounding but potentially wrong answer from training data]
Good with RAG:
System: You are answering based on official policy. Use ONLY the text below.
Context: [Actual refund policy from company database]
User: "What's your refund policy?"
LLM: [Cites specific policy, includes source]
RAG + good prompt engineering = way fewer hallucinations.
AI agents
- Perform tasks using external tools (web search, db query) when LLM isn’t enough.
- Difference between Agents and Chatbots? AI Agent activates tools to provide better results when unsure (unlike chatbot which produces slop)
Instead of writing stuff here I’ll just write the code for all things AI Agents can do, the code is simple and self-explanatory although there are comments when needed
# pip install ollama==0.4.7
import ollama

llm = "qwen2.5"  # open sourced model
stream = ollama.generate(model=llm, prompt="whats up?", stream=True)
# the model produces output in chunks, like discussed before; for each chunk
# in the stream -> print it and flush stdout immediately (flush=True)
for chunk in stream:
    print(chunk['response'], end='', flush=True)
# search the internet
# pip install duckduckgo-search==6.3.5
# pip install langchain-community==0.3.17
from langchain_community.tools import DuckDuckGoSearchResults

def search(query):
    # search for images
    return DuckDuckGoSearchResults(backend="images").run(query)

tool_search = {
    'type': 'function',
    'function': {
        'name': 'search',
        'parameters': {
            'type': 'object',
            'required': ['query'],
            'properties': {
                'query': {'type': 'string', 'description': 'image to search on web'},
            }
        }
    }
}

search(query='pandas')
# what if I want to search only for financial updates?
def search_finance(query: str) -> str:
    engine = DuckDuckGoSearchResults(backend="news")
    return engine.run(f"site:finance.yahoo.com {query}")

tool_search_finance = {
    'type': 'function',
    'function': {
        'name': 'search_finance',  # must match the function we dispatch to
        'description': 'Search for specific financial news',
        'parameters': {
            'type': 'object',
            'required': ['query'],
            'properties': {
                'query': {'type': 'string', 'description': 'the financial topic or subject to search'},
            }
        }
    }
}

search_finance(query="apple")
# you can do the same for making it code; all tools follow a similar structure
import io
import contextlib

def code_exec(code):
    output = io.StringIO()
    # run whatever code is passed in and capture what it prints
    with contextlib.redirect_stdout(output):
        exec(code)
    return output.getvalue()

tool_code_exec = {
    'type': 'function',
    'function': {
        'name': 'code_exec',
        'description': 'execute python code',
        'parameters': {
            'type': 'object',
            'required': ['code'],
            'properties': {
                'code': {'type': 'string', 'description': 'code to execute'},
            }
        }
    }
}

code_exec("multi = 8*8; print(multi)")
# You can also add memory to these AI agents.
prompt = """You are an expert software engineer.
Write clean Python code when asked.
Always compute Fibonacci numbers using the function fib(n)."""

messages = [{"role": "system", "content": prompt}]
start = True
while True:
    try:
        if start:
            q = input("Enter a number for Fibonacci (or 'quit'): ")
        else:
            q = input("> ")
    except EOFError:
        break
    if q.strip() == "" or q.lower() == "quit":
        break
    # Build user prompt
    user_message = f"Solve fibonacci for n = {q}"
    messages.append({"role": "user", "content": user_message})
    # Call model
    agent_res = ollama.chat(
        model=llm,
        messages=messages
    )
    res = agent_res["message"]["content"]
    print(">", res)
    # Save assistant response (this is the "memory")
    messages.append({"role": "assistant", "content": res})
    start = False
- As you can tell, it's pretty easy to get started with AI agents and build your own. This is just an application of LLMs, so there isn't much more to cover here. You can see great examples of AI agents here
prompt engineering: getting better outputs
zero-shot vs few-shot prompting
Zero-Shot: Give the model a task with no examples
"List the best horror movies"
Few-Shot: Provide examples to guide the response format
"List the best movies in each genre:
Rom-Com:
1. When Harry Met Sally
2. The Proposal
Comedy:
1. Superbad
2. The Hangover
Horror:
[model generates this]"
Few-shot prompting is almost always better - the model picks up on the format and style from your examples.
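In a chat API, few-shot examples are often sent as fake prior turns rather than one long prompt string, so the model treats them as a conversation pattern to imitate. A sketch (the example pairs are made up):

```python
def few_shot_messages(examples, query):
    # each (user, assistant) pair becomes a fake prior turn for the model to imitate
    msgs = []
    for user_text, assistant_text in examples:
        msgs.append({"role": "user", "content": user_text})
        msgs.append({"role": "assistant", "content": assistant_text})
    msgs.append({"role": "user", "content": query})
    return msgs

examples = [
    ("List the best rom-coms", "1. When Harry Met Sally\n2. The Proposal"),
    ("List the best comedies", "1. Superbad\n2. The Hangover"),
]
msgs = few_shot_messages(examples, "List the best horror movies")
# pass `msgs` to ollama.chat(...) in place of a single prompt string
```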
chain of thought in prompts
You can explicitly ask the model to show its reasoning:
"What's the best horror movie?
Think step by step:
1. Consider highly rated horror films (IMDB > 8.0)
2. Factor in user preferences (psychological thriller vs slasher)
3. Check recency (prefer newer releases)
4. Combine rankings
Answer: [model's response]"
Already covered this in Chapter 4, but it works just as well in prompts.
system prompts + context engineering
The Problem: System prompts alone have limitations:
- Can’t contain all relevant information (token limits)
- Lack real-time data
- Missing user-specific context
The Solution: Combine multiple sources:
User Query + Relevant Docs (RAG) + System Prompt + Chat History → LLM
Context Engineering: As chat history grows, you can’t pass everything to the LLM (memory limits). Strategies:
- Summarization: Compress chat history into key points
- Sliding window: Keep only recent N messages
- Importance filtering: Extract most relevant information
Keep the context, dodge the token limits.
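The sliding window is the simplest of these strategies to implement: always keep the system prompt, and drop the oldest turns once the history exceeds N messages. A sketch (the window size is arbitrary):

```python
def trim_history(messages, window=6):
    # keep the system prompt, plus only the most recent `window` messages
    system = [m for m in messages if m["role"] == "system"]
    rest = [m for m in messages if m["role"] != "system"]
    return system + rest[-window:]

history = [{"role": "system", "content": "You are helpful."}]
for i in range(10):
    history.append({"role": "user", "content": f"question {i}"})
    history.append({"role": "assistant", "content": f"answer {i}"})

trimmed = trim_history(history)
# the system prompt plus the last 6 messages survive; older turns are dropped
```

Summarization works the same way structurally, except the dropped turns are first compressed into one short message instead of being thrown away.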
Example - Refund Policy Bot:
System: You are answering a customer based on our official policy.
Use the text below to generate your answer. Do not guess.
Context: [2 paragraphs from policy doc]
Customer Query: When will I get a refund?
Response: [model generates accurate, grounded answer]
cost optimization strategies
LLMs in production get expensive fast. Ways to reduce costs:
1. caching
- Vector-based caching: Common queries (“What’s the weather?”) get cached responses
- Store query vectors, return cached answers for similar queries
- Works like a hashmap: Key = Query Vector, Value = Response
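The hashmap analogy can be sketched directly: embed each query, and if a new query's vector is close enough (cosine similarity above a threshold) to a cached one, return the stored answer without touching the LLM. The `embed` function below is a hypothetical letter-count stand-in for a real embedding model, and the "LLM answer" string stands in for a real model call:

```python
import math

def embed(text):
    # hypothetical stand-in: a real system would call an embedding model here
    vec = [0.0] * 26
    for ch in text.lower():
        if ch.isalpha():
            vec[ord(ch) - ord('a')] += 1.0
    return vec

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb) if na and nb else 0.0

cache = []  # list of (query_vector, cached_response) pairs

def cached_answer(query, threshold=0.95):
    qv = embed(query)
    for vec, response in cache:
        if cosine(qv, vec) >= threshold:
            return response  # cache hit: skip the expensive LLM call
    response = f"LLM answer for: {query}"  # stand-in for a real LLM call
    cache.append((qv, response))
    return response

first = cached_answer("what's the weather?")
second = cached_answer("whats the weather")  # near-duplicate -> cache hit
```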
2. model selection
- Use smaller models when appropriate (Mistral, DeepSeek instead of GPT-4)
- Reserve expensive models for complex tasks
3. token reduction
- Reduce context window size
- Compress system prompts
- Limit chat history
4. rate limiting
- Per-user query limits
- Per-session throttling
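Per-user limits can be as simple as counting queries per time window and rejecting requests before any tokens are spent. A sketch of a sliding-window counter (the limits and window length are arbitrary):

```python
import time
from collections import defaultdict

class RateLimiter:
    def __init__(self, max_queries=5, window_seconds=60):
        self.max_queries = max_queries
        self.window = window_seconds
        self.counts = defaultdict(list)  # user_id -> timestamps of recent queries

    def allow(self, user_id):
        now = time.time()
        # drop timestamps that have fallen out of the window
        self.counts[user_id] = [t for t in self.counts[user_id] if now - t < self.window]
        if len(self.counts[user_id]) >= self.max_queries:
            return False  # over the limit: reject before spending tokens
        self.counts[user_id].append(now)
        return True

limiter = RateLimiter(max_queries=3, window_seconds=60)
results = [limiter.allow("alice") for _ in range(5)]
# first 3 calls allowed, the rest rejected
```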
5. AI agents for better ux
User asks: “What size is this t-shirt?” What they really want: “Will this fit me?”
AI agents can understand intent and use context (previous purchases, size history) to answer in one shot instead of multiple back-and-forth queries. Fewer tokens, better experience.
6. security & privacy
- Anonymize queries: Don’t store personally identifiable information
- Filter sensitive data: Remove credit card numbers, emails before logging
- Data retention policies: Delete query logs after N days
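Filtering sensitive data before logging can start with simple regexes. A sketch that redacts emails and card-like digit runs (the patterns are deliberately crude; real PII detection needs much more care):

```python
import re

EMAIL = re.compile(r"[\w.+-]+@[\w-]+\.[\w.]+")
CARD = re.compile(r"\b(?:\d[ -]?){13,16}\b")

def redact(text):
    # replace emails and card-like numbers before the query hits the logs
    text = EMAIL.sub("[EMAIL]", text)
    text = CARD.sub("[CARD]", text)
    return text

log_line = redact("refund to jane@example.com, card 4111 1111 1111 1111")
```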
conclusion
That’s basically everything I picked up from speedrunning the course. The whole pipeline is: turn text into vectors, store them in a vector DB, tokenize and feed into transformers (which are just attention + FFN stacked a bunch of times), optimize the heck out of inference with KV caching and flash attention, then RAG and agents on top to make it actually useful. If you made it this far, nice (lol).

resources + references
Visualizations:
- 3D LLM Visualization - See transformers in action
- Transformer Illustrated - Jay Alammar’s beautiful visual guide
- Word2Vec Illustrated - Understanding embeddings
Papers to Read:
- “Attention Is All You Need” (Vaswani et al., 2017) - The original transformer paper
- “FlashAttention” (Dao et al., 2022) - Making attention fast
- “LLaMA” (Touvron et al., 2023) - Modern open LLM architecture
Courses & Tutorials:
- GKCS AI Engineering course (InterviewReady)
- Andrej Karpathy’s lectures on neural networks and LLMs
Blogs & Visualizations:
- Best blog on Attention
- Modal GPU Glossary - GPU terminology reference
- Awesome AI Agents - Curated agent examples
Tools & Libraries:
- TikTokenizer - Visualize tokenization
- Milvus Documentation - Vector database architecture
- vLLM Blog - Paged attention explained
