The spread of state-space models (SSMs) has been tied to hand-written, hardware-specific acceleration kernels, primarily for NVIDIA GPUs. This dependency is a significant barrier to adoption and experimentation on other hardware. Cosmo Santoni's work challenges this paradigm by showing that the architectural structure of Mamba 2 maps cleanly onto standard compiler optimizations.
XLA Unlocks SSM Performance Without Custom Kernels
The core insight is that Mamba 2's state space duality (SSD), with its diagonal state structure, chunkable recurrence, and einsum-dominated computation under static control flow, matches what the XLA compiler is designed to optimize. By relying on XLA's fusion and tiling passes, the work implements the full inference path, including prefill and cached autoregressive decoding, using only standard primitives with static shapes. No hand-written CUDA or Triton kernels are needed, which makes the architecture performant on any platform with a mature XLA backend, including CPUs, NVIDIA GPUs, and Google Cloud TPUs, all from a single JAX source. This JAX implementation of Mamba 2 is a significant step toward hardware-agnostic model deployment.
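To make the chunked structure concrete, here is a minimal single-head sketch of the SSD computation in JAX. It is not the implementation described above: the function names (`ssd_decode_step`, `ssd_sequential`, `ssd_chunked`) are hypothetical, and the sketch assumes a per-step scalar decay `a_t = exp(log_a_t)` (the scalar-times-identity form of a diagonal state), no batch or head axes, no convolution, gating, or normalization, and a sequence length divisible by the chunk size. Prefill splits the sequence into chunks: inside each chunk the output is one einsum against a causal decay matrix, and a short `jax.lax.scan` over chunks passes the `(N, P)` state between them. Decoding then reuses that cached state with one constant-cost update per token.

```python
import jax
import jax.numpy as jnp


def ssd_decode_step(h, x_t, la_t, b_t, c_t):
    """One cached decoding step (hypothetical helper): h_t = a_t h_{t-1} + B_t x_t^T."""
    h = jnp.exp(la_t) * h + jnp.outer(b_t, x_t)  # scalar decay + rank-1 update, (N, P)
    return c_t @ h, h  # y_t = C_t^T h_t, shape (P,)


def ssd_sequential(x, log_a, B, C):
    """Reference: scan the decode step over x (T, P), log_a (T,), B and C (T, N)."""
    def step(h, inputs):
        y_t, h = ssd_decode_step(h, *inputs)
        return h, y_t
    h0 = jnp.zeros((B.shape[1], x.shape[1]), x.dtype)
    h_final, y = jax.lax.scan(step, h0, (x, log_a, B, C))
    return y, h_final


def ssd_chunked(x, log_a, B, C, chunk):
    """Same outputs computed chunk by chunk (assumes T % chunk == 0)."""
    T, P = x.shape
    N = B.shape[1]
    nc = T // chunk
    # Static reshapes to (num_chunks, chunk, ...): no data-dependent control flow.
    xc, Bc, Cc = x.reshape(nc, chunk, P), B.reshape(nc, chunk, N), C.reshape(nc, chunk, N)
    cs = jnp.cumsum(log_a.reshape(nc, chunk), axis=-1)  # inclusive within-chunk log-decay

    # Intra-chunk: L[i, j] = prod_{k=j+1..i} a_k for j <= i, else 0 (exp(-inf) = 0).
    mask = jnp.tril(jnp.ones((chunk, chunk), dtype=bool))
    L = jnp.exp(jnp.where(mask, cs[:, :, None] - cs[:, None, :], -jnp.inf))
    y_intra = jnp.einsum("cin,cjn,cij,cjp->cip", Cc, Bc, L, xc)

    # State each chunk contributes, decayed to its own last position.
    decay_to_end = jnp.exp(cs[:, -1:] - cs)  # (nc, chunk)
    chunk_states = jnp.einsum("cj,cjn,cjp->cnp", decay_to_end, Bc, xc)

    # Inter-chunk: short scan over chunks, emitting the state entering each chunk.
    def pass_state(h, inputs):
        total_log_decay, s = inputs
        return jnp.exp(total_log_decay) * h + s, h
    h0 = jnp.zeros((N, P), x.dtype)
    h_final, h_in = jax.lax.scan(pass_state, h0, (cs[:, -1], chunk_states))

    # Incoming state decayed to each position i: exp(cs_i) * C_i^T h_in.
    y_inter = jnp.exp(cs)[:, :, None] * jnp.einsum("cin,cnp->cip", Cc, h_in)
    return (y_intra + y_inter).reshape(T, P), h_final


# Usage: random single-head inputs with decays a_t in (0, 1).
kx, ka, kb, kc = jax.random.split(jax.random.PRNGKey(0), 4)
T, N, P = 16, 4, 3
x = jax.random.normal(kx, (T, P))
log_a = -jax.nn.softplus(jax.random.normal(ka, (T,)))
B = jax.random.normal(kb, (T, N))
C = jax.random.normal(kc, (T, N))

y_ref, h_ref = ssd_sequential(x, log_a, B, C)
y_chk, h_chk = jax.jit(ssd_chunked, static_argnums=4)(x, log_a, B, C, 4)

# Prefill the first 12 tokens in chunks, then decode the last 4 from the cached state.
y_pre, h = ssd_chunked(x[:12], log_a[:12], B[:12], C[:12], 4)
decoded = []
for t in range(12, T):
    y_t, h = ssd_decode_step(h, x[t], log_a[t], B[t], C[t])
    decoded.append(y_t)
print(float(jnp.max(jnp.abs(y_ref - y_chk))))
```

Every shape is fixed at trace time and the only loops are `lax.scan` calls, so XLA receives a static graph dominated by einsums, which is the property the section credits for the speed; this toy only checks that the chunked, sequential, and cached-decode paths agree, not how fast they run. On backends that default to reduced-precision matmuls (TPUs, for example), the comparison tolerance may need loosening.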