Self-Attention as Kernel Regression in JAX/Flax NNX


Explainer companion: “Attention is Explainable Because it is a Kernel”. Want the full intuition first? This is the runnable companion to that explainer.

The explainer argued that self-attention is a Nadaraya–Watson kernel smoother that has learned its kernel and its targets end-to-end, and that this single fact, not “the matrix is plottable”, is what makes attention heads readable. This is the implementation companion: build attention from scratch in Flax NNX, turn each prose claim into an assert, and watch the one caveat (the kernel is not symmetric, so not Mercer in the strict sense) fall out of the numbers. The dimensions are tiny so every array is legible.

Scaled dot-product attention, by hand

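A single attention head is three learned projections and a softmax. nnx.Linear applies its weight matrix (the parameter Flax calls kernel, not to be confused with the regression kernel below) to the last axis, so each projection reads the same whether x is one token or a sequence.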

import jax
import jax.numpy as jnp
from flax import nnx


class Head(nnx.Module):
    def __init__(self, d_model, d_head, *, rngs: nnx.Rngs):
        self.wq = nnx.Linear(d_model, d_head, use_bias=False, rngs=rngs)
        self.wk = nnx.Linear(d_model, d_head, use_bias=False, rngs=rngs)
        self.wv = nnx.Linear(d_model, d_head, use_bias=False, rngs=rngs)

    def __call__(self, x):                       # x: [L, d_model]
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        d = q.shape[-1]
        scores = q @ k.T / jnp.sqrt(d)           # [L, L]
        alpha = jax.nn.softmax(scores, axis=-1)  # rows are probability vectors
        return alpha @ v, alpha


head = Head(d_model=16, d_head=16, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(1), (6, 16))
y, alpha = head(x)

That is the whole object. alpha is a row-stochastic matrix, each row a distribution over tokens, and each row of y is the alpha-weighted average of the value vectors. Everything the explainer says follows from those two sentences.
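The two sentences above can be turned into asserts. Here is a minimal sketch in plain JAX rather than Flax. The projection matrices w_q, w_k and w_v are hypothetical random stand-ins for a head's learned weights. The sketch checks that every row of the weights is nonnegative and sums to one. It also checks that each output coordinate lies between the smallest and largest value coordinate, which any convex combination must satisfy.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for a head's learned projections (random, untrained).
k_x, k_q, k_k, k_v = jax.random.split(jax.random.key(0), 4)
L, d_model, d_head = 6, 16, 16
x = jax.random.normal(k_x, (L, d_model))
w_q = jax.random.normal(k_q, (d_model, d_head)) / jnp.sqrt(d_model)
w_k = jax.random.normal(k_k, (d_model, d_head)) / jnp.sqrt(d_model)
w_v = jax.random.normal(k_v, (d_model, d_head)) / jnp.sqrt(d_model)

q, k, v = x @ w_q, x @ w_k, x @ w_v
alpha = jax.nn.softmax(q @ k.T / jnp.sqrt(d_head), axis=-1)
y = alpha @ v

# Row-stochastic: nonnegative entries, each row sums to one.
assert bool((alpha >= 0).all())
assert jnp.allclose(alpha.sum(axis=-1), 1.0, atol=1e-5)

# Convex combination: every output coordinate lies within the range that
# coordinate takes across the value vectors.
lo, hi = v.min(axis=0), v.max(axis=0)
assert bool(((y >= lo - 1e-5) & (y <= hi + 1e-5)).all())
```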

It is Nadaraya–Watson, and you can assert it

The classical Nadaraya–Watson estimator predicts at a query by normalizing a kernel over observations and using it to weight the targets: $\hat f(x)=\sum_i \frac{K(x,x_i)}{\sum_j K(x,x_j)}\,y_i$. Write that estimator with no reference to attention, just a kernel, a normalizer, and a weighted sum of targets, and feed it the exp-dot-product kernel:

def nadaraya_watson(q, k, v, kernel):
    K = kernel(q, k)                             # [Lq, Lk] nonnegative similarities
    w = K / K.sum(axis=-1, keepdims=True)        # normalize: rows sum to one
    return w @ v, w


q, kk, v = head.wq(x), head.wk(x), head.wv(x)
exp_dot = lambda q, k: jnp.exp(q @ k.T / jnp.sqrt(q.shape[-1]))

y_nw, w_nw = nadaraya_watson(q, kk, v, exp_dot)

assert jnp.allclose(y_nw, y, atol=1e-5)          # same output
assert jnp.allclose(w_nw, alpha, atol=1e-5)      # same weights

The two are not analogous; they are the same array up to relabeling. The dictionary is exact:

| Kernel regression | Self-attention | In the code |
|---|---|---|
| query point $x$ | query projection $q_i$ | head.wq(x) |
| observation $x_i$ | key projection $k_j$ | head.wk(x) |
| kernel $K(x,x_i)$ | $\exp(q_i\cdot k_j/\sqrt d)$ | exp_dot |
| target $y_i$ | value projection $v_j$ | head.wv(x) |
| normalizer $\sum_j K$ | softmax denominator | K.sum(-1) |

Softmax is not a special activation that happens to live inside transformers. It is the Nadaraya–Watson normalizer applied to one specific kernel.
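Here is a quick numerical check of that claim, as a sketch on arbitrary random scores rather than scores from the head above. jax.nn.softmax agrees with the explicit exponentiate-then-normalize form. Subtracting each row's maximum before exponentiating, the usual overflow guard, leaves the weights unchanged. The shift multiplies every kernel value in a row by the same constant, and the Nadaraya–Watson normalizer cancels it.

```python
import jax
import jax.numpy as jnp

scores = jax.random.normal(jax.random.key(0), (4, 5)) * 3.0

# Nadaraya–Watson with the exponential kernel: K = exp(scores), normalize rows.
K = jnp.exp(scores)
w_nw = K / K.sum(axis=-1, keepdims=True)

# Rescale every kernel value in a row by exp(-row max): the common factor
# cancels in the normalizer, so the weights do not change.
K_shift = jnp.exp(scores - scores.max(axis=-1, keepdims=True))
w_shift = K_shift / K_shift.sum(axis=-1, keepdims=True)

assert jnp.allclose(jax.nn.softmax(scores, axis=-1), w_nw, atol=1e-6)
assert jnp.allclose(w_shift, w_nw, atol=1e-6)
```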

The kernel is not symmetric: that is the whole RKHS caveat

The explainer flags one wrinkle: queries and keys use different projections $W_Q$ and $W_K$, so the score is a bilinear form $q_i\cdot k_j = x_i^\top (W_Q W_K^\top)\, x_j$ whose operator $M = W_Q W_K^\top$ has no reason to be symmetric. A genuine (Mercer) kernel must be symmetric and positive semi-definite. This one is neither, and you can watch it fail:

wq, wk = head.wq.kernel.value, head.wk.kernel.value   # [d_model, d_head]
M = wq @ wk.T                                          # score_ij = x_i^T M x_j
print("operator symmetric? ", bool(jnp.allclose(M, M.T)))   # False

S = x @ M @ x.T                                        # raw scores q_i · k_j
sym_eigs = jnp.linalg.eigvalsh(0.5 * (S + S.T))       # symmetric part
print("score symmetric?   ", bool(jnp.allclose(S, S.T)))    # False
print("eig span of sym part:", float(sym_eigs.min()), float(sym_eigs.max()))
# min < 0 < max  -> indefinite, so NOT positive semi-definite, so not Mercer

Contrast this with the symmetric exp-inner-product kernel $\exp(z\cdot z'/\sqrt d)$ on a single set of vectors, which is positive semi-definite. Its matrix is the series $\sum_{n\ge 0} (G/\sqrt d)^{\circ n}/n!$ of Hadamard powers of the Gram matrix $G = zz^\top$, and each Hadamard power of a PSD matrix is PSD by the Schur product theorem:

z = jax.random.normal(jax.random.key(2), (6, 16))
G = jnp.exp(z @ z.T / jnp.sqrt(16))
print("symmetric exp-dot min eig:", float(jnp.linalg.eigvalsh(G).min()))  # >= 0

So the non-Mercer-ness comes entirely from $W_Q \neq W_K$, not from the exponential. Here is the subtle part the explainer leans on: softmax rescues the readout but not the kernel. Normalizing the rows makes alpha a stochastic matrix, nonnegative weights summing to one, so the output is still an honest convex combination of values (the attribution story survives). Positive-definiteness of the kernel is a different property, and it stays broken. The convex-readout companion makes the same nonnegative-is-not-positive-definite distinction from the readout side.
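The tied-versus-untied contrast can be reproduced directly. This is a minimal sketch in which random 16×16 matrices w_q and w_k are hypothetical stand-ins for the head's projections. With tied weights the score matrix is the Gram matrix of the projected tokens, so it is symmetric with nonnegative eigenvalues. With untied weights the symmetry is gone. In both cases softmax still returns nonnegative rows that sum to one: the readout survives while the kernel does not.

```python
import jax
import jax.numpy as jnp

k_x, k_q, k_k = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(k_x, (6, 16))
w_q = jax.random.normal(k_q, (16, 16)) / 4.0   # hypothetical projections
w_k = jax.random.normal(k_k, (16, 16)) / 4.0


def score_matrix(w_q, w_k):
    # score_ij = x_i^T (W_Q W_K^T) x_j, the same convention as M above
    return x @ (w_q @ w_k.T) @ x.T


# Tied weights (W_K = W_Q): S = (x W_Q)(x W_Q)^T is a Gram matrix.
S_tied = score_matrix(w_q, w_q)
eig_tied = jnp.linalg.eigvalsh(S_tied)
print("tied symmetric:", bool(jnp.allclose(S_tied, S_tied.T, atol=1e-4)))
print("tied min eig / max eig:", float(eig_tied.min() / eig_tied.max()))

# Untied weights: symmetry is lost; inspect the spectrum of the symmetric part.
S_untied = score_matrix(w_q, w_k)
eig_untied = jnp.linalg.eigvalsh(0.5 * (S_untied + S_untied.T))
print("untied symmetric:", bool(jnp.allclose(S_untied, S_untied.T, atol=1e-4)))
print("untied eig span:", float(eig_untied.min()), float(eig_untied.max()))

# Softmax still yields nonnegative rows summing to one in both cases.
for S in (S_tied, S_untied):
    a = jax.nn.softmax(S / 4.0, axis=-1)        # 4.0 = sqrt(d_head)
    assert bool((a >= 0).all()) and jnp.allclose(a.sum(-1), 1.0, atol=1e-5)
```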

The moment the kernel breaks, computed in JAX. Start with $W_K = W_Q$: the score matrix $x_i^\top M x_j$ is symmetric and its eigenvalues are all nonnegative, a genuine Mercer kernel (a Gram matrix). Slide $W_K$ off $W_Q$ and the matrix loses symmetry; the eigenvalues of its symmetric part cross zero (red bars). Real attention lives at the right-hand end. Softmax will still make each row nonnegative, but the kernel underneath is no longer positive definite. From scripts/render_attention_psd_gif.py.

Swapping the kernel

If attention is Nadaraya–Watson, the exp-dot-product is just one entry in the kernel slot. Drop in others through the same normalizer and ask the only question that decides whether the weights are an attribution: are they nonnegative?

def sqdist(a, b):
    return jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)

d = q.shape[-1]
kernels = {
    "exp-dot (softmax)": lambda q, k: jnp.exp(q @ k.T / jnp.sqrt(d)),
    "gaussian":          lambda q, k: jnp.exp(-sqdist(q, k) / (2.0 * d)),
    "yat":               lambda q, k: (q @ k.T) ** 2 / (sqdist(q, k) + 1e-3),
    "linear (signed)":   lambda q, k: q @ k.T,
}

for name, kf in kernels.items():
    _, w = nadaraya_watson(q, kk, v, kf)
    nonneg = bool((w >= -1e-8).all())
    sums_to_one = bool(jnp.allclose(w.sum(-1), 1.0, atol=1e-5))
    print(f"{name:18s} nonneg={nonneg!s:5}  rows_sum_to_1={sums_to_one}")
# exp-dot, gaussian, yat -> nonneg=True   (valid mixtures; "30% from token j" means something)
# linear (signed)        -> nonneg=False  (weights can be negative, denominator can cross zero)

The exp-dot-product, the Gaussian, and the Yat kernel all keep each row a probability vector, so the output stays inside the convex hull of the values and “this token contributed $\alpha_{ij}$” is a real statement. The linear kernel divides by a sum that can pass through zero and hands back signed “weights”, a smoother in algebra only. Nonnegativity, not positive-definiteness, is what buys the interpretation.
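The failure of the signed kernel fits in two keys. This hand-sized sketch uses scores and values chosen purely for illustration. One query scores 1.0 against the first key and -0.75 against the second. The values sit at 0 and 1, so the convex hull is the interval [0, 1]. The linear kernel normalizes by 0.25 and returns weights 4 and -3, putting the output at -3. Fed the same scores, the exponential kernel keeps the output inside [0, 1]. Had the second score been -1.0, the row sum would be exactly zero and the weights would not exist at all.

```python
import jax.numpy as jnp


def nadaraya_watson(q, k, v, kernel):
    K = kernel(q, k)                          # [Lq, Lk] similarities
    w = K / K.sum(axis=-1, keepdims=True)     # rows sum to one (if the sum is nonzero)
    return w @ v, w


# One query, two keys, scalar values at 0 and 1 (illustrative numbers).
q = jnp.array([[1.0]])
k = jnp.array([[1.0], [-0.75]])
v = jnp.array([[0.0], [1.0]])

# Signed linear kernel: K = [1, -0.75], row sum 0.25, weights [4, -3].
y_lin, w_lin = nadaraya_watson(q, k, v, lambda q, k: q @ k.T)
# Exponential kernel: strictly positive weights, so y stays in [0, 1].
y_exp, w_exp = nadaraya_watson(q, k, v, lambda q, k: jnp.exp(q @ k.T))

print("linear:", w_lin, y_lin)   # w_lin = [[4, -3]], y_lin = [[-3]]: outside the hull
print("exp:   ", w_exp, y_exp)
```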

The same values, four kernels, one normalizer. As the query sweeps a loop (left), each kernel turns it into a different weighting of the same value vectors, and the output $y=\sum_j w_j v_j$ traces a path (right). The exp-dot-product, Gaussian, and Yat kernels keep $w$ a convex partition of unity, so $y$ never leaves the convex hull of the values. The signed linear kernel produces negative weights and a normalizer that crosses zero (✗), flinging $y$ outside the hull, a mixture only in algebra. From scripts/render_attention_kernels_gif.py.

The √d is a bandwidth

In a kernel smoother the bandwidth sets how far the averaging reaches. In attention the same dial is the softmax temperature: the scores are multiplied by a scale, $1/\sqrt d$ by default, so $\sqrt d$ plays the role of the bandwidth. Sweep the scale and read the mean entropy of the rows:

def weights_at(scale):
    return jax.nn.softmax((q @ kk.T) * scale, axis=-1)

def mean_entropy(w):
    return float(-jnp.sum(w * jnp.log(w + 1e-12), axis=-1).mean())

for scale in [0.0, 1.0 / jnp.sqrt(d), 1.0, 6.0]:
    print(f"scale={float(scale):.3f}  mean row entropy={mean_entropy(weights_at(scale)):.3f}")
# scale 0   -> uniform averaging, entropy ~ log(L)  (a wide kernel: every token counts equally)
# large     -> entropy -> 0, a near one-hot row     (a narrow kernel: nearest-neighbor copy)

A wide bandwidth makes the head a blur over all tokens; a narrow one makes it a hard copy from the single most similar token. The “copy” and induction heads the interpretability literature names are kernel smoothers run at the sharp end of this dial. The temperature is not a training nuisance; it is the bandwidth of the regression.
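One identity shows why the scale is literally a bandwidth. Write $\tau=\sqrt d$ and expand $\lVert q-k\rVert^2$. This gives $\exp(q\cdot k/\tau)=\exp\!\big(\lVert q\rVert^2/(2\tau)\big)\,\exp\!\big(\lVert k\rVert^2/(2\tau)\big)\,\exp\!\big(-\lVert q-k\rVert^2/(2\tau)\big)$. The first factor is constant along a row and cancels in the normalizer. If every key has the same norm, the second factor cancels too, and softmax attention is exactly a Gaussian Nadaraya–Watson smoother with variance $\tau$. The sketch below checks this with illustrative random q and k. It assumes the keys share a common norm, which real keys do not satisfy in general.

```python
import jax
import jax.numpy as jnp

k_q, k_k = jax.random.split(jax.random.key(0))
d = 16
tau = jnp.sqrt(d)                      # scores are q·k / sqrt(d)
q = jax.random.normal(k_q, (6, d))
k = jax.random.normal(k_k, (6, d))
k = 3.0 * k / jnp.linalg.norm(k, axis=-1, keepdims=True)   # assumption: equal key norms


def sqdist(a, b):
    return jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)


w_softmax = jax.nn.softmax(q @ k.T / tau, axis=-1)
K_gauss = jnp.exp(-sqdist(q, k) / (2.0 * tau))            # Gaussian kernel, variance tau
w_gauss = K_gauss / K_gauss.sum(axis=-1, keepdims=True)

assert jnp.allclose(w_softmax, w_gauss, atol=1e-5)
```

Without the equal-norm assumption the two differ by the per-key factor $\exp\!\big(\lVert k_j\rVert^2/(2\tau)\big)$. Exp-dot attention then behaves like a Gaussian smoother whose observations carry unequal prior weights.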

Temperature as bandwidth, swept in JAX. At a wide bandwidth (small scale) the row is near-uniform, with near-maximal entropy, every value averaged, and the output sits at the centroid. As the bandwidth narrows the row collapses to a spike (entropy → 0) and the output slides out to a single value vertex: a hard copy. The operating point $1/\sqrt d$ is one stop along this dial. From scripts/render_attention_bandwidth_gif.py.

A head that routes to a marked token

To see the kernel learn, give one head a job: in each sequence one token is flagged (a large value in its first coordinate), and every position must output that token’s vector. The head can only succeed by learning $W_Q, W_K$ that route attention onto the flag and a $W_V$ that copies. One detail earns its keep here: the projections need a bias. Every query has to point at the same token regardless of its own content. A bias-free score $x_i^\top M x_j$ is linear in the query’s own input $x_i$, so on mean-zero inputs it cannot produce a content-independent pull in a fixed direction. A query bias $b_q$ adds the term $b_q\cdot k_j$, which depends only on the key, and that constant pull is what lets every query lean on the flag. Written batched with einsum, the module is otherwise unchanged in spirit:

import optax


class BatchHead(nnx.Module):
    def __init__(self, d_model, d_head, *, rngs: nnx.Rngs):
        self.wq = nnx.Linear(d_model, d_head, use_bias=True, rngs=rngs)
        self.wk = nnx.Linear(d_model, d_head, use_bias=True, rngs=rngs)
        self.wv = nnx.Linear(d_model, d_head, use_bias=True, rngs=rngs)

    def __call__(self, x):                                   # x: [..., L, d_model]
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        scores = jnp.einsum('...id,...jd->...ij', q, k) / jnp.sqrt(q.shape[-1])
        a = jax.nn.softmax(scores, axis=-1)
        return jnp.einsum('...ij,...jd->...id', a, v), a


def make_example(key):
    k1, k2 = jax.random.split(key)
    x = jax.random.normal(k1, (6, 16))
    t = jax.random.randint(k2, (), 0, 6)                     # the marked position
    x = x.at[t, 0].set(6.0)                                  # the flag
    target = jnp.broadcast_to(x[t], (6, 16))                # every row should become token t
    return x, target


xs, targets = jax.vmap(make_example)(jax.random.split(jax.random.key(3), 512))

model = BatchHead(d_model=16, d_head=16, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-2), wrt=nnx.Param)


@nnx.jit
def train_step(model, optimizer, xs, targets):
    def loss_fn(model):
        ys, _ = model(xs)
        return jnp.mean((ys - targets) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss


for step in range(600):
    loss = train_step(model, optimizer, xs, targets)
print("final loss:", float(loss))      # ~0.0, the target is exactly reconstructable

The head drives the loss to nearly zero, and the learned kernel does what the task demands by concentrating its weight on the flagged token:

_, a = model(xs)
routed = a[:, 0, :].argmax(-1)          # where position 0 sends its weight, per sequence
marker = xs[:, :, 0].argmax(-1)         # the flagged token, per sequence
print("routes to the marker:", float((routed == marker).mean()))   # 1.0

No supervision named the marker’s position; the head discovered that the way to minimize the loss is to put a sharp kernel on it. That discovered kernel is the explanation of what the head does. This is the explainer’s whole point, now a printed number.
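The bias argument made before the training code can be checked without any training. This minimal sketch assumes hypothetical random bias-free projections w_q and w_k, with a flag placed on token 2. Without a bias, the score $x_i^\top M x_j$ is linear in the query's input. Negating that input negates the whole row of scores, so the key the query ranked first becomes the key it ranks last. No choice of bias-free $W_Q, W_K$ can therefore make every query prefer the flag. The sketch demonstrates only this bias-free half of the argument.

```python
import jax
import jax.numpy as jnp

k_x, k_q, k_k = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(k_x, (6, 16))
x = x.at[2, 0].set(6.0)                               # token 2 carries the flag
w_q = jax.random.normal(k_q, (16, 16)) / 4.0          # hypothetical bias-free projections
w_k = jax.random.normal(k_k, (16, 16)) / 4.0

keys = x @ w_k
query = x[0] @ w_q
flipped = (-x[0]) @ w_q                               # same projection, negated input

s, s_flip = keys @ query, keys @ flipped
assert jnp.allclose(s_flip, -s)                       # bilinear: the scores flip sign
assert int(s.argmax()) == int(s_flip.argmin())        # top key becomes bottom key
```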

The same head, trained in JAX. Left: the attention matrix for one fixed example, snapshotted every twelve steps. It begins as a wide, near-uniform kernel (every token averaged) and collapses onto the single column of the flagged token. This is a learned nearest-neighbour copy, not a rule anyone wrote in. Right: the mean-squared error that drives it, falling from about 5 to roughly $2\times10^{-5}$. The figure is generated by scripts/render_attention_kernel_gif.py, which runs the exact training loop above and draws the kernel each frame.

What this leaves out

A real attention layer stacks many heads, caches keys and values at inference, and carries a causal mask. The mask is typically a lower-triangular boolean mask built with jnp.tril and passed to jnp.where, which sets the upper-triangle scores to -jnp.inf before the softmax. None of that touches the kernel reading: each head is its own Nadaraya–Watson smoother with its own learned kernel, and the mask just restricts which observations a query is allowed to see.
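The mask fits the smoother view directly. This minimal sketch uses random scores, which are illustrative rather than taken from the trained head. The usual -inf masking before the softmax gives the same weights as zeroing the kernel on future observations and then applying the Nadaraya–Watson normalizer.

```python
import jax
import jax.numpy as jnp

L = 5
scores = jax.random.normal(jax.random.key(0), (L, L))

# Standard causal masking: future positions get -inf before the softmax.
causal = jnp.tril(jnp.ones((L, L), dtype=bool))       # True where j <= i
w_mask = jax.nn.softmax(jnp.where(causal, scores, -jnp.inf), axis=-1)

# Nadaraya–Watson view: the kernel is zero on unseen observations, then normalize.
K = jnp.where(causal, jnp.exp(scores), 0.0)
w_nw = K / K.sum(axis=-1, keepdims=True)

assert jnp.allclose(w_mask, w_nw, atol=1e-6)
```

The first query sees only itself, so its row is one-hot. Every row still sums to one over the observations it is allowed to see.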

The non-positive-definiteness is also not only a curiosity. It is exactly what the linear-attention companion has to confront. To make attention cheap, you replace the exp-dot-product kernel with an explicit feature map $\phi(q)^\top\phi(k)$. The honesty of that approximation (Performer; Katharopoulos et al.) is a statement about how well a positive-definite surrogate stands in for a kernel that never was one. The smoother view is what makes that question askable.

The payoff, in JAX. The full kernel $\exp(q\cdot k)$ (left) is approximated by $\phi(q)^\top\phi(k)$ with $m$ positive random features (middle), using the Performer/FAVOR+ map $\phi(x)_i=\exp(w_i\cdot x-\lVert x\rVert^2/2)/\sqrt m$. As $m$ grows the approximation tightens and the relative Frobenius error falls like $1/\sqrt m$ (right). The kernel that was never positive definite gets stood in for by one that is, and that surrogate is what buys linear-time attention. From scripts/render_attention_features_gif.py.

References: Flax NNX Module API; the kernel lens on attention from Tsai et al. (2019); Nadaraya–Watson regression from Nadaraya (1964) and Watson (1964); linear-attention feature maps from Katharopoulos et al. (2020) and Choromanski et al. (2021).

Cite as

Bouhsine, T. (2026). Self-Attention as Kernel Regression in JAX/Flax NNX. Records of the !mmortal Data Scientist. https://tahabouhsine.com/blog/attention-is-kernel-jax-flax-nnx/

BibTeX
@misc{bouhsine2026attentioniskerneljaxflaxnnx,
  author       = {Bouhsine, Taha},
  title        = {Self-Attention as Kernel Regression in JAX/Flax NNX},
  year         = {2026},
  month        = {jun},
  howpublished = {\url{https://tahabouhsine.com/blog/attention-is-kernel-jax-flax-nnx/}},
  note         = {Blog post, Records of the !mmortal Data Scientist}
}

References

  1. Nadaraya, E. A. (1964). On Estimating Regression. Theory of Probability & Its Applications 9(1), 141–142.
  2. Watson, G. S. (1964). Smooth Regression Analysis. Sankhyā: The Indian Journal of Statistics, Series A 26(4), 359–372.
  3. Tsai, Y.-H. H., et al. (2019). Transformer Dissection: An Unified Understanding for Transformer's Attention via the Lens of Kernel. EMNLP-IJCNLP 2019.
  4. Katharopoulos, A., Vyas, A., Pappas, N., Fleuret, F. (2020). Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. ICML 2020. arXiv:2006.16236
  5. Choromanski, K., et al. (2021). Rethinking Attention with Performers. ICLR 2021. arXiv:2009.14794
  6. Bouhsine, T. (2026). A Universal Reproducing Kernel Hilbert Space from Polynomial Alignment and IMQ Distance. arXiv:2605.03262