
Parallel Transformer Blocks

Ethan TS. Liu | January 15, 2026

These are quick notes on why the "parallel" Transformer block is not exactly equivalent to the usual serial attention-then-MLP block. It is best viewed as an approximation or architectural variant of the standard pre-norm Transformer. Here \(\mathrm{Norm}(\cdot)\) denotes the model's normalization (LayerNorm or RMSNorm, depending on the architecture).

Serial pre-norm block

The common two-step form is: \(h = x + \mathrm{Attn}(\mathrm{Norm}(x))\) and \(y = h + \mathrm{MLP}(\mathrm{Norm}(h))\).

Substitute \(h\) into \(y\): \(y = x + \mathrm{Attn}(\mathrm{Norm}(x)) + \mathrm{MLP}(\mathrm{Norm}(x + \mathrm{Attn}(\mathrm{Norm}(x))))\).
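
As a concrete reference, here is a minimal PyTorch sketch of this serial block. It is only a sketch: the class and hyperparameter names (SerialBlock, d_model, n_heads) and the 4x MLP width are illustrative choices, not taken from any particular codebase.

import torch
import torch.nn as nn

class SerialBlock(nn.Module):
    def __init__(self, d_model=512, n_heads=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        # h = x + Attn(Norm(x))
        a = self.norm1(x)
        h = x + self.attn(a, a, a, need_weights=False)[0]
        # y = h + MLP(Norm(h)): the MLP sees the attention-updated stream
        return h + self.mlp(self.norm2(h))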

Parallel block

The parallel variant uses: \(y = x + \mathrm{Attn}(\mathrm{Norm}(x)) + \mathrm{MLP}(\mathrm{Norm}(x))\).
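
A matching sketch of the parallel block, mirroring the serial one above (same illustrative names and sizes):

import torch
import torch.nn as nn

class ParallelBlock(nn.Module):
    def __init__(self, d_model=512, n_heads=8):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)  # a single shared pre-norm
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        # y = x + Attn(Norm(x)) + MLP(Norm(x)): both branches read the same Norm(x)
        z = self.norm(x)
        return x + self.attn(z, z, z, need_weights=False)[0] + self.mlp(z)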

The only difference is the MLP input: serial uses \(\mathrm{MLP}(\mathrm{Norm}(x + \mathrm{Attn}(\mathrm{Norm}(x))))\), while parallel uses \(\mathrm{MLP}(\mathrm{Norm}(x))\). The two are therefore identical only if \(\mathrm{Norm}(x + \mathrm{Attn}(\mathrm{Norm}(x))) = \mathrm{Norm}(x)\), which does not hold in general. They can, however, be approximately equivalent in a common regime: writing \(\delta = \mathrm{Attn}(\mathrm{Norm}(x))\) for the attention branch, if \(\lVert \delta \rVert \ll \lVert x \rVert\) then \(\mathrm{Norm}(x + \delta) \approx \mathrm{Norm}(x)\), so the two MLP inputs nearly coincide.
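
A quick numerical illustration of that condition (a sketch: the 1e-2 scale is an arbitrary stand-in for a "small" attention update, and the LayerNorm uses its default initialization):

import torch
import torch.nn as nn

torch.manual_seed(0)
norm = nn.LayerNorm(512)
x = torch.randn(4, 512)
delta = 1e-2 * torch.randn(4, 512)   # stand-in for Attn(Norm(x)), deliberately small
rel_gap = (norm(x + delta) - norm(x)).norm() / norm(x).norm()
print(f"relative change in Norm(x): {rel_gap.item():.4f}")  # small when ||delta|| << ||x||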

First-order view: define \(f(\cdot) = \mathrm{MLP}(\mathrm{Norm}(\cdot))\). Then \(f(x + \delta) \approx f(x) + J_f(x)\,\delta\), so the parallel block keeps \(f(x)\) and drops the interaction term \(J_f(x)\,\delta\) that represents "MLP responding to attention's update within the same block."
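
The dropped interaction term can be made concrete with a forward-mode Jacobian-vector product. The sketch below assumes PyTorch's torch.func.jvp and an arbitrary small width and delta scale; it checks that \(f(x + \delta)\) is recovered, up to second order in \(\delta\), by \(f(x)\) plus the \(J_f(x)\,\delta\) term that the parallel block discards.

import torch
import torch.nn as nn
from torch.func import jvp

torch.manual_seed(0)
d = 64
f = nn.Sequential(nn.LayerNorm(d), nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))  # f = MLP(Norm(.))

x = torch.randn(2, d)
delta = 1e-2 * torch.randn(2, d)          # stand-in for Attn(Norm(x))

fx, interaction = jvp(f, (x,), (delta,))  # f(x) and J_f(x) delta
serial_mlp = f(x + delta)                 # what the serial block's MLP computes
parallel_mlp = fx                         # what the parallel block's MLP computes
# the parallel block drops `interaction`; what remains is second order in delta
print((serial_mlp - (parallel_mlp + interaction)).abs().max().item())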

In the serial block, attention updates the representation and the MLP processes that updated representation. In the parallel block, both branches see the same \(\mathrm{Norm}(x)\), so the within-block coupling is removed. This decoupling is why the parallel variant is easier to fuse in implementations: both branches can share a single input projection (one matmul that feeds attention and the first MLP linear), which can be faster.
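
A sketch of what that fusion looks like, assuming the attention Q/K/V projections and the MLP's first linear all read the same \(\mathrm{Norm}(x)\); the weights and shapes below are illustrative stand-ins rather than a real implementation:

import torch
import torch.nn.functional as F

d_model, d_ff = 512, 2048
w_q = torch.randn(d_model, d_model)
w_k = torch.randn(d_model, d_model)
w_v = torch.randn(d_model, d_model)
w_mlp_in = torch.randn(d_model, d_ff)                   # MLP's first linear

# stack the four projections column-wise into one weight matrix
w_fused = torch.cat([w_q, w_k, w_v, w_mlp_in], dim=1)   # (d_model, 3*d_model + d_ff)

x = torch.randn(4, 16, d_model)                         # (batch, seq, d_model)
z = F.layer_norm(x, (d_model,))                         # the shared Norm(x)
q, k, v, mlp_in = (z @ w_fused).split([d_model, d_model, d_model, d_ff], dim=-1)
# a single matmul now feeds both the attention branch and the MLP's first linear

The serial block cannot share this matmul, because its MLP needs \(\mathrm{Norm}(h)\), which only exists after attention has run.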

Reference: see the GPT-J note for context on the parallel block and its implementation tradeoffs: https://arankomatsuzaki.wordpress.com/2021/06/04/gpt-j/