From QK-clip to Muon clip

Ethan TS. Liu | January 17, 2026

QK-clip is a simple but sharp idea: if attention logits explode, rescale the query and key projections so the largest logit is capped. The same intuition survives in multi-head latent attention, but the shared rotary key projection forces a careful tweak. This post walks through the math and the fix, usually called muon clip.

Why exploding logits are harmful

For a token \(i\), attention weights are

\[ A_{ij} = \frac{\exp(S_{ij})}{\sum_{t} \exp(S_{it})}, \quad S_{ij} = \frac{q_i^\top k_j}{\sqrt{d_k}} \]

Dividing by the largest entry for that row makes the collapse explicit. Let \(j^* = \operatorname*{argmax}_j S_{ij}\), so \(S_{ij^*} \ge S_{ij}\) for all \(j\). It follows that

\[ \frac{A_{ij}}{A_{ij^*}} = \exp(S_{ij} - S_{ij^*}) \]

If the gaps \(S_{ij^*} - S_{ij}\) are large, this ratio vanishes, so \(A_{ij}\) is essentially zero for every \(j \ne j^*\). This low-entropy regime starves most keys of gradient, which the softmax Jacobian makes explicit:

\[ \frac{\partial A_{ij}}{\partial S_{ik}} = A_{ij}(\delta_{jk} - A_{ik}). \]

To see this, write \(A_{ij} = \exp(S_{ij})/Z_i\) with \(Z_i = \sum_t \exp(S_{it})\). Then

\[ \frac{\partial A_{ij}}{\partial S_{ik}} = \frac{\partial}{\partial S_{ik}} \left(\frac{e^{S_{ij}}}{Z_i}\right) = \frac{e^{S_{ij}}\delta_{jk} Z_i - e^{S_{ij}} \frac{\partial Z_i}{\partial S_{ik}}}{Z_i^2} = \frac{e^{S_{ij}}}{Z_i}\left(\delta_{jk} - \frac{e^{S_{ik}}}{Z_i}\right) = A_{ij}(\delta_{jk} - A_{ik}). \]

When \(A_{ij^*} \approx 1\), we get \(\partial A_{ij^*}/\partial S_{ij^*} \approx 0\) and \(\partial A_{ij}/\partial S_{ik} \approx 0\) for \(j \ne j^*\). This means brittle, overconfident attention and vanishing gradients to most keys. Large \(S_{ij}\) also push \(\exp(S_{ij})\) toward numeric overflow in reduced precision, which destabilizes training.
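To see the starvation numerically, here is a minimal PyTorch sketch (toy logits and a helper name of my own) that builds the row Jacobian above and compares a moderate logit gap against an exploded one.

```python
import torch

def softmax_jacobian(s: torch.Tensor) -> torch.Tensor:
    """dA_j/dS_k = A_j (delta_jk - A_k) for a single row of logits."""
    a = torch.softmax(s, dim=-1)
    return torch.diag(a) - torch.outer(a, a)

moderate = torch.tensor([2.0, 1.0, 0.0, -1.0])    # ordinary gaps
exploded = torch.tensor([80.0, 1.0, 0.0, -1.0])   # one logit blown up

print(softmax_jacobian(moderate).abs().max())  # roughly 0.2: gradient still flows
print(softmax_jacobian(exploded).abs().max())  # effectively zero: every entry vanishes
```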

QK-clip in one equation

Let \(X \in \mathbb{R}^{n \times d}\), \(Q = XW^Q\), \(K = XW^K\), and logits

\[ S = \frac{QK^\top}{\sqrt{d_k}}. \]

QK-clip introduces a threshold \(\tau\) and rescales both projections when \(S_{\max} > \tau\):

\[ \gamma = \frac{\tau}{S_{\max}}, \quad W^Q \leftarrow \sqrt{\gamma}\,W^Q, \quad W^K \leftarrow \sqrt{\gamma}\,W^K. \]

The dot product is bilinear, so the logits scale by \(\gamma\): \(S \leftarrow \gamma S\), and the maximum is pulled back to \(\tau\).
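As a sanity check, here is a hedged PyTorch sketch (toy shapes and the function name are my own) of the whole-matrix rule: measure the largest logit and, if it exceeds \(\tau\), fold \(\sqrt{\gamma}\) into both projections.

```python
import torch

def qk_clip_(w_q: torch.Tensor, w_k: torch.Tensor, x: torch.Tensor,
             tau: float, d_k: int) -> None:
    """In-place sketch: rescale W^Q and W^K so the max logit is capped at tau."""
    q, k = x @ w_q, x @ w_k
    s_max = (q @ k.T / d_k ** 0.5).max()
    if s_max > tau:
        gamma = tau / s_max
        w_q.mul_(gamma.sqrt())   # sqrt(gamma) on the query projection
        w_k.mul_(gamma.sqrt())   # sqrt(gamma) on the key projection

n, d, d_k, tau = 8, 16, 16, 10.0
x = torch.randn(n, d)
w_q = 5.0 * torch.randn(d, d_k)   # deliberately large so the logits exceed tau
w_k = 5.0 * torch.randn(d, d_k)
qk_clip_(w_q, w_k, x, tau, d_k)
print(((x @ w_q) @ (x @ w_k).T / d_k ** 0.5).max())  # <= tau up to float error
```

In practice the maximum would typically be tracked during the forward pass and the clip applied after the optimizer step; the snippet only illustrates the bilinearity argument.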

Per-head QK-clip in standard multi-head attention

In vanilla multi-head attention, each head \(h\) has its own projections:

\[ Q_h = X W_h^Q, \quad K_h = X W_h^K, \quad V_h = X W_h^V, \quad S_h = \frac{Q_h K_h^\top}{\sqrt{d_k}}. \]

If a single head blows up, we can clip that head alone. Let \(S_{h,\max}\) be the largest logit in head \(h\), define \(\gamma_h = \tau / S_{h,\max}\) when \(S_{h,\max} > \tau\), and rescale its query/key projections:

\[ W_h^Q \leftarrow \sqrt{\gamma_h}\,W_h^Q, \quad W_h^K \leftarrow \sqrt{\gamma_h}\,W_h^K. \]

We do not need to rescale \(W_h^V\): values do not appear in the logits, so scaling \(V_h\) does not prevent softmax saturation. It only rescales the output and can be absorbed by the output projection.
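A per-head version is the same idea looped over head-indexed weights; the sketch below (assumed shapes \((H, d, d_k)\), not taken from any particular codebase) leaves \(W_h^V\) alone for exactly the reason above.

```python
import torch

def per_head_qk_clip_(w_q: torch.Tensor, w_k: torch.Tensor,
                      x: torch.Tensor, tau: float, d_k: int) -> None:
    """w_q, w_k: (H, d, d_k). Clip each head's projections independently."""
    for h in range(w_q.shape[0]):
        q_h, k_h = x @ w_q[h], x @ w_k[h]
        s_max = (q_h @ k_h.T / d_k ** 0.5).max()
        if s_max > tau:
            gamma = tau / s_max
            w_q[h].mul_(gamma.sqrt())
            w_k[h].mul_(gamma.sqrt())
        # w_v[h] is never touched: values do not enter the logits.
```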

MLA and decoupled RoPE

Multi-head latent attention (MLA) compresses queries and KV into low-rank latents:

\[ C^Q = X W^Q, \quad C^{KV} = X W^{KV}, \quad W^{KV} \in \mathbb{R}^{d \times r}, \quad r \ll d. \]

Each head then uses up-projections to recover its own content queries/keys/values:

\[ Q_h^C = C^Q W_h^{Q_C}, \quad K_h^C = C^{KV} W_h^{K_C}, \quad V_h^C = C^{KV} W_h^V. \]

RoPE does not fit cleanly inside the shared low-rank KV path: the rotation is position-dependent and must act in each head's key space, so it cannot simply be folded through the shared latent. The standard fix is decoupled RoPE: add rotary projections computed directly from \(X\),

\[ Q_h^R = R_\theta(X W_h^{Q_R}), \quad K^R = R_\theta(X W^{K_R}), \]

where all heads share \(W^{K_R}\) but each head has its own \(W_h^{Q_R}\). The head vectors are then concatenated:

\[ Q_h = [Q_h^C; Q_h^R], \quad K_h = [K_h^C; K^R], \quad V_h = V_h^C. \]

This is where the new clipping constraint appears: the rotary key projection is shared, so it cannot be rescaled separately for each head.
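To keep the sharing pattern straight, here is a shape-level sketch (toy dimensions of my own; the rotation is stubbed out as the identity because only the wiring matters here) of how the content and rotary pieces are assembled per head.

```python
import torch

n, d, r, dc, dr, H = 8, 32, 8, 16, 8, 4   # tokens, model dim, latent rank, head dims, heads

x    = torch.randn(n, d)
w_q  = torch.randn(d, r)          # query latent down-projection
w_kv = torch.randn(d, r)          # shared KV latent down-projection
w_qc = torch.randn(H, r, dc)      # per-head content query up-projection
w_kc = torch.randn(H, r, dc)      # per-head content key up-projection
w_qr = torch.randn(H, d, dr)      # per-head rotary query projection
w_kr = torch.randn(d, dr)         # SHARED rotary key projection

def rope(t: torch.Tensor) -> torch.Tensor:
    return t                      # stand-in for the rotation R_theta

c_q, c_kv = x @ w_q, x @ w_kv
k_r = rope(x @ w_kr)              # computed once, reused by every head
for h in range(H):
    q_h = torch.cat([c_q @ w_qc[h], rope(x @ w_qr[h])], dim=-1)
    k_h = torch.cat([c_kv @ w_kc[h], k_r], dim=-1)
    s_h = q_h @ k_h.T / (dc + dr) ** 0.5
```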

Muon clip: rescale only the rotary queries

We still clip per head, but we must respect what is shared. For head \(h\), define \(\gamma_h = \tau / S_{h,\max}\) when \(S_{h,\max} > \tau\). Then rescale the head-specific up-projections and the head-specific rotary query:

\[ W_h^{Q_C} \leftarrow \sqrt{\gamma_h}\,W_h^{Q_C}, \quad W_h^{K_C} \leftarrow \sqrt{\gamma_h}\,W_h^{K_C}, \quad W_h^{Q_R} \leftarrow \gamma_h\,W_h^{Q_R}, \quad W^{K_R}\ \text{unchanged}. \]

Because \(K^R\) is fixed, scaling \(Q_h^R\) by \(\gamma_h\) scales the rotary dot product by \(\gamma_h\) directly, matching the content term's \(\gamma_h\) factor from the paired \(\sqrt{\gamma_h}\) scalings.

We do not rescale \(W_h^V\) because values never enter the logits, and we do not rescale the shared \(W^{K_R}\): it serves every head, so applying one head's \(\gamma_h\) would also shrink the logits of heads that never exceeded \(\tau\), and successive per-head updates would compound on the same matrix. This is the muon clip rule: per-head scaling of the content up-projections and the rotary queries, with the shared rotary key left alone.
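Putting the rule in code, a minimal sketch (argument names are my own, and the no-op branch for heads under the threshold is made explicit) is:

```python
import torch

def muon_clip_(w_qc: torch.Tensor, w_kc: torch.Tensor, w_qr: torch.Tensor,
               s_max_per_head: torch.Tensor, tau: float) -> None:
    """w_qc, w_kc, w_qr: per-head projections with a leading head dimension H.
    s_max_per_head: the largest observed logit of each head, shape (H,).
    The shared rotary key projection W^{K_R} is deliberately not an argument."""
    # gamma_h = tau / S_{h,max} for heads that overflowed, 1 otherwise
    gamma = torch.where(s_max_per_head > tau,
                        tau / s_max_per_head,
                        torch.ones_like(s_max_per_head))
    for h, g in enumerate(gamma):
        w_qc[h].mul_(g.sqrt())   # content query up-projection: sqrt(gamma_h)
        w_kc[h].mul_(g.sqrt())   # content key up-projection:   sqrt(gamma_h)
        w_qr[h].mul_(g)          # rotary query projection:     full gamma_h
```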