KV Cache Walkthrough¶
This notebook unrolls a tiny autoregressive attention example cell by cell.
The walkthrough now uses an explicit batch dimension B = 1, so the shapes match standard transformer notation more closely:
- token IDs: (B, T)
- embeddings and hidden states: (B, T, d)
- attention scores: (B, T, T)
- logits for the newest position: (B, V)
For each decoding iteration, the work is split into:
- embedding lookup
- Q/K/V projections
- attention scores
- masking
- softmax and value aggregation
- output projection and token sampling
Run the notebook from top to bottom so later cells can reuse earlier tensors.
# Imports: math for the sqrt(d) attention scaling, NumPy for cost accounting, torch for tensors.
import math
import numpy as np
import torch
# Compact print settings so the small worked tensors stay readable in the outputs below.
torch.set_printoptions(precision=4, sci_mode=False)
np.set_printoptions(suppress=True)
# Toy model hyperparameters used throughout the walkthrough.
batch_size = 1 # -> (scalar) = (B=1)
vocab_size = 10 # -> (scalar) = (V=10)
d_model = 3 # -> (scalar) = (d=3)
max_sequence_length = 5 # -> (scalar) = (T_max=5)
# Embedding table, (V=10, d=3): row i is the embedding of token id i.
# Note rows 4 and 9 are identical, so tokens 4 and 9 are indistinguishable to this toy model.
token_embedding_table = torch.tensor(
[
[-1.0, -2.0, 0.0],
[0.0, -2.0, -1.0],
[-1.0, -2.0, 2.0],
[-2.0, 1.0, -1.0],
[0.0, 1.0, -2.0],
[2.0, 2.0, 1.0],
[2.0, 2.0, 2.0],
[-2.0, 0.0, 2.0],
[-2.0, 1.0, 1.0],
[0.0, 1.0, -2.0],
],
dtype=torch.float32,
)
# Query projection, (d=3, d=3); applied on the right as x @ W_q.
W_q = torch.tensor(
[
[2.0, 2.0, 1.0],
[-1.0, 1.0, -1.0],
[0.0, 1.0, 2.0],
],
dtype=torch.float32,
)
# Key projection, (d=3, d=3).
W_k = torch.tensor(
[
[-1.0, -2.0, 2.0],
[-2.0, 2.0, 2.0],
[0.0, -2.0, -2.0],
],
dtype=torch.float32,
)
# Value projection, (d=3, d=3). Its last row is all zeros, so the third embedding
# component never influences the value vectors.
W_v = torch.tensor(
[
[-1.0, 2.0, 2.0],
[2.0, -2.0, 2.0],
[0.0, 0.0, 0.0],
],
dtype=torch.float32,
)
# Attention output projection, (d=3, d=3).
W_o = torch.tensor(
[
[-1.0, 2.0, 2.0],
[-1.0, -1.0, -2.0],
[1.0, -2.0, -1.0],
],
dtype=torch.float32,
)
# Unembedding / LM head, (d=3, V=10): maps the last hidden state to vocabulary logits.
W_vocab = torch.tensor(
[
[0.0, -1.0, 2.0, -1.0, -1.0, 2.0, -1.0, -2.0, -2.0, 0.0],
[-2.0, 1.0, -1.0, 1.0, -1.0, 0.0, 1.0, -2.0, 1.0, 1.0],
[2.0, 0.0, 0.0, -2.0, 1.0, 1.0, -1.0, 1.0, -2.0, 1.0],
],
dtype=torch.float32,
)
# One-token prompt plus the per-step prefix lengths used for the cost accounting below.
start_prefix = torch.tensor([[1]], dtype=torch.long)  # (1, 1) = (B=1, T=1)
prefix_lengths = np.array([1, 2, 3, 4], dtype=np.int64)  # (4,) = (num_steps=4,)
# Hoist the derived per-step counts so each quantity is computed exactly once.
cached_step_counts = np.ones_like(prefix_lengths)  # (4,) one new token per cached step
naive_score_counts = prefix_lengths ** 2           # (4,) full T x T score matrix per naive step
print('Initial prompt token IDs:', start_prefix.tolist())
print('Naive tokens processed for Q/K/V per step:', prefix_lengths.tolist(), 'total =', int(prefix_lengths.sum()))
print('Cached tokens processed for Q/K/V per step:', cached_step_counts.tolist(), 'total =', int(cached_step_counts.sum()))
print('Naive attention scores computed per step:', naive_score_counts.tolist(), 'total =', int(naive_score_counts.sum()))
print('Cached attention scores computed per step:', prefix_lengths.tolist(), 'total =', int(prefix_lengths.sum()))
Initial prompt token IDs: [[1]] Naive tokens processed for Q/K/V per step: [1, 2, 3, 4] total = 10 Cached tokens processed for Q/K/V per step: [1, 1, 1, 1] total = 4 Naive attention scores computed per step: [1, 4, 9, 16] total = 30 Cached attention scores computed per step: [1, 2, 3, 4] total = 10
Shape Notation¶
The inline comments show both concrete dimensions and symbolic dimensions.
- B: batch size, here B = 1
- T: current sequence length or cache length at this step
- T_old: previous cache length before appending the newest token
- T_new: number of newly decoded tokens at this step, always 1 here
- d: embedding and hidden dimension, here d = 3
- V: vocabulary size, here V = 10
Naive Decoding¶
At every step, recompute the full prefix from scratch.
Iteration 1¶
Start from the one-token prompt [[1]] and compute the first next-token prediction with a full forward pass.
# Naive iteration 1: embedding lookup
naive_prefix_1 = start_prefix # (1, 1) = (B=1, T=1)
# Indexing the table with a (B, T) LongTensor gathers one embedding row per token.
naive_x_1 = token_embedding_table[naive_prefix_1] # (1, 1) -> (1, 1, 3) = (B=1, T=1) -> (B=1, T=1, d=3)
print('Naive iteration 1 prefix:', naive_prefix_1.tolist())
print('Naive iteration 1 embeddings:\n', naive_x_1)
Naive iteration 1 prefix: [[1]] Naive iteration 1 embeddings: tensor([[[ 0., -2., -1.]]])
# Naive iteration 1: Q/K/V projections
naive_Q_1 = naive_x_1 @ W_q # (1, 1, 3) x (3, 3) -> (1, 1, 3) = (B=1, T=1, d=3) x (d=3, d=3) -> (B=1, T=1, d=3)
naive_K_1 = naive_x_1 @ W_k # (1, 1, 3) x (3, 3) -> (1, 1, 3) = (B=1, T=1, d=3) x (d=3, d=3) -> (B=1, T=1, d=3)
naive_V_1 = naive_x_1 @ W_v # (1, 1, 3) x (3, 3) -> (1, 1, 3) = (B=1, T=1, d=3) x (d=3, d=3) -> (B=1, T=1, d=3)
print('Naive Q_1:\n', naive_Q_1)
print('Naive K_1:\n', naive_K_1)
print('Naive V_1:\n', naive_V_1)
Naive Q_1: tensor([[[ 2., -3., 0.]]]) Naive K_1: tensor([[[ 4., -2., -2.]]]) Naive V_1: tensor([[[-4., 4., -4.]]])
# Naive iteration 1: raw attention scores
# Scaled dot-product: scores[i, j] = q_i . k_j / sqrt(d).
naive_scores_1 = naive_Q_1 @ naive_K_1.transpose(-2, -1) / math.sqrt(d_model) # (1, 1, 3) x (1, 3, 1) -> (1, 1, 1) = (B=1, T=1, d=3) x (B=1, d=3, T=1) -> (B=1, T=1, T=1)
print('Naive scores_1:\n', naive_scores_1)
Naive scores_1: tensor([[[8.0829]]])
# Naive iteration 1: causal masking
# tril keeps entries with key index <= query index (causal visibility).
naive_mask_1 = torch.tril(torch.ones((1, 1), dtype=torch.bool)).unsqueeze(0) # (1, 1) -> (1, 1, 1) = (T=1, T=1) -> (B=1, T=1, T=1)
naive_scores_1_masked = naive_scores_1.masked_fill(~naive_mask_1, float('-inf')) # (1, 1, 1) with (1, 1, 1) -> (1, 1, 1) = (B=1, T=1, T=1) with (B=1, T=1, T=1) -> (B=1, T=1, T=1)
print('Naive mask_1:\n', naive_mask_1)
print('Naive masked scores_1:\n', naive_scores_1_masked)
Naive mask_1: tensor([[[True]]]) Naive masked scores_1: tensor([[[8.0829]]])
# Naive iteration 1: softmax and value aggregation
# Softmax over key positions, then a weighted sum of the value vectors.
naive_weights_1 = torch.softmax(naive_scores_1_masked, dim=-1) # (1, 1, 1) -> (1, 1, 1) = (B=1, T=1, T=1) -> (B=1, T=1, T=1)
naive_context_1 = naive_weights_1 @ naive_V_1 # (1, 1, 1) x (1, 1, 3) -> (1, 1, 3) = (B=1, T=1, T=1) x (B=1, T=1, d=3) -> (B=1, T=1, d=3)
print('Naive weights_1:\n', naive_weights_1)
print('Naive context_1:\n', naive_context_1)
Naive weights_1: tensor([[[1.]]]) Naive context_1: tensor([[[-4., 4., -4.]]])
# Naive iteration 1: output projection and sampling
naive_hidden_1 = naive_context_1 @ W_o # (1, 1, 3) x (3, 3) -> (1, 1, 3) = (B=1, T=1, d=3) x (d=3, d=3) -> (B=1, T=1, d=3)
# Only the last position's hidden state is needed to predict the next token.
naive_logits_1 = naive_hidden_1[:, -1, :] @ W_vocab # (1, 3) x (3, 10) -> (1, 10) = (B=1, d=3) x (d=3, V=10) -> (B=1, V=10)
# Greedy decoding: take the argmax logit as the next token, then append it to the prefix.
naive_next_token_1 = torch.argmax(naive_logits_1, dim=-1) # (1, 10) -> (1,) = (B=1, V=10) -> (B=1,)
naive_prefix_2 = torch.cat([naive_prefix_1, naive_next_token_1.unsqueeze(1)], dim=1) # (1, 1) + (1, 1) -> (1, 2) = (B=1, T=1) + (B=1, T_new=1) -> (B=1, T=2)
print('Naive hidden_1:\n', naive_hidden_1)
print('Naive logits_1:\n', naive_logits_1)
print('Naive iteration 1 next token:', naive_next_token_1.tolist())
Naive hidden_1: tensor([[[ -4., -4., -12.]]]) Naive logits_1: tensor([[-16., 0., -4., 24., -4., -20., 12., 4., 28., -16.]]) Naive iteration 1 next token: [8]
Iteration 2¶
Now the prefix has two tokens. The naive path recomputes the old token's work as well as the new token's work.
# Naive iteration 2: embedding lookup
naive_x_2 = token_embedding_table[naive_prefix_2] # (1, 2) -> (1, 2, 3) = (B=1, T=2) -> (B=1, T=2, d=3)
print('Naive iteration 2 prefix:', naive_prefix_2.tolist())
print('Naive iteration 2 embeddings:\n', naive_x_2)
Naive iteration 2 prefix: [[1, 8]]
Naive iteration 2 embeddings:
tensor([[[ 0., -2., -1.],
[-2., 1., 1.]]])
# Naive iteration 2: Q/K/V projections
naive_Q_2 = naive_x_2 @ W_q # (1, 2, 3) x (3, 3) -> (1, 2, 3) = (B=1, T=2, d=3) x (d=3, d=3) -> (B=1, T=2, d=3)
naive_K_2 = naive_x_2 @ W_k # (1, 2, 3) x (3, 3) -> (1, 2, 3) = (B=1, T=2, d=3) x (d=3, d=3) -> (B=1, T=2, d=3)
naive_V_2 = naive_x_2 @ W_v # (1, 2, 3) x (3, 3) -> (1, 2, 3) = (B=1, T=2, d=3) x (d=3, d=3) -> (B=1, T=2, d=3)
# The first row of each projection equals iteration 1's result exactly --
# this recomputed work is what the KV cache section avoids.
assert torch.allclose(naive_Q_2[:, :1, :], naive_Q_1)
assert torch.allclose(naive_K_2[:, :1, :], naive_K_1)
assert torch.allclose(naive_V_2[:, :1, :], naive_V_1)
print('Naive Q_2:\n', naive_Q_2)
print('Naive K_2:\n', naive_K_2)
print('Naive V_2:\n', naive_V_2)
Naive Q_2:
tensor([[[ 2., -3., 0.],
[-5., -2., -1.]]])
Naive K_2:
tensor([[[ 4., -2., -2.],
[ 0., 4., -4.]]])
Naive V_2:
tensor([[[-4., 4., -4.],
[ 4., -6., -2.]]])
# Naive iteration 2: raw attention scores
naive_scores_2 = naive_Q_2 @ naive_K_2.transpose(-2, -1) / math.sqrt(d_model) # (1, 2, 3) x (1, 3, 2) -> (1, 2, 2) = (B=1, T=2, d=3) x (B=1, d=3, T=2) -> (B=1, T=2, T=2)
assert torch.allclose(naive_scores_2[:, :1, :1], naive_scores_1)
print('Naive scores_2:\n', naive_scores_2)
Naive scores_2:
tensor([[[ 8.0829, -6.9282],
[-8.0829, -2.3094]]])
# Naive iteration 2: causal masking
# Lower-triangular mask: position i may attend only to positions j <= i.
naive_mask_2 = torch.tril(torch.ones((2, 2), dtype=torch.bool)).unsqueeze(0) # (2, 2) -> (1, 2, 2) = (T=2, T=2) -> (B=1, T=2, T=2)
naive_scores_2_masked = naive_scores_2.masked_fill(~naive_mask_2, float('-inf')) # (1, 2, 2) with (1, 2, 2) -> (1, 2, 2) = (B=1, T=2, T=2) with (B=1, T=2, T=2) -> (B=1, T=2, T=2)
print('Naive mask_2:\n', naive_mask_2)
print('Naive masked scores_2:\n', naive_scores_2_masked)
Naive mask_2:
tensor([[[ True, False],
[ True, True]]])
Naive masked scores_2:
tensor([[[ 8.0829, -inf],
[-8.0829, -2.3094]]])
# Naive iteration 2: softmax and value aggregation
naive_weights_2 = torch.softmax(naive_scores_2_masked, dim=-1) # (1, 2, 2) -> (1, 2, 2) = (B=1, T=2, T=2) -> (B=1, T=2, T=2)
naive_context_2 = naive_weights_2 @ naive_V_2 # (1, 2, 2) x (1, 2, 3) -> (1, 2, 3) = (B=1, T=2, T=2) x (B=1, T=2, d=3) -> (B=1, T=2, d=3)
print('Naive weights_2:\n', naive_weights_2)
print('Naive context_2:\n', naive_context_2)
Naive weights_2:
tensor([[[1.0000, 0.0000],
[0.0031, 0.9969]]])
Naive context_2:
tensor([[[-4.0000, 4.0000, -4.0000],
[ 3.9752, -5.9690, -2.0062]]])
# Naive iteration 2: output projection and sampling
naive_hidden_2 = naive_context_2 @ W_o # (1, 2, 3) x (3, 3) -> (1, 2, 3) = (B=1, T=2, d=3) x (d=3, d=3) -> (B=1, T=2, d=3)
# Only the newest position feeds the LM head; earlier positions were computed just to be discarded.
naive_logits_2 = naive_hidden_2[:, -1, :] @ W_vocab # (1, 3) x (3, 10) -> (1, 10) = (B=1, d=3) x (d=3, V=10) -> (B=1, V=10)
naive_next_token_2 = torch.argmax(naive_logits_2, dim=-1) # (1, 10) -> (1,) = (B=1, V=10) -> (B=1,)
naive_prefix_3 = torch.cat([naive_prefix_2, naive_next_token_2.unsqueeze(1)], dim=1) # (1, 2) + (1, 1) -> (1, 3) = (B=1, T=2) + (B=1, T_new=1) -> (B=1, T=3)
print('Naive hidden_2:\n', naive_hidden_2)
print('Naive logits_2:\n', naive_logits_2)
print('Naive iteration 2 next token:', naive_next_token_2.tolist())
Naive hidden_2:
tensor([[[ -4.0000, -4.0000, -12.0000],
[ -0.0124, 17.9318, 21.8946]]])
Naive logits_2:
tensor([[ 7.9256, 17.9442, -17.9566, -25.8450, 3.9752, 21.8698, -3.9504,
-13.9442, -25.8326, 39.8264]])
Naive iteration 2 next token: [9]
Iteration 3¶
The prefix grows again, and the naive path recomputes the first two tokens before producing the third prediction.
# Naive iteration 3: embedding lookup
naive_x_3 = token_embedding_table[naive_prefix_3] # (1, 3) -> (1, 3, 3) = (B=1, T=3) -> (B=1, T=3, d=3)
print('Naive iteration 3 prefix:', naive_prefix_3.tolist())
print('Naive iteration 3 embeddings:\n', naive_x_3)
Naive iteration 3 prefix: [[1, 8, 9]]
Naive iteration 3 embeddings:
tensor([[[ 0., -2., -1.],
[-2., 1., 1.],
[ 0., 1., -2.]]])
# Naive iteration 3: Q/K/V projections
naive_Q_3 = naive_x_3 @ W_q # (1, 3, 3) x (3, 3) -> (1, 3, 3) = (B=1, T=3, d=3) x (d=3, d=3) -> (B=1, T=3, d=3)
naive_K_3 = naive_x_3 @ W_k # (1, 3, 3) x (3, 3) -> (1, 3, 3) = (B=1, T=3, d=3) x (d=3, d=3) -> (B=1, T=3, d=3)
naive_V_3 = naive_x_3 @ W_v # (1, 3, 3) x (3, 3) -> (1, 3, 3) = (B=1, T=3, d=3) x (d=3, d=3) -> (B=1, T=3, d=3)
# The first two rows repeat iteration 2's projections verbatim.
assert torch.allclose(naive_Q_3[:, :2, :], naive_Q_2)
assert torch.allclose(naive_K_3[:, :2, :], naive_K_2)
assert torch.allclose(naive_V_3[:, :2, :], naive_V_2)
print('Naive Q_3:\n', naive_Q_3)
print('Naive K_3:\n', naive_K_3)
print('Naive V_3:\n', naive_V_3)
Naive Q_3:
tensor([[[ 2., -3., 0.],
[-5., -2., -1.],
[-1., -1., -5.]]])
Naive K_3:
tensor([[[ 4., -2., -2.],
[ 0., 4., -4.],
[-2., 6., 6.]]])
Naive V_3:
tensor([[[-4., 4., -4.],
[ 4., -6., -2.],
[ 2., -2., 2.]]])
# Naive iteration 3: raw attention scores
naive_scores_3 = naive_Q_3 @ naive_K_3.transpose(-2, -1) / math.sqrt(d_model) # (1, 3, 3) x (1, 3, 3) -> (1, 3, 3) = (B=1, T=3, d=3) x (B=1, d=3, T=3) -> (B=1, T=3, T=3)
assert torch.allclose(naive_scores_3[:, :2, :2], naive_scores_2)
print('Naive scores_3:\n', naive_scores_3)
Naive scores_3:
tensor([[[ 8.0829, -6.9282, -12.7017],
[ -8.0829, -2.3094, -4.6188],
[ 4.6188, 9.2376, -19.6299]]])
# Naive iteration 3: causal masking
naive_mask_3 = torch.tril(torch.ones((3, 3), dtype=torch.bool)).unsqueeze(0) # (3, 3) -> (1, 3, 3) = (T=3, T=3) -> (B=1, T=3, T=3)
naive_scores_3_masked = naive_scores_3.masked_fill(~naive_mask_3, float('-inf')) # (1, 3, 3) with (1, 3, 3) -> (1, 3, 3) = (B=1, T=3, T=3) with (B=1, T=3, T=3) -> (B=1, T=3, T=3)
print('Naive mask_3:\n', naive_mask_3)
print('Naive masked scores_3:\n', naive_scores_3_masked)
Naive mask_3:
tensor([[[ True, False, False],
[ True, True, False],
[ True, True, True]]])
Naive masked scores_3:
tensor([[[ 8.0829, -inf, -inf],
[ -8.0829, -2.3094, -inf],
[ 4.6188, 9.2376, -19.6299]]])
# Naive iteration 3: softmax and value aggregation
naive_weights_3 = torch.softmax(naive_scores_3_masked, dim=-1) # (1, 3, 3) -> (1, 3, 3) = (B=1, T=3, T=3) -> (B=1, T=3, T=3)
naive_context_3 = naive_weights_3 @ naive_V_3 # (1, 3, 3) x (1, 3, 3) -> (1, 3, 3) = (B=1, T=3, T=3) x (B=1, T=3, d=3) -> (B=1, T=3, d=3)
print('Naive weights_3:\n', naive_weights_3)
print('Naive context_3:\n', naive_context_3)
Naive weights_3:
tensor([[[1.0000, 0.0000, 0.0000],
[0.0031, 0.9969, 0.0000],
[0.0098, 0.9902, 0.0000]]])
Naive context_3:
tensor([[[-4.0000, 4.0000, -4.0000],
[ 3.9752, -5.9690, -2.0062],
[ 3.9219, -5.9023, -2.0195]]])
# Naive iteration 3: output projection and sampling
naive_hidden_3 = naive_context_3 @ W_o # (1, 3, 3) x (3, 3) -> (1, 3, 3) = (B=1, T=3, d=3) x (d=3, d=3) -> (B=1, T=3, d=3)
naive_logits_3 = naive_hidden_3[:, -1, :] @ W_vocab # (1, 3) x (3, 10) -> (1, 10) = (B=1, d=3) x (d=3, V=10) -> (B=1, V=10)
naive_next_token_3 = torch.argmax(naive_logits_3, dim=-1) # (1, 10) -> (1,) = (B=1, V=10) -> (B=1,)
naive_prefix_4 = torch.cat([naive_prefix_3, naive_next_token_3.unsqueeze(1)], dim=1) # (1, 3) + (1, 1) -> (1, 4) = (B=1, T=3) + (B=1, T_new=1) -> (B=1, T=4)
print('Naive hidden_3:\n', naive_hidden_3)
print('Naive logits_3:\n', naive_logits_3)
print('Naive iteration 3 next token:', naive_next_token_3.tolist())
Naive hidden_3:
tensor([[[ -4.0000, -4.0000, -12.0000],
[ -0.0124, 17.9318, 21.8946],
[ -0.0391, 17.7851, 21.6679]]])
Naive logits_3:
tensor([[ 7.7656, 17.8242, -17.8632, -25.5116, 3.9219, 21.5897, -3.8437,
-13.8242, -25.4725, 39.4530]])
Naive iteration 3 next token: [9]
Iteration 4¶
One more full-prefix recomputation produces the final token and completes the length-5 sequence.
# Naive iteration 4: embedding lookup
naive_x_4 = token_embedding_table[naive_prefix_4] # (1, 4) -> (1, 4, 3) = (B=1, T=4) -> (B=1, T=4, d=3)
print('Naive iteration 4 prefix:', naive_prefix_4.tolist())
print('Naive iteration 4 embeddings:\n', naive_x_4)
Naive iteration 4 prefix: [[1, 8, 9, 9]]
Naive iteration 4 embeddings:
tensor([[[ 0., -2., -1.],
[-2., 1., 1.],
[ 0., 1., -2.],
[ 0., 1., -2.]]])
# Naive iteration 4: Q/K/V projections
naive_Q_4 = naive_x_4 @ W_q # (1, 4, 3) x (3, 3) -> (1, 4, 3) = (B=1, T=4, d=3) x (d=3, d=3) -> (B=1, T=4, d=3)
naive_K_4 = naive_x_4 @ W_k # (1, 4, 3) x (3, 3) -> (1, 4, 3) = (B=1, T=4, d=3) x (d=3, d=3) -> (B=1, T=4, d=3)
naive_V_4 = naive_x_4 @ W_v # (1, 4, 3) x (3, 3) -> (1, 4, 3) = (B=1, T=4, d=3) x (d=3, d=3) -> (B=1, T=4, d=3)
# Three of the four rows repeat iteration 3's work; only the last row is new.
assert torch.allclose(naive_Q_4[:, :3, :], naive_Q_3)
assert torch.allclose(naive_K_4[:, :3, :], naive_K_3)
assert torch.allclose(naive_V_4[:, :3, :], naive_V_3)
print('Naive Q_4:\n', naive_Q_4)
print('Naive K_4:\n', naive_K_4)
print('Naive V_4:\n', naive_V_4)
Naive Q_4:
tensor([[[ 2., -3., 0.],
[-5., -2., -1.],
[-1., -1., -5.],
[-1., -1., -5.]]])
Naive K_4:
tensor([[[ 4., -2., -2.],
[ 0., 4., -4.],
[-2., 6., 6.],
[-2., 6., 6.]]])
Naive V_4:
tensor([[[-4., 4., -4.],
[ 4., -6., -2.],
[ 2., -2., 2.],
[ 2., -2., 2.]]])
# Naive iteration 4: raw attention scores
naive_scores_4 = naive_Q_4 @ naive_K_4.transpose(-2, -1) / math.sqrt(d_model) # (1, 4, 3) x (1, 3, 4) -> (1, 4, 4) = (B=1, T=4, d=3) x (B=1, d=3, T=4) -> (B=1, T=4, T=4)
assert torch.allclose(naive_scores_4[:, :3, :3], naive_scores_3)
print('Naive scores_4:\n', naive_scores_4)
Naive scores_4:
tensor([[[ 8.0829, -6.9282, -12.7017, -12.7017],
[ -8.0829, -2.3094, -4.6188, -4.6188],
[ 4.6188, 9.2376, -19.6299, -19.6299],
[ 4.6188, 9.2376, -19.6299, -19.6299]]])
# Naive iteration 4: causal masking
naive_mask_4 = torch.tril(torch.ones((4, 4), dtype=torch.bool)).unsqueeze(0) # (4, 4) -> (1, 4, 4) = (T=4, T=4) -> (B=1, T=4, T=4)
naive_scores_4_masked = naive_scores_4.masked_fill(~naive_mask_4, float('-inf')) # (1, 4, 4) with (1, 4, 4) -> (1, 4, 4) = (B=1, T=4, T=4) with (B=1, T=4, T=4) -> (B=1, T=4, T=4)
print('Naive mask_4:\n', naive_mask_4)
print('Naive masked scores_4:\n', naive_scores_4_masked)
Naive mask_4:
tensor([[[ True, False, False, False],
[ True, True, False, False],
[ True, True, True, False],
[ True, True, True, True]]])
Naive masked scores_4:
tensor([[[ 8.0829, -inf, -inf, -inf],
[ -8.0829, -2.3094, -inf, -inf],
[ 4.6188, 9.2376, -19.6299, -inf],
[ 4.6188, 9.2376, -19.6299, -19.6299]]])
# Naive iteration 4: softmax and value aggregation
naive_weights_4 = torch.softmax(naive_scores_4_masked, dim=-1) # (1, 4, 4) -> (1, 4, 4) = (B=1, T=4, T=4) -> (B=1, T=4, T=4)
naive_context_4 = naive_weights_4 @ naive_V_4 # (1, 4, 4) x (1, 4, 3) -> (1, 4, 3) = (B=1, T=4, T=4) x (B=1, T=4, d=3) -> (B=1, T=4, d=3)
print('Naive weights_4:\n', naive_weights_4)
print('Naive context_4:\n', naive_context_4)
Naive weights_4:
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
[0.0031, 0.9969, 0.0000, 0.0000],
[0.0098, 0.9902, 0.0000, 0.0000],
[0.0098, 0.9902, 0.0000, 0.0000]]])
Naive context_4:
tensor([[[-4.0000, 4.0000, -4.0000],
[ 3.9752, -5.9690, -2.0062],
[ 3.9219, -5.9023, -2.0195],
[ 3.9219, -5.9023, -2.0195]]])
# Naive iteration 4: output projection and sampling
naive_hidden_4 = naive_context_4 @ W_o # (1, 4, 3) x (3, 3) -> (1, 4, 3) = (B=1, T=4, d=3) x (d=3, d=3) -> (B=1, T=4, d=3)
naive_logits_4 = naive_hidden_4[:, -1, :] @ W_vocab # (1, 3) x (3, 10) -> (1, 10) = (B=1, d=3) x (d=3, V=10) -> (B=1, V=10)
naive_next_token_4 = torch.argmax(naive_logits_4, dim=-1) # (1, 10) -> (1,) = (B=1, V=10) -> (B=1,)
naive_prefix_5 = torch.cat([naive_prefix_4, naive_next_token_4.unsqueeze(1)], dim=1) # (1, 4) + (1, 1) -> (1, 5) = (B=1, T=4) + (B=1, T_new=1) -> (B=1, T=5)
print('Naive hidden_4:\n', naive_hidden_4)
print('Naive logits_4:\n', naive_logits_4)
print('Naive iteration 4 next token:', naive_next_token_4.tolist())
print('Naive final sequence:', naive_prefix_5.tolist())
Naive hidden_4:
tensor([[[ -4.0000, -4.0000, -12.0000],
[ -0.0124, 17.9318, 21.8946],
[ -0.0391, 17.7851, 21.6679],
[ -0.0391, 17.7851, 21.6679]]])
Naive logits_4:
tensor([[ 7.7656, 17.8242, -17.8632, -25.5116, 3.9219, 21.5897, -3.8437,
-13.8242, -25.4725, 39.4530]])
Naive iteration 4 next token: [9]
Naive final sequence: [[1, 8, 9, 9, 9]]
# Naive section: total work summary.
# Fold the per-step costs with np.sum / np.square; same totals as the method-call form.
naive_total_projection_work = int(np.sum(prefix_lengths))  # () total tokens projected: 1+2+3+4
naive_total_score_work = int(np.sum(np.square(prefix_lengths)))  # () total scores: 1^2+2^2+3^2+4^2
print('Naive total tokens processed for Q/K/V:', naive_total_projection_work, '= 1 + 2 + 3 + 4 = O(T^2) over the full decode')
print('Naive total attention scores computed:', naive_total_score_work, '= 1^2 + 2^2 + 3^2 + 4^2 = O(T^3) over the full decode')
Naive total tokens processed for Q/K/V: 10 = 1 + 2 + 3 + 4 = O(T^2) over the full decode Naive total attention scores computed: 30 = 1^2 + 2^2 + 3^2 + 4^2 = O(T^3) over the full decode
KV-Cache Decoding¶
At each step, compute projections only for the newest token and reuse earlier keys and values.
Iteration 1¶
The cache path starts the same way as the naive path because nothing has been cached yet.
# Cache iteration 1: embedding lookup
cache_prefix_1 = start_prefix # (1, 1) = (B=1, T=1)
cache_x_1 = token_embedding_table[cache_prefix_1] # (1, 1) -> (1, 1, 3) = (B=1, T=1) -> (B=1, T=1, d=3)
print('Cache iteration 1 prefix:', cache_prefix_1.tolist())
print('Cache iteration 1 embeddings:\n', cache_x_1)
Cache iteration 1 prefix: [[1]] Cache iteration 1 embeddings: tensor([[[ 0., -2., -1.]]])
# Cache iteration 1: Q/K/V projections and cache initialization
cache_q_1 = cache_x_1 @ W_q # (1, 1, 3) x (3, 3) -> (1, 1, 3) = (B=1, T_new=1, d=3) x (d=3, d=3) -> (B=1, T_new=1, d=3)
cache_k_1 = cache_x_1 @ W_k # (1, 1, 3) x (3, 3) -> (1, 1, 3) = (B=1, T_new=1, d=3) x (d=3, d=3) -> (B=1, T_new=1, d=3)
cache_v_1 = cache_x_1 @ W_v # (1, 1, 3) x (3, 3) -> (1, 1, 3) = (B=1, T_new=1, d=3) x (d=3, d=3) -> (B=1, T_new=1, d=3)
# The cache starts as the first token's key/value; later iterations only append to it.
K_cache_1 = cache_k_1 # (1, 1, 3) = (B=1, T=1, d=3)
V_cache_1 = cache_v_1 # (1, 1, 3) = (B=1, T=1, d=3)
assert torch.allclose(cache_q_1, naive_Q_1[:, -1:, :])
assert torch.allclose(K_cache_1, naive_K_1)
assert torch.allclose(V_cache_1, naive_V_1)
print('Cache q_1:\n', cache_q_1)
print('K_cache_1:\n', K_cache_1)
print('V_cache_1:\n', V_cache_1)
Cache q_1: tensor([[[ 2., -3., 0.]]]) K_cache_1: tensor([[[ 4., -2., -2.]]]) V_cache_1: tensor([[[-4., 4., -4.]]])
# Cache iteration 1: raw attention scores
# The single new query attends over every cached key.
cache_scores_1 = cache_q_1 @ K_cache_1.transpose(-2, -1) / math.sqrt(d_model) # (1, 1, 3) x (1, 3, 1) -> (1, 1, 1) = (B=1, T_new=1, d=3) x (B=1, d=3, T=1) -> (B=1, T_new=1, T=1)
assert torch.allclose(cache_scores_1, naive_scores_1[:, -1:, :])
print('Cache scores_1:\n', cache_scores_1)
Cache scores_1: tensor([[[8.0829]]])
# Cache iteration 1: masking step
# No extra causal mask is needed because the cache already contains only visible positions.
cache_scores_1_masked = cache_scores_1.clone() # (1, 1, 1) -> (1, 1, 1) = (B=1, T_new=1, T=1) -> (B=1, T_new=1, T=1)
print('Cache masked scores_1:\n', cache_scores_1_masked)
Cache masked scores_1: tensor([[[8.0829]]])
# Cache iteration 1: softmax and value aggregation
cache_weights_1 = torch.softmax(cache_scores_1_masked, dim=-1) # (1, 1, 1) -> (1, 1, 1) = (B=1, T_new=1, T=1) -> (B=1, T_new=1, T=1)
cache_context_1 = cache_weights_1 @ V_cache_1 # (1, 1, 1) x (1, 1, 3) -> (1, 1, 3) = (B=1, T_new=1, T=1) x (B=1, T=1, d=3) -> (B=1, T_new=1, d=3)
assert torch.allclose(cache_weights_1, naive_weights_1[:, -1:, :])
assert torch.allclose(cache_context_1, naive_context_1[:, -1:, :])
print('Cache weights_1:\n', cache_weights_1)
print('Cache context_1:\n', cache_context_1)
Cache weights_1: tensor([[[1.]]]) Cache context_1: tensor([[[-4., 4., -4.]]])
# Cache iteration 1: output projection and sampling
cache_hidden_1 = cache_context_1 @ W_o # (1, 1, 3) x (3, 3) -> (1, 1, 3) = (B=1, T_new=1, d=3) x (d=3, d=3) -> (B=1, T_new=1, d=3)
cache_logits_1 = cache_hidden_1[:, -1, :] @ W_vocab # (1, 3) x (3, 10) -> (1, 10) = (B=1, d=3) x (d=3, V=10) -> (B=1, V=10)
cache_next_token_1 = torch.argmax(cache_logits_1, dim=-1) # (1, 10) -> (1,) = (B=1, V=10) -> (B=1,)
cache_prefix_2 = torch.cat([cache_prefix_1, cache_next_token_1.unsqueeze(1)], dim=1) # (1, 1) + (1, 1) -> (1, 2) = (B=1, T=1) + (B=1, T_new=1) -> (B=1, T=2)
# The cached path must reproduce the naive decoder's choice exactly.
assert torch.equal(cache_next_token_1, naive_next_token_1)
print('Cache hidden_1:\n', cache_hidden_1)
print('Cache logits_1:\n', cache_logits_1)
print('Cache iteration 1 next token:', cache_next_token_1.tolist())
Cache hidden_1: tensor([[[ -4., -4., -12.]]]) Cache logits_1: tensor([[-16., 0., -4., 24., -4., -20., 12., 4., 28., -16.]]) Cache iteration 1 next token: [8]
Iteration 2¶
From here on, only the newest token is embedded and projected; older keys and values are reused from the cache.
# Cache iteration 2: embedding lookup for only the newest token
cache_new_token_2 = cache_next_token_1.unsqueeze(1) # (1,) -> (1, 1) = (B=1,) -> (B=1, T_new=1)
cache_x_2 = token_embedding_table[cache_new_token_2] # (1, 1) -> (1, 1, 3) = (B=1, T_new=1) -> (B=1, T_new=1, d=3)
print('Cache iteration 2 prefix:', cache_prefix_2.tolist())
print('Cache iteration 2 new token:', cache_new_token_2.tolist())
print('Cache iteration 2 embedding:\n', cache_x_2)
Cache iteration 2 prefix: [[1, 8]] Cache iteration 2 new token: [[8]] Cache iteration 2 embedding: tensor([[[-2., 1., 1.]]])
# Cache iteration 2: Q/K/V projections and cache append
cache_q_2 = cache_x_2 @ W_q # (1, 1, 3) x (3, 3) -> (1, 1, 3) = (B=1, T_new=1, d=3) x (d=3, d=3) -> (B=1, T_new=1, d=3)
cache_k_2 = cache_x_2 @ W_k # (1, 1, 3) x (3, 3) -> (1, 1, 3) = (B=1, T_new=1, d=3) x (d=3, d=3) -> (B=1, T_new=1, d=3)
cache_v_2 = cache_x_2 @ W_v # (1, 1, 3) x (3, 3) -> (1, 1, 3) = (B=1, T_new=1, d=3) x (d=3, d=3) -> (B=1, T_new=1, d=3)
# Append the newest key/value along the time dimension instead of recomputing the prefix.
K_cache_2 = torch.cat([K_cache_1, cache_k_2], dim=1) # (1, 1, 3) + (1, 1, 3) -> (1, 2, 3) = (B=1, T_old=1, d=3) + (B=1, T_new=1, d=3) -> (B=1, T=2, d=3)
V_cache_2 = torch.cat([V_cache_1, cache_v_2], dim=1) # (1, 1, 3) + (1, 1, 3) -> (1, 2, 3) = (B=1, T_old=1, d=3) + (B=1, T_new=1, d=3) -> (B=1, T=2, d=3)
assert torch.allclose(cache_q_2, naive_Q_2[:, -1:, :])
assert torch.allclose(K_cache_2, naive_K_2)
assert torch.allclose(V_cache_2, naive_V_2)
print('Cache q_2:\n', cache_q_2)
print('K_cache_2:\n', K_cache_2)
print('V_cache_2:\n', V_cache_2)
Cache q_2:
tensor([[[-5., -2., -1.]]])
K_cache_2:
tensor([[[ 4., -2., -2.],
[ 0., 4., -4.]]])
V_cache_2:
tensor([[[-4., 4., -4.],
[ 4., -6., -2.]]])
# Cache iteration 2: raw attention scores
cache_scores_2 = cache_q_2 @ K_cache_2.transpose(-2, -1) / math.sqrt(d_model) # (1, 1, 3) x (1, 3, 2) -> (1, 1, 2) = (B=1, T_new=1, d=3) x (B=1, d=3, T=2) -> (B=1, T_new=1, T=2)
assert torch.allclose(cache_scores_2, naive_scores_2[:, -1:, :])
print('Cache scores_2:\n', cache_scores_2)
Cache scores_2: tensor([[[-8.0829, -2.3094]]])
# Cache iteration 2: masking step
# No extra causal mask is needed because the cache already contains only visible positions.
cache_scores_2_masked = cache_scores_2.clone() # (1, 1, 2) -> (1, 1, 2) = (B=1, T_new=1, T=2) -> (B=1, T_new=1, T=2)
print('Cache masked scores_2:\n', cache_scores_2_masked)
Cache masked scores_2: tensor([[[-8.0829, -2.3094]]])
# Cache iteration 2: softmax and value aggregation
cache_weights_2 = torch.softmax(cache_scores_2_masked, dim=-1) # (1, 1, 2) -> (1, 1, 2) = (B=1, T_new=1, T=2) -> (B=1, T_new=1, T=2)
cache_context_2 = cache_weights_2 @ V_cache_2 # (1, 1, 2) x (1, 2, 3) -> (1, 1, 3) = (B=1, T_new=1, T=2) x (B=1, T=2, d=3) -> (B=1, T_new=1, d=3)
assert torch.allclose(cache_weights_2, naive_weights_2[:, -1:, :])
assert torch.allclose(cache_context_2, naive_context_2[:, -1:, :])
print('Cache weights_2:\n', cache_weights_2)
print('Cache context_2:\n', cache_context_2)
Cache weights_2: tensor([[[0.0031, 0.9969]]]) Cache context_2: tensor([[[ 3.9752, -5.9690, -2.0062]]])
# Cache iteration 2: output projection and sampling
cache_hidden_2 = cache_context_2 @ W_o # (1, 1, 3) x (3, 3) -> (1, 1, 3) = (B=1, T_new=1, d=3) x (d=3, d=3) -> (B=1, T_new=1, d=3)
cache_logits_2 = cache_hidden_2[:, -1, :] @ W_vocab # (1, 3) x (3, 10) -> (1, 10) = (B=1, d=3) x (d=3, V=10) -> (B=1, V=10)
cache_next_token_2 = torch.argmax(cache_logits_2, dim=-1) # (1, 10) -> (1,) = (B=1, V=10) -> (B=1,)
cache_prefix_3 = torch.cat([cache_prefix_2, cache_next_token_2.unsqueeze(1)], dim=1) # (1, 2) + (1, 1) -> (1, 3) = (B=1, T=2) + (B=1, T_new=1) -> (B=1, T=3)
assert torch.equal(cache_next_token_2, naive_next_token_2)
print('Cache hidden_2:\n', cache_hidden_2)
print('Cache logits_2:\n', cache_logits_2)
print('Cache iteration 2 next token:', cache_next_token_2.tolist())
Cache hidden_2:
tensor([[[-0.0124, 17.9318, 21.8946]]])
Cache logits_2:
tensor([[ 7.9256, 17.9442, -17.9566, -25.8450, 3.9752, 21.8698, -3.9504,
-13.9442, -25.8326, 39.8264]])
Cache iteration 2 next token: [9]
Iteration 3¶
The cache grows by one more key and value, and the new query attends over the whole cached prefix.
# Cache iteration 3: embedding lookup for only the newest token
cache_new_token_3 = cache_next_token_2.unsqueeze(1) # (1,) -> (1, 1) = (B=1,) -> (B=1, T_new=1)
cache_x_3 = token_embedding_table[cache_new_token_3] # (1, 1) -> (1, 1, 3) = (B=1, T_new=1) -> (B=1, T_new=1, d=3)
print('Cache iteration 3 prefix:', cache_prefix_3.tolist())
print('Cache iteration 3 new token:', cache_new_token_3.tolist())
print('Cache iteration 3 embedding:\n', cache_x_3)
Cache iteration 3 prefix: [[1, 8, 9]] Cache iteration 3 new token: [[9]] Cache iteration 3 embedding: tensor([[[ 0., 1., -2.]]])
# Cache iteration 3: Q/K/V projections and cache append
cache_q_3 = cache_x_3 @ W_q # (1, 1, 3) x (3, 3) -> (1, 1, 3) = (B=1, T_new=1, d=3) x (d=3, d=3) -> (B=1, T_new=1, d=3)
cache_k_3 = cache_x_3 @ W_k # (1, 1, 3) x (3, 3) -> (1, 1, 3) = (B=1, T_new=1, d=3) x (d=3, d=3) -> (B=1, T_new=1, d=3)
cache_v_3 = cache_x_3 @ W_v # (1, 1, 3) x (3, 3) -> (1, 1, 3) = (B=1, T_new=1, d=3) x (d=3, d=3) -> (B=1, T_new=1, d=3)
# Append this step's key/value to the running cache.
K_cache_3 = torch.cat([K_cache_2, cache_k_3], dim=1) # (1, 2, 3) + (1, 1, 3) -> (1, 3, 3) = (B=1, T_old=2, d=3) + (B=1, T_new=1, d=3) -> (B=1, T=3, d=3)
V_cache_3 = torch.cat([V_cache_2, cache_v_3], dim=1) # (1, 2, 3) + (1, 1, 3) -> (1, 3, 3) = (B=1, T_old=2, d=3) + (B=1, T_new=1, d=3) -> (B=1, T=3, d=3)
assert torch.allclose(cache_q_3, naive_Q_3[:, -1:, :])
assert torch.allclose(K_cache_3, naive_K_3)
assert torch.allclose(V_cache_3, naive_V_3)
print('Cache q_3:\n', cache_q_3)
print('K_cache_3:\n', K_cache_3)
print('V_cache_3:\n', V_cache_3)
Cache q_3:
tensor([[[-1., -1., -5.]]])
K_cache_3:
tensor([[[ 4., -2., -2.],
[ 0., 4., -4.],
[-2., 6., 6.]]])
V_cache_3:
tensor([[[-4., 4., -4.],
[ 4., -6., -2.],
[ 2., -2., 2.]]])
# Cache iteration 3: raw attention scores
cache_scores_3 = cache_q_3 @ K_cache_3.transpose(-2, -1) / math.sqrt(d_model) # (1, 1, 3) x (1, 3, 3) -> (1, 1, 3) = (B=1, T_new=1, d=3) x (B=1, d=3, T=3) -> (B=1, T_new=1, T=3)
assert torch.allclose(cache_scores_3, naive_scores_3[:, -1:, :])
print('Cache scores_3:\n', cache_scores_3)
Cache scores_3: tensor([[[ 4.6188, 9.2376, -19.6299]]])
# Cache iteration 3: masking step
# No extra causal mask is needed because the cache already contains only visible positions.
cache_scores_3_masked = cache_scores_3.clone() # (1, 1, 3) -> (1, 1, 3) = (B=1, T_new=1, T=3) -> (B=1, T_new=1, T=3)
print('Cache masked scores_3:\n', cache_scores_3_masked)
Cache masked scores_3: tensor([[[ 4.6188, 9.2376, -19.6299]]])
# Cache iteration 3: softmax and value aggregation
cache_weights_3 = torch.softmax(cache_scores_3_masked, dim=-1) # (1, 1, 3) -> (1, 1, 3) = (B=1, T_new=1, T=3) -> (B=1, T_new=1, T=3)
cache_context_3 = cache_weights_3 @ V_cache_3 # (1, 1, 3) x (1, 3, 3) -> (1, 1, 3) = (B=1, T_new=1, T=3) x (B=1, T=3, d=3) -> (B=1, T_new=1, d=3)
assert torch.allclose(cache_weights_3, naive_weights_3[:, -1:, :])
assert torch.allclose(cache_context_3, naive_context_3[:, -1:, :])
print('Cache weights_3:\n', cache_weights_3)
print('Cache context_3:\n', cache_context_3)
Cache weights_3: tensor([[[0.0098, 0.9902, 0.0000]]]) Cache context_3: tensor([[[ 3.9219, -5.9023, -2.0195]]])
# Cache iteration 3: output projection and sampling
cache_hidden_3 = cache_context_3 @ W_o # (1, 1, 3) x (3, 3) -> (1, 1, 3) = (B=1, T_new=1, d=3) x (d=3, d=3) -> (B=1, T_new=1, d=3)
cache_logits_3 = cache_hidden_3[:, -1, :] @ W_vocab # (1, 3) x (3, 10) -> (1, 10) = (B=1, d=3) x (d=3, V=10) -> (B=1, V=10)
cache_next_token_3 = torch.argmax(cache_logits_3, dim=-1) # (1, 10) -> (1,) = (B=1, V=10) -> (B=1,)
cache_prefix_4 = torch.cat([cache_prefix_3, cache_next_token_3.unsqueeze(1)], dim=1) # (1, 3) + (1, 1) -> (1, 4) = (B=1, T=3) + (B=1, T_new=1) -> (B=1, T=4)
assert torch.equal(cache_next_token_3, naive_next_token_3)
print('Cache hidden_3:\n', cache_hidden_3)
print('Cache logits_3:\n', cache_logits_3)
print('Cache iteration 3 next token:', cache_next_token_3.tolist())
Cache hidden_3:
tensor([[[-0.0391, 17.7851, 21.6679]]])
Cache logits_3:
tensor([[ 7.7656, 17.8242, -17.8632, -25.5116, 3.9219, 21.5897, -3.8437,
-13.8242, -25.4725, 39.4530]])
Cache iteration 3 next token: [9]
Iteration 4¶
The final cached step reuses all previous keys and values, appends one more pair, and produces the last token.
# Cache iteration 4: embed only the freshly generated token (the cache covers the rest).
cache_new_token_4 = cache_next_token_3[:, None]  # (B=1,) -> (B=1, T_new=1)
cache_x_4 = token_embedding_table[cache_new_token_4]  # (B=1, T_new=1) -> (B=1, T_new=1, d=3)
print('Cache iteration 4 prefix:', cache_prefix_4.tolist())
print('Cache iteration 4 new token:', cache_new_token_4.tolist())
print('Cache iteration 4 embedding:\n', cache_x_4)
Cache iteration 4 prefix: [[1, 8, 9, 9]] Cache iteration 4 new token: [[9]] Cache iteration 4 embedding: tensor([[[ 0., 1., -2.]]])
# Cache iteration 4: project the single new token, then grow the K/V cache by one row.
cache_q_4 = torch.matmul(cache_x_4, W_q)  # (B=1, T_new=1, d=3) x (d=3, d=3) -> (B=1, T_new=1, d=3)
cache_k_4 = torch.matmul(cache_x_4, W_k)  # same shapes as the query projection
cache_v_4 = torch.matmul(cache_x_4, W_v)  # same shapes as the query projection
K_cache_4 = torch.cat([K_cache_3, cache_k_4], dim=1)  # (B=1, T_old=3, d=3) + (B=1, T_new=1, d=3) -> (B=1, T=4, d=3)
V_cache_4 = torch.cat([V_cache_3, cache_v_4], dim=1)  # (B=1, T_old=3, d=3) + (B=1, T_new=1, d=3) -> (B=1, T=4, d=3)
# Cross-check: the grown cache equals the naive recomputation over the full sequence.
assert torch.allclose(cache_q_4, naive_Q_4[:, -1:, :])
assert torch.allclose(K_cache_4, naive_K_4)
assert torch.allclose(V_cache_4, naive_V_4)
print('Cache q_4:\n', cache_q_4)
print('K_cache_4:\n', K_cache_4)
print('V_cache_4:\n', V_cache_4)
Cache q_4:
tensor([[[-1., -1., -5.]]])
K_cache_4:
tensor([[[ 4., -2., -2.],
[ 0., 4., -4.],
[-2., 6., 6.],
[-2., 6., 6.]]])
V_cache_4:
tensor([[[-4., 4., -4.],
[ 4., -6., -2.],
[ 2., -2., 2.],
[ 2., -2., 2.]]])
# Cache iteration 4: scaled dot-product of the one new query against all 4 cached keys.
cache_scores_4 = torch.matmul(cache_q_4, K_cache_4.transpose(-2, -1)) / math.sqrt(d_model)  # (B=1, T_new=1, d=3) x (B=1, d=3, T=4) -> (B=1, T_new=1, T=4)
assert torch.allclose(cache_scores_4, naive_scores_4[:, -1:, :])  # matches last row of the naive score matrix
print('Cache scores_4:\n', cache_scores_4)
Cache scores_4: tensor([[[ 4.6188, 9.2376, -19.6299, -19.6299]]])
# Cache iteration 4: masking step
# The cache only ever holds positions the new query is allowed to see, so the
# causal mask is a no-op here; the copy just mirrors the naive pipeline's stage.
cache_scores_4_masked = torch.clone(cache_scores_4)  # (B=1, T_new=1, T=4) -> (B=1, T_new=1, T=4)
print('Cache masked scores_4:\n', cache_scores_4_masked)
Cache masked scores_4: tensor([[[ 4.6188, 9.2376, -19.6299, -19.6299]]])
# Cache iteration 4: normalize the scores, then aggregate the cached values.
cache_weights_4 = cache_scores_4_masked.softmax(dim=-1)  # (B=1, T_new=1, T=4) -> (B=1, T_new=1, T=4): attention over the 4 cached positions
cache_context_4 = torch.matmul(cache_weights_4, V_cache_4)  # (B=1, T_new=1, T=4) x (B=1, T=4, d=3) -> (B=1, T_new=1, d=3)
# The cached result must equal the last row of the naive full-sequence attention.
assert torch.allclose(cache_weights_4, naive_weights_4[:, -1:, :])
assert torch.allclose(cache_context_4, naive_context_4[:, -1:, :])
print('Cache weights_4:\n', cache_weights_4)
print('Cache context_4:\n', cache_context_4)
Cache weights_4: tensor([[[0.0098, 0.9902, 0.0000, 0.0000]]]) Cache context_4: tensor([[[ 3.9219, -5.9023, -2.0195]]])
# Cache iteration 4: output projection, greedy sampling, and final verification.
cache_hidden_4 = torch.matmul(cache_context_4, W_o)  # (B=1, T_new=1, d=3) x (d=3, d=3) -> (B=1, T_new=1, d=3)
cache_logits_4 = torch.matmul(cache_hidden_4[:, -1, :], W_vocab)  # (B=1, d=3) x (d=3, V=10) -> (B=1, V=10)
cache_next_token_4 = cache_logits_4.argmax(dim=-1)  # greedy decoding: (B=1, V=10) -> (B=1,)
cache_prefix_5 = torch.cat([cache_prefix_4, cache_next_token_4[:, None]], dim=1)  # (B=1, T=4) + (B=1, 1) -> (B=1, T=5)
assert torch.equal(cache_next_token_4, naive_next_token_4)
print('Cache hidden_4:\n', cache_hidden_4)
print('Cache logits_4:\n', cache_logits_4)
print('Cache iteration 4 next token:', cache_next_token_4.tolist())
# Final sanity checks: the cached decode reproduces the naive decode token-for-token
# and fills the sequence out to its maximum length.
assert torch.equal(cache_prefix_5, naive_prefix_5)
assert cache_prefix_5.shape == (batch_size, max_sequence_length)
print('Cache final sequence:', cache_prefix_5.tolist())
print('Final check: naive and cached decoding match exactly.')
Cache hidden_4:
tensor([[[-0.0391, 17.7851, 21.6679]]])
Cache logits_4:
tensor([[ 7.7656, 17.8242, -17.8632, -25.5116, 3.9219, 21.5897, -3.8437,
-13.8242, -25.4725, 39.4530]])
Cache iteration 4 next token: [9]
Cache final sequence: [[1, 8, 9, 9, 9]]
Final check: naive and cached decoding match exactly.
# Cache section: total work summary
# With a KV cache, each decoding step projects exactly one new token through
# W_q/W_k/W_v, while the score count still grows with the cache length.
cached_total_projection_work = int(np.sum(np.ones_like(prefix_lengths)))  # (num_steps=4,) -> (): one token per step
cached_total_score_work = int(np.sum(prefix_lengths))  # (num_steps=4,) -> (): 1 + 2 + 3 + 4 scores
projection_reduction_factor = naive_total_projection_work / cached_total_projection_work  # () / () -> (): naive-over-cached ratio
score_reduction_factor = naive_total_score_work / cached_total_score_work  # () / () -> (): naive-over-cached ratio
print('Cache total tokens processed for Q/K/V:', cached_total_projection_work, '= 1 + 1 + 1 + 1 = O(T) over the full decode')
print('Cache total attention scores computed:', cached_total_score_work, '= 1 + 2 + 3 + 4 = O(T^2) over the full decode')
print('Q/K/V token processing reduction factor:', f'{projection_reduction_factor:.2f}x')
print('Attention score reduction factor:', f'{score_reduction_factor:.2f}x')
Cache total tokens processed for Q/K/V: 4 = 1 + 1 + 1 + 1 = O(T) over the full decode Cache total attention scores computed: 10 = 1 + 2 + 3 + 4 = O(T^2) over the full decode Q/K/V token processing reduction factor: 2.50x Attention score reduction factor: 3.00x