LLM inference, step by step (what actually happens when you type a prompt)


Below is the end-to-end path a modern decoder-only Transformer (the architecture behind most LLMs) takes at inference time—i.e., with fixed, trained weights—from your characters to its next token; it then repeats the later steps to write a whole response.


0) Setup (one-time per session)

  • Load weights: The model’s learned parameters (matrices and biases) are loaded to GPU/TPU memory (often quantized to 8/4-bit to fit).
  • Disable training-only features: No gradient tracking, no weight updates, dropout off.
  • Allocate the KV cache: Memory to store Keys and Values for attention at every layer as tokens are generated. This makes subsequent tokens fast.

1) Text → tokens (tokenization)

  • Your string prompt is split into tokens by a tokenizer (usually a Byte-Pair Encoding/Unigram model).
  • Example:
    "Transformers are great." → ["Transform", "ers", " are", " great", "."] → IDs like [10213, 421, 504, 778, 13].
  • Shape after tokenization: a 1D integer vector of length T (sequence length).
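As a toy illustration of this step, here is a greedy longest-match tokenizer over a tiny hand-made vocabulary. The entries and IDs are invented to mirror the example above; real BPE/Unigram tokenizers learn ~50k entries from data.

```python
# Toy illustration of greedy longest-match tokenization. The vocabulary and IDs
# below are invented to mirror the example; real tokenizers learn their merges.
toy_vocab = {"Transform": 10213, "ers": 421, " are": 504, " great": 778, ".": 13}

def toy_tokenize(text, vocab):
    """Repeatedly take the longest vocab entry that prefixes the remaining text."""
    ids = []
    while text:
        match = max((t for t in vocab if text.startswith(t)), key=len)
        ids.append(vocab[match])
        text = text[len(match):]
    return ids

print(toy_tokenize("Transformers are great.", toy_vocab))  # [10213, 421, 504, 778, 13]
```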

2) Tokens → embeddings (lookup)

  • The model has an embedding table E of shape (V, d_model) where:
    • V = vocabulary size (e.g., 50k),
    • d_model = hidden width (e.g., 3,072).
  • Each token id t selects row E[t]. Stack them: X ∈ R^{T × d_model}.
  • Intuition: this maps discrete symbols to points in a semantic geometry where distances/angles carry meaning.
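The lookup itself is plain row indexing. A minimal sketch with a made-up table (V = 5, d_model = 4; real tables are on the order of 50k × 3072):

```python
# Embedding lookup is row indexing into a toy table (values are arbitrary).
E = [
    [0.1, 0.2, 0.3, 0.4],
    [0.5, 0.1, 0.0, 0.2],
    [0.9, 0.4, 0.3, 0.1],
    [0.2, 0.8, 0.7, 0.6],
    [0.0, 0.3, 0.5, 0.9],
]

def embed(ids, table):
    """Each token id selects one row; the stacked result X has shape (T, d_model)."""
    return [table[t] for t in ids]

X = embed([2, 0, 4], E)        # T = 3
print(len(X), len(X[0]))       # 3 4
```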

3) Add positional information

Transformers are permutation-invariant, so we inject order:

  • Absolute: add learned P[pos] to each token embedding.
  • Sinusoidal: add fixed sines/cosines.
  • RoPE (rotary): apply a complex-like rotation to Q/K vectors so attention becomes position-aware.

After this step you still have X ∈ R^{T × d_model}, now encoding what and where.
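For the sinusoidal variant, the encoding can be computed in a few lines. This follows the standard sin/cos formulation; the sizes are toy.

```python
import math

def sinusoidal_position(pos, d_model):
    """Fixed encoding: even dims get sin(pos / 10000^(i/d_model)), odd dims the
    matching cos, giving each position a unique, smoothly varying fingerprint."""
    vec = []
    for i in range(0, d_model, 2):
        angle = pos / (10000 ** (i / d_model))
        vec.append(math.sin(angle))
        vec.append(math.cos(angle))
    return vec[:d_model]

# Adding position info leaves the shape (T × d_model) unchanged.
X = [[0.0] * 8 for _ in range(4)]          # dummy embeddings: T = 4, d_model = 8
X = [[x + p for x, p in zip(row, sinusoidal_position(t, 8))]
     for t, row in enumerate(X)]
print(len(X), len(X[0]))                   # 4 8
```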


4) Pass through L stacked Transformer blocks

Each block updates the sequence representation via two sublayers:

  1. Self-Attention sublayer (many heads in parallel)
  2. Feed-Forward Network (FFN/MLP) sublayer

Residual connections and layer norms keep signals stable.

Typical block order (pre-norm style):

H0 = X
for ℓ in 1..L:
    A = H_{ℓ-1} + Attention( LayerNorm(H_{ℓ-1}) )
    Hℓ = A + FFN( LayerNorm(A) )

Everything below details those two sublayers.


5) The attention math (what “pays attention to what”)

5a) Make Q, K, V (linear projections)

Given the sequence representation H ∈ R^{T × d_model}:

Q = H W_Q        # (T × d_model) · (d_model × d_k·h) → (T × d_k·h)
K = H W_K
V = H W_V
  • Split into h heads: reshape (T × d_k·h) → (h × T × d_k), where d_k = d_model / h.

5b) Scaled dot-product attention per head

For each head:

scores = (Q K^T) / sqrt(d_k)           # (T × d_k) · (d_k × T) → (T × T)
scores += mask                          # causal mask: forbid attending to future tokens
weights = softmax(scores, axis=last)    # row-wise softmax over T keys
head_out = weights V                    # (T × T) · (T × d_k) → (T × d_k)
  • Causal mask enforces left-to-right generation.
  • Softmax turns similarity into attention weights.

5c) Combine heads and project back

  • Concatenate heads: (h × T × d_k) → (T × (h·d_k)) = (T × d_model).
  • Final linear mix: AttnOut = Concat(heads) W_O, shape (T × d_model).

Interpretation: Each token becomes a weighted blend of other tokens’ value vectors, where weights depend on query-key similarity and position. This is the model’s “where to look” mechanism.
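Steps 5a–5c can be sketched for a single head in plain Python (toy sizes, no learned weights; a real implementation would use batched, fused GPU matmuls):

```python
import math

def softmax(row):
    """Numerically stable softmax over one list of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

def matmul(A, B):
    """(m × k) · (k × n) → (m × n) on nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def causal_attention(Q, K, V, d_k):
    """Single-head scaled dot-product attention with a causal mask (step 5b)."""
    T = len(Q)
    weights = []
    for i in range(T):
        row = []
        for j in range(T):
            if j <= i:  # current and earlier positions only
                row.append(sum(q * k for q, k in zip(Q[i], K[j])) / math.sqrt(d_k))
            else:       # causal mask: future positions get -inf → weight 0
                row.append(float("-inf"))
        weights.append(softmax(row))
    return matmul(weights, V)  # (T × T) · (T × d_k) → (T × d_k)

# Toy example: T = 3 tokens, d_k = 2.
Q = K = V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = causal_attention(Q, K, V, d_k=2)
print(out[0])  # the first token can only attend to itself, so it gets V[0] back
```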


6) The FFN (nonlinear mixing per position)

Each position (token) goes through an MLP independently:

U = LayerNorm(A)                       # (T × d_model)
Z = U W1 + b1                          # (T × d_ff)
Z = act(Z)                             # GeLU/Swish; with SwiGLU: split & gate
Y = Z W2 + b2                          # (T × d_model)
H = A + Y                              # residual
  • d_ff is larger than d_model (e.g., 4×–8×), giving capacity to re-express features.
  • Many modern models use SwiGLU/GeGLU for better expressivity.
  • In MoE models, W1/W2 are replaced by routed expert MLPs; at inference only a few experts are active per token.
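As a sketch of the SwiGLU variant mentioned above, here is the per-position computation with arbitrary toy weights (the shapes, not the values, are the point):

```python
import math

def silu(x):
    """Swish/SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def swiglu_ffn(u, W_gate, W_up, W_down):
    """One position through a SwiGLU FFN: down( silu(u·W_gate) ⊙ (u·W_up) )."""
    gate = [silu(sum(ui * w for ui, w in zip(u, col))) for col in zip(*W_gate)]
    up = [sum(ui * w for ui, w in zip(u, col)) for col in zip(*W_up)]
    hidden = [g * h for g, h in zip(gate, up)]            # element-wise gating
    return [sum(hi * w for hi, w in zip(hidden, col)) for col in zip(*W_down)]

# Toy shapes: d_model = 2, d_ff = 3; the weight values are made up.
W_gate = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]   # (d_model × d_ff)
W_up   = [[0.2, 0.1, 0.0], [0.3, 0.2, 0.1]]   # (d_model × d_ff)
W_down = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]] # (d_ff × d_model)
y = swiglu_ffn([1.0, -1.0], W_gate, W_up, W_down)
print(len(y))   # 2 — back to d_model
```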

7) Stack repeats & build deep features

  • Repeating attention + FFN across L layers progressively sculpts the representation:
    • Lower layers: local patterns (morphology, short-range syntax).
    • Middle layers: syntax/semantics, entity linking, coreference.
    • Higher layers: long-range dependencies, world knowledge, planning patterns.
  • The residual stream (the thing being added to at every step) carries forward an ever-richer vector at each token position—your running semantic state.

8) Language Modeling Head → logits

After the final block, take the last token’s hidden state (for next-token prediction) or all states (to score every position). Project to vocabulary size:

logits = H E^T          # (T × d_model) · (d_model × V) → (T × V)
  • Weight tying is common: reuse the embedding matrix E (transposed) as the output head.
  • The last row (position T) holds scores for the next token across all V tokens.
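A tiny weight-tied projection, with a made-up 3-token vocabulary:

```python
# Weight-tied LM head (V = 3, d_model = 2, toy values):
# logits[v] = h_last · E[v], i.e., the last hidden state dotted with each row of E.
E = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]   # toy embedding table
h_last = [1.0, 1.0]                         # hidden state at the last position

logits = [sum(h * e for h, e in zip(h_last, row)) for row in E]
print(len(logits))   # 3 — one score per vocabulary entry
```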

9) Turn logits into a token choice (sampling)

  • Convert to probabilities: p = softmax(logits_T).
  • Apply decoding policy:
    • Greedy: pick argmax.
    • Temperature τ: divide logits by τ; τ<1 sharpens, τ>1 diversifies.
    • Top-k / Top-p (nucleus): restrict to the k most likely tokens, or to the smallest set whose probabilities sum to p (e.g., 0.9), then sample.
    • Repetition penalties: adjust logits to discourage loops.

This yields the next token id.
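A minimal sampler combining temperature and top-p, assuming nothing beyond the logit vector itself (the cutoff and seeding choices are illustrative):

```python
import math, random

def sample_next(logits, temperature=0.8, top_p=0.9, rng=random.Random(0)):
    """Temperature scaling + nucleus (top-p) sampling over one logit vector."""
    scaled = [l / temperature for l in logits]          # τ < 1 sharpens
    m = max(scaled)
    exps = [math.exp(l - m) for l in scaled]
    z = sum(exps)
    probs = [e / z for e in exps]                       # softmax
    # Keep the smallest set of tokens whose probability mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    # Renormalize over the kept set and draw one id.
    total = sum(probs[i] for i in kept)
    r = rng.random() * total
    for i in kept:
        r -= probs[i]
        if r <= 0:
            return i
    return kept[-1]

print(sample_next([2.0, 1.0, 0.1, -1.0]))   # an id in {0, 1, 2, 3}
```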


10) Autoregressive loop (generate more)

  • Append the new token to the sequence.
  • KV cache update: For that token, compute its Q/K/V in each layer; store K and V to the cache so the model doesn’t recompute them for older tokens next step.
  • For the next token:
    • Only compute Q freshly for the last position; K/V for all previous positions are read from the cache.
    • Attention then uses [K_cached; K_new] and [V_cached; V_new].
  • Repeat steps 5–9 until end conditions (EOS token, length, or stop sequence).

This turns O(T²) recompute into O(T) incremental cost after the first pass.
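The cache bookkeeping itself is simple. A toy sketch for one layer, where one-element lists stand in for real projected K/V rows of width d_k:

```python
# Toy KV-cache bookkeeping for one attention layer (no real model behind it).
class KVCache:
    def __init__(self):
        self.keys, self.values = [], []

    def append(self, k, v):
        # Store the newest token's K/V so later steps never recompute them.
        self.keys.append(k)
        self.values.append(v)

    def all(self):
        # The new token's attention reads [K_cached; K_new] and [V_cached; V_new].
        return self.keys, self.values

cache = KVCache()
for step in range(3):                         # pretend we decode 3 tokens
    k, v = [float(step)], [float(step) * 2.0] # stand-ins for projected K/V
    cache.append(k, v)

K, V = cache.all()
print(len(K), len(V))   # 3 3 — one cached entry per generated token
```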


11) Where “matrix MATH” happens (shapes and flow)

To make the linear algebra concrete, consider a single layer with:

  • T = sequence length,
  • d_model = 3072,
  • h = 24 heads,
  • d_k = d_model / h = 128,
  • d_ff = 12288.

Key matrix multiplications at inference (per step):

  1. Q/K/V projections
    (T × d_model) · (d_model × d_k·h) → (T × d_k·h) (done three times)
  2. Attention scores (per head)
    (T × d_k) · (d_k × T) → (T × T), then softmax + mask
  3. Weighted sum
    (T × T) · (T × d_k) → (T × d_k)
  4. Heads combine
    Concat heads → (T × d_model); then (T × d_model) · (d_model × d_model) → (T × d_model)
  5. FFN
    (T × d_model) · (d_model × d_ff) → (T × d_ff) → nonlinearity → (T × d_ff) · (d_ff × d_model) → (T × d_model)
  6. LM head
    (T × d_model) · (d_model × V) → (T × V)
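These shapes can be sanity-checked with quick arithmetic. A rough FLOP count per layer, using the conventional 2·m·k·n cost for a (m×k)·(k×n) matmul and ignoring softmax, masks, and norms:

```python
# Rough per-layer matmul FLOPs for the shapes above. T is illustrative;
# the other sizes come from the example (d_model=3072, h=24, d_ff=12288).
T, d_model, h, d_ff = 1024, 3072, 24, 12288
d_k = d_model // h                             # 128

qkv     = 3 * 2 * T * d_model * (d_k * h)      # step 1: Q/K/V projections
scores  = h * 2 * T * d_k * T                  # step 2: per-head scores, all heads
mix     = h * 2 * T * T * d_k                  # step 3: weighted value sums
combine = 2 * T * d_model * d_model            # step 4: W_O projection
ffn     = 2 * (2 * T * d_model * d_ff)         # step 5: up- and down-projections
per_layer = qkv + scores + mix + combine + ffn
print(f"~{per_layer / 1e9:.0f} GFLOPs per layer at T = {T}")
```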

All parameters are fixed at inference. The model is just executing this computational graph.


12) Why the output is “rational” (without hidden hand-waving)

  • During training, the model minimized next-token loss over trillions of tokens.
  • The learned weights encode statistical regularities: grammar, facts, chains of implication, formats, plans—compressed into the geometry of the embedding and hidden spaces.
  • At inference, attention routes information along the graph that solved those training tasks best. When your prompt resembles patterns seen during training, the model follows those circuits to produce high-probability continuations—which we experience as reasoning, explanation, or planning.
  • System prompts, few-shot exemplars, and tool-augmented prompting nudge the routing and the region of the semantic manifold the model moves through while decoding.

13) Practical extras you often see in modern LLMs

  • Grouped/Multi-Query Attention: share K/V across heads to reduce memory bandwidth while keeping multiple Q projections.
  • Flash/Scaled Attention kernels: fused GPU kernels that compute softmax attention with less memory IO.
  • Rotary Position Embeddings (RoPE) scaling: tricks to extend context length (e.g., NTK-aware, YaRN).
  • Speculative decoding: a small draft model proposes tokens, big model verifies, improving throughput.
  • Vision tokens / Function tokens: extra modalities or tool calls are serialized as special tokens that join the same pipeline.
  • Logit bias/ban lists: small adjustments to steer style or prevent certain outputs.
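Of these, logit bias/ban lists are the simplest to sketch (the ids and values below are arbitrary):

```python
def apply_logit_bias(logits, bias, banned):
    """Shift chosen token ids by a bias; hard-ban others (-inf → probability 0)."""
    out = list(logits)
    for i, b in bias.items():
        out[i] += b
    for i in banned:
        out[i] = float("-inf")
    return out

# Arbitrary example: boost id 0, ban id 2.
adjusted = apply_logit_bias([1.0, 2.0, 3.0], bias={0: 5.0}, banned=[2])
print(adjusted)   # [6.0, 2.0, -inf]
```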

14) A compact pseudocode of the whole loop

ids = tokenize(prompt)                     # [T]
H = E[ids]                                 # (T, d_model)
H = add_positions(H)                       # abs/sin/RoPE

for ℓ in range(L):
    U = layernorm[ℓ](H)
    Q = U @ WQ[ℓ]; K = U @ WK[ℓ]; V = U @ WV[ℓ]      # split heads
    attn = softmax((Q @ K.T) / sqrt(d_k) + causal_mask)
    A = (attn @ V) @ WO[ℓ]                           # merge heads
    H = H + A                                        # residual

    U = layernorm2[ℓ](H)
    H = H + ffn[ℓ](U)                                # MLP

logits = H[-1] @ E.T                                 # tie weights
next_id = sample_from(logits, strategy="top_p", p=0.9, temp=0.8)
append(next_id); update_kv_cache(); repeat...

15) Mental model (why “semantic geometry” helps)

  • Tokenization picks coordinates for known symbols.
  • Embeddings place those symbols in a vector space where meaning is partially linearized.
  • Attention is content-addressable lookup: “find the relevant stuff in context and mix it in.”
  • FFN is nonlinear re-expression: “compress, expand, and recombine features.”
  • Depth (stacked blocks) composes these operations into higher-order abstractions.
  • The LM head projects the final point in this space back onto the vocabulary axes to pick the next symbol.

The “reasoning” you see is the trajectory the last-token vector takes through this manifold as the loop unrolls.


16) TL;DR

  1. Tokenize text → ids.
  2. Lookup embeddings and add position.
  3. Run L times: LayerNorm → Multi-Head Self-Attention (+ residual) → LayerNorm → FFN (+ residual).
  4. Project last state through LM head to get logits over the vocab.
  5. Sample a token.
  6. Cache K/V, append token, and repeat until done.

That’s the entire inference phase: pure, deterministic matrix math (plus a stochastic sampler if you enable it) flowing through a fixed artificial neural network that was shaped during training to make the “rational” continuation the most probable one.

