Flax
Flax NNX implementations of attention, linear attention, and prototype readouts, each written as an nnx.Module subclass with the linear algebra run live.
-
Q and K Projections in JAX/Flax NNX
A runnable companion to Why Attention Needs Q and K Projections: build scaled dot-product attention with separate query and key projections in Flax NNX, pull the bilinear form B = W_Q W_Kᵀ out of the module, split it into a symmetric metric and an antisymmetric directed part, wire a toy induction head, add RoPE, and measure the low-rank budget and the gauge freedom, all in plain JAX.
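A minimal sketch of the bilinear-form idea, in plain JAX rather than the article's own Flax NNX module. The sizes, random weights, and variable names are illustrative assumptions. It checks that projected scores equal x_i B x_jᵀ, splits B into symmetric and antisymmetric parts, counts the rank, and applies one gauge transformation.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes, chosen only for illustration.
d_model, d_head, n_tokens = 16, 4, 5
k_q, k_k, k_x, k_g = jax.random.split(jax.random.PRNGKey(0), 4)
W_Q = jax.random.normal(k_q, (d_model, d_head))
W_K = jax.random.normal(k_k, (d_model, d_head))
X = jax.random.normal(k_x, (n_tokens, d_model))

# Unscaled attention logits two ways: via projections, and via B = W_Q W_K^T.
B = W_Q @ W_K.T
scores_proj = (X @ W_Q) @ (X @ W_K).T
scores_bilinear = X @ B @ X.T

# Symmetric part (metric-like) and antisymmetric part (directed).
S = 0.5 * (B + B.T)
A = 0.5 * (B - B.T)
# Only A can make score(i, j) differ from score(j, i).
asym = scores_bilinear - scores_bilinear.T
asym_from_A = 2.0 * (X @ A @ X.T)

# Low-rank budget: B is d_model x d_model but has rank at most d_head.
s = jnp.linalg.svd(B, compute_uv=False)
rank_B = int(jnp.sum(s > 1e-4 * s[0]))

# Gauge freedom: W_Q -> W_Q G, W_K -> W_K G^{-T} leaves B unchanged.
G = jnp.eye(d_head) + 0.1 * jax.random.normal(k_g, (d_head, d_head))
B_gauged = (W_Q @ G) @ (W_K @ jnp.linalg.inv(G).T).T
```

The scaling by 1/sqrt(d_head) and the softmax are left out here because they act on the logits and do not change B itself.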
-
The Prototype Readout in JAX/Flax NNX
A runnable companion to The Readout is a Convex Combination of Prototypes: read the columns of W_out as output prototypes in Flax NNX, measure the convex/conic/affine/linear regimes numerically, then build a Nadaraya–Watson kernel readout that is convex by construction (nonnegative weights that sum to one, a point that never leaves the prototype hull), with the nonnegativity-vs-positive-definiteness distinction checked in code.
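As a rough illustration of "convex by construction", here is a small Nadaraya–Watson readout in plain JAX. The Gaussian kernel, the bandwidth, the anchor points, and all names are assumptions made for this sketch, not the article's code.

```python
import jax
import jax.numpy as jnp

# Hypothetical setup: m prototypes, each with an input-space anchor.
m, d_in, d_out = 6, 3, 2
k_c, k_p, k_x = jax.random.split(jax.random.PRNGKey(0), 3)
anchors = jax.random.normal(k_c, (m, d_in))      # where each prototype "lives"
prototypes = jax.random.normal(k_p, (m, d_out))  # output prototypes
x = jax.random.normal(k_x, (d_in,))              # a single query
bandwidth = 1.0                                  # assumed kernel width

def nw_readout(x, anchors, prototypes, bandwidth):
    # Gaussian kernel weights, normalized with a softmax over prototypes.
    sq_dist = jnp.sum((anchors - x) ** 2, axis=-1)
    weights = jax.nn.softmax(-sq_dist / (2.0 * bandwidth ** 2))
    # Nonnegative weights that sum to one give a convex combination.
    return weights @ prototypes, weights

y, weights = nw_readout(x, anchors, prototypes, bandwidth)

# A necessary (not sufficient) check for staying in the prototype hull:
# every output coordinate lies between the per-coordinate min and max.
lo, hi = prototypes.min(axis=0), prototypes.max(axis=0)
inside_box = bool(jnp.all((y >= lo - 1e-6) & (y <= hi + 1e-6)))
```

The bounding-box test only rules out gross violations; hull membership itself follows from the weights being nonnegative and summing to one.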
-
Cheap Attention in JAX/Flax NNX
A runnable companion to Cheap Attention: implement positive-feature linear attention in JAX and Flax NNX, watch the all-pairs ledger turn into a shared feature state, and see exactly where the N×N matrix disappears.
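One way to see the N×N matrix disappear, sketched in plain JAX for the non-causal case. The feature map elu(x) + 1 is a common choice assumed here, and the shapes and names are illustrative rather than taken from the article.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes for illustration.
N, d = 64, 8
k_q, k_k, k_v = jax.random.split(jax.random.PRNGKey(0), 3)
Q = jax.random.normal(k_q, (N, d))
K = jax.random.normal(k_k, (N, d))
V = jax.random.normal(k_v, (N, d))

def phi(x):
    # Strictly positive feature map (assumed): elu(x) + 1.
    return jax.nn.elu(x) + 1.0

def quadratic_attention(Q, K, V):
    # All-pairs ledger: builds the full N x N weight matrix.
    W = phi(Q) @ phi(K).T
    return (W / W.sum(axis=-1, keepdims=True)) @ V

def linear_attention(Q, K, V):
    # Shared feature state: d x d summary plus a d-vector normalizer.
    state = phi(K).T @ V        # (d, d)
    norm = phi(K).sum(axis=0)   # (d,)
    fq = phi(Q)
    return (fq @ state) / (fq @ norm)[:, None]

out_quadratic = quadratic_attention(Q, K, V)
out_linear = linear_attention(Q, K, V)
```

Both functions compute the same output by associativity; the second never materializes anything larger than d × d, so its cost grows linearly in N.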